opencode-cursor-agent/internal/providers/geminiweb/provider.go

261 lines
6.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package geminiweb
import (
"context"
"cursor-api-proxy/internal/apitypes"
"cursor-api-proxy/internal/config"
"fmt"
"os"
"os/signal"
"syscall"
"time"
"github.com/go-rod/rod"
)
// Provider is the "gemini-web" provider. It holds the bridge configuration
// and a session pool that is created lazily by initPool.
type Provider struct {
// cfg is the bridge configuration supplied to NewProvider.
cfg config.BridgeConfig
// pool stays nil until initPool has run successfully.
pool *SessionPool
}
// NewProvider returns a Provider bound to cfg. The session pool is not
// created here; it is left nil until it is first needed.
func NewProvider(cfg config.BridgeConfig) *Provider {
	provider := new(Provider)
	provider.cfg = cfg
	return provider
}
// Name returns the fixed identifier of this provider, "gemini-web".
func (p *Provider) Name() string {
	const providerName = "gemini-web"
	return providerName
}
// Close currently performs no cleanup and always returns nil.
// NOTE(review): the lazily created session pool is not released here —
// confirm whether SessionPool needs explicit teardown.
func (p *Provider) Close() error {
return nil
}
// initPool creates the session pool on first use, using the account
// directory and session limit from the provider configuration. Later calls
// are no-ops once a pool exists. If pool creation fails the error is wrapped
// and returned, and p.pool is left unset so a later call can try again.
func (p *Provider) initPool() error {
	if p.pool == nil {
		created, err := NewSessionPool(p.cfg.GeminiAccountDir, p.cfg.GeminiMaxSessions)
		if err != nil {
			return fmt.Errorf("failed to init session pool: %w", err)
		}
		p.pool = created
	}
	return nil
}
// Generate sends the conversation to the Gemini web UI through an automated
// browser and streams the reply back through cb.
//
// Steps, in order:
//  1. Make sure the session pool exists and pick a session from it, creating
//     a new one when the pool has none available.
//  2. Launch a browser (forced visible for a freshly created session), load
//     any cookies saved for the session and navigate to Gemini.
//  3. If the page is not logged in and the browser is visible, block on
//     stdin until the user presses Enter, then save cookies if a login is
//     detected.
//  4. Select the model, send the prompt built from messages, and forward
//     text, thinking and done events to cb as stream chunks.
//
// ctx is checked before any work starts and during the fixed waits; when it
// is cancelled Generate returns ctx.Err(). The browser helpers themselves do
// not take a context, so an in-flight browser call is not interrupted.
// tools is currently ignored.
func (p *Provider) Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error {
	fmt.Printf("[GeminiWeb] Starting generation with model: %s\n", model)

	// Do not claim a session or launch a browser for a request that has
	// already been abandoned.
	if err := ctx.Err(); err != nil {
		return err
	}

	// pause waits for d, returning early with the context error if the
	// request is cancelled in the meantime.
	pause := func(d time.Duration) error {
		timer := time.NewTimer(d)
		defer timer.Stop()
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-timer.C:
			return nil
		}
	}

	if err := p.initPool(); err != nil {
		return err
	}

	// Check whether the pool has a usable (presumably already logged-in)
	// session.
	session := p.pool.GetAvailable()
	needLogin := false
	if session == nil {
		// No session available: create a new one.
		fmt.Printf("[GeminiWeb] No existing session found, creating new session...\n")
		var err error
		session, err = p.pool.CreateSession(fmt.Sprintf("session-%d", time.Now().Unix()))
		if err != nil {
			return fmt.Errorf("failed to create session: %w", err)
		}
		needLogin = true
		fmt.Printf("[GeminiWeb] Created new session: %s\n", session.Name)
	} else {
		fmt.Printf("[GeminiWeb] Using existing session: %s\n", session.Name)
	}

	p.pool.StartSession(session)
	defer p.pool.EndSession(session)

	// A session that has never been logged in forces a visible browser so the
	// user can sign in.
	visible := p.cfg.GeminiBrowserVisible || needLogin
	browser, err := NewBrowser(visible)
	if err != nil {
		return fmt.Errorf("failed to create browser: %w", err)
	}
	defer browser.Close()

	page, err := browser.NewPage()
	if err != nil {
		return fmt.Errorf("failed to create page: %w", err)
	}

	// Try to load saved cookies; failures here are logged and do not abort.
	if session.CookieFile != "" {
		fmt.Printf("[GeminiWeb] Loading cookies from: %s\n", session.CookieFile)
		cookies, err := LoadCookiesFromFile(session.CookieFile)
		if err == nil {
			if err := SetCookiesOnPage(page, cookies); err != nil {
				fmt.Printf("[GeminiWeb] Warning: failed to set cookies: %v\n", err)
			} else {
				fmt.Printf("[GeminiWeb] Loaded %d cookies\n", len(cookies))
			}
		} else {
			fmt.Printf("[GeminiWeb] No existing cookies found\n")
		}
	}

	fmt.Printf("[GeminiWeb] Navigating to Gemini...\n")
	if err := NavigateToGemini(page); err != nil {
		return fmt.Errorf("failed to navigate: %w", err)
	}
	// Fixed wait after navigation before probing the login state.
	if err := pause(2 * time.Second); err != nil {
		return err
	}

	fmt.Printf("[GeminiWeb] Checking login status...\n")
	if IsLoggedIn(page) {
		fmt.Printf("[GeminiWeb] Logged in (using saved cookies)\n")
	} else {
		fmt.Printf("[GeminiWeb] Not logged in - continuing without login\n")
		if visible {
			// The browser is visible: tell the user they may log in now.
			fmt.Println("\n========================================")
			fmt.Println("Browser is now open. You can:")
			fmt.Println("1. Log in to Gemini (to use your account)")
			fmt.Println("2. Or continue without login (limited functionality)")
			fmt.Println("\nPress Enter when ready to continue...")
			// Same output as before (rule plus a blank line), written as two
			// calls so the Println argument does not end in "\n", which
			// go vet rejects.
			fmt.Println("========================================")
			fmt.Println()

			// Wait for the user to press Enter. The scan result is
			// deliberately ignored: an empty line makes Scanln report an
			// error, and only the key press matters here.
			var input string
			fmt.Scanln(&input)

			// Re-check the login state after the user has confirmed.
			if IsLoggedIn(page) {
				fmt.Printf("[GeminiWeb] Login detected! Saving cookies for future use...\n")
				cookies, err := GetPageCookies(page)
				if err != nil {
					fmt.Printf("[GeminiWeb] Warning: could not get cookies: %v\n", err)
				} else {
					if err := SaveCookiesToFile(cookies, session.CookieFile); err != nil {
						fmt.Printf("[GeminiWeb] Warning: could not save cookies: %v\n", err)
					} else {
						fmt.Printf("[GeminiWeb] Saved %d cookies to %s\n", len(cookies), session.CookieFile)
					}
				}
			} else {
				fmt.Printf("[GeminiWeb] Continuing without login\n")
			}
		}
	}

	fmt.Printf("[GeminiWeb] Selecting model: %s\n", model)
	if err := SelectModel(page, model); err != nil {
		return fmt.Errorf("failed to select model: %w", err)
	}
	fmt.Printf("[GeminiWeb] Model selected\n")
	// Short fixed wait after selecting the model before typing the prompt.
	if err := pause(500 * time.Millisecond); err != nil {
		return err
	}

	prompt := buildPromptFromMessages(messages)
	fmt.Printf("[GeminiWeb] Sending prompt (length: %d chars)\n", len(prompt))
	if err := SendPrompt(page, prompt); err != nil {
		return fmt.Errorf("failed to send prompt: %w", err)
	}
	fmt.Printf("[GeminiWeb] Prompt sent, waiting for response...\n")

	// Translate the three response callbacks into stream chunks for cb.
	return WaitForResponse(page,
		func(text string) {
			cb(apitypes.StreamChunk{Type: apitypes.ChunkText, Text: text})
		},
		func(thinking string) {
			cb(apitypes.StreamChunk{Type: apitypes.ChunkThinking, Thinking: thinking})
		},
		func() {
			cb(apitypes.StreamChunk{Type: apitypes.ChunkDone, Done: true})
		},
	)
}
// buildPromptFromMessages flattens a chat history into one prompt string.
// System and assistant messages are prefixed with "System: " and
// "Assistant: " respectively, user messages are included without a prefix,
// and messages with any other role are skipped. Every included message is
// followed by a blank line.
func buildPromptFromMessages(messages []apitypes.Message) string {
	prompt := ""
	for i := range messages {
		msg := messages[i]

		var prefix string
		switch msg.Role {
		case "system":
			prefix = "System: "
		case "assistant":
			prefix = "Assistant: "
		case "user":
			// User text is passed through without a role label.
		default:
			// Unknown roles contribute nothing to the prompt.
			continue
		}
		prompt += prefix + msg.Content + "\n\n"
	}
	return prompt
}
// RunLogin opens a visible browser on Gemini so the user can sign in, waits
// for SIGINT (Ctrl+C) or SIGTERM, and then saves the page's cookies to the
// cookie file of the session.
//
// cfg supplies the account directory and session limit for the session pool.
// sessionName names the session to create; when empty, a name of the form
// "session-<unix seconds>" is generated.
//
// It returns a wrapped error if the pool, session, browser or page cannot be
// created, if navigation fails, or if the cookies cannot be read or saved.
func RunLogin(cfg config.BridgeConfig, sessionName string) error {
	if sessionName == "" {
		sessionName = fmt.Sprintf("session-%d", time.Now().Unix())
	}

	pool, err := NewSessionPool(cfg.GeminiAccountDir, cfg.GeminiMaxSessions)
	if err != nil {
		return fmt.Errorf("failed to init pool: %w", err)
	}
	session, err := pool.CreateSession(sessionName)
	if err != nil {
		return fmt.Errorf("failed to create session: %w", err)
	}

	fmt.Printf("Starting browser for login. Session: %s\n", sessionName)
	fmt.Println("Please log in to your Gemini account in the browser window.")
	fmt.Println("Press Ctrl+C when you have completed the login...")

	browser, err := NewBrowser(true)
	if err != nil {
		return fmt.Errorf("failed to create browser: %w", err)
	}
	defer browser.Close()

	page, err := browser.NewPage()
	if err != nil {
		return fmt.Errorf("failed to create page: %w", err)
	}
	if err := NavigateToGemini(page); err != nil {
		return fmt.Errorf("failed to navigate: %w", err)
	}

	// Block until the user signals that the login is finished.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
	// Unregister on return. Without this the process keeps routing
	// SIGINT/SIGTERM to a channel nobody reads, so Ctrl+C would no longer
	// terminate it after RunLogin has finished. Deferring (rather than
	// stopping right after the receive) keeps a second Ctrl+C from killing
	// the process while the cookies are being written.
	defer signal.Stop(sigChan)
	<-sigChan

	cookies, err := GetPageCookies(page)
	if err != nil {
		return fmt.Errorf("failed to get cookies: %w", err)
	}
	if err := SaveCookiesToFile(cookies, session.CookieFile); err != nil {
		return fmt.Errorf("failed to save cookies: %w", err)
	}

	fmt.Printf("Session saved successfully: %s\n", sessionName)
	return nil
}
// GetPageCookies reads the cookies from page (queried with an empty URL
// list) and converts each one into this package's Cookie type, copying the
// name, value, domain, path, HTTPOnly and Secure fields. The result is nil
// when the page reports no cookies. A failure to read the cookies is
// returned as a wrapped error.
func GetPageCookies(page *rod.Page) ([]Cookie, error) {
	pageCookies, err := page.Cookies([]string{})
	if err != nil {
		return nil, fmt.Errorf("failed to get cookies: %w", err)
	}

	// Left as a nil slice on purpose so an empty result stays nil.
	var converted []Cookie
	for i := range pageCookies {
		src := pageCookies[i]
		entry := Cookie{
			Name:     src.Name,
			Value:    src.Value,
			Domain:   src.Domain,
			Path:     src.Path,
			HTTPOnly: src.HTTPOnly,
			Secure:   src.Secure,
		}
		converted = append(converted, entry)
	}
	return converted, nil
}