From 33a0e53709b413ecd64b4701fea922ac9cdf81a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=80=A7=E9=A9=8A?= Date: Thu, 2 Apr 2026 22:45:41 +0800 Subject: [PATCH] feat: Add GeminiWeb provider foundation - Add Provider interface and factory pattern - Create apitypes package for shared types - Implement GeminiWeb provider with: - Browser automation using Rod - Session pool management - Cookie persistence - DOM interaction for Gemini web interface - Add gemini-login command for session setup - Add CURSOR_BRIDGE_PROVIDER env variable Remaining: Integration with chat.go handlers --- cmd/gemini-login/main.go | 28 ++++ go.mod | 6 + go.sum | 13 ++ internal/apitypes/types.go | 40 +++++ internal/config/config.go | 88 +++++----- internal/env/env.go | 103 +++++++----- internal/providers/cursor/provider.go | 27 ++++ internal/providers/factory.go | 32 ++++ internal/providers/geminiweb/browser.go | 125 ++++++++++++++ internal/providers/geminiweb/page.go | 194 ++++++++++++++++++++++ internal/providers/geminiweb/pool.go | 169 +++++++++++++++++++ internal/providers/geminiweb/provider.go | 197 +++++++++++++++++++++++ 12 files changed, 938 insertions(+), 84 deletions(-) create mode 100644 cmd/gemini-login/main.go create mode 100644 internal/apitypes/types.go create mode 100644 internal/providers/cursor/provider.go create mode 100644 internal/providers/factory.go create mode 100644 internal/providers/geminiweb/browser.go create mode 100644 internal/providers/geminiweb/page.go create mode 100644 internal/providers/geminiweb/pool.go create mode 100644 internal/providers/geminiweb/provider.go diff --git a/cmd/gemini-login/main.go b/cmd/gemini-login/main.go new file mode 100644 index 0000000..a04a354 --- /dev/null +++ b/cmd/gemini-login/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "cursor-api-proxy/internal/config" + "cursor-api-proxy/internal/env" + "cursor-api-proxy/internal/providers/geminiweb" + "fmt" + "os" +) + +func main() { + accountName := "" + if len(os.Args) > 1 
{ + accountName = os.Args[1] + } + + e := env.OsEnvToMap() + loaded := env.LoadEnvConfig(e, "") + cfg := config.LoadBridgeConfig(e, "") + + cfg.GeminiAccountDir = loaded.GeminiAccountDir + cfg.GeminiBrowserVisible = true + + if err := geminiweb.RunLogin(cfg, accountName); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } +} diff --git a/go.mod b/go.mod index c231a41..c00e366 100644 --- a/go.mod +++ b/go.mod @@ -9,9 +9,15 @@ require ( require ( github.com/dustin/go-humanize v1.0.1 // indirect + github.com/go-rod/rod v0.116.2 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/ysmood/fetchup v0.2.3 // indirect + github.com/ysmood/goob v0.4.0 // indirect + github.com/ysmood/got v0.40.0 // indirect + github.com/ysmood/gson v0.7.3 // indirect + github.com/ysmood/leakless v0.9.0 // indirect golang.org/x/sys v0.42.0 // indirect modernc.org/libc v1.70.0 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/go.sum b/go.sum index 8411195..8f3d2b9 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/go-rod/rod v0.116.2 h1:A5t2Ky2A+5eD/ZJQr1EfsQSe5rms5Xof/qj296e+ZqA= +github.com/go-rod/rod v0.116.2/go.mod h1:H+CMO9SCNc2TJ2WfrG+pKhITz57uGNYU43qYHh438Mg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -12,6 +14,17 @@ github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOF github.com/ncruces/go-strftime v1.0.0/go.mod 
// Message is one chat turn handed to a provider.
type Message struct {
	// Role identifies the speaker; providers in this patch switch on
	// "system", "user" and "assistant".
	Role string
	// Content is the plain-text body of the turn.
	Content string
}

// Tool describes a callable tool offered to the model.
type Tool struct {
	Type     string
	Function ToolFunction
}

// ToolFunction is the function declaration carried by a Tool.
type ToolFunction struct {
	Name        string
	Description string
	// Parameters is an opaque, provider-defined parameter description.
	// NOTE(review): presumably a JSON-schema-like value — confirm with callers.
	Parameters any
}

// ToolCall is a tool invocation requested by the model.
type ToolCall struct {
	ID        string
	Name      string
	Arguments string
}

// ChunkType discriminates the payload carried by a StreamChunk.
type ChunkType int

// Chunk kinds, in declaration order starting at zero. ChunkText is therefore
// the zero value of ChunkType.
const (
	ChunkText ChunkType = iota
	ChunkThinking
	ChunkToolCall
	ChunkDone
)

// StreamChunk is one incremental piece of provider output delivered to the
// Generate callback. Type tells the consumer which of the payload fields
// (Text, Thinking, ToolCall, Done) is meaningful.
type StreamChunk struct {
	Type     ChunkType
	Text     string
	Thinking string
	ToolCall *ToolCall
	Done     bool
}
+ Port: loaded.Port, + RequiredKey: loaded.RequiredKey, + DefaultModel: loaded.DefaultModel, + Mode: "ask", + Provider: loaded.Provider, + Force: loaded.Force, + ApproveMcps: loaded.ApproveMcps, + StrictModel: loaded.StrictModel, + Workspace: loaded.Workspace, + TimeoutMs: loaded.TimeoutMs, + TLSCertPath: loaded.TLSCertPath, + TLSKeyPath: loaded.TLSKeyPath, + SessionsLogPath: loaded.SessionsLogPath, + ChatOnlyWorkspace: loaded.ChatOnlyWorkspace, + Verbose: loaded.Verbose, + MaxMode: loaded.MaxMode, + ConfigDirs: loaded.ConfigDirs, + MultiPort: loaded.MultiPort, + WinCmdlineMax: loaded.WinCmdlineMax, + GeminiAccountDir: loaded.GeminiAccountDir, + GeminiBrowserVisible: loaded.GeminiBrowserVisible, + GeminiMaxSessions: loaded.GeminiMaxSessions, } } diff --git a/internal/env/env.go b/internal/env/env.go index 239e3d8..45dc3be 100644 --- a/internal/env/env.go +++ b/internal/env/env.go @@ -12,28 +12,32 @@ import ( type EnvSource map[string]string type LoadedEnv struct { - AgentBin string - AgentNode string - AgentScript string - CommandShell string - Host string - Port int - RequiredKey string - DefaultModel string - Force bool - ApproveMcps bool - StrictModel bool - Workspace string - TimeoutMs int - TLSCertPath string - TLSKeyPath string - SessionsLogPath string - ChatOnlyWorkspace bool - Verbose bool - MaxMode bool - ConfigDirs []string - MultiPort bool - WinCmdlineMax int + AgentBin string + AgentNode string + AgentScript string + CommandShell string + Host string + Port int + RequiredKey string + DefaultModel string + Provider string + Force bool + ApproveMcps bool + StrictModel bool + Workspace string + TimeoutMs int + TLSCertPath string + TLSKeyPath string + SessionsLogPath string + ChatOnlyWorkspace bool + Verbose bool + MaxMode bool + ConfigDirs []string + MultiPort bool + WinCmdlineMax int + GeminiAccountDir string + GeminiBrowserVisible bool + GeminiMaxSessions int } type AgentCommand struct { @@ -256,29 +260,40 @@ func LoadEnvConfig(e EnvSource, cwd string) 
LoadedEnv { workspace = cwd } + geminiAccountDir := getEnvVal(e, []string{"GEMINI_ACCOUNT_DIR"}) + if geminiAccountDir == "" { + geminiAccountDir = filepath.Join(home, ".cursor-api-proxy", "gemini-accounts") + } else { + geminiAccountDir = resolveAbs(geminiAccountDir, cwd) + } + return LoadedEnv{ - AgentBin: agentBin, - AgentNode: getEnvVal(e, []string{"CURSOR_AGENT_NODE"}), - AgentScript: getEnvVal(e, []string{"CURSOR_AGENT_SCRIPT"}), - CommandShell: commandShell, - Host: host, - Port: port, - RequiredKey: getEnvVal(e, []string{"CURSOR_BRIDGE_API_KEY"}), - DefaultModel: normalizeModelId(getEnvVal(e, []string{"CURSOR_BRIDGE_DEFAULT_MODEL"})), - Force: envBool(e, []string{"CURSOR_BRIDGE_FORCE"}, false), - ApproveMcps: envBool(e, []string{"CURSOR_BRIDGE_APPROVE_MCPS"}, false), - StrictModel: envBool(e, []string{"CURSOR_BRIDGE_STRICT_MODEL"}, true), - Workspace: workspace, - TimeoutMs: envInt(e, []string{"CURSOR_BRIDGE_TIMEOUT_MS"}, 300000), - TLSCertPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_CERT"}), cwd), - TLSKeyPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_KEY"}), cwd), - SessionsLogPath: sessionsLogPath, - ChatOnlyWorkspace: envBool(e, []string{"CURSOR_BRIDGE_CHAT_ONLY_WORKSPACE"}, true), - Verbose: envBool(e, []string{"CURSOR_BRIDGE_VERBOSE"}, false), - MaxMode: envBool(e, []string{"CURSOR_BRIDGE_MAX_MODE"}, false), - ConfigDirs: configDirs, - MultiPort: envBool(e, []string{"CURSOR_BRIDGE_MULTI_PORT"}, false), - WinCmdlineMax: winMax, + AgentBin: agentBin, + AgentNode: getEnvVal(e, []string{"CURSOR_AGENT_NODE"}), + AgentScript: getEnvVal(e, []string{"CURSOR_AGENT_SCRIPT"}), + CommandShell: commandShell, + Host: host, + Port: port, + RequiredKey: getEnvVal(e, []string{"CURSOR_BRIDGE_API_KEY"}), + DefaultModel: normalizeModelId(getEnvVal(e, []string{"CURSOR_BRIDGE_DEFAULT_MODEL"})), + Provider: getEnvVal(e, []string{"CURSOR_BRIDGE_PROVIDER"}), + Force: envBool(e, []string{"CURSOR_BRIDGE_FORCE"}, false), + ApproveMcps: envBool(e, 
[]string{"CURSOR_BRIDGE_APPROVE_MCPS"}, false), + StrictModel: envBool(e, []string{"CURSOR_BRIDGE_STRICT_MODEL"}, true), + Workspace: workspace, + TimeoutMs: envInt(e, []string{"CURSOR_BRIDGE_TIMEOUT_MS"}, 300000), + TLSCertPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_CERT"}), cwd), + TLSKeyPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_KEY"}), cwd), + SessionsLogPath: sessionsLogPath, + ChatOnlyWorkspace: envBool(e, []string{"CURSOR_BRIDGE_CHAT_ONLY_WORKSPACE"}, true), + Verbose: envBool(e, []string{"CURSOR_BRIDGE_VERBOSE"}, false), + MaxMode: envBool(e, []string{"CURSOR_BRIDGE_MAX_MODE"}, false), + ConfigDirs: configDirs, + MultiPort: envBool(e, []string{"CURSOR_BRIDGE_MULTI_PORT"}, false), + WinCmdlineMax: winMax, + GeminiAccountDir: geminiAccountDir, + GeminiBrowserVisible: envBool(e, []string{"GEMINI_BROWSER_VISIBLE"}, false), + GeminiMaxSessions: envInt(e, []string{"GEMINI_MAX_SESSIONS"}, 3), } } diff --git a/internal/providers/cursor/provider.go b/internal/providers/cursor/provider.go new file mode 100644 index 0000000..e30b3c7 --- /dev/null +++ b/internal/providers/cursor/provider.go @@ -0,0 +1,27 @@ +package cursor + +import ( + "context" + "cursor-api-proxy/internal/apitypes" + "cursor-api-proxy/internal/config" +) + +type Provider struct { + cfg config.BridgeConfig +} + +func NewProvider(cfg config.BridgeConfig) *Provider { + return &Provider{cfg: cfg} +} + +func (p *Provider) Name() string { + return "cursor" +} + +func (p *Provider) Close() error { + return nil +} + +func (p *Provider) Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error { + return nil +} diff --git a/internal/providers/factory.go b/internal/providers/factory.go new file mode 100644 index 0000000..8fd9ea5 --- /dev/null +++ b/internal/providers/factory.go @@ -0,0 +1,32 @@ +package providers + +import ( + "context" + "cursor-api-proxy/internal/apitypes" + 
"cursor-api-proxy/internal/config" + "cursor-api-proxy/internal/providers/cursor" + "cursor-api-proxy/internal/providers/geminiweb" + "fmt" +) + +type Provider interface { + Name() string + Close() error + Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error +} + +func NewProvider(cfg config.BridgeConfig) (Provider, error) { + providerType := cfg.Provider + if providerType == "" { + providerType = "cursor" + } + + switch providerType { + case "cursor": + return cursor.NewProvider(cfg), nil + case "gemini-web": + return geminiweb.NewProvider(cfg), nil + default: + return nil, fmt.Errorf("unknown provider: %s", providerType) + } +} diff --git a/internal/providers/geminiweb/browser.go b/internal/providers/geminiweb/browser.go new file mode 100644 index 0000000..a94894e --- /dev/null +++ b/internal/providers/geminiweb/browser.go @@ -0,0 +1,125 @@ +package geminiweb + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/go-rod/rod" + "github.com/go-rod/rod/lib/launcher" + "github.com/go-rod/rod/lib/proto" +) + +type Browser struct { + browser *rod.Browser + visible bool +} + +func NewBrowser(visible bool) (*Browser, error) { + l := launcher.New() + if visible { + l = l.Headless(false) + } else { + l = l.Headless(true) + } + + url, err := l.Launch() + if err != nil { + return nil, fmt.Errorf("failed to launch browser: %w", err) + } + + b := rod.New().ControlURL(url) + if err := b.Connect(); err != nil { + return nil, fmt.Errorf("failed to connect browser: %w", err) + } + + return &Browser{browser: b, visible: visible}, nil +} + +func (b *Browser) Close() error { + if b.browser != nil { + return b.browser.Close() + } + return nil +} + +func (b *Browser) NewPage() (*rod.Page, error) { + return b.browser.Page(proto.TargetCreateTarget{URL: "about:blank"}) +} + +type Cookie struct { + Name string `json:"name"` + Value string `json:"value"` + Domain 
string `json:"domain"` + Path string `json:"path"` + Expires float64 `json:"expires"` + HTTPOnly bool `json:"httpOnly"` + Secure bool `json:"secure"` +} + +func LoadCookiesFromFile(cookieFile string) ([]Cookie, error) { + data, err := os.ReadFile(cookieFile) + if err != nil { + return nil, fmt.Errorf("failed to read cookies: %w", err) + } + + var cookies []Cookie + if err := json.Unmarshal(data, &cookies); err != nil { + return nil, fmt.Errorf("failed to parse cookies: %w", err) + } + + return cookies, nil +} + +func SaveCookiesToFile(cookies []Cookie, cookieFile string) error { + data, err := json.MarshalIndent(cookies, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal cookies: %w", err) + } + + dir := filepath.Dir(cookieFile) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create cookie dir: %w", err) + } + + if err := os.WriteFile(cookieFile, data, 0644); err != nil { + return fmt.Errorf("failed to write cookies: %w", err) + } + + return nil +} + +func SetCookiesOnPage(page *rod.Page, cookies []Cookie) error { + var protoCookies []*proto.NetworkCookieParam + for _, c := range cookies { + p := &proto.NetworkCookieParam{ + Name: c.Name, + Value: c.Value, + Domain: c.Domain, + Path: c.Path, + HTTPOnly: c.HTTPOnly, + Secure: c.Secure, + } + if c.Expires > 0 { + exp := proto.TimeSinceEpoch(c.Expires) + p.Expires = exp + } + protoCookies = append(protoCookies, p) + } + return page.SetCookies(protoCookies) +} + +func WaitForElement(page *rod.Page, selector string, timeout time.Duration) (*rod.Element, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return page.Context(ctx).Element(selector) +} + +func WaitForElements(page *rod.Page, selector string, timeout time.Duration) (rod.Elements, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return page.Context(ctx).Elements(selector) +} diff --git 
a/internal/providers/geminiweb/page.go b/internal/providers/geminiweb/page.go new file mode 100644 index 0000000..6ba935b --- /dev/null +++ b/internal/providers/geminiweb/page.go @@ -0,0 +1,194 @@ +package geminiweb + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/go-rod/rod" + "github.com/go-rod/rod/lib/proto" +) + +const geminiURL = "https://gemini.google.com/app" + +var modelSelectors = map[string]string{ + "gemini-2.0-flash": "Flash", + "gemini-2.5-pro": "Pro", + "gemini-2.5-pro-thinking": "Thinking", +} + +func NormalizeModel(model string) string { + if strings.HasPrefix(model, "gemini-") { + return model + } + return "gemini-" + model +} + +func GetModelDisplayName(model string) string { + if name, ok := modelSelectors[model]; ok { + return name + } + return "Flash" +} + +func NavigateToGemini(page *rod.Page) error { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if err := page.Context(ctx).Navigate(geminiURL); err != nil { + return fmt.Errorf("failed to navigate to gemini: %w", err) + } + return page.Context(ctx).WaitLoad() +} + +func IsLoggedIn(page *rod.Page) bool { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, err := page.Context(ctx).Element(`[aria-label*="New chat"], [data-test-id*="new-chat"], button[aria-label*="chat"]`) + return err == nil +} + +func SelectModel(page *rod.Page, model string) error { + displayName := GetModelDisplayName(model) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + modelSwitcher, err := page.Context(ctx).Element(`button[aria-label*="model"], [data-test-id="model-selector"], button[aria-haspopup="listbox"]`) + if err != nil { + return fmt.Errorf("model selector not found: %w", err) + } + + if err := modelSwitcher.Click(proto.InputMouseButtonLeft, 1); err != nil { + return fmt.Errorf("failed to click model selector: %w", err) + } + + time.Sleep(500 * 
time.Millisecond) + + option, err := page.Context(ctx).Element(fmt.Sprintf(`[aria-label*="%s"], [data-value="%s"]`, displayName, displayName)) + if err != nil { + return fmt.Errorf("model option %s not found: %w", displayName, err) + } + + return option.Click(proto.InputMouseButtonLeft, 1) +} + +func SendPrompt(page *rod.Page, prompt string) error { + textarea, err := page.Element(`textarea[aria-label*="message"], textarea[placeholder*="message"], rich-textarea, .ql-editor, div[contenteditable="true"]`) + if err != nil { + return fmt.Errorf("input field not found: %w", err) + } + + if err := textarea.Input(prompt); err != nil { + return fmt.Errorf("failed to input prompt: %w", err) + } + + time.Sleep(300 * time.Millisecond) + + sendBtn, err := page.Element(`button[aria-label*="Send"], button[aria-label*="submit"], button[type="submit"]`) + if err != nil { + return fmt.Errorf("send button not found: %w", err) + } + + return sendBtn.Click(proto.InputMouseButtonLeft, 1) +} + +func WaitForResponse(page *rod.Page, onChunk func(text string), onThinking func(thinking string), onComplete func()) error { + lastText := "" + lastThinking := "" + responseComplete := false + + timeout := time.NewTimer(120 * time.Second) + defer timeout.Stop() + + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-timeout.C: + return fmt.Errorf("response timeout") + case <-ticker.C: + textChanged := false + + responseEls, err := page.Elements(`.response-text, message-content, .model-response, div[data-test-id="response"]`) + if err == nil && len(responseEls) > 0 { + for _, el := range responseEls { + text, _ := el.Text() + text = strings.TrimSpace(text) + if text != "" && text != lastText { + if strings.Contains(text, lastText) { + newPart := strings.TrimPrefix(text, lastText) + if newPart != "" { + onChunk(newPart) + } + } else { + onChunk(text) + } + lastText = text + textChanged = true + } + } + } + + thinkingEls, err := 
page.Elements(`.thinking-content, .thought-text, div[data-test-id="thinking"]`) + if err == nil && len(thinkingEls) > 0 { + for _, el := range thinkingEls { + thinking, _ := el.Text() + thinking = strings.TrimSpace(thinking) + if thinking != "" && thinking != lastThinking { + if strings.Contains(thinking, lastThinking) { + newPart := strings.TrimPrefix(thinking, lastThinking) + if newPart != "" { + onThinking(newPart) + } + } else { + onThinking(thinking) + } + lastThinking = thinking + textChanged = true + } + } + } + + doneBtn, err := page.Element(`button[aria-label*="stop"], button[aria-label*="regenerate"]`) + if err == nil && doneBtn != nil { + ariaLabel, _ := doneBtn.Attribute("aria-label") + if ariaLabel != nil && (*ariaLabel == "Stop" || strings.Contains(*ariaLabel, "regenerate")) { + if !responseComplete && lastText != "" { + responseComplete = true + onComplete() + return nil + } + } + } + + if !textChanged && responseComplete { + return nil + } + } + } +} + +func IsRateLimited(page *rod.Page) bool { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + el, err := page.Context(ctx).Element(`[class*="rate-limit"], [class*="quota"], [data-test-id="rate-limited"]`) + return err == nil && el != nil +} + +func GetRateLimitMessage(page *rod.Page) string { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + el, err := page.Context(ctx).Element(`[class*="rate-limit"], [class*="quota"], [class*="error-message"]`) + if err != nil || el == nil { + return "" + } + + text, _ := el.Text() + return strings.TrimSpace(text) +} diff --git a/internal/providers/geminiweb/pool.go b/internal/providers/geminiweb/pool.go new file mode 100644 index 0000000..88d4f89 --- /dev/null +++ b/internal/providers/geminiweb/pool.go @@ -0,0 +1,169 @@ +package geminiweb + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + "time" +) + +type GeminiSession struct { + Name string `json:"name"` 
// GeminiSession is the metadata stored for one logged-in account. It is
// persisted as <pool dir>/<Name>/session.json.
type GeminiSession struct {
	Name         string `json:"name"`
	CookieFile   string `json:"cookie_file"`
	LastUsed     int64  `json:"last_used"`      // Unix milliseconds
	ActiveCount  int    `json:"active_count"`   // requests currently using the session
	RateLimitEnd int64  `json:"rate_limit_end"` // Unix milliseconds; session is skipped until then
}

// SessionPool tracks the sessions found under dir and picks the least busy
// one for each request. All methods are safe for concurrent use.
type SessionPool struct {
	mu       sync.Mutex
	sessions []*GeminiSession
	dir      string
	maxCount int // stored from NewSessionPool; not enforced by the pool itself
}

// NewSessionPool creates dir if necessary and loads every session stored in it.
func NewSessionPool(dir string, maxSessions int) (*SessionPool, error) {
	if err := os.MkdirAll(dir, 0o755); err != nil {
		return nil, fmt.Errorf("failed to create session dir: %w", err)
	}

	sessions, err := loadSessions(dir)
	if err != nil {
		return nil, fmt.Errorf("failed to load sessions: %w", err)
	}

	return &SessionPool{
		sessions: sessions,
		dir:      dir,
		maxCount: maxSessions,
	}, nil
}

// loadSessions reads <dir>/*/session.json. Sub-directories without a readable
// or parseable session.json are skipped.
func loadSessions(dir string) ([]*GeminiSession, error) {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, err
	}

	var sessions []*GeminiSession
	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}
		name := entry.Name()
		data, err := os.ReadFile(filepath.Join(dir, name, "session.json"))
		if err != nil {
			continue
		}

		var s GeminiSession
		if err := json.Unmarshal(data, &s); err != nil {
			continue
		}
		// saveSession derives the path from Name, so the directory that
		// actually exists wins over whatever the file claims.
		s.Name = name
		// ActiveCount counts in-flight requests of a running process. A value
		// read from disk is left over from a previous run, possibly one that
		// crashed mid-request, and would skew selection forever.
		s.ActiveCount = 0
		sessions = append(sessions, &s)
	}

	return sessions, nil
}

// Count returns the number of sessions in the pool.
func (p *SessionPool) Count() int {
	p.mu.Lock()
	defer p.mu.Unlock()
	return len(p.sessions)
}

// GetAvailable returns the session that is not rate limited and has the
// fewest active requests, breaking ties by least recent use. It returns nil
// when every session is rate limited or the pool is empty.
func (p *SessionPool) GetAvailable() *GeminiSession {
	p.mu.Lock()
	defer p.mu.Unlock()

	now := time.Now().UnixMilli()

	var best *GeminiSession
	for _, s := range p.sessions {
		if s.RateLimitEnd >= now {
			continue // still rate limited
		}
		switch {
		case best == nil,
			s.ActiveCount < best.ActiveCount,
			s.ActiveCount == best.ActiveCount && s.LastUsed < best.LastUsed:
			best = s
		}
	}
	return best
}

// StartSession marks s as in use by one more request and records the time.
func (p *SessionPool) StartSession(s *GeminiSession) {
	if s == nil {
		return
	}
	p.mu.Lock()
	defer p.mu.Unlock()
	s.ActiveCount++
	s.LastUsed = time.Now().UnixMilli()
	// Persistence is best-effort here: the in-memory state is authoritative
	// and the signature has no way to report a write failure.
	_ = p.saveSession(s)
}

// EndSession releases one use of s. The count never drops below zero.
func (p *SessionPool) EndSession(s *GeminiSession) {
	if s == nil {
		return
	}
	p.mu.Lock()
	defer p.mu.Unlock()
	if s.ActiveCount > 0 {
		s.ActiveCount--
	}
	_ = p.saveSession(s) // best-effort, see StartSession
}

// RateLimitSession excludes s from GetAvailable for the next durationMs
// milliseconds.
func (p *SessionPool) RateLimitSession(s *GeminiSession, durationMs int64) {
	if s == nil {
		return
	}
	p.mu.Lock()
	defer p.mu.Unlock()
	s.RateLimitEnd = time.Now().UnixMilli() + durationMs
	_ = p.saveSession(s) // best-effort, see StartSession
}

// saveSession writes s to <dir>/<Name>/session.json. The caller holds p.mu.
func (p *SessionPool) saveSession(s *GeminiSession) error {
	data, err := json.MarshalIndent(s, "", "  ")
	if err != nil {
		return err
	}
	return os.WriteFile(filepath.Join(p.dir, s.Name, "session.json"), data, 0o644)
}

// validSessionName reports whether name is a single path element, so that
// joining it to the pool directory cannot escape that directory.
func validSessionName(name string) bool {
	return name != "" && name != "." && name != ".." && name == filepath.Base(name)
}

// CreateSession registers a session called name and persists its metadata.
//
// If a session with that name is already in the pool it is returned as is,
// so logging in again for the same account does not add a duplicate entry.
// It returns an error when name is not a plain directory name or when the
// session directory or metadata cannot be written; in that case the pool is
// left unchanged.
func (p *SessionPool) CreateSession(name string) (*GeminiSession, error) {
	if !validSessionName(name) {
		return nil, fmt.Errorf("invalid session name %q", name)
	}

	p.mu.Lock()
	defer p.mu.Unlock()

	for _, existing := range p.sessions {
		if existing.Name == name {
			return existing, nil
		}
	}

	// The directory will hold the account's login cookies: owner-only.
	sessionDir := filepath.Join(p.dir, name)
	if err := os.MkdirAll(sessionDir, 0o700); err != nil {
		return nil, fmt.Errorf("failed to create session dir: %w", err)
	}

	s := &GeminiSession{
		Name:       name,
		CookieFile: filepath.Join(sessionDir, "cookies.json"),
		LastUsed:   time.Now().UnixMilli(),
	}
	if err := p.saveSession(s); err != nil {
		return nil, fmt.Errorf("failed to save session: %w", err)
	}

	p.sessions = append(p.sessions, s)
	return s, nil
}

// GetSessionNames returns the names of all sessions in pool order.
func (p *SessionPool) GetSessionNames() []string {
	p.mu.Lock()
	defer p.mu.Unlock()
	names := make([]string, len(p.sessions))
	for i, s := range p.sessions {
		names[i] = s.Name
	}
	return names
}
"gemini-web" +} + +func (p *Provider) Close() error { + return nil +} + +func (p *Provider) initPool() error { + if p.pool != nil { + return nil + } + pool, err := NewSessionPool(p.cfg.GeminiAccountDir, p.cfg.GeminiMaxSessions) + if err != nil { + return fmt.Errorf("failed to init session pool: %w", err) + } + p.pool = pool + return nil +} + +func (p *Provider) Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error { + if err := p.initPool(); err != nil { + return err + } + + session := p.pool.GetAvailable() + if session == nil { + return fmt.Errorf("no available sessions") + } + + p.pool.StartSession(session) + defer p.pool.EndSession(session) + + browser, err := NewBrowser(p.cfg.GeminiBrowserVisible) + if err != nil { + return fmt.Errorf("failed to create browser: %w", err) + } + defer browser.Close() + + page, err := browser.NewPage() + if err != nil { + return fmt.Errorf("failed to create page: %w", err) + } + + if session.CookieFile != "" { + cookies, err := LoadCookiesFromFile(session.CookieFile) + if err == nil { + if err := SetCookiesOnPage(page, cookies); err != nil { + return fmt.Errorf("failed to set cookies: %w", err) + } + } + } + + if err := NavigateToGemini(page); err != nil { + return fmt.Errorf("failed to navigate: %w", err) + } + + time.Sleep(2 * time.Second) + + if !IsLoggedIn(page) { + return fmt.Errorf("session not logged in, please run gemini-login first") + } + + if err := SelectModel(page, model); err != nil { + return fmt.Errorf("failed to select model: %w", err) + } + + time.Sleep(500 * time.Millisecond) + + prompt := buildPromptFromMessages(messages) + if err := SendPrompt(page, prompt); err != nil { + return fmt.Errorf("failed to send prompt: %w", err) + } + + return WaitForResponse(page, + func(text string) { + cb(apitypes.StreamChunk{Type: apitypes.ChunkText, Text: text}) + }, + func(thinking string) { + cb(apitypes.StreamChunk{Type: apitypes.ChunkThinking, 
Thinking: thinking}) + }, + func() { + cb(apitypes.StreamChunk{Type: apitypes.ChunkDone, Done: true}) + }, + ) +} + +func buildPromptFromMessages(messages []apitypes.Message) string { + var prompt string + for _, m := range messages { + switch m.Role { + case "system": + prompt += "System: " + m.Content + "\n\n" + case "user": + prompt += m.Content + "\n\n" + case "assistant": + prompt += "Assistant: " + m.Content + "\n\n" + } + } + return prompt +} + +func RunLogin(cfg config.BridgeConfig, sessionName string) error { + if sessionName == "" { + sessionName = fmt.Sprintf("session-%d", time.Now().Unix()) + } + + pool, err := NewSessionPool(cfg.GeminiAccountDir, cfg.GeminiMaxSessions) + if err != nil { + return fmt.Errorf("failed to init pool: %w", err) + } + + session, err := pool.CreateSession(sessionName) + if err != nil { + return fmt.Errorf("failed to create session: %w", err) + } + + fmt.Printf("Starting browser for login. Session: %s\n", sessionName) + fmt.Println("Please log in to your Gemini account in the browser window.") + fmt.Println("Press Ctrl+C when you have completed the login...") + + browser, err := NewBrowser(true) + if err != nil { + return fmt.Errorf("failed to create browser: %w", err) + } + defer browser.Close() + + page, err := browser.NewPage() + if err != nil { + return fmt.Errorf("failed to create page: %w", err) + } + + if err := NavigateToGemini(page); err != nil { + return fmt.Errorf("failed to navigate: %w", err) + } + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + <-sigChan + + cookies, err := GetPageCookies(page) + if err != nil { + return fmt.Errorf("failed to get cookies: %w", err) + } + + if err := SaveCookiesToFile(cookies, session.CookieFile); err != nil { + return fmt.Errorf("failed to save cookies: %w", err) + } + + fmt.Printf("Session saved successfully: %s\n", sessionName) + return nil +} + +func GetPageCookies(page *rod.Page) ([]Cookie, error) { + cookies, err := 
page.Cookies([]string{}) + if err != nil { + return nil, fmt.Errorf("failed to get cookies: %w", err) + } + + var result []Cookie + for _, c := range cookies { + result = append(result, Cookie{ + Name: c.Name, + Value: c.Value, + Domain: c.Domain, + Path: c.Path, + HTTPOnly: c.HTTPOnly, + Secure: c.Secure, + }) + } + return result, nil +}