197 lines
5.3 KiB
Go
197 lines
5.3 KiB
Go
package geminiweb
|
||
|
||
import (
	"context"
	"fmt"
	"os"
	"os/signal"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"cursor-api-proxy/internal/apitypes"
	"cursor-api-proxy/internal/config"
)
|
||
|
||
// Provider 使用持久化瀏覽器管理器
|
||
type Provider struct {
|
||
cfg config.BridgeConfig
|
||
managerOnce sync.Once
|
||
manager *BrowserManager
|
||
managerErr error
|
||
}
|
||
|
||
// NewProvider 建立新的 Provider
|
||
func NewProvider(cfg config.BridgeConfig) *Provider {
|
||
return &Provider{cfg: cfg}
|
||
}
|
||
|
||
// getName 返回 Provider 名稱
|
||
func (p *Provider) Name() string {
|
||
return "gemini-web"
|
||
}
|
||
|
||
// Close 關閉瀏覽器
|
||
func (p *Provider) Close() error {
|
||
if p.manager != nil {
|
||
return p.manager.Close()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// getManager 獲取或初始化瀏覽器管理器(單例)
|
||
func (p *Provider) getManager() (*BrowserManager, error) {
|
||
p.managerOnce.Do(func() {
|
||
sessionDir := p.getSessionDir()
|
||
p.manager, p.managerErr = GetBrowserManager(sessionDir, p.cfg.GeminiBrowserVisible)
|
||
})
|
||
return p.manager, p.managerErr
|
||
}
|
||
|
||
// getSessionDir 獲取 session 目錄
|
||
func (p *Provider) getSessionDir() string {
|
||
// 使用單一 session 目錄(簡化設計)
|
||
return filepath.Join(p.cfg.GeminiAccountDir, "default-session")
|
||
}
|
||
|
||
// Generate 生成回應
|
||
func (p *Provider) Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error {
|
||
fmt.Printf("[GeminiWeb] Starting generation with model: %s\n", model)
|
||
|
||
// 1. 獲取瀏覽器管理器
|
||
manager, err := p.getManager()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to get browser manager: %w", err)
|
||
}
|
||
|
||
// 2. 啟動瀏覽器(如果尚未啟動)
|
||
if !manager.IsRunning() {
|
||
fmt.Printf("[GeminiWeb] Launching browser...\n")
|
||
if err := manager.Launch(); err != nil {
|
||
return fmt.Errorf("failed to launch browser: %w", err)
|
||
}
|
||
}
|
||
|
||
// 3. 獲取頁面
|
||
page, err := manager.GetPage()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to get page: %w", err)
|
||
}
|
||
|
||
// 4. 檢查當前 URL,如果不是 Gemini 則導航
|
||
currentURL, _ := page.Info()
|
||
if !strings.Contains(currentURL.URL, "gemini.google.com") {
|
||
fmt.Printf("[GeminiWeb] Navigating to Gemini...\n")
|
||
if err := NavigateToGemini(page); err != nil {
|
||
return fmt.Errorf("failed to navigate: %w", err)
|
||
}
|
||
time.Sleep(2 * time.Second)
|
||
}
|
||
|
||
// 5. 檢查登入狀態
|
||
fmt.Printf("[GeminiWeb] Checking login status...\n")
|
||
if !IsLoggedIn(page) {
|
||
fmt.Printf("[GeminiWeb] Not logged in, continuing anyway\n")
|
||
|
||
if p.cfg.GeminiBrowserVisible {
|
||
fmt.Println("\n========================================")
|
||
fmt.Println("Browser is open. You can:")
|
||
fmt.Println("1. Log in to Gemini now")
|
||
fmt.Println("2. Continue without login")
|
||
fmt.Println("========================================\n")
|
||
}
|
||
} else {
|
||
fmt.Printf("[GeminiWeb] Logged in\n")
|
||
}
|
||
|
||
// 6. 等待頁面就緒
|
||
if err := WaitForReady(page); err != nil {
|
||
fmt.Printf("[GeminiWeb] Warning: %v\n", err)
|
||
}
|
||
|
||
// 7. 建構提示詞
|
||
prompt := buildPromptFromMessages(messages)
|
||
fmt.Printf("[GeminiWeb] Typing prompt (%d chars)...\n", len(prompt))
|
||
|
||
// 8. 輸入文字
|
||
if err := TypeInput(page, prompt); err != nil {
|
||
return fmt.Errorf("failed to type input: %w", err)
|
||
}
|
||
|
||
// 9. 發送
|
||
fmt.Printf("[GeminiWeb] Sending message...\n")
|
||
if err := ClickSend(page); err != nil {
|
||
return fmt.Errorf("failed to send: %w", err)
|
||
}
|
||
|
||
// 10. 提取回應
|
||
fmt.Printf("[GeminiWeb] Waiting for response...\n")
|
||
response, err := ExtractResponse(page)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to extract response: %w", err)
|
||
}
|
||
|
||
// 11. 串流回調
|
||
cb(apitypes.StreamChunk{Type: apitypes.ChunkText, Text: response})
|
||
cb(apitypes.StreamChunk{Type: apitypes.ChunkDone, Done: true})
|
||
|
||
fmt.Printf("[GeminiWeb] Response complete (%d chars)\n", len(response))
|
||
return nil
|
||
}
|
||
|
||
// buildPromptFromMessages 從訊息列表建構提示詞
|
||
func buildPromptFromMessages(messages []apitypes.Message) string {
|
||
var prompt string
|
||
for _, m := range messages {
|
||
switch m.Role {
|
||
case "system":
|
||
prompt += "System: " + m.Content + "\n\n"
|
||
case "user":
|
||
prompt += m.Content + "\n\n"
|
||
case "assistant":
|
||
prompt += "Assistant: " + m.Content + "\n\n"
|
||
}
|
||
}
|
||
return prompt
|
||
}
|
||
|
||
// RunLogin 執行登入流程(供 gemini-login 命令使用)
|
||
func RunLogin(cfg config.BridgeConfig, sessionName string) error {
|
||
if sessionName == "" {
|
||
sessionName = "default-session"
|
||
}
|
||
|
||
sessionDir := filepath.Join(cfg.GeminiAccountDir, sessionName)
|
||
if err := os.MkdirAll(sessionDir, 0755); err != nil {
|
||
return fmt.Errorf("failed to create session dir: %w", err)
|
||
}
|
||
|
||
fmt.Printf("Starting browser for login. Session: %s\n", sessionName)
|
||
fmt.Printf("Session directory: %s\n", sessionDir)
|
||
fmt.Println("Please log in to your Gemini account in the browser window.")
|
||
fmt.Println("Press Ctrl+C when you have completed the login...")
|
||
|
||
manager, err := NewBrowserManager(sessionDir, true) // visible=true
|
||
if err != nil {
|
||
return fmt.Errorf("failed to create browser manager: %w", err)
|
||
}
|
||
|
||
if err := manager.Launch(); err != nil {
|
||
return fmt.Errorf("failed to launch browser: %w", err)
|
||
}
|
||
defer manager.Close()
|
||
|
||
page, err := manager.GetPage()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to get page: %w", err)
|
||
}
|
||
|
||
if err := NavigateToGemini(page); err != nil {
|
||
return fmt.Errorf("failed to navigate: %w", err)
|
||
}
|
||
|
||
// 等待用戶手動登入...
|
||
// 使用 Ctrl+C 退出,瀏覽器資料會自動保存在 userDataDir
|
||
|
||
return nil
|
||
}
|