261 lines
6.9 KiB
Go
261 lines
6.9 KiB
Go
package geminiweb
|
||
|
||
import (
|
||
"context"
|
||
"cursor-api-proxy/internal/apitypes"
|
||
"cursor-api-proxy/internal/config"
|
||
"fmt"
|
||
"os"
|
||
"os/signal"
|
||
"syscall"
|
||
"time"
|
||
|
||
"github.com/go-rod/rod"
|
||
)
|
||
|
||
type Provider struct {
|
||
cfg config.BridgeConfig
|
||
pool *SessionPool
|
||
}
|
||
|
||
func NewProvider(cfg config.BridgeConfig) *Provider {
|
||
return &Provider{cfg: cfg}
|
||
}
|
||
|
||
func (p *Provider) Name() string {
|
||
return "gemini-web"
|
||
}
|
||
|
||
func (p *Provider) Close() error {
|
||
return nil
|
||
}
|
||
|
||
func (p *Provider) initPool() error {
|
||
if p.pool != nil {
|
||
return nil
|
||
}
|
||
pool, err := NewSessionPool(p.cfg.GeminiAccountDir, p.cfg.GeminiMaxSessions)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to init session pool: %w", err)
|
||
}
|
||
p.pool = pool
|
||
return nil
|
||
}
|
||
|
||
func (p *Provider) Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error {
|
||
fmt.Printf("[GeminiWeb] Starting generation with model: %s\n", model)
|
||
|
||
if err := p.initPool(); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 檢查是否有可用的已登入 session
|
||
session := p.pool.GetAvailable()
|
||
needLogin := false
|
||
|
||
if session == nil {
|
||
// 沒有 session,建立一個新的
|
||
fmt.Printf("[GeminiWeb] No existing session found, creating new session...\n")
|
||
var err error
|
||
session, err = p.pool.CreateSession(fmt.Sprintf("session-%d", time.Now().Unix()))
|
||
if err != nil {
|
||
return fmt.Errorf("failed to create session: %w", err)
|
||
}
|
||
needLogin = true
|
||
fmt.Printf("[GeminiWeb] Created new session: %s\n", session.Name)
|
||
} else {
|
||
fmt.Printf("[GeminiWeb] Using existing session: %s\n", session.Name)
|
||
}
|
||
|
||
p.pool.StartSession(session)
|
||
defer p.pool.EndSession(session)
|
||
|
||
// 如果沒有登入過,強制使用可見瀏覽器
|
||
visible := p.cfg.GeminiBrowserVisible || needLogin
|
||
|
||
browser, err := NewBrowser(visible)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to create browser: %w", err)
|
||
}
|
||
defer browser.Close()
|
||
|
||
page, err := browser.NewPage()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to create page: %w", err)
|
||
}
|
||
|
||
// 嘗試載入 cookies
|
||
if session.CookieFile != "" {
|
||
fmt.Printf("[GeminiWeb] Loading cookies from: %s\n", session.CookieFile)
|
||
cookies, err := LoadCookiesFromFile(session.CookieFile)
|
||
if err == nil {
|
||
if err := SetCookiesOnPage(page, cookies); err != nil {
|
||
fmt.Printf("[GeminiWeb] Warning: failed to set cookies: %v\n", err)
|
||
} else {
|
||
fmt.Printf("[GeminiWeb] Loaded %d cookies\n", len(cookies))
|
||
}
|
||
} else {
|
||
fmt.Printf("[GeminiWeb] No existing cookies found\n")
|
||
}
|
||
}
|
||
|
||
fmt.Printf("[GeminiWeb] Navigating to Gemini...\n")
|
||
if err := NavigateToGemini(page); err != nil {
|
||
return fmt.Errorf("failed to navigate: %w", err)
|
||
}
|
||
|
||
time.Sleep(2 * time.Second)
|
||
|
||
fmt.Printf("[GeminiWeb] Checking login status...\n")
|
||
if IsLoggedIn(page) {
|
||
fmt.Printf("[GeminiWeb] Logged in (using saved cookies)\n")
|
||
} else {
|
||
fmt.Printf("[GeminiWeb] Not logged in - continuing without login\n")
|
||
|
||
if visible {
|
||
// 開啟可見瀏覽器,提示使用者可以登入
|
||
fmt.Println("\n========================================")
|
||
fmt.Println("Browser is now open. You can:")
|
||
fmt.Println("1. Log in to Gemini (to use your account)")
|
||
fmt.Println("2. Or continue without login (limited functionality)")
|
||
fmt.Println("\nPress Enter when ready to continue...")
|
||
fmt.Println("========================================\n")
|
||
|
||
// 等待使用者按 Enter
|
||
var input string
|
||
fmt.Scanln(&input)
|
||
|
||
// 檢查是否已登入
|
||
if IsLoggedIn(page) {
|
||
fmt.Printf("[GeminiWeb] Login detected! Saving cookies for future use...\n")
|
||
cookies, err := GetPageCookies(page)
|
||
if err != nil {
|
||
fmt.Printf("[GeminiWeb] Warning: could not get cookies: %v\n", err)
|
||
} else {
|
||
if err := SaveCookiesToFile(cookies, session.CookieFile); err != nil {
|
||
fmt.Printf("[GeminiWeb] Warning: could not save cookies: %v\n", err)
|
||
} else {
|
||
fmt.Printf("[GeminiWeb] Saved %d cookies to %s\n", len(cookies), session.CookieFile)
|
||
}
|
||
}
|
||
} else {
|
||
fmt.Printf("[GeminiWeb] Continuing without login\n")
|
||
}
|
||
}
|
||
}
|
||
|
||
fmt.Printf("[GeminiWeb] Selecting model: %s\n", model)
|
||
if err := SelectModel(page, model); err != nil {
|
||
return fmt.Errorf("failed to select model: %w", err)
|
||
}
|
||
fmt.Printf("[GeminiWeb] Model selected\n")
|
||
|
||
time.Sleep(500 * time.Millisecond)
|
||
|
||
prompt := buildPromptFromMessages(messages)
|
||
fmt.Printf("[GeminiWeb] Sending prompt (length: %d chars)\n", len(prompt))
|
||
if err := SendPrompt(page, prompt); err != nil {
|
||
return fmt.Errorf("failed to send prompt: %w", err)
|
||
}
|
||
fmt.Printf("[GeminiWeb] Prompt sent, waiting for response...\n")
|
||
|
||
return WaitForResponse(page,
|
||
func(text string) {
|
||
cb(apitypes.StreamChunk{Type: apitypes.ChunkText, Text: text})
|
||
},
|
||
func(thinking string) {
|
||
cb(apitypes.StreamChunk{Type: apitypes.ChunkThinking, Thinking: thinking})
|
||
},
|
||
func() {
|
||
cb(apitypes.StreamChunk{Type: apitypes.ChunkDone, Done: true})
|
||
},
|
||
)
|
||
}
|
||
|
||
func buildPromptFromMessages(messages []apitypes.Message) string {
|
||
var prompt string
|
||
for _, m := range messages {
|
||
switch m.Role {
|
||
case "system":
|
||
prompt += "System: " + m.Content + "\n\n"
|
||
case "user":
|
||
prompt += m.Content + "\n\n"
|
||
case "assistant":
|
||
prompt += "Assistant: " + m.Content + "\n\n"
|
||
}
|
||
}
|
||
return prompt
|
||
}
|
||
|
||
func RunLogin(cfg config.BridgeConfig, sessionName string) error {
|
||
if sessionName == "" {
|
||
sessionName = fmt.Sprintf("session-%d", time.Now().Unix())
|
||
}
|
||
|
||
pool, err := NewSessionPool(cfg.GeminiAccountDir, cfg.GeminiMaxSessions)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to init pool: %w", err)
|
||
}
|
||
|
||
session, err := pool.CreateSession(sessionName)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to create session: %w", err)
|
||
}
|
||
|
||
fmt.Printf("Starting browser for login. Session: %s\n", sessionName)
|
||
fmt.Println("Please log in to your Gemini account in the browser window.")
|
||
fmt.Println("Press Ctrl+C when you have completed the login...")
|
||
|
||
browser, err := NewBrowser(true)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to create browser: %w", err)
|
||
}
|
||
defer browser.Close()
|
||
|
||
page, err := browser.NewPage()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to create page: %w", err)
|
||
}
|
||
|
||
if err := NavigateToGemini(page); err != nil {
|
||
return fmt.Errorf("failed to navigate: %w", err)
|
||
}
|
||
|
||
sigChan := make(chan os.Signal, 1)
|
||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
||
|
||
<-sigChan
|
||
|
||
cookies, err := GetPageCookies(page)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to get cookies: %w", err)
|
||
}
|
||
|
||
if err := SaveCookiesToFile(cookies, session.CookieFile); err != nil {
|
||
return fmt.Errorf("failed to save cookies: %w", err)
|
||
}
|
||
|
||
fmt.Printf("Session saved successfully: %s\n", sessionName)
|
||
return nil
|
||
}
|
||
|
||
func GetPageCookies(page *rod.Page) ([]Cookie, error) {
|
||
cookies, err := page.Cookies([]string{})
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get cookies: %w", err)
|
||
}
|
||
|
||
var result []Cookie
|
||
for _, c := range cookies {
|
||
result = append(result, Cookie{
|
||
Name: c.Name,
|
||
Value: c.Value,
|
||
Domain: c.Domain,
|
||
Path: c.Path,
|
||
HTTPOnly: c.HTTPOnly,
|
||
Secure: c.Secure,
|
||
})
|
||
}
|
||
return result, nil
|
||
}
|