From 081f404f775fc09dc479428dea5ddda0b142d5ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=80=A7=E9=A9=8A?= Date: Fri, 3 Apr 2026 22:54:18 +0800 Subject: [PATCH] Task 9: Cleanup - remove old internal files, update import paths, add go-zero entry point - Removed old files from internal/* (migrated to pkg/*) - Removed old CLI files from cmd/ (now in cmd/cli/) - Updated import paths (internal/* -> pkg/*) - Rewrote main.go to support CLI commands + go-zero HTTP server - Fixed AccountStat type references (use entity.AccountStat) - Fixed cmd/cli/* to use usecase package instead of agent - Fixed logger to use entity.AccountStat - Fixed geminiweb fmt.Println redundant newlines - Fixed scripts/detect-gemini-dom.go pointer format issues --- cmd/accounts.go | 196 ------ cmd/args.go | 118 ---- cmd/cli/accounts.go | 11 +- cmd/cli/login.go | 11 +- cmd/gemini-login/main.go | 4 +- cmd/login.go | 125 ---- cmd/resethwid.go | 261 ------- cmd/sqlite.go | 29 - cmd/usage.go | 255 ------- internal/agent/cmdargs.go | 28 - internal/agent/maxmode.go | 85 --- internal/agent/runner.go | 72 -- internal/agent/token.go | 36 - internal/anthropic/anthropic.go | 174 ----- internal/anthropic/anthropic_test.go | 109 --- internal/apitypes/types.go | 40 -- internal/config/config.go | 2 +- internal/config/config_test.go | 2 +- internal/env/env.go | 381 ----------- internal/env/env_test.go | 65 -- internal/handlers/anthropic_handler.go | 577 ---------------- internal/handlers/chat.go | 471 ------------- internal/handlers/gemini_handler.go | 203 ------ internal/handlers/health.go | 20 - internal/handlers/models.go | 107 --- internal/handlers/resolve_model.go | 27 - internal/httputil/httputil.go | 50 -- internal/httputil/httputil_test.go | 50 -- internal/logger/logger.go | 309 --------- internal/models/cursorcli.go | 62 -- internal/models/cursorcli_test.go | 33 - internal/models/cursormap.go | 123 ---- internal/models/cursormap_test.go | 163 ----- internal/openai/openai.go | 243 ------- internal/openai/openai_test.go | 80 --- internal/parser/stream.go | 110 --- internal/parser/stream_test.go | 304 --------- internal/pool/pool.go | 284 -------- internal/pool/pool_test.go | 152 ----- internal/process/kill_unix.go | 21 - internal/process/kill_windows.go | 14 - internal/process/process.go | 250 ------- internal/process/process_test.go | 283 -------- internal/providers/cursor/provider.go | 27 - internal/providers/factory.go | 32 - internal/providers/geminiweb/browser.go | 125 ---- .../providers/geminiweb/browser_manager.go | 173 ----- internal/providers/geminiweb/page.go | 250 ------- .../geminiweb/playwright_provider.go | 641 ------------------ internal/providers/geminiweb/pool.go | 169 ----- internal/providers/geminiweb/provider.go | 196 ------ internal/router/router.go | 147 ---- internal/sanitize/sanitize.go | 95 --- internal/sanitize/sanitize_test.go | 60 -- internal/server/server.go | 159 ----- internal/server/server_test.go | 331 --------- internal/toolcall/toolcall.go | 154 ----- internal/winlimit/winlimit.go | 181 ----- internal/winlimit/winlimit_test.go | 37 - internal/workspace/workspace.go | 30 - main.go | 73 +- pkg/infrastructure/logger/logger.go | 7 +- pkg/infrastructure/process/process_test.go | 2 +- pkg/infrastructure/process/runner.go | 2 +- pkg/infrastructure/winlimit/winlimit.go | 2 +- pkg/provider/geminiweb/playwright_provider.go | 2 +- pkg/provider/geminiweb/provider.go | 2 +- scripts/detect-gemini-dom.go | 15 +- 68 files changed, 81 insertions(+), 8771 deletions(-) delete mode 100644 cmd/accounts.go delete mode 100644 cmd/args.go delete mode 100644 cmd/login.go delete mode 100644 cmd/resethwid.go delete mode 100644 cmd/sqlite.go delete mode 100644 cmd/usage.go delete mode 100644 internal/agent/cmdargs.go delete mode 100644 internal/agent/maxmode.go delete mode 100644 internal/agent/runner.go delete mode 100644 internal/agent/token.go delete mode 100644 internal/anthropic/anthropic.go delete mode 100644 internal/anthropic/anthropic_test.go delete mode 100644 internal/apitypes/types.go delete mode 100644 internal/env/env.go delete mode 100644 internal/env/env_test.go delete mode 100644 internal/handlers/anthropic_handler.go delete mode 100644 internal/handlers/chat.go delete mode 100644 internal/handlers/gemini_handler.go delete mode 100644 internal/handlers/health.go delete mode 100644 internal/handlers/models.go delete mode 100644 internal/handlers/resolve_model.go delete mode 100644 internal/httputil/httputil.go delete mode 100644 internal/httputil/httputil_test.go delete mode 100644 internal/logger/logger.go delete mode 100644 internal/models/cursorcli.go delete mode 100644 internal/models/cursorcli_test.go delete mode 100644 internal/models/cursormap.go delete mode 100644 internal/models/cursormap_test.go delete mode 100644 internal/openai/openai.go delete mode 100644 internal/openai/openai_test.go delete mode 100644 internal/parser/stream.go delete mode 100644 internal/parser/stream_test.go delete mode 100644 internal/pool/pool.go delete mode 100644 internal/pool/pool_test.go delete mode 100644 internal/process/kill_unix.go delete mode 100644 internal/process/kill_windows.go delete mode 100644 internal/process/process.go delete mode 100644 internal/process/process_test.go delete mode 100644 internal/providers/cursor/provider.go delete mode 100644 internal/providers/factory.go delete mode 100644 internal/providers/geminiweb/browser.go delete mode 100644 internal/providers/geminiweb/browser_manager.go delete mode 100644 internal/providers/geminiweb/page.go delete mode 100644 internal/providers/geminiweb/playwright_provider.go delete mode 100644 internal/providers/geminiweb/pool.go delete mode 100644 internal/providers/geminiweb/provider.go delete mode 100644 internal/router/router.go delete mode 100644 internal/sanitize/sanitize.go delete mode 100644 internal/sanitize/sanitize_test.go delete mode 100644 internal/server/server.go delete mode 100644 internal/server/server_test.go delete mode 100644 internal/toolcall/toolcall.go delete mode 100644 internal/winlimit/winlimit.go delete mode 100644 internal/winlimit/winlimit_test.go delete mode 100644 internal/workspace/workspace.go diff --git a/cmd/accounts.go b/cmd/accounts.go deleted file mode 100644 index df1d999..0000000 --- a/cmd/accounts.go +++ /dev/null @@ -1,196 +0,0 @@ -package cmd - -import ( - "cursor-api-proxy/internal/agent" - "encoding/json" - "fmt" - "os" - "path/filepath" -) - -type AccountInfo struct { - Name string - ConfigDir string - Authenticated bool - Email string - DisplayName string - AuthID string - Plan string - SubscriptionStatus string - ExpiresAt string -} - -func ReadAccountInfo(name, configDir string) AccountInfo { - info := AccountInfo{Name: name, ConfigDir: configDir} - - configFile := filepath.Join(configDir, "cli-config.json") - data, err := os.ReadFile(configFile) - if err != nil { - return info - } - - var raw struct { - AuthInfo *struct { - Email string `json:"email"` - DisplayName string `json:"displayName"` - AuthID string `json:"authId"` - } `json:"authInfo"` - } - if err := json.Unmarshal(data, &raw); err == nil && raw.AuthInfo != nil { - info.Authenticated = true - info.Email = raw.AuthInfo.Email - info.DisplayName = raw.AuthInfo.DisplayName - info.AuthID = raw.AuthInfo.AuthID - } - - statsigFile := filepath.Join(configDir, "statsig-cache.json") - statsigData, err := os.ReadFile(statsigFile) - if err != nil { - return info - } - - var statsigWrapper struct { - Data string `json:"data"` - } - if err := json.Unmarshal(statsigData, &statsigWrapper); err != nil || statsigWrapper.Data == "" { - return info - } - - var statsig struct { - User *struct { - Custom *struct { - IsEnterpriseUser bool `json:"isEnterpriseUser"` - StripeSubscriptionStatus string `json:"stripeSubscriptionStatus"` - StripeMembershipExpiration string `json:"stripeMembershipExpiration"` - } `json:"custom"` - } `json:"user"` - } - if err := json.Unmarshal([]byte(statsigWrapper.Data), &statsig); err != nil { - return info - } - - if statsig.User != nil && statsig.User.Custom != nil { - c := statsig.User.Custom - if c.IsEnterpriseUser { - info.Plan = "Enterprise" - } else if c.StripeSubscriptionStatus == "active" { - info.Plan = "Pro" - } else { - info.Plan = "Free" - } - info.SubscriptionStatus = c.StripeSubscriptionStatus - info.ExpiresAt = c.StripeMembershipExpiration - } - - return info -} - -func HandleAccountsList() error { - accountsDir := agent.AccountsDir() - - entries, err := os.ReadDir(accountsDir) - if err != nil { - fmt.Println("No accounts found. Use 'cursor-api-proxy login' to add one.") - return nil - } - - var names []string - for _, e := range entries { - if e.IsDir() { - names = append(names, e.Name()) - } - } - - if len(names) == 0 { - fmt.Println("No accounts found. Use 'cursor-api-proxy login' to add one.") - return nil - } - - fmt.Print("Cursor Accounts:\n\n") - - keychainToken := agent.ReadKeychainToken() - - for i, name := range names { - configDir := filepath.Join(accountsDir, name) - info := ReadAccountInfo(name, configDir) - - fmt.Printf(" %d. %s\n", i+1, name) - - if info.Authenticated { - cachedToken := agent.ReadCachedToken(configDir) - keychainMatchesAccount := keychainToken != "" && info.AuthID != "" && TokenSub(keychainToken) == info.AuthID - token := cachedToken - if token == "" && keychainMatchesAccount { - token = keychainToken - } - - var liveProfile *StripeProfile - var liveUsage *UsageData - if token != "" { - liveUsage, _ = FetchAccountUsage(token) - liveProfile, _ = FetchStripeProfile(token) - } - - if info.Email != "" { - display := "" - if info.DisplayName != "" { - display = " (" + info.DisplayName + ")" - } - fmt.Printf(" %s%s\n", info.Email, display) - } - - if info.Plan != "" && liveProfile == nil { - canceled := "" - if info.SubscriptionStatus == "canceled" { - canceled = " · canceled" - } - expiry := "" - if info.ExpiresAt != "" { - expiry = " · expires " + info.ExpiresAt - } - fmt.Printf(" %s%s%s\n", info.Plan, canceled, expiry) - } - fmt.Println(" Authenticated") - - if liveProfile != nil { - fmt.Printf(" %s\n", DescribePlan(liveProfile)) - } - if liveUsage != nil { - for _, line := range FormatUsageSummary(liveUsage) { - fmt.Println(line) - } - } - } else { - fmt.Println(" Not authenticated") - } - - fmt.Println("") - } - - fmt.Println("Tip: run 'cursor-api-proxy logout ' to remove an account.") - return nil -} - -func HandleLogout(accountName string) error { - if accountName == "" { - fmt.Fprintln(os.Stderr, "Error: Please specify the account name to remove.") - fmt.Fprintln(os.Stderr, "Usage: cursor-api-proxy logout ") - os.Exit(1) - } - - accountsDir := agent.AccountsDir() - configDir := filepath.Join(accountsDir, accountName) - - if _, err := os.Stat(configDir); os.IsNotExist(err) { - fmt.Fprintf(os.Stderr, "Account '%s' not found.\n", accountName) - os.Exit(1) - } - - if err := os.RemoveAll(configDir); err != nil { - fmt.Fprintf(os.Stderr, "Error removing account: %v\n", err) - os.Exit(1) - } - - fmt.Printf("Account '%s' removed.\n", accountName) - return nil -} diff --git a/cmd/args.go b/cmd/args.go deleted file mode 100644 index 05e1c05..0000000 --- a/cmd/args.go +++ /dev/null @@ -1,118 +0,0 @@ -package cmd - -import "fmt" - -type ParsedArgs struct { - Tailscale bool - Help bool - Login bool - AccountsList bool - Logout bool - AccountName string - Proxies []string - ResetHwid bool - DeepClean bool - DryRun bool -} - -func ParseArgs(argv []string) (ParsedArgs, error) { - var args ParsedArgs - - for i := 0; i < len(argv); i++ { - arg := argv[i] - - switch arg { - case "login", "add-account": - args.Login = true - if i+1 < len(argv) && len(argv[i+1]) > 0 && argv[i+1][0] != '-' { - i++ - args.AccountName = argv[i] - } - - case "logout", "remove-account": - args.Logout = true - if i+1 < len(argv) && len(argv[i+1]) > 0 && argv[i+1][0] != '-' { - i++ - args.AccountName = argv[i] - } - - case "accounts", "list-accounts": - args.AccountsList = true - - case "reset-hwid", "reset": - args.ResetHwid = true - - case "--deep-clean": - args.DeepClean = true - - case "--dry-run": - args.DryRun = true - - case "--tailscale": - args.Tailscale = true - - case "--help", "-h": - args.Help = true - - default: - if len(arg) > len("--proxy=") && arg[:len("--proxy=")] == "--proxy=" { - raw := arg[len("--proxy="):] - parts := splitComma(raw) - for _, p := range parts { - if p != "" { - args.Proxies = append(args.Proxies, p) - } - } - } else { - return args, fmt.Errorf("Unknown argument: %s", arg) - } - } - } - - return args, nil -} - -func splitComma(s string) []string { - var result []string - start := 0 - for i := 0; i <= len(s); i++ { - if i == len(s) || s[i] == ',' { - part := trim(s[start:i]) - if part != "" { - result = append(result, part) - } - start = i + 1 - } - } - return result -} - -func trim(s string) string { - start := 0 - end := len(s) - for start < end && (s[start] == ' ' || s[start] == '\t') { - start++ - } - for end > start && (s[end-1] == ' ' || s[end-1] == '\t') { - end-- - } - return s[start:end] -} - -func PrintHelp(version string) { - fmt.Printf("cursor-api-proxy v%s\n\n", version) - fmt.Println("Usage:") - fmt.Println(" cursor-api-proxy [options]") - fmt.Println("") - fmt.Println("Commands:") - fmt.Println(" login [name] Log into a Cursor account (saved to ~/.cursor-api-proxy/accounts/)") - fmt.Println(" login [name] --proxy=... Same, but with a proxy from a comma-separated list") - fmt.Println(" logout Remove a saved Cursor account") - fmt.Println(" accounts List saved accounts with plan info") - fmt.Println(" reset-hwid Reset Cursor machine/telemetry IDs (anti-ban)") - fmt.Println(" reset-hwid --deep-clean Also wipe session storage and cookies") - fmt.Println("") - fmt.Println("Options:") - fmt.Println(" --tailscale Bind to 0.0.0.0 for tailnet/LAN access") - fmt.Println(" -h, --help Show this help message") -} diff --git a/cmd/cli/accounts.go b/cmd/cli/accounts.go index df1d999..1a34a6f 100644 --- a/cmd/cli/accounts.go +++ b/cmd/cli/accounts.go @@ -1,11 +1,12 @@ package cmd import ( - "cursor-api-proxy/internal/agent" "encoding/json" "fmt" "os" "path/filepath" + + "cursor-api-proxy/pkg/usecase" ) type AccountInfo struct { @@ -86,7 +87,7 @@ func ReadAccountInfo(name, configDir string) AccountInfo { } func HandleAccountsList() error { - accountsDir := agent.AccountsDir() + accountsDir := usecase.AccountsDir() entries, err := os.ReadDir(accountsDir) if err != nil { @@ -108,7 +109,7 @@ func HandleAccountsList() error { fmt.Print("Cursor Accounts:\n\n") - keychainToken := agent.ReadKeychainToken() + keychainToken := usecase.ReadKeychainToken() for i, name := range names { configDir := filepath.Join(accountsDir, name) @@ -117,7 +118,7 @@ func HandleAccountsList() error { fmt.Printf(" %d. %s\n", i+1, name) if info.Authenticated { - cachedToken := agent.ReadCachedToken(configDir) + cachedToken := usecase.ReadCachedToken(configDir) keychainMatchesAccount := keychainToken != "" && info.AuthID != "" && TokenSub(keychainToken) == info.AuthID token := cachedToken if token == "" && keychainMatchesAccount { @@ -178,7 +179,7 @@ func HandleLogout(accountName string) error { os.Exit(1) } - accountsDir := agent.AccountsDir() + accountsDir := usecase.AccountsDir() configDir := filepath.Join(accountsDir, accountName) if _, err := os.Stat(configDir); os.IsNotExist(err) { diff --git a/cmd/cli/login.go b/cmd/cli/login.go index f1eb22d..995a5d5 100644 --- a/cmd/cli/login.go +++ b/cmd/cli/login.go @@ -2,8 +2,6 @@ package cmd import ( "bufio" - "cursor-api-proxy/internal/agent" - "cursor-api-proxy/internal/env" "fmt" "os" "os/exec" @@ -12,6 +10,9 @@ import ( "regexp" "syscall" "time" + + "cursor-api-proxy/pkg/infrastructure/env" + "cursor-api-proxy/pkg/usecase" ) var loginURLRe = regexp.MustCompile(`https://cursor\.com/loginDeepControl.*?redirectTarget=cli`) @@ -25,7 +26,7 @@ func HandleLogin(accountName string, proxies []string) error { accountName = fmt.Sprintf("account-%d", time.Now().UnixMilli()%10000) } - accountsDir := agent.AccountsDir() + accountsDir := usecase.AccountsDir() configDir := filepath.Join(accountsDir, accountName) dirWasNew := !fileExists(configDir) @@ -110,9 +111,9 @@ func HandleLogin(accountName string, proxies []string) error { } // Cache keychain token for this account - token := agent.ReadKeychainToken() + token := usecase.ReadKeychainToken() if token != "" { - agent.WriteCachedToken(configDir, token) + usecase.WriteCachedToken(configDir, token) } fmt.Printf("\nAccount '%s' saved — it will be auto-discovered when you start the proxy.\n", accountName) diff --git a/cmd/gemini-login/main.go b/cmd/gemini-login/main.go index 7a0d7f7..725519a 100644 --- a/cmd/gemini-login/main.go +++ b/cmd/gemini-login/main.go @@ -2,8 +2,8 @@ package main import ( "cursor-api-proxy/internal/config" - "cursor-api-proxy/internal/env" - "cursor-api-proxy/internal/providers/geminiweb" + "cursor-api-proxy/pkg/infrastructure/env" + "cursor-api-proxy/pkg/provider/geminiweb" "fmt" "os" "strings" diff --git a/cmd/login.go b/cmd/login.go deleted file mode 100644 index f1eb22d..0000000 --- a/cmd/login.go +++ /dev/null @@ -1,125 +0,0 @@ -package cmd - -import ( - "bufio" - "cursor-api-proxy/internal/agent" - "cursor-api-proxy/internal/env" - "fmt" - "os" - "os/exec" - "os/signal" - "path/filepath" - "regexp" - "syscall" - "time" -) - -var loginURLRe = regexp.MustCompile(`https://cursor\.com/loginDeepControl.*?redirectTarget=cli`) - -func HandleLogin(accountName string, proxies []string) error { - e := env.OsEnvToMap() - loaded := env.LoadEnvConfig(e, "") - agentBin := loaded.AgentBin - - if accountName == "" { - accountName = fmt.Sprintf("account-%d", time.Now().UnixMilli()%10000) - } - - accountsDir := agent.AccountsDir() - configDir := filepath.Join(accountsDir, accountName) - dirWasNew := !fileExists(configDir) - - if err := os.MkdirAll(accountsDir, 0755); err != nil { - return fmt.Errorf("failed to create accounts dir: %w", err) - } - if err := os.MkdirAll(configDir, 0755); err != nil { - return fmt.Errorf("failed to create config dir: %w", err) - } - - fmt.Printf("Logging into Cursor account: %s\n", accountName) - fmt.Printf("Config: %s\n\n", configDir) - fmt.Println("Run the login command — complete the login in your browser.") - fmt.Println("") - - cleanupDir := func() { - if dirWasNew { - _ = os.RemoveAll(configDir) - } - } - - cmdEnv := make([]string, 0, len(e)+2) - for k, v := range e { - cmdEnv = append(cmdEnv, k+"="+v) - } - cmdEnv = append(cmdEnv, "CURSOR_CONFIG_DIR="+configDir) - cmdEnv = append(cmdEnv, "NO_OPEN_BROWSER=1") - - child := exec.Command(agentBin, "login") - child.Env = cmdEnv - child.Stdin = os.Stdin - child.Stderr = os.Stderr - - stdoutPipe, err := child.StdoutPipe() - if err != nil { - return fmt.Errorf("failed to create stdout pipe: %w", err) - } - - if err := child.Start(); err != nil { - cleanupDir() - if os.IsNotExist(err) { - return fmt.Errorf("could not find '%s'. Make sure the Cursor CLI is installed", agentBin) - } - return fmt.Errorf("error launching agent login: %w", err) - } - - // Handle cancellation signals - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) - go func() { - sig := <-sigCh - _ = child.Process.Kill() - cleanupDir() - if sig == syscall.SIGINT { - fmt.Println("\n\nLogin cancelled.") - } - os.Exit(0) - }() - defer signal.Stop(sigCh) - - var stdoutBuf string - scanner := bufio.NewScanner(stdoutPipe) - for scanner.Scan() { - line := scanner.Text() - fmt.Println(line) - stdoutBuf += line + "\n" - - if loginURLRe.MatchString(stdoutBuf) { - match := loginURLRe.FindString(stdoutBuf) - if match != "" { - fmt.Printf("\nOpen this URL in your browser (incognito recommended):\n%s\n\n", match) - } - } - } - - if err := child.Wait(); err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - cleanupDir() - return fmt.Errorf("login failed with code %d", exitErr.ExitCode()) - } - return err - } - - // Cache keychain token for this account - token := agent.ReadKeychainToken() - if token != "" { - agent.WriteCachedToken(configDir, token) - } - - fmt.Printf("\nAccount '%s' saved — it will be auto-discovered when you start the proxy.\n", accountName) - return nil -} - -func fileExists(path string) bool { - _, err := os.Stat(path) - return err == nil -} diff --git a/cmd/resethwid.go b/cmd/resethwid.go deleted file mode 100644 index ba53c2b..0000000 --- a/cmd/resethwid.go +++ /dev/null @@ -1,261 +0,0 @@ -package cmd - -import ( - "crypto/rand" - "crypto/sha256" - "crypto/sha512" - "encoding/hex" - "encoding/json" - "fmt" - "os" - "os/exec" - "path/filepath" - "runtime" - "time" - - "github.com/google/uuid" -) - -func sha256hex() string { - b := make([]byte, 32) - _, _ = rand.Read(b) - h := sha256.Sum256(b) - return hex.EncodeToString(h[:]) -} - -func sha512hex() string { - b := make([]byte, 64) - _, _ = rand.Read(b) - h := sha512.Sum512(b) - return hex.EncodeToString(h[:]) -} - -func newUUID() string { - return uuid.New().String() -} - -func log(icon, msg string) { - fmt.Printf(" %s %s\n", icon, msg) -} - -func getCursorGlobalStorage() string { - switch runtime.GOOS { - case "darwin": - home, _ := os.UserHomeDir() - return filepath.Join(home, "Library", "Application Support", "Cursor", "User", "globalStorage") - case "windows": - appdata := os.Getenv("APPDATA") - return filepath.Join(appdata, "Cursor", "User", "globalStorage") - default: - xdg := os.Getenv("XDG_CONFIG_HOME") - if xdg == "" { - home, _ := os.UserHomeDir() - xdg = filepath.Join(home, ".config") - } - return filepath.Join(xdg, "Cursor", "User", "globalStorage") - } -} - -func getCursorRoot() string { - gs := getCursorGlobalStorage() - return filepath.Dir(filepath.Dir(gs)) -} - -func generateNewIDs() map[string]string { - return map[string]string{ - "telemetry.machineId": sha256hex(), - "telemetry.macMachineId": sha512hex(), - "telemetry.devDeviceId": newUUID(), - "telemetry.sqmId": "{" + fmt.Sprintf("%s", newUUID()+"") + "}", - "storage.serviceMachineId": newUUID(), - } -} - -func killCursor() { - log("", "Stopping Cursor processes...") - switch runtime.GOOS { - case "windows": - exec.Command("taskkill", "/F", "/IM", "Cursor.exe").Run() - default: - exec.Command("pkill", "-x", "Cursor").Run() - exec.Command("pkill", "-f", "Cursor.app").Run() - } - log("", "Cursor stopped (or was not running)") -} - -func updateStorageJSON(storagePath string, ids map[string]string) { - if _, err := os.Stat(storagePath); os.IsNotExist(err) { - log("", fmt.Sprintf("storage.json not found: %s", storagePath)) - return - } - - if runtime.GOOS == "darwin" { - exec.Command("chflags", "nouchg", storagePath).Run() - exec.Command("chmod", "644", storagePath).Run() - } - - data, err := os.ReadFile(storagePath) - if err != nil { - log("", fmt.Sprintf("storage.json read error: %v", err)) - return - } - - var obj map[string]interface{} - if err := json.Unmarshal(data, &obj); err != nil { - log("", fmt.Sprintf("storage.json parse error: %v", err)) - return - } - - for k, v := range ids { - obj[k] = v - } - - out, err := json.MarshalIndent(obj, "", " ") - if err != nil { - log("", fmt.Sprintf("storage.json marshal error: %v", err)) - return - } - - if err := os.WriteFile(storagePath, out, 0644); err != nil { - log("", fmt.Sprintf("storage.json write error: %v", err)) - return - } - log("", "storage.json updated") -} - -func updateStateVscdb(dbPath string, ids map[string]string) { - if _, err := os.Stat(dbPath); os.IsNotExist(err) { - log("", fmt.Sprintf("state.vscdb not found: %s", dbPath)) - return - } - - if runtime.GOOS == "darwin" { - exec.Command("chflags", "nouchg", dbPath).Run() - exec.Command("chmod", "644", dbPath).Run() - } - - if err := updateVscdbPureGo(dbPath, ids); err != nil { - log("", fmt.Sprintf("state.vscdb error: %v", err)) - } else { - log("", "state.vscdb updated") - } -} - -func updateMachineIDFile(machineID, cursorRoot string) { - var candidates []string - if runtime.GOOS == "linux" { - candidates = []string{ - filepath.Join(cursorRoot, "machineid"), - filepath.Join(cursorRoot, "machineId"), - } - } else { - candidates = []string{filepath.Join(cursorRoot, "machineId")} - } - - filePath := candidates[0] - for _, c := range candidates { - if _, err := os.Stat(c); err == nil { - filePath = c - break - } - } - - if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { - log("", fmt.Sprintf("machineId dir error: %v", err)) - return - } - - if runtime.GOOS == "darwin" { - if _, err := os.Stat(filePath); err == nil { - exec.Command("chflags", "nouchg", filePath).Run() - exec.Command("chmod", "644", filePath).Run() - } - } - - if err := os.WriteFile(filePath, []byte(machineID+"\n"), 0644); err != nil { - log("", fmt.Sprintf("machineId write error: %v", err)) - return - } - log("", fmt.Sprintf("machineId file updated (%s)", filepath.Base(filePath))) -} - -var dirsToWipe = []string{ - "Session Storage", "Local Storage", "IndexedDB", "Cache", "Code Cache", - "GPUCache", "Service Worker", "Network", "Cookies", "Cookies-journal", -} - -func deepClean(cursorRoot string) { - log("", "Deep-cleaning session data...") - wiped := 0 - for _, name := range dirsToWipe { - target := filepath.Join(cursorRoot, name) - if _, err := os.Stat(target); os.IsNotExist(err) { - continue - } - info, err := os.Stat(target) - if err != nil { - continue - } - if info.IsDir() { - if err := os.RemoveAll(target); err == nil { - wiped++ - } - } else { - if err := os.Remove(target); err == nil { - wiped++ - } - } - } - log("", fmt.Sprintf("Wiped %d cache/session items", wiped)) -} - -func HandleResetHwid(doDeepClean, dryRun bool) error { - fmt.Print("\nCursor HWID Reset\n\n") - fmt.Println(" Resets all machine / telemetry IDs so Cursor sees a fresh install.") - fmt.Print(" Cursor must be closed — it will be killed automatically.\n\n") - - globalStorage := getCursorGlobalStorage() - cursorRoot := getCursorRoot() - - if _, err := os.Stat(globalStorage); os.IsNotExist(err) { - fmt.Printf("Cursor config not found at:\n %s\n", globalStorage) - fmt.Println(" Make sure Cursor is installed and has been run at least once.") - os.Exit(1) - } - - if dryRun { - fmt.Println(" [DRY RUN] Would reset IDs in:") - fmt.Printf(" %s\n", filepath.Join(globalStorage, "storage.json")) - fmt.Printf(" %s\n", filepath.Join(globalStorage, "state.vscdb")) - fmt.Printf(" %s\n", filepath.Join(cursorRoot, "machineId")) - return nil - } - - killCursor() - - time.Sleep(800 * time.Millisecond) - - newIDs := generateNewIDs() - log("", "Generated new IDs:") - for k, v := range newIDs { - fmt.Printf(" %s: %s\n", k, v) - } - fmt.Println() - - log("", "Updating storage.json...") - updateStorageJSON(filepath.Join(globalStorage, "storage.json"), newIDs) - - log("", "Updating state.vscdb...") - updateStateVscdb(filepath.Join(globalStorage, "state.vscdb"), newIDs) - - log("", "Updating machineId file...") - updateMachineIDFile(newIDs["telemetry.machineId"], cursorRoot) - - if doDeepClean { - fmt.Println() - deepClean(cursorRoot) - } - - fmt.Print("\nHWID reset complete. You can now restart Cursor.\n\n") - return nil -} diff --git a/cmd/sqlite.go b/cmd/sqlite.go deleted file mode 100644 index a8173b2..0000000 --- a/cmd/sqlite.go +++ /dev/null @@ -1,29 +0,0 @@ -package cmd - -import ( - "database/sql" - "fmt" - - _ "modernc.org/sqlite" -) - -func updateVscdbPureGo(dbPath string, ids map[string]string) error { - db, err := sql.Open("sqlite", dbPath) - if err != nil { - return fmt.Errorf("open db: %w", err) - } - defer db.Close() - - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS ItemTable (key TEXT PRIMARY KEY, value TEXT NOT NULL)`) - if err != nil { - return fmt.Errorf("create table: %w", err) - } - - for k, v := range ids { - _, err = db.Exec(`INSERT OR REPLACE INTO ItemTable (key, value) VALUES (?, ?)`, k, v) - if err != nil { - return fmt.Errorf("insert %s: %w", k, err) - } - } - return nil -} diff --git a/cmd/usage.go b/cmd/usage.go deleted file mode 100644 index 4999020..0000000 --- a/cmd/usage.go +++ /dev/null @@ -1,255 +0,0 @@ -package cmd - -import ( - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "time" -) - -type ModelUsage struct { - NumRequests int `json:"numRequests"` - NumRequestsTotal int `json:"numRequestsTotal"` - NumTokens int `json:"numTokens"` - MaxTokenUsage *int `json:"maxTokenUsage"` - MaxRequestUsage *int `json:"maxRequestUsage"` -} - -type UsageData struct { - StartOfMonth string `json:"startOfMonth"` - Models map[string]ModelUsage `json:"-"` -} - -type StripeProfile struct { - MembershipType string `json:"membershipType"` - SubscriptionStatus string `json:"subscriptionStatus"` - DaysRemainingOnTrial *int `json:"daysRemainingOnTrial"` - IsTeamMember bool `json:"isTeamMember"` - IsYearlyPlan bool `json:"isYearlyPlan"` -} - -func DecodeJWTPayload(token string) map[string]interface{} { - parts := strings.Split(token, ".") - if len(parts) < 2 { - return nil - } - padded := strings.ReplaceAll(parts[1], "-", "+") - padded = strings.ReplaceAll(padded, "_", "/") - data, err := base64.StdEncoding.DecodeString(padded + strings.Repeat("=", (4-len(padded)%4)%4)) - if err != nil { - return nil - } - var result map[string]interface{} - if err := json.Unmarshal(data, &result); err != nil { - return nil - } - return result -} - -func TokenSub(token string) string { - payload := DecodeJWTPayload(token) - if payload == nil { - return "" - } - if sub, ok := payload["sub"].(string); ok { - return sub - } - return "" -} - -func apiGet(path, token string) (map[string]interface{}, error) { - client := &http.Client{Timeout: 8 * time.Second} - req, err := http.NewRequest("GET", "https://api2.cursor.sh"+path, nil) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", "Bearer "+token) - - resp, err := client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - data, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - var result map[string]interface{} - if err := json.Unmarshal(data, &result); err != nil { - return nil, nil - } - return result, nil -} - -func FetchAccountUsage(token string) (*UsageData, error) { - raw, err := apiGet("/auth/usage", token) - if err != nil || raw == nil { - return nil, err - } - - startOfMonth, _ := raw["startOfMonth"].(string) - usage := &UsageData{ - StartOfMonth: startOfMonth, - Models: make(map[string]ModelUsage), - } - - for k, v := range raw { - if k == "startOfMonth" { - continue - } - data, err := json.Marshal(v) - if err != nil { - continue - } - var mu ModelUsage - if err := json.Unmarshal(data, &mu); err == nil { - usage.Models[k] = mu - } - } - return usage, nil -} - -func FetchStripeProfile(token string) (*StripeProfile, error) { - raw, err := apiGet("/auth/full_stripe_profile", token) - if err != nil || raw == nil { - return nil, err - } - - profile := &StripeProfile{ - MembershipType: fmt.Sprintf("%v", raw["membershipType"]), - SubscriptionStatus: fmt.Sprintf("%v", raw["subscriptionStatus"]), - IsTeamMember: raw["isTeamMember"] == true, - IsYearlyPlan: raw["isYearlyPlan"] == true, - } - if d, ok := raw["daysRemainingOnTrial"].(float64); ok { - di := int(d) - profile.DaysRemainingOnTrial = &di - } - return profile, nil -} - -func DescribePlan(profile *StripeProfile) string { - if profile == nil { - return "" - } - switch profile.MembershipType { - case "free_trial": - days := 0 - if profile.DaysRemainingOnTrial != nil { - days = *profile.DaysRemainingOnTrial - } - return fmt.Sprintf("Pro Trial (%dd left) — unlimited fast requests", days) - case "pro": - return "Pro — extended limits" - case "pro_plus": - return "Pro+ — extended limits" - case "ultra": - return "Ultra — extended limits" - case "free", "hobby": - return "Hobby (free) — limited agent requests" - default: - return fmt.Sprintf("%s · %s", profile.MembershipType, profile.SubscriptionStatus) - } -} - -var modelLabels = map[string]string{ - "gpt-4": "Fast Premium Requests", - "claude-sonnet-4-6": "Claude Sonnet 4.6", - "claude-sonnet-4-5-20250929-v1": "Claude Sonnet 4.5", - "claude-sonnet-4-20250514-v1": "Claude Sonnet 4", - "claude-opus-4-6-v1": "Claude Opus 4.6", - "claude-opus-4-5-20251101-v1": "Claude Opus 4.5", - "claude-opus-4-1-20250805-v1": "Claude Opus 4.1", - "claude-opus-4-20250514-v1": "Claude Opus 4", - "claude-haiku-4-5-20251001-v1": "Claude Haiku 4.5", - "claude-3-5-haiku-20241022-v1": "Claude 3.5 Haiku", - "gpt-5": "GPT-5", - "gpt-4o": "GPT-4o", - "o1": "o1", - "o3-mini": "o3-mini", - "cursor-small": "Cursor Small (free)", -} - -func modelLabel(key string) string { - if label, ok := modelLabels[key]; ok { - return label - } - return key -} - -func FormatUsageSummary(usage *UsageData) []string { - if usage == nil { - return nil - } - var lines []string - - start := "?" - if usage.StartOfMonth != "" { - if t, err := time.Parse(time.RFC3339, usage.StartOfMonth); err == nil { - start = t.Format("2006-01-02") - } else { - start = usage.StartOfMonth - } - } - lines = append(lines, fmt.Sprintf(" Billing period from %s", start)) - - if len(usage.Models) == 0 { - lines = append(lines, " No requests this billing period") - return lines - } - - type entry struct { - key string - usage ModelUsage - } - var entries []entry - for k, v := range usage.Models { - entries = append(entries, entry{k, v}) - } - - // Sort: entries with limits first, then by usage descending - for i := 1; i < len(entries); i++ { - for j := i; j > 0; j-- { - a, b := entries[j-1], entries[j] - aHasLimit := a.usage.MaxRequestUsage != nil - bHasLimit := b.usage.MaxRequestUsage != nil - if !aHasLimit && bHasLimit { - entries[j-1], entries[j] = entries[j], entries[j-1] - } else if aHasLimit == bHasLimit && a.usage.NumRequests < b.usage.NumRequests { - entries[j-1], entries[j] = entries[j], entries[j-1] - } else { - break - } - } - } - - for _, e := range entries { - used := e.usage.NumRequests - max := e.usage.MaxRequestUsage - label := modelLabel(e.key) - if max != nil && *max > 0 { - pct := int(float64(used) / float64(*max) * 100) - bar := makeBar(used, *max, 12) - lines = append(lines, fmt.Sprintf(" %s: %d/%d (%d%%) [%s]", label, used, *max, pct, bar)) - } else if used > 0 { - lines = append(lines, fmt.Sprintf(" %s: %d requests", label, used)) - } else { - lines = append(lines, fmt.Sprintf(" %s: 0 requests (unlimited)", label)) - } - } - - return lines -} - -func makeBar(used, max, width int) string { - fill := int(float64(used) / float64(max) * float64(width)) - if fill > width { - fill = width - } - return strings.Repeat("█", fill) + strings.Repeat("░", width-fill) -} diff --git a/internal/agent/cmdargs.go b/internal/agent/cmdargs.go deleted file mode 100644 index cfb1030..0000000 --- a/internal/agent/cmdargs.go +++ /dev/null @@ -1,28 +0,0 @@ -package agent - -import "cursor-api-proxy/internal/config" - -func BuildAgentFixedArgs(cfg config.BridgeConfig, workspaceDir, model string, stream bool) []string { - args := []string{"--print"} - if cfg.ApproveMcps { - args = append(args, "--approve-mcps") - } - if cfg.Force { - args = append(args, "--force") - } - if cfg.ChatOnlyWorkspace { - args = append(args, "--trust") - } - args = append(args, "--workspace", workspaceDir) - args = append(args, "--model", model) - if stream { - args = append(args, "--stream-partial-output", "--output-format", "stream-json") - } else { - args = append(args, "--output-format", "text") - } - return args -} - -func BuildAgentCmdArgs(cfg config.BridgeConfig, workspaceDir, model, prompt string, stream bool) []string { - return append(BuildAgentFixedArgs(cfg, workspaceDir, model, stream), prompt) -} diff --git a/internal/agent/maxmode.go b/internal/agent/maxmode.go deleted file mode 100644 index 0a094c7..0000000 --- a/internal/agent/maxmode.go +++ /dev/null @@ -1,85 +0,0 @@ -package agent - -import ( - "encoding/json" - "os" - "path/filepath" - "runtime" -) - -func getCandidates(agentScriptPath, configDirOverride string) []string { - if configDirOverride != "" { - return []string{filepath.Join(configDirOverride, "cli-config.json")} - } - - var result []string - - if dir := os.Getenv("CURSOR_CONFIG_DIR"); dir != "" { - result = append(result, filepath.Join(dir, "cli-config.json")) - } - - if agentScriptPath != "" { - agentDir := filepath.Dir(agentScriptPath) - result = append(result, filepath.Join(agentDir, "..", "data", "config", "cli-config.json")) - } - - home := os.Getenv("HOME") - if home == "" { - home = os.Getenv("USERPROFILE") - } - - switch runtime.GOOS { - case "windows": - local := os.Getenv("LOCALAPPDATA") - if local == "" { - local = filepath.Join(home, "AppData", "Local") - } - result = append(result, filepath.Join(local, "cursor-agent", "cli-config.json")) - case "darwin": - result = append(result, filepath.Join(home, "Library", "Application Support", "cursor-agent", "cli-config.json")) - default: - xdg := os.Getenv("XDG_CONFIG_HOME") - if xdg == "" { - xdg = filepath.Join(home, ".config") - } - result = append(result, filepath.Join(xdg, "cursor-agent", "cli-config.json")) - } - - return result -} - -func RunMaxModePreflight(agentScriptPath, configDirOverride string) { - for _, candidate := range getCandidates(agentScriptPath, configDirOverride) { - data, err := os.ReadFile(candidate) - if err != nil { - continue - } - - // Strip BOM if present - if len(data) >= 3 && data[0] == 0xEF && data[1] == 0xBB && data[2] == 0xBF { - data = data[3:] - } - - var raw map[string]interface{} - if err := json.Unmarshal(data, &raw); err != nil { - continue - } - if raw == nil || len(raw) <= 1 { - continue - } - - raw["maxMode"] = true - if model, ok := raw["model"].(map[string]interface{}); ok { - model["maxMode"] = true - } - - out, err := json.MarshalIndent(raw, "", " ") - if err != nil { - continue - } - if err := os.WriteFile(candidate, out, 0644); err != nil { - continue - } - return - } -} diff --git a/internal/agent/runner.go b/internal/agent/runner.go deleted file mode 100644 index 2428b4c..0000000 --- a/internal/agent/runner.go +++ /dev/null @@ -1,72 +0,0 @@ -package agent - -import ( - "context" - "cursor-api-proxy/internal/config" - "cursor-api-proxy/internal/process" - "os" - "path/filepath" -) - -func init() { - process.MaxModeFn = RunMaxModePreflight -} - -func cacheTokenForAccount(configDir string) { - if configDir == "" { - return - } - token := ReadKeychainToken() - if token != "" { - WriteCachedToken(configDir, token) - } -} - -func AccountsDir() string { - home := os.Getenv("HOME") - if home == "" { - home = os.Getenv("USERPROFILE") - } - return filepath.Join(home, ".cursor-api-proxy", "accounts") -} - -func RunAgentSync(cfg config.BridgeConfig, workspaceDir string, cmdArgs []string, tempDir, configDir string, ctx context.Context) (process.RunResult, error) { - opts := process.RunOptions{ - Cwd: workspaceDir, - TimeoutMs: cfg.TimeoutMs, - MaxMode: cfg.MaxMode, - ConfigDir: configDir, - Ctx: ctx, - } - - result, err := process.Run(cfg.AgentBin, cmdArgs, opts) - - cacheTokenForAccount(configDir) - if tempDir != "" { - os.RemoveAll(tempDir) - } - - return result, err -} - -func RunAgentStreamWithContext(cfg config.BridgeConfig, workspaceDir string, cmdArgs []string, onLine func(string), tempDir, configDir string, ctx context.Context) (process.StreamResult, error) { - opts := process.RunStreamingOptions{ - RunOptions: process.RunOptions{ - Cwd: workspaceDir, - TimeoutMs: cfg.TimeoutMs, - MaxMode: cfg.MaxMode, - ConfigDir: configDir, - Ctx: ctx, - }, - OnLine: onLine, - } - - result, err := process.RunStreaming(cfg.AgentBin, cmdArgs, opts) - - cacheTokenForAccount(configDir) - if tempDir != "" { - os.RemoveAll(tempDir) - } - - return result, err -} diff --git a/internal/agent/token.go b/internal/agent/token.go deleted file mode 100644 index cb9d025..0000000 --- a/internal/agent/token.go +++ /dev/null @@ -1,36 +0,0 @@ -package agent - -import ( - "os" - "os/exec" - "path/filepath" - "runtime" - "strings" -) - -const tokenFile = ".cursor-token" - -func ReadCachedToken(configDir string) string { - p := filepath.Join(configDir, tokenFile) - data, err := os.ReadFile(p) - if err != nil { - return "" - } - return strings.TrimSpace(string(data)) -} - -func WriteCachedToken(configDir, token string) { - p := filepath.Join(configDir, tokenFile) - _ = os.WriteFile(p, []byte(token), 0600) -} - -func ReadKeychainToken() string { - if runtime.GOOS != "darwin" { - return "" - } - out, err := exec.Command("security", "find-generic-password", "-s", "cursor-access-token", "-w").Output() - if err != nil { - return "" - } - return strings.TrimSpace(string(out)) -} diff --git a/internal/anthropic/anthropic.go b/internal/anthropic/anthropic.go deleted file mode 100644 index 1f29404..0000000 --- a/internal/anthropic/anthropic.go +++ /dev/null @@ -1,174 +0,0 @@ -package anthropic - -import ( - "cursor-api-proxy/internal/openai" - "encoding/json" - "fmt" - "strings" -) - -type MessageParam struct { - Role string `json:"role"` - Content interface{} `json:"content"` -} - -type MessagesRequest struct { - Model string `json:"model"` - MaxTokens int `json:"max_tokens"` - Messages []MessageParam `json:"messages"` - System interface{} `json:"system"` - Stream bool `json:"stream"` - Tools []interface{} `json:"tools"` -} - -func systemToText(system interface{}) string { - if system == nil { - return "" - } - switch v := system.(type) { - case string: - return strings.TrimSpace(v) - case []interface{}: - var parts []string - for _, p := range v { - if m, ok := p.(map[string]interface{}); ok { - if m["type"] == "text" { - if t, ok := m["text"].(string); ok { - parts = append(parts, t) - } - } - } - } - return strings.Join(parts, "\n") - } - return "" -} - -func anthropicBlockToText(p interface{}) string { - if p == nil { - return "" - } - switch v := p.(type) { - case string: - return v - case map[string]interface{}: - typ, _ := v["type"].(string) - switch typ { - case "text": - if t, ok := v["text"].(string); ok { - return t - } - case "image": - if src, ok := v["source"].(map[string]interface{}); ok { - srcType, _ := src["type"].(string) - switch srcType { - case "base64": - mt, _ := src["media_type"].(string) - if mt == "" { - mt = "image" - } - return "[Image: base64 " + mt + "]" - case "url": - url, _ := src["url"].(string) - return "[Image: " + url + "]" - } - } - return "[Image]" - case "document": - title, _ := v["title"].(string) - if title == "" { - if src, ok := v["source"].(map[string]interface{}); ok { - title, _ = src["url"].(string) - } - } - if title != "" { - return "[Document: " + title + "]" - } - return "[Document]" - case "tool_use": - name, _ := v["name"].(string) - id, _ := v["id"].(string) - input := v["input"] - inputJSON, _ := json.Marshal(input) - if inputJSON == nil { - inputJSON = []byte("{}") - } - tag := fmt.Sprintf("\n{\"name\": \"%s\", \"arguments\": %s}\n", name, string(inputJSON)) - if id != "" { - tag = fmt.Sprintf("[tool_use_id=%s] ", id) + tag - } - return tag - case "tool_result": - toolUseID, _ := v["tool_use_id"].(string) - content := v["content"] - var contentText string - switch c := content.(type) { - case string: - contentText = c - case []interface{}: - var parts []string - for _, block := range c { - if bm, ok := block.(map[string]interface{}); ok { - if bm["type"] == "text" { - if t, ok := bm["text"].(string); ok { - parts = append(parts, t) - } - } - } - } - contentText = strings.Join(parts, "\n") - } - label := "Tool result" - if toolUseID != "" { - label += " [id=" + toolUseID + "]" - } - return label + ": " + contentText - } - } - return "" -} - -func anthropicContentToText(content interface{}) string { - switch v := content.(type) { - case string: - return v - case []interface{}: - var parts []string - for _, p := range v { - if t := anthropicBlockToText(p); t != "" { - parts = append(parts, t) - } - } - return strings.Join(parts, " ") - } - return "" -} - -func BuildPromptFromAnthropicMessages(messages []MessageParam, system interface{}) string { - var oaiMessages []interface{} - - systemText := systemToText(system) - if systemText != "" { - oaiMessages = append(oaiMessages, map[string]interface{}{ - "role": "system", - "content": systemText, - }) - } - - for _, m := range messages { - text := anthropicContentToText(m.Content) - if text == "" { - continue - } - role := m.Role - if role != "user" && role != "assistant" { - role = "user" - } - oaiMessages = append(oaiMessages, map[string]interface{}{ - "role": role, - "content": text, - }) - } - - return openai.BuildPromptFromMessages(oaiMessages) -} diff --git a/internal/anthropic/anthropic_test.go b/internal/anthropic/anthropic_test.go deleted file mode 100644 index 8237a5a..0000000 --- a/internal/anthropic/anthropic_test.go +++ /dev/null @@ -1,109 +0,0 @@ -package anthropic_test - -import ( - "cursor-api-proxy/internal/anthropic" - "strings" - "testing" -) - -func TestBuildPromptFromAnthropicMessages_Simple(t *testing.T) { - messages := []anthropic.MessageParam{ - {Role: "user", Content: "Hello"}, - {Role: "assistant", Content: "Hi there"}, - } - prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil) - if !strings.Contains(prompt, "Hello") { - t.Errorf("prompt missing user message: %q", prompt) - } - if !strings.Contains(prompt, "Hi there") { - t.Errorf("prompt missing assistant message: %q", prompt) - } -} - -func TestBuildPromptFromAnthropicMessages_WithSystem(t *testing.T) { - messages := []anthropic.MessageParam{ - {Role: "user", Content: "ping"}, - } - prompt := anthropic.BuildPromptFromAnthropicMessages(messages, "You are a helpful bot.") - if !strings.Contains(prompt, "You are a helpful bot.") { - t.Errorf("prompt missing system: %q", prompt) - } - if !strings.Contains(prompt, "ping") { - t.Errorf("prompt missing user: %q", prompt) - } -} - -func TestBuildPromptFromAnthropicMessages_SystemArray(t *testing.T) { - system := []interface{}{ - map[string]interface{}{"type": "text", "text": "Part A"}, - map[string]interface{}{"type": "text", "text": "Part B"}, - } - messages := []anthropic.MessageParam{ - {Role: "user", Content: "test"}, - } - prompt := anthropic.BuildPromptFromAnthropicMessages(messages, system) - if !strings.Contains(prompt, "Part A") { - t.Errorf("prompt missing Part A: %q", prompt) - } - if !strings.Contains(prompt, "Part B") { - t.Errorf("prompt missing Part B: %q", prompt) - } -} - -func TestBuildPromptFromAnthropicMessages_ContentBlocks(t *testing.T) { - content := []interface{}{ - map[string]interface{}{"type": "text", "text": "block one"}, - map[string]interface{}{"type": "text", "text": "block two"}, - } - messages := []anthropic.MessageParam{ - {Role: "user", Content: content}, - } - prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil) - if !strings.Contains(prompt, "block one") { - t.Errorf("prompt missing 'block one': %q", prompt) - } - if !strings.Contains(prompt, "block two") { - t.Errorf("prompt missing 'block two': %q", prompt) - } -} - -func TestBuildPromptFromAnthropicMessages_ImageBlock(t *testing.T) { - content := []interface{}{ - map[string]interface{}{ - "type": "image", - "source": map[string]interface{}{ - "type": "base64", - "media_type": "image/png", - "data": "abc123", - }, - }, - } - messages := []anthropic.MessageParam{ - {Role: "user", Content: content}, - } - prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil) - if !strings.Contains(prompt, "[Image") { - t.Errorf("prompt missing [Image]: %q", prompt) - } -} - -func TestBuildPromptFromAnthropicMessages_EmptyContentSkipped(t *testing.T) { - messages := []anthropic.MessageParam{ - {Role: "user", Content: ""}, - {Role: "assistant", Content: "response"}, - } - prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil) - if !strings.Contains(prompt, "response") { - t.Errorf("prompt missing 'response': %q", prompt) - } -} - -func TestBuildPromptFromAnthropicMessages_UnknownRoleBecomesUser(t *testing.T) { - messages := []anthropic.MessageParam{ - {Role: "system", Content: "system-as-user"}, - } - prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil) - if !strings.Contains(prompt, "system-as-user") { - t.Errorf("prompt missing 'system-as-user': %q", prompt) - } -} diff --git a/internal/apitypes/types.go b/internal/apitypes/types.go deleted file mode 100644 index abdcfcd..0000000 --- a/internal/apitypes/types.go +++ /dev/null @@ -1,40 +0,0 @@ -package apitypes - -type Message struct { - Role string - Content string -} - -type Tool struct { - Type string - Function ToolFunction -} - -type ToolFunction struct { - Name string - Description string - Parameters interface{} -} - -type ToolCall struct { - ID string - Name string - Arguments string -} - -type StreamChunk struct { - Type ChunkType - Text string - Thinking string - ToolCall *ToolCall - Done bool -} - -type ChunkType int - -const ( - ChunkText ChunkType = iota - ChunkThinking - ChunkToolCall - ChunkDone -) diff --git a/internal/config/config.go b/internal/config/config.go index 6c959b9..b60e758 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,7 +1,7 @@ package config import ( - "cursor-api-proxy/internal/env" + "cursor-api-proxy/pkg/infrastructure/env" "github.com/zeromicro/go-zero/rest" ) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 15f2abc..c9ce0ec 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -2,7 +2,7 @@ package config_test import ( "cursor-api-proxy/internal/config" - "cursor-api-proxy/internal/env" + "cursor-api-proxy/pkg/infrastructure/env" "path/filepath" "strings" "testing" diff --git a/internal/env/env.go b/internal/env/env.go deleted file mode 100644 index 45dc3be..0000000 --- a/internal/env/env.go +++ /dev/null @@ -1,381 +0,0 @@ -package env - -import ( - "encoding/json" - "os" - "path/filepath" - "runtime" - "strconv" - "strings" -) - -type EnvSource map[string]string - -type LoadedEnv struct { - AgentBin string - AgentNode string - AgentScript string - CommandShell string - Host string - Port int - RequiredKey string - DefaultModel string - Provider string - Force bool - ApproveMcps bool - StrictModel bool - Workspace string - TimeoutMs int - TLSCertPath string - TLSKeyPath string - SessionsLogPath string - ChatOnlyWorkspace bool - Verbose bool - MaxMode bool - ConfigDirs []string - MultiPort bool - WinCmdlineMax int - GeminiAccountDir string - GeminiBrowserVisible bool - GeminiMaxSessions int -} - -type AgentCommand struct { - Command string - Args []string - Env map[string]string - WindowsVerbatimArguments bool - AgentScriptPath string - ConfigDir string -} - -func getEnvVal(e EnvSource, names []string) string { - for _, name := range names { - if v, ok := e[name]; ok && strings.TrimSpace(v) != "" { - return strings.TrimSpace(v) - } - } - return "" -} - -func envBool(e EnvSource, names []string, def bool) bool { - raw := getEnvVal(e, names) - if raw == "" { - return def - } - switch strings.ToLower(raw) { - case "1", "true", "yes", "on": - return true - case "0", "false", "no", "off": - return false - } - return def -} - -func envInt(e EnvSource, names []string, def int) int { - raw := getEnvVal(e, names) - if raw == "" { - return def - } - v, err := strconv.Atoi(raw) - if err != nil { - return def - } - return v -} - -func normalizeModelId(raw string) string { - raw = strings.TrimSpace(raw) - if raw == "" { - return "auto" - } - parts := strings.Split(raw, "/") - last := parts[len(parts)-1] - if last == "" { - return "auto" - } - return last -} - -func resolveAbs(raw, cwd string) string { - if raw == "" { - return "" - } - if filepath.IsAbs(raw) { - return raw - } - return filepath.Join(cwd, raw) -} - -func isAuthenticatedAccountDir(dir string) bool { - data, err := os.ReadFile(filepath.Join(dir, "cli-config.json")) - if err != nil { - return false - } - var cfg struct { - AuthInfo *struct { - Email string `json:"email"` - } `json:"authInfo"` - } - if err := json.Unmarshal(data, &cfg); err != nil { - return false - } - return cfg.AuthInfo != nil && cfg.AuthInfo.Email != "" -} - -func discoverAccountDirs(homeDir string) []string { - if homeDir == "" { - return nil - } - accountsDir := filepath.Join(homeDir, ".cursor-api-proxy", "accounts") - entries, err := os.ReadDir(accountsDir) - if err != nil { - return nil - } - var dirs []string - for _, e := range entries { - if !e.IsDir() { - continue - } - dir := filepath.Join(accountsDir, e.Name()) - if isAuthenticatedAccountDir(dir) { - dirs = append(dirs, dir) - } - } - return dirs -} - -func parseDotEnv(path string) EnvSource { - data, err := os.ReadFile(path) - if err != nil { - return nil - } - m := make(EnvSource) - for _, line := range strings.Split(string(data), "\n") { - line = strings.TrimSpace(line) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - parts := strings.SplitN(line, "=", 2) - if len(parts) == 2 { - m[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) - } - } - return m -} - -func OsEnvToMap(cwdHint ...string) EnvSource { - m := make(EnvSource) - for _, kv := range os.Environ() { - parts := strings.SplitN(kv, "=", 2) - if len(parts) == 2 { - m[parts[0]] = parts[1] - } - } - - cwd := "" - if len(cwdHint) > 0 && cwdHint[0] != "" { - cwd = cwdHint[0] - } else { - cwd, _ = os.Getwd() - } - - if dotenv := parseDotEnv(filepath.Join(cwd, ".env")); dotenv != nil { - for k, v := range dotenv { - if _, exists := m[k]; !exists { - m[k] = v - } - } - } - - return m -} - -func LoadEnvConfig(e EnvSource, cwd string) LoadedEnv { - if e == nil { - e = OsEnvToMap() - } - if cwd == "" { - var err error - cwd, err = os.Getwd() - if err != nil { - cwd = "." - } - } - - host := getEnvVal(e, []string{"CURSOR_BRIDGE_HOST"}) - if host == "" { - host = "127.0.0.1" - } - port := envInt(e, []string{"CURSOR_BRIDGE_PORT"}, 8765) - if port <= 0 { - port = 8765 - } - - home := getEnvVal(e, []string{"HOME", "USERPROFILE"}) - - sessionsLogPath := func() string { - if p := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_SESSIONS_LOG"}), cwd); p != "" { - return p - } - if home != "" { - return filepath.Join(home, ".cursor-api-proxy", "sessions.log") - } - return filepath.Join(cwd, "sessions.log") - }() - - var configDirs []string - if raw := getEnvVal(e, []string{"CURSOR_CONFIG_DIRS", "CURSOR_ACCOUNT_DIRS"}); raw != "" { - for _, d := range strings.Split(raw, ",") { - d = strings.TrimSpace(d) - if d != "" { - if p := resolveAbs(d, cwd); p != "" { - configDirs = append(configDirs, p) - } - } - } - } - if len(configDirs) == 0 { - configDirs = discoverAccountDirs(home) - } - - winMax := envInt(e, []string{"CURSOR_BRIDGE_WIN_CMDLINE_MAX"}, 30000) - if winMax < 4096 { - winMax = 4096 - } - if winMax > 32700 { - winMax = 32700 - } - - agentBin := getEnvVal(e, []string{"CURSOR_AGENT_BIN", "CURSOR_CLI_BIN", "CURSOR_CLI_PATH"}) - if agentBin == "" { - agentBin = "agent" - } - commandShell := getEnvVal(e, []string{"COMSPEC"}) - if commandShell == "" { - commandShell = "cmd.exe" - } - workspace := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_WORKSPACE"}), cwd) - if workspace == "" { - workspace = cwd - } - - geminiAccountDir := getEnvVal(e, []string{"GEMINI_ACCOUNT_DIR"}) - if geminiAccountDir == "" { - geminiAccountDir = filepath.Join(home, ".cursor-api-proxy", "gemini-accounts") - } else { - geminiAccountDir = resolveAbs(geminiAccountDir, cwd) - } - - return LoadedEnv{ - AgentBin: agentBin, - AgentNode: getEnvVal(e, []string{"CURSOR_AGENT_NODE"}), - AgentScript: getEnvVal(e, []string{"CURSOR_AGENT_SCRIPT"}), - CommandShell: commandShell, - Host: host, - Port: port, - RequiredKey: getEnvVal(e, []string{"CURSOR_BRIDGE_API_KEY"}), - DefaultModel: normalizeModelId(getEnvVal(e, []string{"CURSOR_BRIDGE_DEFAULT_MODEL"})), - Provider: getEnvVal(e, []string{"CURSOR_BRIDGE_PROVIDER"}), - Force: envBool(e, []string{"CURSOR_BRIDGE_FORCE"}, false), - ApproveMcps: envBool(e, []string{"CURSOR_BRIDGE_APPROVE_MCPS"}, false), - StrictModel: envBool(e, []string{"CURSOR_BRIDGE_STRICT_MODEL"}, true), - Workspace: workspace, - TimeoutMs: envInt(e, []string{"CURSOR_BRIDGE_TIMEOUT_MS"}, 300000), - TLSCertPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_CERT"}), cwd), - TLSKeyPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_KEY"}), cwd), - SessionsLogPath: sessionsLogPath, - ChatOnlyWorkspace: envBool(e, []string{"CURSOR_BRIDGE_CHAT_ONLY_WORKSPACE"}, true), - Verbose: envBool(e, []string{"CURSOR_BRIDGE_VERBOSE"}, false), - MaxMode: envBool(e, []string{"CURSOR_BRIDGE_MAX_MODE"}, false), - ConfigDirs: configDirs, - MultiPort: envBool(e, []string{"CURSOR_BRIDGE_MULTI_PORT"}, false), - WinCmdlineMax: winMax, - GeminiAccountDir: geminiAccountDir, - GeminiBrowserVisible: envBool(e, []string{"GEMINI_BROWSER_VISIBLE"}, false), - GeminiMaxSessions: envInt(e, []string{"GEMINI_MAX_SESSIONS"}, 3), - } -} - -func ResolveAgentCommand(cmd string, args []string, e EnvSource, cwd string) AgentCommand { - if e == nil { - e = OsEnvToMap() - } - loaded := LoadEnvConfig(e, cwd) - - cloneEnv := func() map[string]string { - m := make(map[string]string, len(e)) - for k, v := range e { - m[k] = v - } - return m - } - - if runtime.GOOS == "windows" { - if loaded.AgentNode != "" && loaded.AgentScript != "" { - agentScriptPath := loaded.AgentScript - if !filepath.IsAbs(agentScriptPath) { - agentScriptPath = filepath.Join(cwd, agentScriptPath) - } - agentDir := filepath.Dir(agentScriptPath) - configDir := filepath.Join(agentDir, "..", "data", "config") - env2 := cloneEnv() - env2["CURSOR_INVOKED_AS"] = "agent.cmd" - ac := AgentCommand{ - Command: loaded.AgentNode, - Args: append([]string{loaded.AgentScript}, args...), - Env: env2, - AgentScriptPath: agentScriptPath, - } - if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil { - ac.ConfigDir = configDir - } - return ac - } - - if strings.HasSuffix(strings.ToLower(cmd), ".cmd") { - cmdResolved := cmd - if !filepath.IsAbs(cmd) { - cmdResolved = filepath.Join(cwd, cmd) - } - dir := filepath.Dir(cmdResolved) - nodeBin := filepath.Join(dir, "node.exe") - script := filepath.Join(dir, "index.js") - if _, err1 := os.Stat(nodeBin); err1 == nil { - if _, err2 := os.Stat(script); err2 == nil { - configDir := filepath.Join(dir, "..", "data", "config") - env2 := cloneEnv() - env2["CURSOR_INVOKED_AS"] = "agent.cmd" - ac := AgentCommand{ - Command: nodeBin, - Args: append([]string{script}, args...), - Env: env2, - AgentScriptPath: script, - } - if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil { - ac.ConfigDir = configDir - } - return ac - } - } - - quotedArgs := make([]string, len(args)) - for i, a := range args { - if strings.Contains(a, " ") { - quotedArgs[i] = `"` + a + `"` - } else { - quotedArgs[i] = a - } - } - cmdLine := `""` + cmd + `" ` + strings.Join(quotedArgs, " ") + `"` - return AgentCommand{ - Command: loaded.CommandShell, - Args: []string{"/d", "/s", "/c", cmdLine}, - Env: cloneEnv(), - WindowsVerbatimArguments: true, - } - } - } - - return AgentCommand{Command: cmd, Args: args, Env: cloneEnv()} -} diff --git a/internal/env/env_test.go b/internal/env/env_test.go deleted file mode 100644 index e589d59..0000000 --- a/internal/env/env_test.go +++ /dev/null @@ -1,65 +0,0 @@ -package env - -import "testing" - -func TestLoadEnvConfigDefaults(t *testing.T) { - e := EnvSource{} - loaded := LoadEnvConfig(e, "/tmp") - - if loaded.Host != "127.0.0.1" { - t.Errorf("expected 127.0.0.1, got %s", loaded.Host) - } - if loaded.Port != 8765 { - t.Errorf("expected 8765, got %d", loaded.Port) - } - if loaded.DefaultModel != "auto" { - t.Errorf("expected auto, got %s", loaded.DefaultModel) - } - if loaded.AgentBin != "agent" { - t.Errorf("expected agent, got %s", loaded.AgentBin) - } - if !loaded.StrictModel { - t.Error("expected strictModel=true by default") - } -} - -func TestLoadEnvConfigOverride(t *testing.T) { - e := EnvSource{ - "CURSOR_BRIDGE_HOST": "0.0.0.0", - "CURSOR_BRIDGE_PORT": "9000", - "CURSOR_BRIDGE_DEFAULT_MODEL": "gpt-4", - "CURSOR_AGENT_BIN": "/usr/local/bin/agent", - } - loaded := LoadEnvConfig(e, "/tmp") - - if loaded.Host != "0.0.0.0" { - t.Errorf("expected 0.0.0.0, got %s", loaded.Host) - } - if loaded.Port != 9000 { - t.Errorf("expected 9000, got %d", loaded.Port) - } - if loaded.DefaultModel != "gpt-4" { - t.Errorf("expected gpt-4, got %s", loaded.DefaultModel) - } - if loaded.AgentBin != "/usr/local/bin/agent" { - t.Errorf("expected /usr/local/bin/agent, got %s", loaded.AgentBin) - } -} - -func TestNormalizeModelID(t *testing.T) { - tests := []struct { - input string - want string - }{ - {"gpt-4", "gpt-4"}, - {"openai/gpt-4", "gpt-4"}, - {"", "auto"}, - {" ", "auto"}, - } - for _, tc := range tests { - got := normalizeModelId(tc.input) - if got != tc.want { - t.Errorf("normalizeModelId(%q) = %q, want %q", tc.input, got, tc.want) - } - } -} diff --git a/internal/handlers/anthropic_handler.go b/internal/handlers/anthropic_handler.go deleted file mode 100644 index fc2f74e..0000000 --- a/internal/handlers/anthropic_handler.go +++ /dev/null @@ -1,577 +0,0 @@ -package handlers - -import ( - "context" - "cursor-api-proxy/internal/agent" - "cursor-api-proxy/internal/anthropic" - "cursor-api-proxy/internal/config" - "cursor-api-proxy/internal/httputil" - "cursor-api-proxy/internal/logger" - "cursor-api-proxy/internal/models" - "cursor-api-proxy/internal/openai" - "cursor-api-proxy/internal/parser" - "cursor-api-proxy/internal/pool" - "cursor-api-proxy/internal/sanitize" - "cursor-api-proxy/internal/toolcall" - "cursor-api-proxy/internal/winlimit" - "cursor-api-proxy/internal/workspace" - "encoding/json" - "fmt" - "net/http" - "regexp" - "strings" - "time" - - "github.com/google/uuid" -) - -func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, ph pool.PoolHandle, lastModelRef *string, rawBody, method, pathname, remoteAddress string) { - var req anthropic.MessagesRequest - if err := json.Unmarshal([]byte(rawBody), &req); err != nil { - httputil.WriteJSON(w, 400, map[string]interface{}{ - "error": map[string]string{"type": "invalid_request_error", "message": "invalid JSON body"}, - }, nil) - return - } - - requested := openai.NormalizeModelID(req.Model) - model := ResolveModel(requested, lastModelRef, cfg) - - var rawMap map[string]interface{} - _ = json.Unmarshal([]byte(rawBody), &rawMap) - - cleanSystem := sanitize.SanitizeSystem(req.System) - - rawMessages := make([]interface{}, len(req.Messages)) - for i, m := range req.Messages { - rawMessages[i] = map[string]interface{}{"role": m.Role, "content": m.Content} - } - cleanRawMessages := sanitize.SanitizeMessages(rawMessages) - - var cleanMessages []anthropic.MessageParam - for _, raw := range cleanRawMessages { - if m, ok := raw.(map[string]interface{}); ok { - role, _ := m["role"].(string) - cleanMessages = append(cleanMessages, anthropic.MessageParam{Role: role, Content: m["content"]}) - } - } - - toolsText := openai.ToolsToSystemText(req.Tools, nil) - var systemWithTools interface{} - if toolsText != "" { - sysStr := "" - switch v := cleanSystem.(type) { - case string: - sysStr = v - } - if sysStr != "" { - systemWithTools = sysStr + "\n\n" + toolsText - } else { - systemWithTools = toolsText - } - } else { - systemWithTools = cleanSystem - } - - prompt := anthropic.BuildPromptFromAnthropicMessages(cleanMessages, systemWithTools) - - if req.MaxTokens == 0 { - httputil.WriteJSON(w, 400, map[string]interface{}{ - "error": map[string]string{"type": "invalid_request_error", "message": "max_tokens is required"}, - }, nil) - return - } - - cursorModel := models.ResolveToCursorModel(model) - if cursorModel == "" { - cursorModel = model - } - - var trafficMsgs []logger.TrafficMessage - if s := systemToString(cleanSystem); s != "" { - trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: "system", Content: s}) - } - for _, m := range cleanMessages { - text := contentToString(m.Content) - if text != "" { - trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: m.Role, Content: text}) - } - } - logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, req.Stream) - - headerWs := r.Header.Get("x-cursor-workspace") - ws := workspace.ResolveWorkspace(cfg, headerWs) - - fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, req.Stream) - fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir) - - if cfg.Verbose { - if len(prompt) > 200 { - logger.LogDebug("model=%s prompt_len=%d prompt_preview=%q", cursorModel, len(prompt), prompt[:200]+"...") - } else { - logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, len(prompt), prompt) - } - logger.LogDebug("cmd_args=%v", fit.Args) - } - - if !fit.OK { - httputil.WriteJSON(w, 500, map[string]interface{}{ - "error": map[string]string{"type": "api_error", "message": fit.Error}, - }, nil) - return - } - if fit.Truncated { - logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength) - } - - cmdArgs := fit.Args - msgID := "msg_" + uuid.New().String() - - var truncatedHeaders map[string]string - if fit.Truncated { - truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"} - } - - hasTools := len(req.Tools) > 0 - var toolNames map[string]bool - if hasTools { - toolNames = toolcall.CollectToolNames(req.Tools) - } - - if req.Stream { - httputil.WriteSSEHeaders(w, truncatedHeaders) - flusher, _ := w.(http.Flusher) - - writeEvent := func(evt interface{}) { - data, _ := json.Marshal(evt) - fmt.Fprintf(w, "data: %s\n\n", data) - if flusher != nil { - flusher.Flush() - } - } - - var accumulated string - var accumulatedThinking string - var chunkNum int - var p parser.Parser - - writeEvent(map[string]interface{}{ - "type": "message_start", - "message": map[string]interface{}{ - "id": msgID, - "type": "message", - "role": "assistant", - "model": model, - "content": []interface{}{}, - }, - }) - - if hasTools { - toolCallMarkerRe := regexp.MustCompile(`行政法规|`) - var toolCallMode bool - - textBlockOpen := false - textBlockIndex := 0 - thinkingOpen := false - thinkingBlockIndex := 0 - blockCount := 0 - - p = parser.CreateStreamParserWithThinking( - func(text string) { - accumulated += text - chunkNum++ - logger.LogStreamChunk(model, text, chunkNum) - - if toolCallMode { - return - } - if toolCallMarkerRe.MatchString(text) { - if textBlockOpen { - writeEvent(map[string]interface{}{"type": "content_block_stop", "index": textBlockIndex}) - textBlockOpen = false - } - if thinkingOpen { - writeEvent(map[string]interface{}{"type": "content_block_stop", "index": thinkingBlockIndex}) - thinkingOpen = false - } - toolCallMode = true - return - } - if !textBlockOpen && !thinkingOpen { - textBlockIndex = blockCount - writeEvent(map[string]interface{}{ - "type": "content_block_start", - "index": textBlockIndex, - "content_block": map[string]string{"type": "text", "text": ""}, - }) - textBlockOpen = true - blockCount++ - } - if thinkingOpen { - writeEvent(map[string]interface{}{"type": "content_block_stop", "index": thinkingBlockIndex}) - thinkingOpen = false - } - writeEvent(map[string]interface{}{ - "type": "content_block_delta", - "index": textBlockIndex, - "delta": map[string]string{"type": "text_delta", "text": text}, - }) - }, - func(thinking string) { - accumulatedThinking += thinking - chunkNum++ - if toolCallMode { - return - } - if !thinkingOpen { - thinkingBlockIndex = blockCount - writeEvent(map[string]interface{}{ - "type": "content_block_start", - "index": thinkingBlockIndex, - "content_block": map[string]string{"type": "thinking", "thinking": ""}, - }) - thinkingOpen = true - blockCount++ - } - writeEvent(map[string]interface{}{ - "type": "content_block_delta", - "index": thinkingBlockIndex, - "delta": map[string]string{"type": "thinking_delta", "thinking": thinking}, - }) - }, - func() { - logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true) - parsed := toolcall.ExtractToolCalls(accumulated, toolNames) - - blockIndex := 0 - if thinkingOpen { - writeEvent(map[string]interface{}{"type": "content_block_stop", "index": thinkingBlockIndex}) - thinkingOpen = false - } - - if parsed.HasToolCalls() { - if textBlockOpen { - writeEvent(map[string]interface{}{"type": "content_block_stop", "index": textBlockIndex}) - blockIndex = textBlockIndex + 1 - } - if parsed.TextContent != "" && !textBlockOpen && !toolCallMode { - writeEvent(map[string]interface{}{ - "type": "content_block_start", "index": blockIndex, - "content_block": map[string]string{"type": "text", "text": ""}, - }) - writeEvent(map[string]interface{}{ - "type": "content_block_delta", "index": blockIndex, - "delta": map[string]string{"type": "text_delta", "text": parsed.TextContent}, - }) - writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex}) - blockIndex++ - } - for _, tc := range parsed.ToolCalls { - toolID := "toolu_" + uuid.New().String()[:12] - var inputObj interface{} - _ = json.Unmarshal([]byte(tc.Arguments), &inputObj) - if inputObj == nil { - inputObj = map[string]interface{}{} - } - writeEvent(map[string]interface{}{ - "type": "content_block_start", "index": blockIndex, - "content_block": map[string]interface{}{ - "type": "tool_use", "id": toolID, "name": tc.Name, "input": map[string]interface{}{}, - }, - }) - writeEvent(map[string]interface{}{ - "type": "content_block_delta", "index": blockIndex, - "delta": map[string]interface{}{ - "type": "input_json_delta", "partial_json": tc.Arguments, - }, - }) - writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex}) - blockIndex++ - } - writeEvent(map[string]interface{}{ - "type": "message_delta", - "delta": map[string]interface{}{"stop_reason": "tool_use", "stop_sequence": nil}, - "usage": map[string]int{"output_tokens": 0}, - }) - writeEvent(map[string]interface{}{"type": "message_stop"}) - } else { - if textBlockOpen { - writeEvent(map[string]interface{}{"type": "content_block_stop", "index": textBlockIndex}) - } else if accumulated != "" { - writeEvent(map[string]interface{}{ - "type": "content_block_start", "index": blockIndex, - "content_block": map[string]string{"type": "text", "text": ""}, - }) - writeEvent(map[string]interface{}{ - "type": "content_block_delta", "index": blockIndex, - "delta": map[string]string{"type": "text_delta", "text": accumulated}, - }) - writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex}) - blockIndex++ - } else { - writeEvent(map[string]interface{}{ - "type": "content_block_start", "index": blockIndex, - "content_block": map[string]string{"type": "text", "text": ""}, - }) - writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex}) - blockIndex++ - } - writeEvent(map[string]interface{}{ - "type": "message_delta", - "delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil}, - "usage": map[string]int{"output_tokens": 0}, - }) - writeEvent(map[string]interface{}{"type": "message_stop"}) - } - }, - ) - } else { - // 非 tools 模式:即時串流 thinking 和 text - blockCount := 0 - thinkingOpen := false - textOpen := false - - p = parser.CreateStreamParserWithThinking( - func(text string) { - accumulated += text - chunkNum++ - logger.LogStreamChunk(model, text, chunkNum) - if thinkingOpen { - writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1}) - thinkingOpen = false - } - if !textOpen { - writeEvent(map[string]interface{}{ - "type": "content_block_start", - "index": blockCount, - "content_block": map[string]string{"type": "text", "text": ""}, - }) - textOpen = true - blockCount++ - } - writeEvent(map[string]interface{}{ - "type": "content_block_delta", - "index": blockCount - 1, - "delta": map[string]string{"type": "text_delta", "text": text}, - }) - }, - func(thinking string) { - accumulatedThinking += thinking - chunkNum++ - if !thinkingOpen { - writeEvent(map[string]interface{}{ - "type": "content_block_start", - "index": blockCount, - "content_block": map[string]string{"type": "thinking", "thinking": ""}, - }) - thinkingOpen = true - blockCount++ - } - writeEvent(map[string]interface{}{ - "type": "content_block_delta", - "index": blockCount - 1, - "delta": map[string]string{"type": "thinking_delta", "thinking": thinking}, - }) - }, - func() { - logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true) - if thinkingOpen { - writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1}) - thinkingOpen = false - } - if !textOpen { - writeEvent(map[string]interface{}{ - "type": "content_block_start", - "index": blockCount, - "content_block": map[string]string{"type": "text", "text": ""}, - }) - blockCount++ - } - writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1}) - writeEvent(map[string]interface{}{ - "type": "message_delta", - "delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil}, - "usage": map[string]int{"output_tokens": 0}, - }) - writeEvent(map[string]interface{}{"type": "message_stop"}) - }, - ) - } - - configDir := ph.GetNextConfigDir() - logger.LogAccountAssigned(configDir) - ph.ReportRequestStart(configDir) - logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true) - streamStart := time.Now().UnixMilli() - - ctx := r.Context() - wrappedParser := func(line string) { - logger.LogRawLine(line) - p.Parse(line) - } - result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, ctx) - - if ctx.Err() == nil { - p.Flush() - } - - latencyMs := time.Now().UnixMilli() - streamStart - ph.ReportRequestEnd(configDir) - - if ctx.Err() == context.DeadlineExceeded { - logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs) - } else if ctx.Err() == context.Canceled { - logger.LogClientDisconnect(method, pathname, model, latencyMs) - } else if err == nil && isRateLimited(result.Stderr) { - ph.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr)) - } - - if err != nil || (result.Code != 0 && ctx.Err() == nil) { - ph.ReportRequestError(configDir, latencyMs) - if err != nil { - logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error()) - } else { - logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr) - } - logger.LogRequestDone(method, pathname, model, latencyMs, result.Code) - } else if ctx.Err() == nil { - ph.ReportRequestSuccess(configDir, latencyMs) - logger.LogRequestDone(method, pathname, model, latencyMs, 0) - } - logger.LogAccountStats(cfg.Verbose, ph.GetStats()) - return - } - - configDir := ph.GetNextConfigDir() - logger.LogAccountAssigned(configDir) - ph.ReportRequestStart(configDir) - logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, false) - syncStart := time.Now().UnixMilli() - - out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context()) - syncLatency := time.Now().UnixMilli() - syncStart - ph.ReportRequestEnd(configDir) - - ctx := r.Context() - if ctx.Err() == context.DeadlineExceeded { - logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs) - httputil.WriteJSON(w, 504, map[string]interface{}{ - "error": map[string]string{"type": "api_error", "message": fmt.Sprintf("request timed out after %dms", cfg.TimeoutMs)}, - }, nil) - return - } - if ctx.Err() == context.Canceled { - logger.LogClientDisconnect(method, pathname, model, syncLatency) - return - } - - if err != nil { - ph.ReportRequestError(configDir, syncLatency) - logger.LogAccountStats(cfg.Verbose, ph.GetStats()) - logger.LogRequestDone(method, pathname, model, syncLatency, -1) - httputil.WriteJSON(w, 500, map[string]interface{}{ - "error": map[string]string{"type": "api_error", "message": err.Error()}, - }, nil) - return - } - - if isRateLimited(out.Stderr) { - ph.ReportRateLimit(configDir, extractRetryAfterMs(out.Stderr)) - } - - if out.Code != 0 { - ph.ReportRequestError(configDir, syncLatency) - logger.LogAccountStats(cfg.Verbose, ph.GetStats()) - errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr) - logger.LogRequestDone(method, pathname, model, syncLatency, out.Code) - httputil.WriteJSON(w, 500, map[string]interface{}{ - "error": map[string]string{"type": "api_error", "message": errMsg}, - }, nil) - return - } - - ph.ReportRequestSuccess(configDir, syncLatency) - content := strings.TrimSpace(out.Stdout) - logger.LogTrafficResponse(cfg.Verbose, model, content, false) - logger.LogAccountStats(cfg.Verbose, ph.GetStats()) - logger.LogRequestDone(method, pathname, model, syncLatency, 0) - - if hasTools { - parsed := toolcall.ExtractToolCalls(content, toolNames) - if parsed.HasToolCalls() { - var contentBlocks []map[string]interface{} - if parsed.TextContent != "" { - contentBlocks = append(contentBlocks, map[string]interface{}{ - "type": "text", "text": parsed.TextContent, - }) - } - for _, tc := range parsed.ToolCalls { - toolID := "toolu_" + uuid.New().String()[:12] - var inputObj interface{} - _ = json.Unmarshal([]byte(tc.Arguments), &inputObj) - if inputObj == nil { - inputObj = map[string]interface{}{} - } - contentBlocks = append(contentBlocks, map[string]interface{}{ - "type": "tool_use", "id": toolID, "name": tc.Name, "input": inputObj, - }) - } - httputil.WriteJSON(w, 200, map[string]interface{}{ - "id": msgID, - "type": "message", - "role": "assistant", - "content": contentBlocks, - "model": model, - "stop_reason": "tool_use", - "usage": map[string]int{"input_tokens": 0, "output_tokens": 0}, - }, truncatedHeaders) - return - } - } - - httputil.WriteJSON(w, 200, map[string]interface{}{ - "id": msgID, - "type": "message", - "role": "assistant", - "content": []map[string]string{{"type": "text", "text": content}}, - "model": model, - "stop_reason": "end_turn", - "usage": map[string]int{"input_tokens": 0, "output_tokens": 0}, - }, truncatedHeaders) -} - -func systemToString(system interface{}) string { - switch v := system.(type) { - case string: - return v - case []interface{}: - result := "" - for _, p := range v { - if m, ok := p.(map[string]interface{}); ok && m["type"] == "text" { - if t, ok := m["text"].(string); ok { - result += t - } - } - } - return result - } - return "" -} - -func contentToString(content interface{}) string { - switch v := content.(type) { - case string: - return v - case []interface{}: - result := "" - for _, p := range v { - if m, ok := p.(map[string]interface{}); ok && m["type"] == "text" { - if t, ok := m["text"].(string); ok { - result += t - } - } - } - return result - } - return "" -} diff --git a/internal/handlers/chat.go b/internal/handlers/chat.go deleted file mode 100644 index 9d67324..0000000 --- a/internal/handlers/chat.go +++ /dev/null @@ -1,471 +0,0 @@ -package handlers - -import ( - "context" - "cursor-api-proxy/internal/agent" - "cursor-api-proxy/internal/config" - "cursor-api-proxy/internal/httputil" - "cursor-api-proxy/internal/logger" - "cursor-api-proxy/internal/models" - "cursor-api-proxy/internal/openai" - "cursor-api-proxy/internal/parser" - "cursor-api-proxy/internal/pool" - "cursor-api-proxy/internal/sanitize" - "cursor-api-proxy/internal/toolcall" - "cursor-api-proxy/internal/winlimit" - "cursor-api-proxy/internal/workspace" - "encoding/json" - "fmt" - "net/http" - "regexp" - "strconv" - "strings" - "time" - - "github.com/google/uuid" -) - -var rateLimitRe = regexp.MustCompile(`(?i)\b429\b|rate.?limit|too many requests`) -var retryAfterRe = regexp.MustCompile(`(?i)retry-after:\s*(\d+)`) - -func isRateLimited(stderr string) bool { - return rateLimitRe.MatchString(stderr) -} - -func extractRetryAfterMs(stderr string) int64 { - if m := retryAfterRe.FindStringSubmatch(stderr); len(m) > 1 { - if secs, err := strconv.ParseInt(m[1], 10, 64); err == nil && secs > 0 { - return secs * 1000 - } - } - return 60000 -} - -func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, ph pool.PoolHandle, lastModelRef *string, rawBody, method, pathname, remoteAddress string) { - var bodyMap map[string]interface{} - if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil { - httputil.WriteJSON(w, 400, map[string]interface{}{ - "error": map[string]string{"message": "invalid JSON body", "code": "bad_request"}, - }, nil) - return - } - - rawModel, _ := bodyMap["model"].(string) - requested := openai.NormalizeModelID(rawModel) - model := ResolveModel(requested, lastModelRef, cfg) - cursorModel := models.ResolveToCursorModel(model) - if cursorModel == "" { - cursorModel = model - } - - var messages []interface{} - if m, ok := bodyMap["messages"].([]interface{}); ok { - messages = m - } - - var tools []interface{} - if t, ok := bodyMap["tools"].([]interface{}); ok { - tools = t - } - var funcs []interface{} - if f, ok := bodyMap["functions"].([]interface{}); ok { - funcs = f - } - - cleanMessages := sanitize.SanitizeMessages(messages) - - toolsText := openai.ToolsToSystemText(tools, funcs) - messagesWithTools := cleanMessages - if toolsText != "" { - messagesWithTools = append([]interface{}{map[string]interface{}{"role": "system", "content": toolsText}}, cleanMessages...) - } - prompt := openai.BuildPromptFromMessages(messagesWithTools) - - var trafficMsgs []logger.TrafficMessage - for _, raw := range cleanMessages { - if m, ok := raw.(map[string]interface{}); ok { - role, _ := m["role"].(string) - content := openai.MessageContentToText(m["content"]) - trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: role, Content: content}) - } - } - - isStream := false - if s, ok := bodyMap["stream"].(bool); ok { - isStream = s - } - - logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, isStream) - - headerWs := r.Header.Get("x-cursor-workspace") - ws := workspace.ResolveWorkspace(cfg, headerWs) - - promptLen := len(prompt) - if cfg.Verbose { - if promptLen > 200 { - logger.LogDebug("model=%s prompt_len=%d prompt_start=%q", cursorModel, promptLen, prompt[:200]) - } else { - logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, promptLen, prompt) - } - } - - fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, isStream) - fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir) - - if cfg.Verbose { - logger.LogDebug("cmd=%s args=%v", cfg.AgentBin, fit.Args) - } - - if !fit.OK { - httputil.WriteJSON(w, 500, map[string]interface{}{ - "error": map[string]string{"message": fit.Error, "code": "windows_cmdline_limit"}, - }, nil) - return - } - if fit.Truncated { - logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength) - } - - cmdArgs := fit.Args - id := "chatcmpl_" + uuid.New().String() - created := time.Now().Unix() - - var truncatedHeaders map[string]string - if fit.Truncated { - truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"} - } - - hasTools := len(tools) > 0 || len(funcs) > 0 - var toolNames map[string]bool - if hasTools { - toolNames = toolcall.CollectToolNames(tools) - for _, f := range funcs { - if fm, ok := f.(map[string]interface{}); ok { - if name, ok := fm["name"].(string); ok { - toolNames[name] = true - } - } - } - } - - if isStream { - httputil.WriteSSEHeaders(w, truncatedHeaders) - flusher, _ := w.(http.Flusher) - - var accumulated string - var chunkNum int - var p parser.Parser - - // toolCallMarkerRe 偵測 tool call 開頭標記,一旦出現就停止即時輸出並進入累積模式 - toolCallMarkerRe := regexp.MustCompile(`|`) - if hasTools { - var toolCallMode bool // 是否已進入 tool call 累積模式 - p = parser.CreateStreamParserWithThinking( - func(text string) { - accumulated += text - chunkNum++ - logger.LogStreamChunk(model, text, chunkNum) - if toolCallMode { - // 已進入累積模式,不即時輸出 - return - } - if toolCallMarkerRe.MatchString(text) { - // 偵測到 tool call 標記,切換為累積模式 - toolCallMode = true - return - } - chunk := map[string]interface{}{ - "id": id, "object": "chat.completion.chunk", "created": created, "model": model, - "choices": []map[string]interface{}{ - {"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil}, - }, - } - data, _ := json.Marshal(chunk) - fmt.Fprintf(w, "data: %s\n\n", data) - if flusher != nil { - flusher.Flush() - } - }, - func(_ string) {}, // thinking ignored in tools mode - func() { - logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true) - parsed := toolcall.ExtractToolCalls(accumulated, toolNames) - - if parsed.HasToolCalls() { - if parsed.TextContent != "" && toolCallMode { - // 已有部分 text 被即時輸出,只補發剩餘的 - chunk := map[string]interface{}{ - "id": id, "object": "chat.completion.chunk", "created": created, "model": model, - "choices": []map[string]interface{}{ - {"index": 0, "delta": map[string]interface{}{"role": "assistant", "content": parsed.TextContent}, "finish_reason": nil}, - }, - } - data, _ := json.Marshal(chunk) - fmt.Fprintf(w, "data: %s\n\n", data) - if flusher != nil { - flusher.Flush() - } - } - for i, tc := range parsed.ToolCalls { - callID := "call_" + uuid.New().String()[:8] - chunk := map[string]interface{}{ - "id": id, "object": "chat.completion.chunk", "created": created, "model": model, - "choices": []map[string]interface{}{ - {"index": 0, "delta": map[string]interface{}{ - "tool_calls": []map[string]interface{}{ - { - "index": i, - "id": callID, - "type": "function", - "function": map[string]interface{}{ - "name": tc.Name, - "arguments": tc.Arguments, - }, - }, - }, - }, "finish_reason": nil}, - }, - } - data, _ := json.Marshal(chunk) - fmt.Fprintf(w, "data: %s\n\n", data) - if flusher != nil { - flusher.Flush() - } - } - stopChunk := map[string]interface{}{ - "id": id, "object": "chat.completion.chunk", "created": created, "model": model, - "choices": []map[string]interface{}{ - {"index": 0, "delta": map[string]interface{}{}, "finish_reason": "tool_calls"}, - }, - } - data, _ := json.Marshal(stopChunk) - fmt.Fprintf(w, "data: %s\n\n", data) - fmt.Fprintf(w, "data: [DONE]\n\n") - if flusher != nil { - flusher.Flush() - } - } else { - stopChunk := map[string]interface{}{ - "id": id, "object": "chat.completion.chunk", "created": created, "model": model, - "choices": []map[string]interface{}{ - {"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"}, - }, - } - data, _ := json.Marshal(stopChunk) - fmt.Fprintf(w, "data: %s\n\n", data) - fmt.Fprintf(w, "data: [DONE]\n\n") - if flusher != nil { - flusher.Flush() - } - } - }, - ) - } else { - p = parser.CreateStreamParserWithThinking( - func(text string) { - accumulated += text - chunkNum++ - logger.LogStreamChunk(model, text, chunkNum) - chunk := map[string]interface{}{ - "id": id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": []map[string]interface{}{ - {"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil}, - }, - } - data, _ := json.Marshal(chunk) - fmt.Fprintf(w, "data: %s\n\n", data) - if flusher != nil { - flusher.Flush() - } - }, - func(thinking string) { - chunk := map[string]interface{}{ - "id": id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": []map[string]interface{}{ - {"index": 0, "delta": map[string]interface{}{"reasoning_content": thinking}, "finish_reason": nil}, - }, - } - data, _ := json.Marshal(chunk) - fmt.Fprintf(w, "data: %s\n\n", data) - if flusher != nil { - flusher.Flush() - } - }, - func() { - logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true) - stopChunk := map[string]interface{}{ - "id": id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": []map[string]interface{}{ - {"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"}, - }, - } - data, _ := json.Marshal(stopChunk) - fmt.Fprintf(w, "data: %s\n\n", data) - fmt.Fprintf(w, "data: [DONE]\n\n") - if flusher != nil { - flusher.Flush() - } - }, - ) - } - - configDir := ph.GetNextConfigDir() - logger.LogAccountAssigned(configDir) - ph.ReportRequestStart(configDir) - logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true) - streamStart := time.Now().UnixMilli() - - ctx := r.Context() - wrappedParser := func(line string) { - logger.LogRawLine(line) - p.Parse(line) - } - result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, ctx) - - // agent 結束後,若未收到 result/success 訊號,強制 flush 以確保 SSE stream 正確結尾 - if ctx.Err() == nil { - p.Flush() - } - - latencyMs := time.Now().UnixMilli() - streamStart - ph.ReportRequestEnd(configDir) - - if ctx.Err() == context.DeadlineExceeded { - logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs) - } else if ctx.Err() == context.Canceled { - logger.LogClientDisconnect(method, pathname, model, latencyMs) - } else if err == nil && isRateLimited(result.Stderr) { - ph.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr)) - } - - if err != nil || (result.Code != 0 && ctx.Err() == nil) { - ph.ReportRequestError(configDir, latencyMs) - if err != nil { - logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error()) - } else { - logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr) - } - logger.LogRequestDone(method, pathname, model, latencyMs, result.Code) - } else if ctx.Err() == nil { - ph.ReportRequestSuccess(configDir, latencyMs) - logger.LogRequestDone(method, pathname, model, latencyMs, 0) - } - logger.LogAccountStats(cfg.Verbose, ph.GetStats()) - return - } - - configDir := ph.GetNextConfigDir() - logger.LogAccountAssigned(configDir) - ph.ReportRequestStart(configDir) - logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, false) - syncStart := time.Now().UnixMilli() - - out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context()) - syncLatency := time.Now().UnixMilli() - syncStart - ph.ReportRequestEnd(configDir) - - ctx := r.Context() - if ctx.Err() == context.DeadlineExceeded { - logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs) - httputil.WriteJSON(w, 504, map[string]interface{}{ - "error": map[string]string{"message": fmt.Sprintf("request timed out after %dms", cfg.TimeoutMs), "code": "timeout"}, - }, nil) - return - } - if ctx.Err() == context.Canceled { - logger.LogClientDisconnect(method, pathname, model, syncLatency) - return - } - - if err != nil { - ph.ReportRequestError(configDir, syncLatency) - logger.LogAccountStats(cfg.Verbose, ph.GetStats()) - logger.LogRequestDone(method, pathname, model, syncLatency, -1) - httputil.WriteJSON(w, 500, map[string]interface{}{ - "error": map[string]string{"message": err.Error(), "code": "cursor_cli_error"}, - }, nil) - return - } - - if isRateLimited(out.Stderr) { - ph.ReportRateLimit(configDir, extractRetryAfterMs(out.Stderr)) - } - - if out.Code != 0 { - ph.ReportRequestError(configDir, syncLatency) - logger.LogAccountStats(cfg.Verbose, ph.GetStats()) - errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr) - logger.LogRequestDone(method, pathname, model, syncLatency, out.Code) - httputil.WriteJSON(w, 500, map[string]interface{}{ - "error": map[string]string{"message": errMsg, "code": "cursor_cli_error"}, - }, nil) - return - } - - ph.ReportRequestSuccess(configDir, syncLatency) - content := strings.TrimSpace(out.Stdout) - logger.LogTrafficResponse(cfg.Verbose, model, content, false) - logger.LogAccountStats(cfg.Verbose, ph.GetStats()) - logger.LogRequestDone(method, pathname, model, syncLatency, 0) - - if hasTools { - parsed := toolcall.ExtractToolCalls(content, toolNames) - if parsed.HasToolCalls() { - msg := map[string]interface{}{"role": "assistant"} - if parsed.TextContent != "" { - msg["content"] = parsed.TextContent - } else { - msg["content"] = nil - } - var tcArr []map[string]interface{} - for _, tc := range parsed.ToolCalls { - callID := "call_" + uuid.New().String()[:8] - tcArr = append(tcArr, map[string]interface{}{ - "id": callID, - "type": "function", - "function": map[string]interface{}{ - "name": tc.Name, - "arguments": tc.Arguments, - }, - }) - } - msg["tool_calls"] = tcArr - httputil.WriteJSON(w, 200, map[string]interface{}{ - "id": id, - "object": "chat.completion", - "created": created, - "model": model, - "choices": []map[string]interface{}{ - {"index": 0, "message": msg, "finish_reason": "tool_calls"}, - }, - "usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, - }, truncatedHeaders) - return - } - } - - httputil.WriteJSON(w, 200, map[string]interface{}{ - "id": id, - "object": "chat.completion", - "created": created, - "model": model, - "choices": []map[string]interface{}{ - { - "index": 0, - "message": map[string]string{"role": "assistant", "content": content}, - "finish_reason": "stop", - }, - }, - "usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, - }, truncatedHeaders) -} - diff --git a/internal/handlers/gemini_handler.go b/internal/handlers/gemini_handler.go deleted file mode 100644 index 04d7065..0000000 --- a/internal/handlers/gemini_handler.go +++ /dev/null @@ -1,203 +0,0 @@ -package handlers - -import ( - "context" - "cursor-api-proxy/internal/apitypes" - "cursor-api-proxy/internal/config" - "cursor-api-proxy/internal/httputil" - "cursor-api-proxy/internal/logger" - "cursor-api-proxy/internal/providers/geminiweb" - "encoding/json" - "fmt" - "net/http" - "time" - - "github.com/google/uuid" -) - -func HandleGeminiChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, rawBody, method, pathname, remoteAddress string) { - _ = context.Background() // 確保 context 被使用 - var bodyMap map[string]interface{} - if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil { - httputil.WriteJSON(w, 400, map[string]interface{}{ - "error": map[string]string{"message": "invalid JSON body", "code": "bad_request"}, - }, nil) - return - } - - rawModel, _ := bodyMap["model"].(string) - if rawModel == "" { - rawModel = "gemini-2.0-flash" - } - - var messages []interface{} - if m, ok := bodyMap["messages"].([]interface{}); ok { - messages = m - } - - isStream := false - if s, ok := bodyMap["stream"].(bool); ok { - isStream = s - } - - // 轉換 messages 為 apitypes.Message - var apiMessages []apitypes.Message - for _, m := range messages { - if msgMap, ok := m.(map[string]interface{}); ok { - role, _ := msgMap["role"].(string) - content := "" - if c, ok := msgMap["content"].(string); ok { - content = c - } - apiMessages = append(apiMessages, apitypes.Message{ - Role: role, - Content: content, - }) - } - } - - logger.LogRequestStart(method, pathname, rawModel, cfg.TimeoutMs, isStream) - start := time.Now().UnixMilli() - - // 創建 Gemini provider (使用 Playwright) - provider, provErr := geminiweb.NewPlaywrightProvider(cfg) - if provErr != nil { - logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, provErr.Error()) - httputil.WriteJSON(w, 500, map[string]interface{}{ - "error": map[string]string{"message": provErr.Error(), "code": "provider_error"}, - }, nil) - return - } - - if isStream { - httputil.WriteSSEHeaders(w, nil) - flusher, _ := w.(http.Flusher) - - id := "chatcmpl_" + uuid.New().String() - created := time.Now().Unix() - var accumulated string - - err := provider.Generate(r.Context(), rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) { - if chunk.Type == apitypes.ChunkText { - accumulated += chunk.Text - respChunk := map[string]interface{}{ - "id": id, - "object": "chat.completion.chunk", - "created": created, - "model": rawModel, - "choices": []map[string]interface{}{ - { - "index": 0, - "delta": map[string]string{"content": chunk.Text}, - "finish_reason": nil, - }, - }, - } - data, _ := json.Marshal(respChunk) - fmt.Fprintf(w, "data: %s\n\n", data) - if flusher != nil { - flusher.Flush() - } - } else if chunk.Type == apitypes.ChunkThinking { - respChunk := map[string]interface{}{ - "id": id, - "object": "chat.completion.chunk", - "created": created, - "model": rawModel, - "choices": []map[string]interface{}{ - { - "index": 0, - "delta": map[string]interface{}{"reasoning_content": chunk.Thinking}, - "finish_reason": nil, - }, - }, - } - data, _ := json.Marshal(respChunk) - fmt.Fprintf(w, "data: %s\n\n", data) - if flusher != nil { - flusher.Flush() - } - } else if chunk.Type == apitypes.ChunkDone { - stopChunk := map[string]interface{}{ - "id": id, - "object": "chat.completion.chunk", - "created": created, - "model": rawModel, - "choices": []map[string]interface{}{ - { - "index": 0, - "delta": map[string]interface{}{}, - "finish_reason": "stop", - }, - }, - } - data, _ := json.Marshal(stopChunk) - fmt.Fprintf(w, "data: %s\n\n", data) - fmt.Fprintf(w, "data: [DONE]\n\n") - if flusher != nil { - flusher.Flush() - } - } - }) - - latencyMs := time.Now().UnixMilli() - start - if err != nil { - logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error()) - logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1) - return - } - - logger.LogTrafficResponse(cfg.Verbose, rawModel, accumulated, true) - logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0) - return - } - - // 非串流模式 - var resultText string - var resultThinking string - - err := provider.Generate(r.Context(), rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) { - if chunk.Type == apitypes.ChunkText { - resultText += chunk.Text - } else if chunk.Type == apitypes.ChunkThinking { - resultThinking += chunk.Thinking - } - }) - - latencyMs := time.Now().UnixMilli() - start - - if err != nil { - logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error()) - logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1) - httputil.WriteJSON(w, 500, map[string]interface{}{ - "error": map[string]string{"message": err.Error(), "code": "gemini_error"}, - }, nil) - return - } - - logger.LogTrafficResponse(cfg.Verbose, rawModel, resultText, false) - logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0) - - id := "chatcmpl_" + uuid.New().String() - created := time.Now().Unix() - - resp := map[string]interface{}{ - "id": id, - "object": "chat.completion", - "created": created, - "model": rawModel, - "choices": []map[string]interface{}{ - { - "index": 0, - "message": map[string]interface{}{ - "role": "assistant", - "content": resultText, - }, - "finish_reason": "stop", - }, - }, - "usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, - } - - httputil.WriteJSON(w, 200, resp, nil) -} diff --git a/internal/handlers/health.go b/internal/handlers/health.go deleted file mode 100644 index 6b26713..0000000 --- a/internal/handlers/health.go +++ /dev/null @@ -1,20 +0,0 @@ -package handlers - -import ( - "cursor-api-proxy/internal/config" - "cursor-api-proxy/internal/httputil" - "net/http" -) - -func HandleHealth(w http.ResponseWriter, r *http.Request, version string, cfg config.BridgeConfig) { - httputil.WriteJSON(w, 200, map[string]interface{}{ - "ok": true, - "version": version, - "workspace": cfg.Workspace, - "mode": cfg.Mode, - "defaultModel": cfg.DefaultModel, - "force": cfg.Force, - "approveMcps": cfg.ApproveMcps, - "strictModel": cfg.StrictModel, - }, nil) -} diff --git a/internal/handlers/models.go b/internal/handlers/models.go deleted file mode 100644 index 16c7873..0000000 --- a/internal/handlers/models.go +++ /dev/null @@ -1,107 +0,0 @@ -package handlers - -import ( - "cursor-api-proxy/internal/config" - "cursor-api-proxy/internal/httputil" - "cursor-api-proxy/internal/models" - "net/http" - "sync" - "time" -) - -const modelCacheTTLMs = 5 * 60 * 1000 - -type ModelCache struct { - At int64 - Models []models.CursorCliModel -} - -type ModelCacheRef struct { - mu sync.Mutex - cache *ModelCache - inflight bool - waiters []chan struct{} -} - -func (ref *ModelCacheRef) HandleModels(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig) { - now := time.Now().UnixMilli() - - ref.mu.Lock() - if ref.cache != nil && now-ref.cache.At <= modelCacheTTLMs { - cache := ref.cache - ref.mu.Unlock() - writeModels(w, cache.Models) - return - } - - if ref.inflight { - // Wait for the in-flight fetch - ch := make(chan struct{}, 1) - ref.waiters = append(ref.waiters, ch) - ref.mu.Unlock() - <-ch - ref.mu.Lock() - cache := ref.cache - ref.mu.Unlock() - writeModels(w, cache.Models) - return - } - - ref.inflight = true - ref.mu.Unlock() - - fetched, err := models.ListCursorCliModels(cfg.AgentBin, 60000) - - ref.mu.Lock() - ref.inflight = false - if err == nil { - ref.cache = &ModelCache{At: time.Now().UnixMilli(), Models: fetched} - } - waiters := ref.waiters - ref.waiters = nil - ref.mu.Unlock() - - for _, ch := range waiters { - ch <- struct{}{} - } - - if err != nil { - httputil.WriteJSON(w, 500, map[string]interface{}{ - "error": map[string]string{"message": err.Error(), "code": "models_fetch_error"}, - }, nil) - return - } - - writeModels(w, fetched) -} - -func writeModels(w http.ResponseWriter, mods []models.CursorCliModel) { - cursorModels := make([]map[string]interface{}, len(mods)) - for i, m := range mods { - cursorModels[i] = map[string]interface{}{ - "id": m.ID, - "object": "model", - "owned_by": "cursor", - "name": m.Name, - } - } - - ids := make([]string, len(mods)) - for i, m := range mods { - ids[i] = m.ID - } - aliases := models.GetAnthropicModelAliases(ids) - for _, a := range aliases { - cursorModels = append(cursorModels, map[string]interface{}{ - "id": a.ID, - "object": "model", - "owned_by": "cursor", - "name": a.Name, - }) - } - - httputil.WriteJSON(w, 200, map[string]interface{}{ - "object": "list", - "data": cursorModels, - }, nil) -} diff --git a/internal/handlers/resolve_model.go b/internal/handlers/resolve_model.go deleted file mode 100644 index e20b353..0000000 --- a/internal/handlers/resolve_model.go +++ /dev/null @@ -1,27 +0,0 @@ -package handlers - -import "cursor-api-proxy/internal/config" - -func ResolveModel(requested string, lastModelRef *string, cfg config.BridgeConfig) string { - isAuto := requested == "auto" - var explicitModel string - if requested != "" && !isAuto { - explicitModel = requested - } - if explicitModel != "" { - *lastModelRef = explicitModel - } - if isAuto { - return "auto" - } - if explicitModel != "" { - return explicitModel - } - if cfg.StrictModel && *lastModelRef != "" { - return *lastModelRef - } - if *lastModelRef != "" { - return *lastModelRef - } - return cfg.DefaultModel -} diff --git a/internal/httputil/httputil.go b/internal/httputil/httputil.go deleted file mode 100644 index bb39663..0000000 --- a/internal/httputil/httputil.go +++ /dev/null @@ -1,50 +0,0 @@ -package httputil - -import ( - "encoding/json" - "io" - "net/http" - "regexp" -) - -var bearerRe = regexp.MustCompile(`(?i)^Bearer\s+(.+)$`) - -func ExtractBearerToken(r *http.Request) string { - h := r.Header.Get("Authorization") - if h == "" { - return "" - } - m := bearerRe.FindStringSubmatch(h) - if m == nil { - return "" - } - return m[1] -} - -func WriteJSON(w http.ResponseWriter, status int, body interface{}, extraHeaders map[string]string) { - for k, v := range extraHeaders { - w.Header().Set(k, v) - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(body) -} - -func WriteSSEHeaders(w http.ResponseWriter, extraHeaders map[string]string) { - for k, v := range extraHeaders { - w.Header().Set(k, v) - } - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("X-Accel-Buffering", "no") - w.WriteHeader(200) -} - -func ReadBody(r *http.Request) (string, error) { - data, err := io.ReadAll(r.Body) - if err != nil { - return "", err - } - return string(data), nil -} diff --git a/internal/httputil/httputil_test.go b/internal/httputil/httputil_test.go deleted file mode 100644 index 530e536..0000000 --- a/internal/httputil/httputil_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package httputil - -import ( - "net/http/httptest" - "testing" -) - -func TestExtractBearerToken(t *testing.T) { - tests := []struct { - header string - want string - }{ - {"Bearer mytoken123", "mytoken123"}, - {"bearer MYTOKEN", "MYTOKEN"}, - {"", ""}, - {"Basic abc", ""}, - {"Bearer ", ""}, - } - for _, tc := range tests { - req := httptest.NewRequest("GET", "/", nil) - if tc.header != "" { - req.Header.Set("Authorization", tc.header) - } - got := ExtractBearerToken(req) - if got != tc.want { - t.Errorf("ExtractBearerToken(%q) = %q, want %q", tc.header, got, tc.want) - } - } -} - -func TestWriteJSON(t *testing.T) { - w := httptest.NewRecorder() - WriteJSON(w, 200, map[string]string{"ok": "true"}, nil) - - if w.Code != 200 { - t.Errorf("expected 200, got %d", w.Code) - } - if w.Header().Get("Content-Type") != "application/json" { - t.Errorf("expected application/json, got %s", w.Header().Get("Content-Type")) - } -} - -func TestWriteJSONWithExtraHeaders(t *testing.T) { - w := httptest.NewRecorder() - WriteJSON(w, 201, nil, map[string]string{"X-Custom": "value"}) - - if w.Header().Get("X-Custom") != "value" { - t.Errorf("expected X-Custom=value, got %s", w.Header().Get("X-Custom")) - } -} diff --git a/internal/logger/logger.go b/internal/logger/logger.go deleted file mode 100644 index db18a0a..0000000 --- a/internal/logger/logger.go +++ /dev/null @@ -1,309 +0,0 @@ -package logger - -import ( - "cursor-api-proxy/internal/config" - "cursor-api-proxy/internal/pool" - "fmt" - "os" - "path/filepath" - "strings" - "time" -) - -const ( - cReset = "\x1b[0m" - cBold = "\x1b[1m" - cDim = "\x1b[2m" - cCyan = "\x1b[36m" - cBCyan = "\x1b[1;96m" - cGreen = "\x1b[32m" - cBGreen = "\x1b[1;92m" - cYellow = "\x1b[33m" - cMagenta = "\x1b[35m" - cBMagenta = "\x1b[1;95m" - cRed = "\x1b[31m" - cGray = "\x1b[90m" - cWhite = "\x1b[97m" -) - -var roleStyle = map[string]string{ - "system": cYellow, - "user": cCyan, - "assistant": cGreen, -} - -var roleEmoji = map[string]string{ - "system": "🔧", - "user": "👤", - "assistant": "🤖", -} - -func ts() string { - return cGray + time.Now().UTC().Format("15:04:05") + cReset -} - -func tsDate() string { - return cGray + time.Now().UTC().Format("2006-01-02 15:04:05") + cReset -} - -func truncate(s string, max int) string { - if len(s) <= max { - return s - } - head := int(float64(max) * 0.6) - tail := max - head - omitted := len(s) - head - tail - return s[:head] + fmt.Sprintf("%s … (%d chars omitted) … ", cDim, omitted) + s[len(s)-tail:] + cReset -} - -func hr(ch string, length int) string { - return cGray + strings.Repeat(ch, length) + cReset -} - -type TrafficMessage struct { - Role string - Content string -} - -func LogDebug(format string, args ...interface{}) { - msg := fmt.Sprintf(format, args...) - fmt.Printf("%s %s[DEBUG]%s %s\n", ts(), cGray, cReset, msg) -} - -func LogServerStart(version, scheme, host string, port int, cfg config.BridgeConfig) { - provider := cfg.Provider - if provider == "" { - provider = "cursor" - } - fmt.Printf("\n%s%s╔══════════════════════════════════════════╗%s\n", cBold, cBCyan, cReset) - fmt.Printf("%s%s cursor-api-proxy %sv%s%s%s%s ready%s\n", - cBold, cBCyan, cReset, cBold, cWhite, version, cBCyan, cReset) - fmt.Printf("%s%s╚══════════════════════════════════════════╝%s\n\n", cBold, cBCyan, cReset) - url := fmt.Sprintf("%s://%s:%d", scheme, host, port) - fmt.Printf(" %s●%s listening %s%s%s\n", cBGreen, cReset, cBold, url, cReset) - fmt.Printf(" %s▸%s provider %s%s%s\n", cCyan, cReset, cBold, provider, cReset) - fmt.Printf(" %s▸%s agent %s%s%s\n", cCyan, cReset, cDim, cfg.AgentBin, cReset) - fmt.Printf(" %s▸%s workspace %s%s%s\n", cCyan, cReset, cDim, cfg.Workspace, cReset) - fmt.Printf(" %s▸%s model %s%s%s\n", cCyan, cReset, cDim, cfg.DefaultModel, cReset) - fmt.Printf(" %s▸%s mode %s%s%s\n", cCyan, cReset, cDim, cfg.Mode, cReset) - fmt.Printf(" %s▸%s timeout %s%d ms%s\n", cCyan, cReset, cDim, cfg.TimeoutMs, cReset) - - // 顯示 Gemini Web Provider 相關設定 - if provider == "gemini-web" { - fmt.Printf(" %s▸%s gemini-dir %s%s%s\n", cCyan, cReset, cDim, cfg.GeminiAccountDir, cReset) - fmt.Printf(" %s▸%s max-sess %s%d%s\n", cCyan, cReset, cDim, cfg.GeminiMaxSessions, cReset) - } - - flags := []string{} - if cfg.Force { - flags = append(flags, "force") - } - if cfg.ApproveMcps { - flags = append(flags, "approve-mcps") - } - if cfg.MaxMode { - flags = append(flags, "max-mode") - } - if cfg.Verbose { - flags = append(flags, "verbose") - } - if cfg.ChatOnlyWorkspace { - flags = append(flags, "chat-only") - } - if cfg.RequiredKey != "" { - flags = append(flags, "api-key-required") - } - if len(flags) > 0 { - fmt.Printf(" %s▸%s flags %s%s%s\n", cCyan, cReset, cYellow, strings.Join(flags, " · "), cReset) - } - if len(cfg.ConfigDirs) > 0 { - fmt.Printf(" %s▸%s pool %s%d accounts%s\n", cCyan, cReset, cBGreen, len(cfg.ConfigDirs), cReset) - } - fmt.Println() -} - -func LogShutdown(sig string) { - fmt.Printf("\n%s %s⊘ %s received — shutting down gracefully…%s\n", tsDate(), cYellow, sig, cReset) -} - -func LogRequestStart(method, pathname, model string, timeoutMs int, isStream bool) { - modeTag := fmt.Sprintf("%ssync%s", cDim, cReset) - if isStream { - modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset) - } - fmt.Printf("%s %s▶%s %s %s %s timeout:%dms %s\n", - ts(), cBCyan, cReset, method, pathname, model, timeoutMs, modeTag) -} - -func LogRequestDone(method, pathname, model string, latencyMs int64, code int) { - statusColor := cBGreen - if code != 0 { - statusColor = cRed - } - fmt.Printf("%s %s■%s %s %s %s %s%dms exit:%d%s\n", - ts(), statusColor, cReset, method, pathname, model, cDim, latencyMs, code, cReset) -} - -func LogRequestTimeout(method, pathname, model string, timeoutMs int) { - fmt.Printf("%s %s⏱%s %s %s %s %stimed-out after %dms%s\n", - ts(), cRed, cReset, method, pathname, model, cRed, timeoutMs, cReset) -} - -func LogClientDisconnect(method, pathname, model string, latencyMs int64) { - fmt.Printf("%s %s⚡%s %s %s %s %sclient disconnected after %dms%s\n", - ts(), cYellow, cReset, method, pathname, model, cYellow, latencyMs, cReset) -} - -func LogStreamChunk(model string, text string, chunkNum int) { - preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 120) - fmt.Printf("%s %s▸%s #%d %s%s%s\n", - ts(), cDim, cReset, chunkNum, cWhite, preview, cReset) -} - -func LogRawLine(line string) { - preview := truncate(strings.ReplaceAll(line, "\n", "↵ "), 200) - fmt.Printf("%s %s│%s %sraw%s %s\n", - ts(), cGray, cReset, cDim, cReset, preview) -} - -func LogIncoming(method, pathname, remoteAddress string) { - methodColor := cBCyan - switch method { - case "POST": - methodColor = cBMagenta - case "GET": - methodColor = cBCyan - case "DELETE": - methodColor = cRed - } - fmt.Printf("%s %s%s%s%s %s%s%s %s(%s)%s\n", - ts(), - methodColor, cBold, method, cReset, - cWhite, pathname, cReset, - cDim, remoteAddress, cReset, - ) -} - -func LogAccountAssigned(configDir string) { - if configDir == "" { - return - } - name := filepath.Base(configDir) - fmt.Printf("%s %s→%s account %s%s%s\n", ts(), cBCyan, cReset, cBold, name, cReset) -} - -func LogAccountStats(verbose bool, stats []pool.AccountStat) { - if !verbose || len(stats) == 0 { - return - } - now := time.Now().UnixMilli() - fmt.Printf("%s┌─ Account Stats %s┐%s\n", cGray, strings.Repeat("─", 44), cReset) - for _, s := range stats { - name := fmt.Sprintf("%-20s", filepath.Base(s.ConfigDir)) - active := fmt.Sprintf("%sactive:0%s", cDim, cReset) - if s.ActiveRequests > 0 { - active = fmt.Sprintf("%sactive:%d%s", cBCyan, s.ActiveRequests, cReset) - } - total := fmt.Sprintf("total:%s%d%s", cBold, s.TotalRequests, cReset) - ok := fmt.Sprintf("%sok:%d%s", cGreen, s.TotalSuccess, cReset) - errStr := fmt.Sprintf("%serr:0%s", cDim, cReset) - if s.TotalErrors > 0 { - errStr = fmt.Sprintf("%serr:%d%s", cRed, s.TotalErrors, cReset) - } - rl := fmt.Sprintf("%srl:0%s", cDim, cReset) - if s.TotalRateLimits > 0 { - rl = fmt.Sprintf("%srl:%d%s", cYellow, s.TotalRateLimits, cReset) - } - avg := "avg:-" - if s.TotalRequests > 0 { - avg = fmt.Sprintf("avg:%dms", s.TotalLatencyMs/int64(s.TotalRequests)) - } - status := fmt.Sprintf("%s✓%s", cGreen, cReset) - if s.IsRateLimited { - recovers := time.UnixMilli(s.RateLimitUntil).UTC().Format(time.RFC3339) - _ = now - status = fmt.Sprintf("%s⛔ rate-limited (recovers %s)%s", cRed, recovers, cReset) - } - fmt.Printf(" %s%s%s %s %s %s %s %s %s%s%s %s\n", - cBold, name, cReset, active, total, ok, errStr, rl, cDim, avg, cReset, status) - } - fmt.Printf("%s└%s┘%s\n", cGray, strings.Repeat("─", 60), cReset) -} - -func LogTrafficRequest(verbose bool, model string, messages []TrafficMessage, isStream bool) { - if !verbose { - return - } - modeTag := fmt.Sprintf("%ssync%s", cDim, cReset) - if isStream { - modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset) - } - modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset) - fmt.Println(hr("─", 60)) - fmt.Printf("%s 📤 %s%sREQUEST%s %s %s\n", ts(), cBCyan, cBold, cReset, modelStr, modeTag) - for _, m := range messages { - roleColor := cWhite - if c, ok := roleStyle[m.Role]; ok { - roleColor = c - } - emoji := "💬" - if e, ok := roleEmoji[m.Role]; ok { - emoji = e - } - label := fmt.Sprintf("%s%s[%s]%s", roleColor, cBold, m.Role, cReset) - charCount := fmt.Sprintf("%s(%d chars)%s", cDim, len(m.Content), cReset) - preview := truncate(strings.ReplaceAll(m.Content, "\n", "↵ "), 280) - fmt.Printf(" %s %s %s\n", emoji, label, charCount) - fmt.Printf(" %s%s%s\n", cDim, preview, cReset) - } -} - -func LogTrafficResponse(verbose bool, model, text string, isStream bool) { - if !verbose { - return - } - modeTag := fmt.Sprintf("%ssync%s", cDim, cReset) - if isStream { - modeTag = fmt.Sprintf("%s⚡ stream%s", cBGreen, cReset) - } - modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset) - charCount := fmt.Sprintf("%s%d%s%s chars%s", cBold, len(text), cReset, cDim, cReset) - preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 480) - fmt.Printf("%s 📥 %s%sRESPONSE%s %s %s %s\n", ts(), cBGreen, cBold, cReset, modelStr, modeTag, charCount) - fmt.Printf(" 🤖 %s%s%s\n", cGreen, preview, cReset) - fmt.Println(hr("─", 60)) -} - -func AppendSessionLine(logPath, method, pathname, remoteAddress string, statusCode int) { - line := fmt.Sprintf("%s %s %s %s %d\n", time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, statusCode) - dir := filepath.Dir(logPath) - if err := os.MkdirAll(dir, 0755); err == nil { - f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err == nil { - _, _ = f.WriteString(line) - f.Close() - } - } -} - -func LogTruncation(originalLen, finalLen int) { - fmt.Printf("%s %s⚠ prompt truncated%s %s(%d → %d chars, tail preserved)%s\n", - ts(), cYellow, cReset, cDim, originalLen, finalLen, cReset) -} - -func LogAgentError(logPath, method, pathname, remoteAddress string, exitCode int, stderr string) string { - errMsg := fmt.Sprintf("Cursor CLI failed (exit %d): %s", exitCode, strings.TrimSpace(stderr)) - fmt.Fprintf(os.Stderr, "%s %s✗ agent error%s %s%s%s\n", ts(), cRed, cReset, cDim, errMsg, cReset) - truncated := strings.TrimSpace(stderr) - if len(truncated) > 200 { - truncated = truncated[:200] - } - truncated = strings.ReplaceAll(truncated, "\n", " ") - line := fmt.Sprintf("%s ERROR %s %s %s agent_exit_%d %s\n", - time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, exitCode, truncated) - if f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil { - _, _ = f.WriteString(line) - f.Close() - } - return errMsg -} diff --git a/internal/models/cursorcli.go b/internal/models/cursorcli.go deleted file mode 100644 index 69d1f6c..0000000 --- a/internal/models/cursorcli.go +++ /dev/null @@ -1,62 +0,0 @@ -package models - -import ( - "cursor-api-proxy/internal/process" - "fmt" - "os" - "regexp" - "strings" -) - -type CursorCliModel struct { - ID string - Name string -} - -var modelLineRe = regexp.MustCompile(`^([A-Za-z0-9][A-Za-z0-9._:/-]*)\s+-\s+(.*)$`) -var trailingParenRe = regexp.MustCompile(`\s*\([^)]*\)\s*$`) - -func ParseCursorCliModels(output string) []CursorCliModel { - lines := strings.Split(output, "\n") - seen := make(map[string]CursorCliModel) - var order []string - - for _, line := range lines { - line = strings.TrimSpace(line) - m := modelLineRe.FindStringSubmatch(line) - if m == nil { - continue - } - id := m[1] - rawName := m[2] - name := strings.TrimSpace(trailingParenRe.ReplaceAllString(rawName, "")) - if name == "" { - name = id - } - if _, exists := seen[id]; !exists { - seen[id] = CursorCliModel{ID: id, Name: name} - order = append(order, id) - } - } - - result := make([]CursorCliModel, 0, len(order)) - for _, id := range order { - result = append(result, seen[id]) - } - return result -} - -func ListCursorCliModels(agentBin string, timeoutMs int) ([]CursorCliModel, error) { - tmpDir := os.TempDir() - result, err := process.Run(agentBin, []string{"--list-models"}, process.RunOptions{ - Cwd: tmpDir, - TimeoutMs: timeoutMs, - }) - if err != nil { - return nil, err - } - if result.Code != 0 { - return nil, fmt.Errorf("agent --list-models failed: %s", strings.TrimSpace(result.Stderr)) - } - return ParseCursorCliModels(result.Stdout), nil -} diff --git a/internal/models/cursorcli_test.go b/internal/models/cursorcli_test.go deleted file mode 100644 index 9a911ac..0000000 --- a/internal/models/cursorcli_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package models - -import "testing" - -func TestParseCursorCliModels(t *testing.T) { - output := ` -gpt-4o - GPT-4o (some info) -claude-3-5-sonnet - Claude 3.5 Sonnet -gpt-4o - GPT-4o duplicate -invalid line without dash -` - result := ParseCursorCliModels(output) - - if len(result) != 2 { - t.Fatalf("expected 2 unique models, got %d: %v", len(result), result) - } - if result[0].ID != "gpt-4o" { - t.Errorf("expected gpt-4o, got %s", result[0].ID) - } - if result[0].Name != "GPT-4o" { - t.Errorf("expected 'GPT-4o', got %s", result[0].Name) - } - if result[1].ID != "claude-3-5-sonnet" { - t.Errorf("expected claude-3-5-sonnet, got %s", result[1].ID) - } -} - -func TestParseCursorCliModelsEmpty(t *testing.T) { - result := ParseCursorCliModels("") - if len(result) != 0 { - t.Fatalf("expected empty, got %v", result) - } -} diff --git a/internal/models/cursormap.go b/internal/models/cursormap.go deleted file mode 100644 index bdcc568..0000000 --- a/internal/models/cursormap.go +++ /dev/null @@ -1,123 +0,0 @@ -package models - -import ( - "regexp" - "strings" -) - -var anthropicToCursor = map[string]string{ - "claude-opus-4-6": "opus-4.6", - "claude-opus-4.6": "opus-4.6", - "claude-sonnet-4-6": "sonnet-4.6", - "claude-sonnet-4.6": "sonnet-4.6", - "claude-opus-4-5": "opus-4.5", - "claude-opus-4.5": "opus-4.5", - "claude-sonnet-4-5": "sonnet-4.5", - "claude-sonnet-4.5": "sonnet-4.5", - "claude-opus-4": "opus-4.6", - "claude-sonnet-4": "sonnet-4.6", - "claude-haiku-4-5-20251001": "sonnet-4.5", - "claude-haiku-4-5": "sonnet-4.5", - "claude-haiku-4-6": "sonnet-4.6", - "claude-haiku-4": "sonnet-4.5", - "claude-opus-4-6-thinking": "opus-4.6-thinking", - "claude-sonnet-4-6-thinking": "sonnet-4.6-thinking", - "claude-opus-4-5-thinking": "opus-4.5-thinking", - "claude-sonnet-4-5-thinking": "sonnet-4.5-thinking", -} - -type ModelAlias struct { - CursorID string - AnthropicID string - Name string -} - -var cursorToAnthropicAlias = []ModelAlias{ - {"opus-4.6", "claude-opus-4-6", "Claude 4.6 Opus"}, - {"opus-4.6-thinking", "claude-opus-4-6-thinking", "Claude 4.6 Opus (Thinking)"}, - {"sonnet-4.6", "claude-sonnet-4-6", "Claude 4.6 Sonnet"}, - {"sonnet-4.6-thinking", "claude-sonnet-4-6-thinking", "Claude 4.6 Sonnet (Thinking)"}, - {"opus-4.5", "claude-opus-4-5", "Claude 4.5 Opus"}, - {"opus-4.5-thinking", "claude-opus-4-5-thinking", "Claude 4.5 Opus (Thinking)"}, - {"sonnet-4.5", "claude-sonnet-4-5", "Claude 4.5 Sonnet"}, - {"sonnet-4.5-thinking", "claude-sonnet-4-5-thinking", "Claude 4.5 Sonnet (Thinking)"}, -} - -// cursorModelPattern matches cursor model IDs like "opus-4.6", "sonnet-4.7-thinking". -var cursorModelPattern = regexp.MustCompile(`^([a-zA-Z]+)-(\d+)\.(\d+)(-thinking)?$`) - -// reverseDynamicPattern matches dynamically generated anthropic aliases -// like "claude-opus-4-7", "claude-sonnet-4-7-thinking". -var reverseDynamicPattern = regexp.MustCompile(`^claude-([a-zA-Z]+)-(\d+)-(\d+)(-thinking)?$`) - -func generateDynamicAlias(cursorID string) (AnthropicAlias, bool) { - m := cursorModelPattern.FindStringSubmatch(cursorID) - if m == nil { - return AnthropicAlias{}, false - } - family := m[1] - major := m[2] - minor := m[3] - thinking := m[4] - - anthropicID := "claude-" + family + "-" + major + "-" + minor + thinking - capFamily := strings.ToUpper(family[:1]) + family[1:] - name := capFamily + " " + major + "." + minor - if thinking == "-thinking" { - name += " (Thinking)" - } - return AnthropicAlias{ID: anthropicID, Name: name}, true -} - -func reverseDynamicAlias(anthropicID string) (string, bool) { - m := reverseDynamicPattern.FindStringSubmatch(anthropicID) - if m == nil { - return "", false - } - return m[1] + "-" + m[2] + "." + m[3] + m[4], true -} - -func ResolveToCursorModel(requested string) string { - if strings.TrimSpace(requested) == "" { - return "" - } - key := strings.ToLower(strings.TrimSpace(requested)) - if v, ok := anthropicToCursor[key]; ok { - return v - } - if v, ok := reverseDynamicAlias(key); ok { - return v - } - return strings.TrimSpace(requested) -} - -type AnthropicAlias struct { - ID string - Name string -} - -func GetAnthropicModelAliases(availableCursorIDs []string) []AnthropicAlias { - set := make(map[string]bool, len(availableCursorIDs)) - for _, id := range availableCursorIDs { - set[id] = true - } - - staticSet := make(map[string]bool, len(cursorToAnthropicAlias)) - var result []AnthropicAlias - for _, a := range cursorToAnthropicAlias { - if set[a.CursorID] { - staticSet[a.CursorID] = true - result = append(result, AnthropicAlias{ID: a.AnthropicID, Name: a.Name}) - } - } - - for _, id := range availableCursorIDs { - if staticSet[id] { - continue - } - if alias, ok := generateDynamicAlias(id); ok { - result = append(result, alias) - } - } - return result -} diff --git a/internal/models/cursormap_test.go b/internal/models/cursormap_test.go deleted file mode 100644 index 3b3d54b..0000000 --- a/internal/models/cursormap_test.go +++ /dev/null @@ -1,163 +0,0 @@ -package models - -import "testing" - -func TestGetAnthropicModelAliases_StaticOnly(t *testing.T) { - aliases := GetAnthropicModelAliases([]string{"sonnet-4.6", "opus-4.5"}) - if len(aliases) != 2 { - t.Fatalf("expected 2 aliases, got %d: %v", len(aliases), aliases) - } - ids := map[string]string{} - for _, a := range aliases { - ids[a.ID] = a.Name - } - if ids["claude-sonnet-4-6"] != "Claude 4.6 Sonnet" { - t.Errorf("unexpected name for claude-sonnet-4-6: %s", ids["claude-sonnet-4-6"]) - } - if ids["claude-opus-4-5"] != "Claude 4.5 Opus" { - t.Errorf("unexpected name for claude-opus-4-5: %s", ids["claude-opus-4-5"]) - } -} - -func TestGetAnthropicModelAliases_DynamicFallback(t *testing.T) { - aliases := GetAnthropicModelAliases([]string{"sonnet-4.7", "opus-5.0-thinking", "gpt-4o"}) - ids := map[string]string{} - for _, a := range aliases { - ids[a.ID] = a.Name - } - if ids["claude-sonnet-4-7"] != "Sonnet 4.7" { - t.Errorf("unexpected name for claude-sonnet-4-7: %s", ids["claude-sonnet-4-7"]) - } - if ids["claude-opus-5-0-thinking"] != "Opus 5.0 (Thinking)" { - t.Errorf("unexpected name for claude-opus-5-0-thinking: %s", ids["claude-opus-5-0-thinking"]) - } - if _, ok := ids["claude-gpt-4o"]; ok { - t.Errorf("gpt-4o should not generate a claude alias") - } -} - -func TestGetAnthropicModelAliases_Mixed(t *testing.T) { - aliases := GetAnthropicModelAliases([]string{"sonnet-4.6", "opus-4.7", "gpt-4o"}) - ids := map[string]string{} - for _, a := range aliases { - ids[a.ID] = a.Name - } - // static entry keeps its custom name - if ids["claude-sonnet-4-6"] != "Claude 4.6 Sonnet" { - t.Errorf("static alias should keep original name, got: %s", ids["claude-sonnet-4-6"]) - } - // dynamic entry uses auto-generated name - if ids["claude-opus-4-7"] != "Opus 4.7" { - t.Errorf("dynamic alias name mismatch: %s", ids["claude-opus-4-7"]) - } -} - -func TestGetAnthropicModelAliases_UnknownPattern(t *testing.T) { - aliases := GetAnthropicModelAliases([]string{"some-unknown-model"}) - if len(aliases) != 0 { - t.Fatalf("expected 0 aliases for unknown pattern, got %d: %v", len(aliases), aliases) - } -} - -func TestResolveToCursorModel_Static(t *testing.T) { - tests := []struct { - input string - want string - }{ - {"claude-opus-4-6", "opus-4.6"}, - {"claude-opus-4.6", "opus-4.6"}, - {"claude-sonnet-4-5", "sonnet-4.5"}, - {"claude-opus-4-6-thinking", "opus-4.6-thinking"}, - } - for _, tc := range tests { - got := ResolveToCursorModel(tc.input) - if got != tc.want { - t.Errorf("ResolveToCursorModel(%q) = %q, want %q", tc.input, got, tc.want) - } - } -} - -func TestResolveToCursorModel_DynamicFallback(t *testing.T) { - tests := []struct { - input string - want string - }{ - {"claude-opus-4-7", "opus-4.7"}, - {"claude-sonnet-5-0", "sonnet-5.0"}, - {"claude-opus-4-7-thinking", "opus-4.7-thinking"}, - {"claude-sonnet-5-0-thinking", "sonnet-5.0-thinking"}, - } - for _, tc := range tests { - got := ResolveToCursorModel(tc.input) - if got != tc.want { - t.Errorf("ResolveToCursorModel(%q) = %q, want %q", tc.input, got, tc.want) - } - } -} - -func TestResolveToCursorModel_Passthrough(t *testing.T) { - tests := []string{"sonnet-4.6", "gpt-4o", "custom-model"} - for _, input := range tests { - got := ResolveToCursorModel(input) - if got != input { - t.Errorf("ResolveToCursorModel(%q) = %q, want passthrough %q", input, got, input) - } - } -} - -func TestResolveToCursorModel_Empty(t *testing.T) { - if got := ResolveToCursorModel(""); got != "" { - t.Errorf("ResolveToCursorModel(\"\") = %q, want empty", got) - } - if got := ResolveToCursorModel(" "); got != "" { - t.Errorf("ResolveToCursorModel(\" \") = %q, want empty", got) - } -} - -func TestGenerateDynamicAlias(t *testing.T) { - tests := []struct { - input string - want AnthropicAlias - ok bool - }{ - {"opus-4.7", AnthropicAlias{"claude-opus-4-7", "Opus 4.7"}, true}, - {"sonnet-5.0-thinking", AnthropicAlias{"claude-sonnet-5-0-thinking", "Sonnet 5.0 (Thinking)"}, true}, - {"gpt-4o", AnthropicAlias{}, false}, - {"invalid", AnthropicAlias{}, false}, - } - for _, tc := range tests { - got, ok := generateDynamicAlias(tc.input) - if ok != tc.ok { - t.Errorf("generateDynamicAlias(%q) ok = %v, want %v", tc.input, ok, tc.ok) - continue - } - if ok && (got.ID != tc.want.ID || got.Name != tc.want.Name) { - t.Errorf("generateDynamicAlias(%q) = {%q, %q}, want {%q, %q}", tc.input, got.ID, got.Name, tc.want.ID, tc.want.Name) - } - } -} - -func TestReverseDynamicAlias(t *testing.T) { - tests := []struct { - input string - want string - ok bool - }{ - {"claude-opus-4-7", "opus-4.7", true}, - {"claude-sonnet-5-0-thinking", "sonnet-5.0-thinking", true}, - {"claude-opus-4-6", "opus-4.6", true}, - {"claude-opus-4.6", "", false}, - {"claude-haiku-4-5-20251001", "", false}, - {"some-model", "", false}, - } - for _, tc := range tests { - got, ok := reverseDynamicAlias(tc.input) - if ok != tc.ok { - t.Errorf("reverseDynamicAlias(%q) ok = %v, want %v", tc.input, ok, tc.ok) - continue - } - if ok && got != tc.want { - t.Errorf("reverseDynamicAlias(%q) = %q, want %q", tc.input, got, tc.want) - } - } -} diff --git a/internal/openai/openai.go b/internal/openai/openai.go deleted file mode 100644 index 86919a6..0000000 --- a/internal/openai/openai.go +++ /dev/null @@ -1,243 +0,0 @@ -package openai - -import ( - "encoding/json" - "fmt" - "strings" -) - -type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []interface{} `json:"messages"` - Stream bool `json:"stream"` - Tools []interface{} `json:"tools"` - ToolChoice interface{} `json:"tool_choice"` - Functions []interface{} `json:"functions"` - FunctionCall interface{} `json:"function_call"` -} - -func NormalizeModelID(raw string) string { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return "" - } - parts := strings.Split(trimmed, "/") - last := parts[len(parts)-1] - if last == "" { - return "" - } - return last -} - -func imageURLToText(imageURL interface{}) string { - if imageURL == nil { - return "[Image]" - } - var url string - switch v := imageURL.(type) { - case string: - url = v - case map[string]interface{}: - if u, ok := v["url"].(string); ok { - url = u - } - } - if url == "" { - return "[Image]" - } - if strings.HasPrefix(url, "data:") { - end := strings.Index(url, ";") - mime := "image" - if end > 5 { - mime = url[5:end] - } - return "[Image: base64 " + mime + "]" - } - return "[Image: " + url + "]" -} - -func MessageContentToText(content interface{}) string { - if content == nil { - return "" - } - switch v := content.(type) { - case string: - return v - case []interface{}: - var parts []string - for _, p := range v { - if p == nil { - continue - } - switch part := p.(type) { - case string: - parts = append(parts, part) - case map[string]interface{}: - typ, _ := part["type"].(string) - switch typ { - case "text": - if t, ok := part["text"].(string); ok { - parts = append(parts, t) - } - case "image_url": - parts = append(parts, imageURLToText(part["image_url"])) - case "image": - src := part["source"] - if src == nil { - src = part["url"] - } - parts = append(parts, imageURLToText(src)) - } - } - } - return strings.Join(parts, " ") - } - return "" -} - -func ToolsToSystemText(tools []interface{}, functions []interface{}) string { - var defs []interface{} - - for _, t := range tools { - if m, ok := t.(map[string]interface{}); ok { - if m["type"] == "function" { - if fn := m["function"]; fn != nil { - defs = append(defs, fn) - } - } else { - defs = append(defs, t) - } - } - } - defs = append(defs, functions...) - - if len(defs) == 0 { - return "" - } - - var lines []string - lines = append(lines, "Available tools (respond with a JSON object to call one):", "") - - for _, raw := range defs { - fn, ok := raw.(map[string]interface{}) - if !ok { - continue - } - name, _ := fn["name"].(string) - desc, _ := fn["description"].(string) - params := "{}" - if p := fn["parameters"]; p != nil { - if b, err := json.MarshalIndent(p, "", " "); err == nil { - params = string(b) - } - } else if p := fn["input_schema"]; p != nil { - if b, err := json.MarshalIndent(p, "", " "); err == nil { - params = string(b) - } - } - lines = append(lines, "Function: "+name+"\nDescription: "+desc+"\nParameters: "+params) - } - - lines = append(lines, "", - "When you want to call a tool, use this EXACT format:", - "", - "", - `{"name": "function_name", "arguments": {"param1": "value1"}}`, - "", - "", - "Rules:", - "- Write your reasoning BEFORE the tool call", - "- You may make multiple tool calls by using multiple blocks", - "- STOP writing after the last tag", - "- If no tool is needed, respond normally without tags", - ) - - return strings.Join(lines, "\n") -} - -type SimpleMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -func BuildPromptFromMessages(messages []interface{}) string { - var systemParts []string - var convo []string - - for _, raw := range messages { - m, ok := raw.(map[string]interface{}) - if !ok { - continue - } - role, _ := m["role"].(string) - text := MessageContentToText(m["content"]) - - switch role { - case "system", "developer": - if text != "" { - systemParts = append(systemParts, text) - } - case "user": - if text != "" { - convo = append(convo, "User: "+text) - } - case "assistant": - toolCalls, _ := m["tool_calls"].([]interface{}) - if len(toolCalls) > 0 { - var parts []string - if text != "" { - parts = append(parts, text) - } - for _, tc := range toolCalls { - tcMap, ok := tc.(map[string]interface{}) - if !ok { - continue - } - fn, _ := tcMap["function"].(map[string]interface{}) - if fn == nil { - continue - } - name, _ := fn["name"].(string) - args, _ := fn["arguments"].(string) - if args == "" { - args = "{}" - } - parts = append(parts, fmt.Sprintf("\n{\"name\": \"%s\", \"arguments\": %s}\n", name, args)) - } - if len(parts) > 0 { - convo = append(convo, "Assistant: "+strings.Join(parts, "\n")) - } - } else if text != "" { - convo = append(convo, "Assistant: "+text) - } - case "tool", "function": - name, _ := m["name"].(string) - toolCallID, _ := m["tool_call_id"].(string) - label := "Tool result" - if name != "" { - label = "Tool result (" + name + ")" - } - if toolCallID != "" { - label += " [id=" + toolCallID + "]" - } - if text != "" { - convo = append(convo, label+": "+text) - } - } - } - - system := "" - if len(systemParts) > 0 { - system = "System:\n" + strings.Join(systemParts, "\n\n") + "\n\n" - } - transcript := strings.Join(convo, "\n\n") - return system + transcript + "\n\nAssistant:" -} - -func BuildPromptFromSimpleMessages(messages []SimpleMessage) string { - ifaces := make([]interface{}, len(messages)) - for i, m := range messages { - ifaces[i] = map[string]interface{}{"role": m.Role, "content": m.Content} - } - return BuildPromptFromMessages(ifaces) -} diff --git a/internal/openai/openai_test.go b/internal/openai/openai_test.go deleted file mode 100644 index 04ede1b..0000000 --- a/internal/openai/openai_test.go +++ /dev/null @@ -1,80 +0,0 @@ -package openai - -import "testing" - -func TestNormalizeModelID(t *testing.T) { - tests := []struct { - input string - want string - }{ - {"gpt-4", "gpt-4"}, - {"openai/gpt-4", "gpt-4"}, - {"anthropic/claude-3", "claude-3"}, - {"", ""}, - {" ", ""}, - {"a/b/c", "c"}, - } - for _, tc := range tests { - got := NormalizeModelID(tc.input) - if got != tc.want { - t.Errorf("NormalizeModelID(%q) = %q, want %q", tc.input, got, tc.want) - } - } -} - -func TestBuildPromptFromMessages(t *testing.T) { - messages := []interface{}{ - map[string]interface{}{"role": "system", "content": "You are helpful."}, - map[string]interface{}{"role": "user", "content": "Hello"}, - map[string]interface{}{"role": "assistant", "content": "Hi there"}, - } - got := BuildPromptFromMessages(messages) - if got == "" { - t.Fatal("expected non-empty prompt") - } - containsSystem := false - containsUser := false - containsAssistant := false - for i := 0; i < len(got)-10; i++ { - if got[i:i+6] == "System" { - containsSystem = true - } - if got[i:i+4] == "User" { - containsUser = true - } - if got[i:i+9] == "Assistant" { - containsAssistant = true - } - } - if !containsSystem || !containsUser || !containsAssistant { - t.Errorf("prompt missing sections: system=%v user=%v assistant=%v\n%s", - containsSystem, containsUser, containsAssistant, got) - } -} - -func TestToolsToSystemText(t *testing.T) { - tools := []interface{}{ - map[string]interface{}{ - "type": "function", - "function": map[string]interface{}{ - "name": "get_weather", - "description": "Get weather", - "parameters": map[string]interface{}{"type": "object"}, - }, - }, - } - got := ToolsToSystemText(tools, nil) - if got == "" { - t.Fatal("expected non-empty tools text") - } - if len(got) < 10 { - t.Errorf("tools text too short: %q", got) - } -} - -func TestToolsToSystemTextEmpty(t *testing.T) { - got := ToolsToSystemText(nil, nil) - if got != "" { - t.Errorf("expected empty string for no tools, got %q", got) - } -} diff --git a/internal/parser/stream.go b/internal/parser/stream.go deleted file mode 100644 index bbdd231..0000000 --- a/internal/parser/stream.go +++ /dev/null @@ -1,110 +0,0 @@ -package parser - -import "encoding/json" - -type StreamParser func(line string) - -type Parser struct { - Parse StreamParser - Flush func() -} - -// CreateStreamParser 建立串流解析器(向後相容,不傳遞 thinking) -func CreateStreamParser(onText func(string), onDone func()) Parser { - return CreateStreamParserWithThinking(onText, nil, onDone) -} - -// CreateStreamParserWithThinking 建立串流解析器,支援思考過程輸出。 -// onThinking 可為 nil,表示忽略思考過程。 -func CreateStreamParserWithThinking(onText func(string), onThinking func(string), onDone func()) Parser { - // accumulated 是所有已輸出內容的串接 - accumulatedText := "" - accumulatedThinking := "" - done := false - - parse := func(line string) { - if done { - return - } - - var obj struct { - Type string `json:"type"` - Subtype string `json:"subtype"` - Message *struct { - Content []struct { - Type string `json:"type"` - Text string `json:"text"` - Thinking string `json:"thinking"` - } `json:"content"` - } `json:"message"` - } - - if err := json.Unmarshal([]byte(line), &obj); err != nil { - return - } - - if obj.Type == "assistant" && obj.Message != nil { - fullText := "" - fullThinking := "" - for _, p := range obj.Message.Content { - switch p.Type { - case "text": - if p.Text != "" { - fullText += p.Text - } - case "thinking": - if p.Thinking != "" { - fullThinking += p.Thinking - } - } - } - - // 處理思考過程(不因去重而 return,避免跳過同行的文字內容) - if onThinking != nil && fullThinking != "" && fullThinking != accumulatedThinking { - // 增量模式:新內容以 accumulated 為前綴 - if len(fullThinking) >= len(accumulatedThinking) && fullThinking[:len(accumulatedThinking)] == accumulatedThinking { - delta := fullThinking[len(accumulatedThinking):] - if delta != "" { - onThinking(delta) - } - accumulatedThinking = fullThinking - } else { - // 獨立片段:直接輸出,但 accumulated 要串接 - onThinking(fullThinking) - accumulatedThinking = accumulatedThinking + fullThinking - } - } - - // 處理一般文字 - if fullText == "" || fullText == accumulatedText { - return - } - // 增量模式:新內容以 accumulated 為前綴 - if len(fullText) >= len(accumulatedText) && fullText[:len(accumulatedText)] == accumulatedText { - delta := fullText[len(accumulatedText):] - if delta != "" { - onText(delta) - } - accumulatedText = fullText - } else { - // 獨立片段:直接輸出,但 accumulated 要串接 - onText(fullText) - accumulatedText = accumulatedText + fullText - } - } - - if obj.Type == "result" && obj.Subtype == "success" { - done = true - onDone() - } - } - - flush := func() { - if !done { - done = true - onDone() - } - } - - return Parser{Parse: parse, Flush: flush} -} diff --git a/internal/parser/stream_test.go b/internal/parser/stream_test.go deleted file mode 100644 index 146cdc2..0000000 --- a/internal/parser/stream_test.go +++ /dev/null @@ -1,304 +0,0 @@ -package parser - -import ( - "encoding/json" - "testing" -) - -func makeAssistantLine(text string) string { - obj := map[string]interface{}{ - "type": "assistant", - "message": map[string]interface{}{ - "content": []map[string]interface{}{ - {"type": "text", "text": text}, - }, - }, - } - b, _ := json.Marshal(obj) - return string(b) -} - -func makeResultLine() string { - b, _ := json.Marshal(map[string]string{"type": "result", "subtype": "success"}) - return string(b) -} - -func TestStreamParserFragmentMode(t *testing.T) { - // cursor --stream-partial-output 模式:每個訊息是獨立 token fragment - var texts []string - p := CreateStreamParser( - func(text string) { texts = append(texts, text) }, - func() {}, - ) - - p.Parse(makeAssistantLine("你")) - p.Parse(makeAssistantLine("好!有")) - p.Parse(makeAssistantLine("什")) - p.Parse(makeAssistantLine("麼")) - - if len(texts) != 4 { - t.Fatalf("expected 4 fragments, got %d: %v", len(texts), texts) - } - if texts[0] != "你" || texts[1] != "好!有" || texts[2] != "什" || texts[3] != "麼" { - t.Fatalf("unexpected texts: %v", texts) - } -} - -func TestStreamParserDeduplicatesFinalFullText(t *testing.T) { - // 最後一個訊息是完整的累積文字,應被跳過(去重) - var texts []string - p := CreateStreamParser( - func(text string) { texts = append(texts, text) }, - func() {}, - ) - - p.Parse(makeAssistantLine("Hello")) - p.Parse(makeAssistantLine(" world")) - // 最後一個是完整累積文字,應被去重 - p.Parse(makeAssistantLine("Hello world")) - - if len(texts) != 2 { - t.Fatalf("expected 2 fragments (final full text deduplicated), got %d: %v", len(texts), texts) - } - if texts[0] != "Hello" || texts[1] != " world" { - t.Fatalf("unexpected texts: %v", texts) - } -} - -func TestStreamParserCallsOnDone(t *testing.T) { - var texts []string - doneCount := 0 - p := CreateStreamParser( - func(text string) { texts = append(texts, text) }, - func() { doneCount++ }, - ) - - p.Parse(makeResultLine()) - if doneCount != 1 { - t.Fatalf("expected onDone called once, got %d", doneCount) - } - if len(texts) != 0 { - t.Fatalf("expected no text, got %v", texts) - } -} - -func TestStreamParserIgnoresLinesAfterDone(t *testing.T) { - var texts []string - doneCount := 0 - p := CreateStreamParser( - func(text string) { texts = append(texts, text) }, - func() { doneCount++ }, - ) - - p.Parse(makeResultLine()) - p.Parse(makeAssistantLine("late")) - if len(texts) != 0 { - t.Fatalf("expected no text after done, got %v", texts) - } - if doneCount != 1 { - t.Fatalf("expected onDone called once, got %d", doneCount) - } -} - -func TestStreamParserIgnoresNonAssistantLines(t *testing.T) { - var texts []string - p := CreateStreamParser( - func(text string) { texts = append(texts, text) }, - func() {}, - ) - - b1, _ := json.Marshal(map[string]interface{}{"type": "user", "message": map[string]interface{}{}}) - p.Parse(string(b1)) - b2, _ := json.Marshal(map[string]interface{}{ - "type": "assistant", - "message": map[string]interface{}{"content": []interface{}{}}, - }) - p.Parse(string(b2)) - p.Parse(`{"type":"assistant","message":{"content":[{"type":"code","text":"x"}]}}`) - - if len(texts) != 0 { - t.Fatalf("expected no texts, got %v", texts) - } -} - -func TestStreamParserIgnoresParseErrors(t *testing.T) { - var texts []string - doneCount := 0 - p := CreateStreamParser( - func(text string) { texts = append(texts, text) }, - func() { doneCount++ }, - ) - - p.Parse("not json") - p.Parse("{") - p.Parse("") - - if len(texts) != 0 || doneCount != 0 { - t.Fatalf("expected nothing, got texts=%v done=%d", texts, doneCount) - } -} - -func TestStreamParserJoinsMultipleTextParts(t *testing.T) { - var texts []string - p := CreateStreamParser( - func(text string) { texts = append(texts, text) }, - func() {}, - ) - - obj := map[string]interface{}{ - "type": "assistant", - "message": map[string]interface{}{ - "content": []map[string]interface{}{ - {"type": "text", "text": "Hello"}, - {"type": "text", "text": " "}, - {"type": "text", "text": "world"}, - }, - }, - } - b, _ := json.Marshal(obj) - p.Parse(string(b)) - - if len(texts) != 1 || texts[0] != "Hello world" { - t.Fatalf("expected ['Hello world'], got %v", texts) - } -} - -func TestStreamParserFlushTriggersDone(t *testing.T) { - var texts []string - doneCount := 0 - p := CreateStreamParser( - func(text string) { texts = append(texts, text) }, - func() { doneCount++ }, - ) - - p.Parse(makeAssistantLine("Hello")) - // agent 結束但沒有 result/success,手動 flush - p.Flush() - if doneCount != 1 { - t.Fatalf("expected onDone called once after Flush, got %d", doneCount) - } - // 再 flush 不應重複觸發 - p.Flush() - if doneCount != 1 { - t.Fatalf("expected onDone called only once, got %d", doneCount) - } -} - -func TestStreamParserFlushAfterDoneIsNoop(t *testing.T) { - doneCount := 0 - p := CreateStreamParser( - func(text string) {}, - func() { doneCount++ }, - ) - - p.Parse(makeResultLine()) - p.Flush() - if doneCount != 1 { - t.Fatalf("expected onDone called once, got %d", doneCount) - } -} - -func makeThinkingLine(thinking string) string { - obj := map[string]interface{}{ - "type": "assistant", - "message": map[string]interface{}{ - "content": []map[string]interface{}{ - {"type": "thinking", "thinking": thinking}, - }, - }, - } - b, _ := json.Marshal(obj) - return string(b) -} - -func makeThinkingAndTextLine(thinking, text string) string { - obj := map[string]interface{}{ - "type": "assistant", - "message": map[string]interface{}{ - "content": []map[string]interface{}{ - {"type": "thinking", "thinking": thinking}, - {"type": "text", "text": text}, - }, - }, - } - b, _ := json.Marshal(obj) - return string(b) -} - -func TestStreamParserWithThinkingCallsOnThinking(t *testing.T) { - var texts []string - var thinkings []string - p := CreateStreamParserWithThinking( - func(text string) { texts = append(texts, text) }, - func(thinking string) { thinkings = append(thinkings, thinking) }, - func() {}, - ) - - p.Parse(makeThinkingLine("思考中...")) - p.Parse(makeAssistantLine("回答")) - - if len(thinkings) != 1 || thinkings[0] != "思考中..." { - t.Fatalf("expected thinkings=['思考中...'], got %v", thinkings) - } - if len(texts) != 1 || texts[0] != "回答" { - t.Fatalf("expected texts=['回答'], got %v", texts) - } -} - -func TestStreamParserWithThinkingNilOnThinkingIgnoresThinking(t *testing.T) { - var texts []string - p := CreateStreamParserWithThinking( - func(text string) { texts = append(texts, text) }, - nil, - func() {}, - ) - - p.Parse(makeThinkingLine("忽略的思考")) - p.Parse(makeAssistantLine("文字")) - - if len(texts) != 1 || texts[0] != "文字" { - t.Fatalf("expected texts=['文字'], got %v", texts) - } -} - -func TestStreamParserWithThinkingDeduplication(t *testing.T) { - var thinkings []string - p := CreateStreamParserWithThinking( - func(text string) {}, - func(thinking string) { thinkings = append(thinkings, thinking) }, - func() {}, - ) - - p.Parse(makeThinkingLine("A")) - p.Parse(makeThinkingLine("B")) - // 重複的完整思考,應被跳過 - p.Parse(makeThinkingLine("AB")) - - if len(thinkings) != 2 || thinkings[0] != "A" || thinkings[1] != "B" { - t.Fatalf("expected thinkings=['A','B'], got %v", thinkings) - } -} - -// TestStreamParserThinkingDuplicateButTextStillEmitted 驗證 bug 修復: -// 當 thinking 重複(去重跳過)但同一行有 text 時,text 仍必須輸出。 -func TestStreamParserThinkingDuplicateButTextStillEmitted(t *testing.T) { - var texts []string - var thinkings []string - p := CreateStreamParserWithThinking( - func(text string) { texts = append(texts, text) }, - func(thinking string) { thinkings = append(thinkings, thinking) }, - func() {}, - ) - - // 第一行:thinking="思考中" + text(thinking 為新增,兩者都應輸出) - p.Parse(makeThinkingAndTextLine("思考中", "第一段")) - // 第二行:thinking 與上一行相同(去重),但 text 是新的,text 仍應輸出 - p.Parse(makeThinkingAndTextLine("思考中", "第二段")) - - if len(thinkings) != 1 || thinkings[0] != "思考中" { - t.Fatalf("expected thinkings=['思考中'], got %v", thinkings) - } - if len(texts) != 2 || texts[0] != "第一段" || texts[1] != "第二段" { - t.Fatalf("expected texts=['第一段','第二段'], got %v", texts) - } -} diff --git a/internal/pool/pool.go b/internal/pool/pool.go deleted file mode 100644 index 20a2a53..0000000 --- a/internal/pool/pool.go +++ /dev/null @@ -1,284 +0,0 @@ -package pool - -import ( - "sync" - "time" -) - -type accountStatus struct { - configDir string - activeRequests int - lastUsed int64 - rateLimitUntil int64 - totalRequests int - totalSuccess int - totalErrors int - totalRateLimits int - totalLatencyMs int64 -} - -type AccountStat struct { - ConfigDir string - ActiveRequests int - TotalRequests int - TotalSuccess int - TotalErrors int - TotalRateLimits int - TotalLatencyMs int64 - IsRateLimited bool - RateLimitUntil int64 -} - -type AccountPool struct { - mu sync.Mutex - accounts []*accountStatus -} - -func NewAccountPool(configDirs []string) *AccountPool { - accounts := make([]*accountStatus, 0, len(configDirs)) - for _, dir := range configDirs { - accounts = append(accounts, &accountStatus{configDir: dir}) - } - return &AccountPool{accounts: accounts} -} - -func (p *AccountPool) GetNextConfigDir() string { - p.mu.Lock() - defer p.mu.Unlock() - - if len(p.accounts) == 0 { - return "" - } - - now := time.Now().UnixMilli() - - available := make([]*accountStatus, 0, len(p.accounts)) - for _, a := range p.accounts { - if a.rateLimitUntil < now { - available = append(available, a) - } - } - - target := available - if len(target) == 0 { - target = make([]*accountStatus, len(p.accounts)) - copy(target, p.accounts) - // sort by earliest recovery - for i := 1; i < len(target); i++ { - for j := i; j > 0 && target[j].rateLimitUntil < target[j-1].rateLimitUntil; j-- { - target[j], target[j-1] = target[j-1], target[j] - } - } - } - - // pick least busy then least recently used - best := target[0] - for _, a := range target[1:] { - if a.activeRequests < best.activeRequests { - best = a - } else if a.activeRequests == best.activeRequests && a.lastUsed < best.lastUsed { - best = a - } - } - best.lastUsed = now - return best.configDir -} - -func (p *AccountPool) find(configDir string) *accountStatus { - for _, a := range p.accounts { - if a.configDir == configDir { - return a - } - } - return nil -} - -func (p *AccountPool) ReportRequestStart(configDir string) { - if configDir == "" { - return - } - p.mu.Lock() - defer p.mu.Unlock() - if a := p.find(configDir); a != nil { - a.activeRequests++ - a.totalRequests++ - } -} - -func (p *AccountPool) ReportRequestEnd(configDir string) { - if configDir == "" { - return - } - p.mu.Lock() - defer p.mu.Unlock() - if a := p.find(configDir); a != nil && a.activeRequests > 0 { - a.activeRequests-- - } -} - -func (p *AccountPool) ReportRequestSuccess(configDir string, latencyMs int64) { - if configDir == "" { - return - } - p.mu.Lock() - defer p.mu.Unlock() - if a := p.find(configDir); a != nil { - a.totalSuccess++ - a.totalLatencyMs += latencyMs - } -} - -func (p *AccountPool) ReportRequestError(configDir string, latencyMs int64) { - if configDir == "" { - return - } - p.mu.Lock() - defer p.mu.Unlock() - if a := p.find(configDir); a != nil { - a.totalErrors++ - a.totalLatencyMs += latencyMs - } -} - -func (p *AccountPool) ReportRateLimit(configDir string, penaltyMs int64) { - if configDir == "" { - return - } - if penaltyMs <= 0 { - penaltyMs = 60000 - } - p.mu.Lock() - defer p.mu.Unlock() - if a := p.find(configDir); a != nil { - a.rateLimitUntil = time.Now().UnixMilli() + penaltyMs - a.totalRateLimits++ - } -} - -func (p *AccountPool) GetStats() []AccountStat { - p.mu.Lock() - defer p.mu.Unlock() - now := time.Now().UnixMilli() - stats := make([]AccountStat, len(p.accounts)) - for i, a := range p.accounts { - stats[i] = AccountStat{ - ConfigDir: a.configDir, - ActiveRequests: a.activeRequests, - TotalRequests: a.totalRequests, - TotalSuccess: a.totalSuccess, - TotalErrors: a.totalErrors, - TotalRateLimits: a.totalRateLimits, - TotalLatencyMs: a.totalLatencyMs, - IsRateLimited: a.rateLimitUntil > now, - RateLimitUntil: a.rateLimitUntil, - } - } - return stats -} - -func (p *AccountPool) Count() int { - return len(p.accounts) -} - - -// ─── PoolHandle interface ────────────────────────────────────────────────── -// PoolHandle 讓 handler 可以注入獨立的 pool 實例,避免多 port 模式共用全域 pool。 - -type PoolHandle interface { - GetNextConfigDir() string - ReportRequestStart(configDir string) - ReportRequestEnd(configDir string) - ReportRequestSuccess(configDir string, latencyMs int64) - ReportRequestError(configDir string, latencyMs int64) - ReportRateLimit(configDir string, penaltyMs int64) - GetStats() []AccountStat -} - -// GlobalPoolHandle 包裝全域函式以實作 PoolHandle 介面(單 port 模式使用) -type GlobalPoolHandle struct{} - -func (GlobalPoolHandle) GetNextConfigDir() string { return GetNextAccountConfigDir() } -func (GlobalPoolHandle) ReportRequestStart(d string) { ReportRequestStart(d) } -func (GlobalPoolHandle) ReportRequestEnd(d string) { ReportRequestEnd(d) } -func (GlobalPoolHandle) ReportRequestSuccess(d string, l int64) { ReportRequestSuccess(d, l) } -func (GlobalPoolHandle) ReportRequestError(d string, l int64) { ReportRequestError(d, l) } -func (GlobalPoolHandle) ReportRateLimit(d string, p int64) { ReportRateLimit(d, p) } -func (GlobalPoolHandle) GetStats() []AccountStat { return GetAccountStats() } - -// ─── Global pool ─────────────────────────────────────────────────────────── - -var ( - globalPool *AccountPool - globalMu sync.Mutex -) - -func InitAccountPool(configDirs []string) { - globalMu.Lock() - defer globalMu.Unlock() - globalPool = NewAccountPool(configDirs) -} - -func GetNextAccountConfigDir() string { - globalMu.Lock() - p := globalPool - globalMu.Unlock() - if p == nil { - return "" - } - return p.GetNextConfigDir() -} - -func ReportRequestStart(configDir string) { - globalMu.Lock() - p := globalPool - globalMu.Unlock() - if p != nil { - p.ReportRequestStart(configDir) - } -} - -func ReportRequestEnd(configDir string) { - globalMu.Lock() - p := globalPool - globalMu.Unlock() - if p != nil { - p.ReportRequestEnd(configDir) - } -} - -func ReportRequestSuccess(configDir string, latencyMs int64) { - globalMu.Lock() - p := globalPool - globalMu.Unlock() - if p != nil { - p.ReportRequestSuccess(configDir, latencyMs) - } -} - -func ReportRequestError(configDir string, latencyMs int64) { - globalMu.Lock() - p := globalPool - globalMu.Unlock() - if p != nil { - p.ReportRequestError(configDir, latencyMs) - } -} - -func ReportRateLimit(configDir string, penaltyMs int64) { - globalMu.Lock() - p := globalPool - globalMu.Unlock() - if p != nil { - p.ReportRateLimit(configDir, penaltyMs) - } -} - -func GetAccountStats() []AccountStat { - globalMu.Lock() - p := globalPool - globalMu.Unlock() - if p == nil { - return nil - } - return p.GetStats() -} diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go deleted file mode 100644 index 27c1621..0000000 --- a/internal/pool/pool_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package pool - -import ( - "testing" - "time" -) - -func TestEmptyPool(t *testing.T) { - p := NewAccountPool(nil) - if got := p.GetNextConfigDir(); got != "" { - t.Fatalf("expected empty string for empty pool, got %q", got) - } - if p.Count() != 0 { - t.Fatalf("expected count 0, got %d", p.Count()) - } -} - -func TestSingleDir(t *testing.T) { - p := NewAccountPool([]string{"/dir1"}) - if got := p.GetNextConfigDir(); got != "/dir1" { - t.Fatalf("expected /dir1, got %q", got) - } - if got := p.GetNextConfigDir(); got != "/dir1" { - t.Fatalf("expected /dir1 again, got %q", got) - } -} - -func TestRoundRobin(t *testing.T) { - p := NewAccountPool([]string{"/a", "/b", "/c"}) - got := []string{ - p.GetNextConfigDir(), - p.GetNextConfigDir(), - p.GetNextConfigDir(), - p.GetNextConfigDir(), - } - want := []string{"/a", "/b", "/c", "/a"} - for i, w := range want { - if got[i] != w { - t.Fatalf("call %d: expected %q, got %q", i, w, got[i]) - } - } -} - -func TestLeastBusy(t *testing.T) { - p := NewAccountPool([]string{"/dir1", "/dir2", "/dir3"}) - p.ReportRequestStart("/dir1") - p.ReportRequestStart("/dir2") - - if got := p.GetNextConfigDir(); got != "/dir3" { - t.Fatalf("expected /dir3 (least busy), got %q", got) - } - - p.ReportRequestStart("/dir3") - p.ReportRequestEnd("/dir1") - - if got := p.GetNextConfigDir(); got != "/dir1" { - t.Fatalf("expected /dir1 after end, got %q", got) - } -} - -func TestSkipsRateLimited(t *testing.T) { - p := NewAccountPool([]string{"/dir1", "/dir2"}) - p.ReportRateLimit("/dir1", 60000) - - if got := p.GetNextConfigDir(); got != "/dir2" { - t.Fatalf("expected /dir2, got %q", got) - } - if got := p.GetNextConfigDir(); got != "/dir2" { - t.Fatalf("expected /dir2 again, got %q", got) - } -} - -func TestFallbackToSoonestRecovery(t *testing.T) { - p := NewAccountPool([]string{"/dir1", "/dir2"}) - p.ReportRateLimit("/dir1", 60000) - p.ReportRateLimit("/dir2", 30000) - - // dir2 recovers sooner — should be selected - if got := p.GetNextConfigDir(); got != "/dir2" { - t.Fatalf("expected /dir2 (sooner recovery), got %q", got) - } -} - -func TestActiveRequestsDoesNotGoNegative(t *testing.T) { - p := NewAccountPool([]string{"/dir1"}) - p.ReportRequestEnd("/dir1") - p.ReportRequestEnd("/dir1") - if got := p.GetNextConfigDir(); got != "/dir1" { - t.Fatalf("pool should still work, got %q", got) - } -} - -func TestIgnoreUnknownConfigDir(t *testing.T) { - p := NewAccountPool([]string{"/dir1"}) - p.ReportRequestStart("/nonexistent") - p.ReportRequestEnd("/nonexistent") - p.ReportRateLimit("/nonexistent", 60000) - if got := p.GetNextConfigDir(); got != "/dir1" { - t.Fatalf("expected /dir1, got %q", got) - } -} - -func TestRateLimitExpires(t *testing.T) { - p := NewAccountPool([]string{"/dir1", "/dir2"}) - p.ReportRateLimit("/dir1", 50) - - if got := p.GetNextConfigDir(); got != "/dir2" { - t.Fatalf("immediately expected /dir2, got %q", got) - } - - time.Sleep(100 * time.Millisecond) - - if got := p.GetNextConfigDir(); got != "/dir1" { - t.Fatalf("after expiry expected /dir1, got %q", got) - } -} - -func TestGlobalPool(t *testing.T) { - InitAccountPool([]string{"/g1", "/g2"}) - if got := GetNextAccountConfigDir(); got != "/g1" { - t.Fatalf("expected /g1, got %q", got) - } - if got := GetNextAccountConfigDir(); got != "/g2" { - t.Fatalf("expected /g2, got %q", got) - } - if got := GetNextAccountConfigDir(); got != "/g1" { - t.Fatalf("expected /g1 again, got %q", got) - } -} - -func TestGlobalPoolEmpty(t *testing.T) { - InitAccountPool(nil) - if got := GetNextAccountConfigDir(); got != "" { - t.Fatalf("expected empty string for empty global pool, got %q", got) - } -} - -func TestGlobalPoolReinit(t *testing.T) { - InitAccountPool([]string{"/old1", "/old2"}) - GetNextAccountConfigDir() - InitAccountPool([]string{"/new1"}) - if got := GetNextAccountConfigDir(); got != "/new1" { - t.Fatalf("expected /new1 after reinit, got %q", got) - } -} - -func TestGlobalPoolFunctionsNoopBeforeInit(t *testing.T) { - InitAccountPool(nil) - ReportRequestStart("/dir1") - ReportRequestEnd("/dir1") - ReportRateLimit("/dir1", 1000) -} diff --git a/internal/process/kill_unix.go b/internal/process/kill_unix.go deleted file mode 100644 index b235d96..0000000 --- a/internal/process/kill_unix.go +++ /dev/null @@ -1,21 +0,0 @@ -//go:build !windows - -package process - -import ( - "os/exec" - "syscall" -) - -func killProcessGroup(c *exec.Cmd) error { - if c.Process == nil { - return nil - } - // 殺死整個 process group(負號表示 group) - pgid, err := syscall.Getpgid(c.Process.Pid) - if err == nil { - _ = syscall.Kill(-pgid, syscall.SIGKILL) - } - // 同時也 kill 主程序,以防萬一 - return c.Process.Kill() -} diff --git a/internal/process/kill_windows.go b/internal/process/kill_windows.go deleted file mode 100644 index 1874bbf..0000000 --- a/internal/process/kill_windows.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build windows - -package process - -import ( - "os/exec" -) - -func killProcessGroup(c *exec.Cmd) error { - if c.Process == nil { - return nil - } - return c.Process.Kill() -} diff --git a/internal/process/process.go b/internal/process/process.go deleted file mode 100644 index 681e42f..0000000 --- a/internal/process/process.go +++ /dev/null @@ -1,250 +0,0 @@ -package process - -import ( - "bufio" - "context" - "cursor-api-proxy/internal/env" - "fmt" - "os/exec" - "strings" - "sync" - "syscall" - "time" -) - -type RunResult struct { - Code int - Stdout string - Stderr string -} - -type RunOptions struct { - Cwd string - TimeoutMs int - MaxMode bool - ConfigDir string - Ctx context.Context -} - -type RunStreamingOptions struct { - RunOptions - OnLine func(line string) -} - -// ─── Global child process registry ────────────────────────────────────────── - -var ( - activeMu sync.Mutex - activeChildren []*exec.Cmd -) - -func registerChild(c *exec.Cmd) { - activeMu.Lock() - activeChildren = append(activeChildren, c) - activeMu.Unlock() -} - -func unregisterChild(c *exec.Cmd) { - activeMu.Lock() - for i, ch := range activeChildren { - if ch == c { - activeChildren = append(activeChildren[:i], activeChildren[i+1:]...) - break - } - } - activeMu.Unlock() -} - -func KillAllChildProcesses() { - activeMu.Lock() - all := make([]*exec.Cmd, len(activeChildren)) - copy(all, activeChildren) - activeChildren = nil - activeMu.Unlock() - for _, c := range all { - killProcessGroup(c) - } -} - -// ─── Spawn ──────────────────────────────────────────────────────────────── - -func spawnChild(cmdStr string, args []string, opts *RunOptions, maxModeFn func(scriptPath, configDir string)) *exec.Cmd { - envSrc := env.OsEnvToMap() - resolved := env.ResolveAgentCommand(cmdStr, args, envSrc, opts.Cwd) - - if opts.MaxMode && maxModeFn != nil { - maxModeFn(resolved.AgentScriptPath, opts.ConfigDir) - } - - envMap := make(map[string]string, len(resolved.Env)) - for k, v := range resolved.Env { - envMap[k] = v - } - if opts.ConfigDir != "" { - envMap["CURSOR_CONFIG_DIR"] = opts.ConfigDir - } else if resolved.ConfigDir != "" { - if _, exists := envMap["CURSOR_CONFIG_DIR"]; !exists { - envMap["CURSOR_CONFIG_DIR"] = resolved.ConfigDir - } - } - - envSlice := make([]string, 0, len(envMap)) - for k, v := range envMap { - envSlice = append(envSlice, k+"="+v) - } - - ctx := opts.Ctx - if ctx == nil { - ctx = context.Background() - } - - // 使用 WaitDelay 確保 context cancel 後子程序 goroutine 能及時退出 - c := exec.CommandContext(ctx, resolved.Command, resolved.Args...) - c.Dir = opts.Cwd - c.Env = envSlice - // 設定新的 process group,使 kill 能傳遞給所有子孫程序 - c.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} - // WaitDelay:context cancel 後額外等待這麼久再強制關閉 pipes - c.WaitDelay = 5 * time.Second - // Cancel 函式:殺死整個 process group - c.Cancel = func() error { - return killProcessGroup(c) - } - return c -} - -// MaxModeFn is set by the agent package to avoid import cycle. -var MaxModeFn func(agentScriptPath, configDir string) - -func Run(cmdStr string, args []string, opts RunOptions) (RunResult, error) { - ctx := opts.Ctx - var cancel context.CancelFunc - if opts.TimeoutMs > 0 { - if ctx == nil { - ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond) - } else { - ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond) - } - defer cancel() - opts.Ctx = ctx - } else if ctx == nil { - opts.Ctx = context.Background() - } - - c := spawnChild(cmdStr, args, &opts, MaxModeFn) - var stdoutBuf, stderrBuf strings.Builder - c.Stdout = &stdoutBuf - c.Stderr = &stderrBuf - - if err := c.Start(); err != nil { - // context 已取消或命令找不到時 - if opts.Ctx != nil && opts.Ctx.Err() != nil { - return RunResult{Code: -1}, nil - } - if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") { - return RunResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr) - } - return RunResult{}, err - } - registerChild(c) - defer unregisterChild(c) - - err := c.Wait() - code := 0 - if err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - code = exitErr.ExitCode() - if code == 0 { - code = -1 - } - } else { - // context cancelled or killed — return -1 but no error - return RunResult{Code: -1, Stdout: stdoutBuf.String(), Stderr: stderrBuf.String()}, nil - } - } - return RunResult{ - Code: code, - Stdout: stdoutBuf.String(), - Stderr: stderrBuf.String(), - }, nil -} - -type StreamResult struct { - Code int - Stderr string -} - -func RunStreaming(cmdStr string, args []string, opts RunStreamingOptions) (StreamResult, error) { - ctx := opts.Ctx - var cancel context.CancelFunc - if opts.TimeoutMs > 0 { - if ctx == nil { - ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond) - } else { - ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond) - } - defer cancel() - opts.RunOptions.Ctx = ctx - } else if opts.RunOptions.Ctx == nil { - opts.RunOptions.Ctx = context.Background() - } - - c := spawnChild(cmdStr, args, &opts.RunOptions, MaxModeFn) - stdoutPipe, err := c.StdoutPipe() - if err != nil { - return StreamResult{}, err - } - stderrPipe, err := c.StderrPipe() - if err != nil { - return StreamResult{}, err - } - - if err := c.Start(); err != nil { - if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") { - return StreamResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr) - } - return StreamResult{}, err - } - registerChild(c) - defer unregisterChild(c) - - var stderrBuf strings.Builder - var wg sync.WaitGroup - - wg.Add(1) - go func() { - defer wg.Done() - scanner := bufio.NewScanner(stdoutPipe) - scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024) - for scanner.Scan() { - line := scanner.Text() - if strings.TrimSpace(line) != "" { - opts.OnLine(line) - } - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - scanner := bufio.NewScanner(stderrPipe) - scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024) - for scanner.Scan() { - stderrBuf.WriteString(scanner.Text()) - stderrBuf.WriteString("\n") - } - }() - - wg.Wait() - err = c.Wait() - code := 0 - if err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - code = exitErr.ExitCode() - if code == 0 { - code = -1 - } - } - } - return StreamResult{Code: code, Stderr: stderrBuf.String()}, nil -} diff --git a/internal/process/process_test.go b/internal/process/process_test.go deleted file mode 100644 index c48d4b5..0000000 --- a/internal/process/process_test.go +++ /dev/null @@ -1,283 +0,0 @@ -package process_test - -import ( - "context" - "cursor-api-proxy/internal/process" - "os" - "testing" - "time" -) - -// sh 是跨平台 shell 執行小 script 的輔助函式 -func sh(t *testing.T, script string, opts process.RunOptions) (process.RunResult, error) { - t.Helper() - return process.Run("sh", []string{"-c", script}, opts) -} - -func TestRun_StdoutAndStderr(t *testing.T) { - result, err := sh(t, "echo hello; echo world >&2", process.RunOptions{}) - if err != nil { - t.Fatal(err) - } - if result.Code != 0 { - t.Errorf("Code = %d, want 0", result.Code) - } - if result.Stdout != "hello\n" { - t.Errorf("Stdout = %q, want %q", result.Stdout, "hello\n") - } - if result.Stderr != "world\n" { - t.Errorf("Stderr = %q, want %q", result.Stderr, "world\n") - } -} - -func TestRun_BasicSpawn(t *testing.T) { - result, err := sh(t, "printf ok", process.RunOptions{}) - if err != nil { - t.Fatal(err) - } - if result.Code != 0 { - t.Errorf("Code = %d, want 0", result.Code) - } - if result.Stdout != "ok" { - t.Errorf("Stdout = %q, want ok", result.Stdout) - } -} - -func TestRun_ConfigDir_Propagated(t *testing.T) { - result, err := process.Run("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`}, - process.RunOptions{ConfigDir: "/test/account/dir"}) - if err != nil { - t.Fatal(err) - } - if result.Stdout != "/test/account/dir" { - t.Errorf("Stdout = %q, want /test/account/dir", result.Stdout) - } -} - -func TestRun_ConfigDir_Absent(t *testing.T) { - // 確保沒有殘留的環境變數 - _ = os.Unsetenv("CURSOR_CONFIG_DIR") - result, err := process.Run("sh", []string{"-c", `printf "${CURSOR_CONFIG_DIR:-unset}"`}, - process.RunOptions{}) - if err != nil { - t.Fatal(err) - } - if result.Stdout != "unset" { - t.Errorf("Stdout = %q, want unset", result.Stdout) - } -} - -func TestRun_NonZeroExit(t *testing.T) { - result, err := sh(t, "exit 42", process.RunOptions{}) - if err != nil { - t.Fatal(err) - } - if result.Code != 42 { - t.Errorf("Code = %d, want 42", result.Code) - } -} - -func TestRun_Timeout(t *testing.T) { - start := time.Now() - result, err := sh(t, "sleep 30", process.RunOptions{TimeoutMs: 300}) - elapsed := time.Since(start) - if err != nil { - t.Fatal(err) - } - if result.Code == 0 { - t.Error("expected non-zero exit code after timeout") - } - if elapsed > 2*time.Second { - t.Errorf("elapsed %v, want < 2s", elapsed) - } -} - -func TestRunStreaming_OnLine(t *testing.T) { - var lines []string - result, err := process.RunStreaming("sh", []string{"-c", "printf 'a\nb\nc\n'"}, - process.RunStreamingOptions{ - OnLine: func(line string) { lines = append(lines, line) }, - }) - if err != nil { - t.Fatal(err) - } - if result.Code != 0 { - t.Errorf("Code = %d, want 0", result.Code) - } - if len(lines) != 3 { - t.Errorf("got %d lines, want 3: %v", len(lines), lines) - } - if lines[0] != "a" || lines[1] != "b" || lines[2] != "c" { - t.Errorf("lines = %v, want [a b c]", lines) - } -} - -func TestRunStreaming_FlushFinalLine(t *testing.T) { - var lines []string - result, err := process.RunStreaming("sh", []string{"-c", "printf tail"}, - process.RunStreamingOptions{ - OnLine: func(line string) { lines = append(lines, line) }, - }) - if err != nil { - t.Fatal(err) - } - if result.Code != 0 { - t.Errorf("Code = %d, want 0", result.Code) - } - if len(lines) != 1 { - t.Errorf("got %d lines, want 1: %v", len(lines), lines) - } - if lines[0] != "tail" { - t.Errorf("lines[0] = %q, want tail", lines[0]) - } -} - -func TestRunStreaming_ConfigDir(t *testing.T) { - var lines []string - _, err := process.RunStreaming("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`}, - process.RunStreamingOptions{ - RunOptions: process.RunOptions{ConfigDir: "/my/config/dir"}, - OnLine: func(line string) { lines = append(lines, line) }, - }) - if err != nil { - t.Fatal(err) - } - if len(lines) != 1 || lines[0] != "/my/config/dir" { - t.Errorf("lines = %v, want [/my/config/dir]", lines) - } -} - -func TestRunStreaming_Stderr(t *testing.T) { - result, err := process.RunStreaming("sh", []string{"-c", "echo err-output >&2"}, - process.RunStreamingOptions{OnLine: func(string) {}}) - if err != nil { - t.Fatal(err) - } - if result.Stderr == "" { - t.Error("expected stderr to contain output") - } -} - -func TestRunStreaming_Timeout(t *testing.T) { - start := time.Now() - result, err := process.RunStreaming("sh", []string{"-c", "sleep 30"}, - process.RunStreamingOptions{ - RunOptions: process.RunOptions{TimeoutMs: 300}, - OnLine: func(string) {}, - }) - elapsed := time.Since(start) - if err != nil { - t.Fatal(err) - } - if result.Code == 0 { - t.Error("expected non-zero exit code after timeout") - } - if elapsed > 2*time.Second { - t.Errorf("elapsed %v, want < 2s", elapsed) - } -} - -func TestRunStreaming_Concurrent(t *testing.T) { - var lines1, lines2 []string - done := make(chan struct{}, 2) - - run := func(label string, target *[]string) { - process.RunStreaming("sh", []string{"-c", "printf '" + label + "'"}, - process.RunStreamingOptions{ - OnLine: func(line string) { *target = append(*target, line) }, - }) - done <- struct{}{} - } - - go run("stream1", &lines1) - go run("stream2", &lines2) - - <-done - <-done - - if len(lines1) != 1 || lines1[0] != "stream1" { - t.Errorf("lines1 = %v, want [stream1]", lines1) - } - if len(lines2) != 1 || lines2[0] != "stream2" { - t.Errorf("lines2 = %v, want [stream2]", lines2) - } -} - -func TestRunStreaming_ContextCancel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - start := time.Now() - done := make(chan struct{}) - - go func() { - process.RunStreaming("sh", []string{"-c", "sleep 30"}, - process.RunStreamingOptions{ - RunOptions: process.RunOptions{Ctx: ctx}, - OnLine: func(string) {}, - }) - close(done) - }() - - time.AfterFunc(100*time.Millisecond, cancel) - <-done - elapsed := time.Since(start) - - if elapsed > 2*time.Second { - t.Errorf("elapsed %v, want < 2s", elapsed) - } -} - -func TestRun_ContextCancel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - start := time.Now() - done := make(chan process.RunResult, 1) - - go func() { - r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx}) - done <- r - }() - - time.AfterFunc(100*time.Millisecond, cancel) - result := <-done - elapsed := time.Since(start) - - if result.Code == 0 { - t.Error("expected non-zero exit code after cancel") - } - if elapsed > 2*time.Second { - t.Errorf("elapsed %v, want < 2s", elapsed) - } -} - -func TestRun_AlreadyCancelledContext(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() // 已取消 - - start := time.Now() - result, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx}) - elapsed := time.Since(start) - - if result.Code == 0 { - t.Error("expected non-zero exit code") - } - if elapsed > 2*time.Second { - t.Errorf("elapsed %v, want < 2s", elapsed) - } -} - -func TestKillAllChildProcesses(t *testing.T) { - done := make(chan process.RunResult, 1) - go func() { - r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{}) - done <- r - }() - - time.Sleep(80 * time.Millisecond) - process.KillAllChildProcesses() - result := <-done - - if result.Code == 0 { - t.Error("expected non-zero exit code after kill") - } - // 再次呼叫不應 panic - process.KillAllChildProcesses() -} diff --git a/internal/providers/cursor/provider.go b/internal/providers/cursor/provider.go deleted file mode 100644 index e30b3c7..0000000 --- a/internal/providers/cursor/provider.go +++ /dev/null @@ -1,27 +0,0 @@ -package cursor - -import ( - "context" - "cursor-api-proxy/internal/apitypes" - "cursor-api-proxy/internal/config" -) - -type Provider struct { - cfg config.BridgeConfig -} - -func NewProvider(cfg config.BridgeConfig) *Provider { - return &Provider{cfg: cfg} -} - -func (p *Provider) Name() string { - return "cursor" -} - -func (p *Provider) Close() error { - return nil -} - -func (p *Provider) Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error { - return nil -} diff --git a/internal/providers/factory.go b/internal/providers/factory.go deleted file mode 100644 index 63cf767..0000000 --- a/internal/providers/factory.go +++ /dev/null @@ -1,32 +0,0 @@ -package providers - -import ( - "context" - "cursor-api-proxy/internal/apitypes" - "cursor-api-proxy/internal/config" - "cursor-api-proxy/internal/providers/cursor" - "cursor-api-proxy/internal/providers/geminiweb" - "fmt" -) - -type Provider interface { - Name() string - Close() error - Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error -} - -func NewProvider(cfg config.BridgeConfig) (Provider, error) { - providerType := cfg.Provider - if providerType == "" { - providerType = "cursor" - } - - switch providerType { - case "cursor": - return cursor.NewProvider(cfg), nil - case "gemini-web": - return geminiweb.NewPlaywrightProvider(cfg) - default: - return nil, fmt.Errorf("unknown provider: %s", providerType) - } -} diff --git a/internal/providers/geminiweb/browser.go b/internal/providers/geminiweb/browser.go deleted file mode 100644 index a94894e..0000000 --- a/internal/providers/geminiweb/browser.go +++ /dev/null @@ -1,125 +0,0 @@ -package geminiweb - -import ( - "context" - "encoding/json" - "fmt" - "os" - "path/filepath" - "time" - - "github.com/go-rod/rod" - "github.com/go-rod/rod/lib/launcher" - "github.com/go-rod/rod/lib/proto" -) - -type Browser struct { - browser *rod.Browser - visible bool -} - -func NewBrowser(visible bool) (*Browser, error) { - l := launcher.New() - if visible { - l = l.Headless(false) - } else { - l = l.Headless(true) - } - - url, err := l.Launch() - if err != nil { - return nil, fmt.Errorf("failed to launch browser: %w", err) - } - - b := rod.New().ControlURL(url) - if err := b.Connect(); err != nil { - return nil, fmt.Errorf("failed to connect browser: %w", err) - } - - return &Browser{browser: b, visible: visible}, nil -} - -func (b *Browser) Close() error { - if b.browser != nil { - return b.browser.Close() - } - return nil -} - -func (b *Browser) NewPage() (*rod.Page, error) { - return b.browser.Page(proto.TargetCreateTarget{URL: "about:blank"}) -} - -type Cookie struct { - Name string `json:"name"` - Value string `json:"value"` - Domain string `json:"domain"` - Path string `json:"path"` - Expires float64 `json:"expires"` - HTTPOnly bool `json:"httpOnly"` - Secure bool `json:"secure"` -} - -func LoadCookiesFromFile(cookieFile string) ([]Cookie, error) { - data, err := os.ReadFile(cookieFile) - if err != nil { - return nil, fmt.Errorf("failed to read cookies: %w", err) - } - - var cookies []Cookie - if err := json.Unmarshal(data, &cookies); err != nil { - return nil, fmt.Errorf("failed to parse cookies: %w", err) - } - - return cookies, nil -} - -func SaveCookiesToFile(cookies []Cookie, cookieFile string) error { - data, err := json.MarshalIndent(cookies, "", " ") - if err != nil { - return fmt.Errorf("failed to marshal cookies: %w", err) - } - - dir := filepath.Dir(cookieFile) - if err := os.MkdirAll(dir, 0755); err != nil { - return fmt.Errorf("failed to create cookie dir: %w", err) - } - - if err := os.WriteFile(cookieFile, data, 0644); err != nil { - return fmt.Errorf("failed to write cookies: %w", err) - } - - return nil -} - -func SetCookiesOnPage(page *rod.Page, cookies []Cookie) error { - var protoCookies []*proto.NetworkCookieParam - for _, c := range cookies { - p := &proto.NetworkCookieParam{ - Name: c.Name, - Value: c.Value, - Domain: c.Domain, - Path: c.Path, - HTTPOnly: c.HTTPOnly, - Secure: c.Secure, - } - if c.Expires > 0 { - exp := proto.TimeSinceEpoch(c.Expires) - p.Expires = exp - } - protoCookies = append(protoCookies, p) - } - return page.SetCookies(protoCookies) -} - -func WaitForElement(page *rod.Page, selector string, timeout time.Duration) (*rod.Element, error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return page.Context(ctx).Element(selector) -} - -func WaitForElements(page *rod.Page, selector string, timeout time.Duration) (rod.Elements, error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return page.Context(ctx).Elements(selector) -} diff --git a/internal/providers/geminiweb/browser_manager.go b/internal/providers/geminiweb/browser_manager.go deleted file mode 100644 index 3cff6f7..0000000 --- a/internal/providers/geminiweb/browser_manager.go +++ /dev/null @@ -1,173 +0,0 @@ -package geminiweb - -import ( - "fmt" - "os" - "path/filepath" - "sync" - - "github.com/go-rod/rod" - "github.com/go-rod/rod/lib/launcher" - "github.com/go-rod/rod/lib/proto" -) - -// BrowserManager 管理瀏覽器實例的生命週期 -type BrowserManager struct { - mu sync.Mutex - browser *rod.Browser - userDataDir string - page *rod.Page - visible bool - isRunning bool - currentModel string -} - -var ( - globalManager *BrowserManager - globalMu sync.Mutex -) - -// GetBrowserManager 獲取全域瀏覽器管理器(單例) -func GetBrowserManager(userDataDir string, visible bool) (*BrowserManager, error) { - globalMu.Lock() - defer globalMu.Unlock() - - if globalManager != nil { - return globalManager, nil - } - - manager, err := NewBrowserManager(userDataDir, visible) - if err != nil { - return nil, err - } - - globalManager = manager - return globalManager, nil -} - -// NewBrowserManager 建立新的瀏覽器管理器 -func NewBrowserManager(userDataDir string, visible bool) (*BrowserManager, error) { - cleanLockFiles(userDataDir) - - if err := os.MkdirAll(userDataDir, 0755); err != nil { - return nil, fmt.Errorf("failed to create user data dir: %w", err) - } - - return &BrowserManager{ - userDataDir: userDataDir, - visible: visible, - }, nil -} - -// cleanLockFiles 清理 Chrome 的殘留鎖檔案 -func cleanLockFiles(userDataDir string) { - lockFiles := []string{ - "SingletonLock", - "SingletonCookie", - "SingletonSocket", - "Default/SingletonLock", - "Default/SingletonCookie", - "Default/SingletonSocket", - } - - for _, file := range lockFiles { - path := filepath.Join(userDataDir, file) - os.Remove(path) - } -} - -// Launch 啟動瀏覽器(如果尚未啟動) -func (m *BrowserManager) Launch() error { - m.mu.Lock() - defer m.mu.Unlock() - - if m.isRunning && m.browser != nil { - return nil - } - - l := launcher.New() - - if m.visible { - l = l.Headless(false) - } else { - l = l.Headless(true) - } - - l = l.UserDataDir(m.userDataDir) - - url, err := l.Launch() - if err != nil { - return fmt.Errorf("failed to launch browser: %w", err) - } - - b := rod.New().ControlURL(url) - if err := b.Connect(); err != nil { - return fmt.Errorf("failed to connect browser: %w", err) - } - - m.browser = b - - page, err := b.Page(proto.TargetCreateTarget{URL: "about:blank"}) - if err != nil { - _ = b.Close() - return fmt.Errorf("failed to create page: %w", err) - } - - m.page = page - m.isRunning = true - - return nil -} - -// GetPage 獲取頁面 -func (m *BrowserManager) GetPage() (*rod.Page, error) { - m.mu.Lock() - defer m.mu.Unlock() - - if !m.isRunning || m.browser == nil { - return nil, fmt.Errorf("browser not running") - } - - return m.page, nil -} - -// Close 關閉瀏覽器 -func (m *BrowserManager) Close() error { - m.mu.Lock() - defer m.mu.Unlock() - - if !m.isRunning { - return nil - } - - var err error - if m.browser != nil { - err = m.browser.Close() - m.browser = nil - } - - m.page = nil - m.isRunning = false - return err -} - -// IsRunning 檢查瀏覽器是否正在運行 -func (m *BrowserManager) IsRunning() bool { - m.mu.Lock() - defer m.mu.Unlock() - return m.isRunning -} - -// SetCurrentModel 設定當前模型 -func (m *BrowserManager) SetCurrentModel(model string) { - m.mu.Lock() - defer m.mu.Unlock() - m.currentModel = model -} - -// GetCurrentModel 獲取當前模型 -func (m *BrowserManager) GetCurrentModel() string { - m.mu.Lock() - defer m.mu.Unlock() - return m.currentModel -} diff --git a/internal/providers/geminiweb/page.go b/internal/providers/geminiweb/page.go deleted file mode 100644 index 7365bd3..0000000 --- a/internal/providers/geminiweb/page.go +++ /dev/null @@ -1,250 +0,0 @@ -package geminiweb - -import ( - "context" - "fmt" - "strings" - "time" - - "github.com/go-rod/rod" -) - -const geminiURL = "https://gemini.google.com/app" - -// 輸入框選擇器(依優先順序) -var inputSelectors = []string{ - ".ProseMirror", - "rich-textarea", - "div[role='textbox'][contenteditable='true']", - "div[contenteditable='true']", - "textarea", -} - -// NavigateToGemini 導航到 Gemini -func NavigateToGemini(page *rod.Page) error { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - if err := page.Context(ctx).Navigate(geminiURL); err != nil { - return fmt.Errorf("failed to navigate: %w", err) - } - return page.Context(ctx).WaitLoad() -} - -// IsLoggedIn 檢查是否已登入 -func IsLoggedIn(page *rod.Page) bool { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - for _, sel := range inputSelectors { - if _, err := page.Context(ctx).Element(sel); err == nil { - return true - } - } - return false -} - -// SelectModel 選擇模型(可選) -func SelectModel(page *rod.Page, model string) error { - fmt.Printf("[GeminiWeb] Model selection skipped (using current model)\n") - return nil -} - -// TypeInput 在輸入框中輸入文字 -func TypeInput(page *rod.Page, text string) error { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - fmt.Println("[GeminiWeb] Looking for input field...") - - // 1. 嘗試所有選擇器 - var inputEl *rod.Element - var err error - - for _, sel := range inputSelectors { - fmt.Printf(" Trying: %s\n", sel) - inputEl, err = page.Context(ctx).Element(sel) - if err == nil { - fmt.Printf(" ✓ Found with: %s\n", sel) - break - } - } - - if err != nil { - // 2. Fallback: 嘗試等待頁面載入完成後重試 - fmt.Println("[GeminiWeb] Waiting for page to fully load...") - time.Sleep(3 * time.Second) - - for _, sel := range inputSelectors { - fmt.Printf(" Retrying: %s\n", sel) - inputEl, err = page.Context(ctx).Element(sel) - if err == nil { - fmt.Printf(" ✓ Found with: %s\n", sel) - break - } - } - } - - if err != nil { - // 3. Debug: 印出頁面標題和 URL - info, _ := page.Info() - fmt.Printf("[GeminiWeb] DEBUG: URL=%s Title=%s\n", info.URL, info.Title) - - // 4. Fallback: 嘗試更通用的選擇器 - fmt.Println("[GeminiWeb] Trying generic selectors...") - genericSelectors := []string{ - "div[contenteditable]", - "[contenteditable]", - "textarea", - "input[type='text']", - } - - for _, sel := range genericSelectors { - fmt.Printf(" Trying generic: %s\n", sel) - inputEl, err = page.Context(ctx).Element(sel) - if err == nil { - fmt.Printf(" ✓ Found with: %s\n", sel) - break - } - } - } - - if err != nil { - info, _ := page.Info() - return fmt.Errorf("input field not found after trying all selectors (URL=%s)", info.URL) - } - - // 2. Focus 輸入框 - fmt.Printf("[GeminiWeb] Focusing input field...\n") - if err := inputEl.Focus(); err != nil { - return fmt.Errorf("failed to focus input: %w", err) - } - - time.Sleep(500 * time.Millisecond) - - // 3. 使用 Input 方法 - fmt.Printf("[GeminiWeb] Typing %d chars...\n", len(text)) - if err := inputEl.Input(text); err != nil { - return fmt.Errorf("failed to input text: %w", err) - } - - time.Sleep(200 * time.Millisecond) - - fmt.Println("[GeminiWeb] Input complete") - return nil -} - -// ClickSend 發送訊息 -func ClickSend(page *rod.Page) error { - // 方法 1: 按 Enter - if err := page.Keyboard.Press('\r'); err != nil { - return fmt.Errorf("failed to press Enter: %w", err) - } - - time.Sleep(200 * time.Millisecond) - return nil -} - -// WaitForReady 等待頁面空閒 -func WaitForReady(page *rod.Page) error { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - defer cancel() - - fmt.Println("[GeminiWeb] Checking if page is ready...") - - for { - select { - case <-ctx.Done(): - fmt.Println("[GeminiWeb] Page ready check timeout, proceeding anyway") - return nil - default: - time.Sleep(500 * time.Millisecond) - - // 檢查是否有停止按鈕 - hasStopBtn := false - stopBtns, _ := page.Elements("button[aria-label*='Stop'], button[aria-label*='停止']") - for _, btn := range stopBtns { - visible, _ := btn.Visible() - if visible { - hasStopBtn = true - break - } - } - - if !hasStopBtn { - fmt.Println("[GeminiWeb] Page is ready") - return nil - } - } - } -} - -// ExtractResponse 提取回應文字 -func ExtractResponse(page *rod.Page) (string, error) { - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) - defer cancel() - - var lastText string - lastUpdate := time.Now() - - for { - select { - case <-ctx.Done(): - if lastText != "" { - return lastText, nil - } - return "", fmt.Errorf("response timeout") - default: - time.Sleep(500 * time.Millisecond) - - // 尋找回應文字 - for _, sel := range responseSelectors { - elements, err := page.Elements(sel) - if err != nil || len(elements) == 0 { - continue - } - - // 取得最後一個元素的文字 - lastEl := elements[len(elements)-1] - text, err := lastEl.Text() - if err != nil { - continue - } - - text = strings.TrimSpace(text) - if text != "" && text != lastText && len(text) > len(lastText) { - lastText = text - lastUpdate = time.Now() - fmt.Printf("[GeminiWeb] Response length: %d\n", len(text)) - } - } - - // 檢查是否已完成(2 秒內沒有新內容) - if time.Since(lastUpdate) > 2*time.Second && lastText != "" { - // 最後檢查一次是否還有停止按鈕 - hasStopBtn := false - stopBtns, _ := page.Elements("button[aria-label*='Stop'], button[aria-label*='停止']") - for _, btn := range stopBtns { - visible, _ := btn.Visible() - if visible { - hasStopBtn = true - break - } - } - - if !hasStopBtn { - return lastText, nil - } - } - } - } -} - -// 默認的回應選擇器 -var responseSelectors = []string{ - ".model-response-text", - ".message-content", - ".markdown", - ".prose", - "model-response", -} diff --git a/internal/providers/geminiweb/playwright_provider.go b/internal/providers/geminiweb/playwright_provider.go deleted file mode 100644 index ced9dfa..0000000 --- a/internal/providers/geminiweb/playwright_provider.go +++ /dev/null @@ -1,641 +0,0 @@ -package geminiweb - -import ( - "context" - "cursor-api-proxy/internal/apitypes" - "cursor-api-proxy/internal/config" - "fmt" - "os" - "path/filepath" - "strings" - "sync" - "time" - - "github.com/playwright-community/playwright-go" -) - -// PlaywrightProvider 使用 Playwright 的 Gemini Provider -type PlaywrightProvider struct { - cfg config.BridgeConfig - pw *playwright.Playwright - browser playwright.Browser - context playwright.BrowserContext - page playwright.Page - mu sync.Mutex - userDataDir string -} - -var ( - playwrightInstance *playwright.Playwright - playwrightOnce sync.Once - playwrightErr error -) - -// NewPlaywrightProvider 建立新的 Playwright Provider -func NewPlaywrightProvider(cfg config.BridgeConfig) (*PlaywrightProvider, error) { - // 確保 Playwright 已初始化(單例) - playwrightOnce.Do(func() { - playwrightInstance, playwrightErr = playwright.Run() - if playwrightErr != nil { - playwrightErr = fmt.Errorf("failed to run playwright: %w", playwrightErr) - } - }) - - if playwrightErr != nil { - return nil, playwrightErr - } - - // 清理 Chrome 鎖檔案 - userDataDir := filepath.Join(cfg.GeminiAccountDir, "default-session") - cleanLockFiles(userDataDir) - - // 確保目錄存在 - if err := os.MkdirAll(userDataDir, 0755); err != nil { - return nil, fmt.Errorf("failed to create user data dir: %w", err) - } - - return &PlaywrightProvider{ - cfg: cfg, - pw: playwrightInstance, - userDataDir: userDataDir, - }, nil -} - -// getName 返回 Provider 名稱 -func (p *PlaywrightProvider) Name() string { - return "gemini-web" -} - -// launchIfNeeded 如果需要則啟動瀏覽器 -func (p *PlaywrightProvider) launchIfNeeded() error { - p.mu.Lock() - defer p.mu.Unlock() - - if p.context != nil && p.page != nil { - return nil - } - - fmt.Println("[GeminiWeb] Launching Chromium...") - - // 使用 LaunchPersistentContext(自動保存 session) - context, err := p.pw.Chromium.LaunchPersistentContext(p.userDataDir, - playwright.BrowserTypeLaunchPersistentContextOptions{ - Headless: playwright.Bool(!p.cfg.GeminiBrowserVisible), - Args: []string{ - "--no-first-run", - "--no-default-browser-check", - "--disable-background-networking", - "--disable-extensions", - "--disable-plugins", - "--disable-sync", - }, - }) - if err != nil { - return fmt.Errorf("failed to launch persistent context: %w", err) - } - - p.context = context - - // 取得或建立頁面 - pages := context.Pages() - if len(pages) > 0 { - p.page = pages[0] - } else { - page, err := context.NewPage() - if err != nil { - _ = context.Close() - return fmt.Errorf("failed to create page: %w", err) - } - p.page = page - } - - fmt.Println("[GeminiWeb] Browser launched") - return nil -} - -// Generate 生成回應 -func (p *PlaywrightProvider) Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) (err error) { - // 確保在返回錯誤時保存診斷 - defer func() { - if err != nil { - fmt.Println("[GeminiWeb] Error occurred, saving diagnostics...") - _ = p.saveDiagnostics() - } - }() - - fmt.Printf("[GeminiWeb] Starting generation with model: %s\n", model) - - // 1. 確保瀏覽器已啟動 - if err := p.launchIfNeeded(); err != nil { - return fmt.Errorf("failed to launch browser: %w", err) - } - - // 2. 導航到 Gemini(如果需要) - currentURL := p.page.URL() - if !strings.Contains(currentURL, "gemini.google.com") { - fmt.Println("[GeminiWeb] Navigating to Gemini...") - if _, err := p.page.Goto("https://gemini.google.com/app", playwright.PageGotoOptions{ - WaitUntil: playwright.WaitUntilStateDomcontentloaded, - Timeout: playwright.Float(60000), - }); err != nil { - return fmt.Errorf("failed to navigate: %w", err) - } - // 額外等待 JavaScript 載入 - fmt.Println("[GeminiWeb] Waiting for page to initialize...") - time.Sleep(3 * time.Second) - } - - // 3. 調試模式:等待用戶確認 - if p.cfg.GeminiBrowserVisible { - fmt.Println("\n" + strings.Repeat("=", 70)) - fmt.Println("🔍 調試模式:瀏覽器已開啟") - fmt.Println("請檢查瀏覽器畫面,然後按 ENTER 繼續...") - fmt.Println("如果有問題,請查看: /tmp/gemini-debug.*") - fmt.Println(strings.Repeat("=", 70)) - - var input string - fmt.Scanln(&input) - } - - // 4. 等待頁面完全載入(project-golem 策略) - fmt.Println("[GeminiWeb] Waiting for page to be ready...") - if err := p.waitForPageReady(); err != nil { - fmt.Printf("[GeminiWeb] Warning: %v\n", err) - - // 額外調試:輸出頁面 HTML 結構 - if p.cfg.GeminiBrowserVisible { - html, _ := p.page.Content() - debugPath := "/tmp/gemini-debug.html" - if err := os.WriteFile(debugPath, []byte(html), 0644); err == nil { - fmt.Printf("[GeminiWeb] HTML saved to: %s\n", debugPath) - } - } - } - - // 4. 檢查登入狀態 - fmt.Println("[GeminiWeb] Checking login status...") - loggedIn := p.isLoggedIn() - if !loggedIn { - fmt.Println("[GeminiWeb] Not logged in, continuing anyway") - if p.cfg.GeminiBrowserVisible { - fmt.Println("\n========================================") - fmt.Println("Browser is open. You can:") - fmt.Println("1. Log in to Gemini now") - fmt.Println("2. Continue without login") - fmt.Println("========================================\n") - } - } else { - fmt.Println("[GeminiWeb] ✓ Logged in") - } - - // 5. 選擇模型(如果支援) - if err := p.selectModel(model); err != nil { - fmt.Printf("[GeminiWeb] Warning: model selection failed: %v\n", err) - } - - // 6. 建構提示詞 - prompt := buildPromptFromMessagesPlaywright(messages) - fmt.Printf("[GeminiWeb] Typing prompt (%d chars)...\n", len(prompt)) - - // 7. 輸入文字(使用 Playwright 的 Auto-wait) - if err := p.typeInput(prompt); err != nil { - return fmt.Errorf("failed to type: %w", err) - } - - // 7. 發送訊息 - fmt.Println("[GeminiWeb] Sending message...") - if err := p.sendMessage(); err != nil { - return fmt.Errorf("failed to send: %w", err) - } - - // 8. 提取回應 - fmt.Println("[GeminiWeb] Waiting for response...") - response, err := p.extractResponse() - if err != nil { - return fmt.Errorf("failed to extract response: %w", err) - } - - // 9. 回調 - cb(apitypes.StreamChunk{Type: apitypes.ChunkText, Text: response}) - cb(apitypes.StreamChunk{Type: apitypes.ChunkDone, Done: true}) - - fmt.Printf("[GeminiWeb] Response complete (%d chars)\n", len(response)) - return nil -} - -// Close 關閉 Provider -func (p *PlaywrightProvider) Close() error { - p.mu.Lock() - defer p.mu.Unlock() - - if p.context != nil { - if err := p.context.Close(); err != nil { - return err - } - p.context = nil - p.page = nil - } - return nil -} - -// saveDiagnostics 保存診斷信息 -func (p *PlaywrightProvider) saveDiagnostics() error { - if p.page == nil { - return fmt.Errorf("no page available") - } - - // 截圖 - screenshotPath := "/tmp/gemini-debug.png" - if _, err := p.page.Screenshot(playwright.PageScreenshotOptions{ - Path: playwright.String(screenshotPath), - }); err == nil { - fmt.Printf("[GeminiWeb] Screenshot saved: %s\n", screenshotPath) - } - - // HTML - htmlPath := "/tmp/gemini-debug.html" - if html, err := p.page.Content(); err == nil { - if err := os.WriteFile(htmlPath, []byte(html), 0644); err == nil { - fmt.Printf("[GeminiWeb] HTML saved: %s\n", htmlPath) - } - } - - // 輸出頁面信息 - url := p.page.URL() - title, _ := p.page.Title() - fmt.Printf("[GeminiWeb] Diagnostics: URL=%s, Title=%s\n", url, title) - - return nil -} - -// waitForPageReady 等待頁面完全就緒(project-golem 策略) -func (p *PlaywrightProvider) waitForPageReady() error { - fmt.Println("[GeminiWeb] Checking for ready state...") - - // 1. 等待停止按鈕消失(如果存在) - _, _ = p.page.WaitForSelector("button[aria-label*='Stop'], button[aria-label*='停止']", playwright.PageWaitForSelectorOptions{ - State: playwright.WaitForSelectorStateDetached, - Timeout: playwright.Float(5000), - }) - - // 2. 嘗試多種等待策略 - inputSelectors := []string{ - ".ql-editor.ql-blank", - ".ql-editor", - "div[contenteditable='true'][role='textbox']", - "div[contenteditable='true']", - ".ProseMirror", - "rich-textarea", - "textarea", - } - - // 策略 A: 等待任一輸入框出現 - for i, sel := range inputSelectors { - fmt.Printf(" [%d/%d] Waiting for: %s\n", i+1, len(inputSelectors), sel) - locator := p.page.Locator(sel) - if err := locator.WaitFor(playwright.LocatorWaitForOptions{ - Timeout: playwright.Float(5000), - State: playwright.WaitForSelectorStateVisible, - }); err == nil { - fmt.Printf(" ✓ Input field found: %s\n", sel) - return nil - } - } - - // 策略 B: 等待頁面完全載入 - fmt.Println("[GeminiWeb] Waiting for page load...") - time.Sleep(3 * time.Second) - - // 策略 C: 使用 JavaScript 檢查 - fmt.Println("[GeminiWeb] Checking with JavaScript...") - result, err := p.page.Evaluate(` - () => { - // 檢查所有可能的輸入元素 - const selectors = [ - '.ql-editor.ql-blank', - '.ql-editor', - 'div[contenteditable="true"][role="textbox"]', - 'div[contenteditable="true"]', - '.ProseMirror', - 'rich-textarea', - 'textarea' - ]; - - for (const sel of selectors) { - const el = document.querySelector(sel); - if (el) { - return { - found: true, - selector: sel, - tagName: el.tagName, - className: el.className, - visible: el.offsetParent !== null - }; - } - } - - return { found: false }; - } - `) - - if err == nil { - if m, ok := result.(map[string]interface{}); ok { - if found, _ := m["found"].(bool); found { - sel, _ := m["selector"].(string) - fmt.Printf(" ✓ JavaScript found: %s\n", sel) - return nil - } - } - } - - // 策略 D: 調試模式 - 輸出頁面結構 - if p.cfg.GeminiBrowserVisible { - fmt.Println("[GeminiWeb].dump: Page structure analysis") - _, _ = p.page.Evaluate(` - () => { - const allElements = document.querySelectorAll('*'); - const inputLike = []; - for (const el of allElements) { - if (el.contentEditable === 'true' || - el.role === 'textbox' || - el.tagName === 'TEXTAREA' || - el.tagName === 'INPUT') { - inputLike.push({ - tag: el.tagName, - class: el.className, - id: el.id, - role: el.role, - contentEditable: el.contentEditable - }); - } - } - console.log('Input-like elements:', inputLike); - } - `) - } - - return fmt.Errorf("no input field found after all strategies") -} - -// isLoggedIn 檢查是否已登入 -func (p *PlaywrightProvider) isLoggedIn() bool { - // 嘗試找輸入框(登入狀態的主要特徵) - selectors := []string{ - ".ProseMirror", - "rich-textarea", - "div[role='textbox']", - "div[contenteditable='true']", - "textarea", - } - - for _, sel := range selectors { - locator := p.page.Locator(sel) - if count, _ := locator.Count(); count > 0 { - return true - } - } - return false -} - -// typeInput 輸入文字(使用 Playwright 的 Auto-wait) -func (p *PlaywrightProvider) typeInput(text string) error { - fmt.Println("[GeminiWeb] Looking for input field...") - - selectors := []string{ - ".ql-editor.ql-blank", - ".ql-editor", - "div[contenteditable='true'][role='textbox']", - "div[contenteditable='true']", - ".ProseMirror", - "rich-textarea", - "textarea", - } - - var inputLocator playwright.Locator - var found bool - - for _, sel := range selectors { - fmt.Printf(" Trying: %s\n", sel) - locator := p.page.Locator(sel) - if err := locator.WaitFor(playwright.LocatorWaitForOptions{ - Timeout: playwright.Float(3000), - }); err == nil { - inputLocator = locator - found = true - fmt.Printf(" ✓ Found with: %s\n", sel) - break - } - } - - if !found { - // 錯誤會被 Generate 的 defer 捕獲並保存診斷 - url := p.page.URL() - title, _ := p.page.Title() - return fmt.Errorf("input field not found (URL=%s, Title=%s). Diagnostics will be saved to /tmp/", url, title) - } - - // Focus 並填充(Playwright 自動等待) - fmt.Printf("[GeminiWeb] Typing %d chars...\n", len(text)) - if err := inputLocator.Fill(text); err != nil { - return fmt.Errorf("failed to fill: %w", err) - } - - fmt.Println("[GeminiWeb] Input complete") - return nil -} - -// sendMessage 發送訊息 -func (p *PlaywrightProvider) sendMessage() error { - // 方法 1: 按 Enter(最可靠) - if err := p.page.Keyboard().Press("Enter"); err != nil { - return fmt.Errorf("failed to press Enter: %w", err) - } - - time.Sleep(200 * time.Millisecond) - - // 方法 2: 嘗試點擊發送按鈕(補強) - _, _ = p.page.Evaluate(` - () => { - const keywords = ['發送', 'Send', '傳送']; - const buttons = Array.from(document.querySelectorAll('button, [role="button"]')); - - for (const btn of buttons) { - const text = (btn.innerText || btn.textContent || '').trim(); - const label = (btn.getAttribute('aria-label') || '').trim(); - - // 跳過停止按鈕 - if (['停止', 'Stop', '中斷'].includes(text) || label.toLowerCase().includes('stop')) { - continue; - } - - if (keywords.some(kw => text.includes(kw) || label.includes(kw))) { - btn.click(); - return true; - } - } - return false; - } - `) - - return nil -} - -// extractResponse 提取回應 -func (p *PlaywrightProvider) extractResponse() (string, error) { - var lastText string - var stableCount int - lastUpdate := time.Now() - timeout := 120 * time.Second - startTime := time.Now() - - for time.Since(startTime) < timeout { - time.Sleep(500 * time.Millisecond) - - // 使用 JavaScript 提取回應文字(更精確) - result, err := p.page.Evaluate(` - () => { - // 尋找所有可能的回應容器 - const selectors = [ - 'model-response', - '.model-response', - 'message-content', - '.message-content' - ]; - - for (const sel of selectors) { - const el = document.querySelector(sel); - if (el) { - // 嘗試找markdown內容 - const markdown = el.querySelector('.markdown, .prose, [class*="markdown"]'); - if (markdown && markdown.innerText.trim()) { - let text = markdown.innerText.trim(); - // 移除常見的標籤前綴 - text = text.replace(/^Gemini said\s*\n*/i, '').replace(/^Gemini\s*[::]\s*\n*/i, '').trim(); - return { text: text, source: sel + ' .markdown' }; - } - - // 嘗試找純文字內容(排除標籤) - let textContent = el.innerText.trim(); - if (textContent) { - // 移除常見的標籤前綴 - textContent = textContent.replace(/^Gemini said\s*\n*/i, '').replace(/^Gemini\s*[::]\s*\n*/i, '').trim(); - return { text: textContent, source: sel }; - } - } - } - - return { text: '', source: 'none' }; - } - `) - - if err == nil { - if m, ok := result.(map[string]interface{}); ok { - text, _ := m["text"].(string) - text = strings.TrimSpace(text) - - if text != "" && len(text) > len(lastText) { - lastText = text - lastUpdate = time.Now() - stableCount = 0 - fmt.Printf("[GeminiWeb] Response: %d chars\n", len(text)) - } - } - } - - // 檢查是否完成(需要連續 3 次穩定) - if time.Since(lastUpdate) > 500*time.Millisecond && lastText != "" { - stableCount++ - if stableCount >= 3 { - // 最終檢查:停止按鈕是否還存在 - stopBtn := p.page.Locator("button[aria-label*='Stop'], button[aria-label*='停止'], button[data-test-id='stop-button']") - count, _ := stopBtn.Count() - - if count == 0 { - fmt.Println("[GeminiWeb] ✓ Response complete") - return lastText, nil - } - } - } - } - - if lastText != "" { - fmt.Println("[GeminiWeb] ✓ Response complete (timeout)") - return lastText, nil - } - return "", fmt.Errorf("response timeout") -} - -// selectModel 選擇 Gemini 模型 -// Gemini Web 只有三種模型:fast, thinking, pro -func (p *PlaywrightProvider) selectModel(model string) error { - // 映射模型名稱到 Gemini Web 的模型選擇器 - modelMap := map[string]string{ - "fast": "Fast", - "thinking": "Thinking", - "pro": "Pro", - "gemini-fast": "Fast", - "gemini-thinking": "Thinking", - "gemini-pro": "Pro", - "gemini-2.0-fast": "Fast", - "gemini-2.0-flash": "Fast", // 相容舊名稱 - "gemini-2.5-pro": "Pro", - "gemini-2.5-pro-thinking": "Thinking", - } - - // 從完整模型名稱中提取類型 - modelType := "" - modelLower := strings.ToLower(model) - for key, value := range modelMap { - if strings.Contains(modelLower, strings.ToLower(key)) || modelLower == strings.ToLower(key) { - modelType = value - break - } - } - - if modelType == "" { - // 默認使用 Fast - fmt.Printf("[GeminiWeb] Unknown model '%s', defaulting to Fast\n", model) - return nil - } - - fmt.Printf("[GeminiWeb] Selecting model: %s\n", modelType) - - // 點擊模型選擇器 - modelSelector := p.page.Locator("button[aria-label*='Model'], button[aria-label*='模型'], [data-test-id='model-selector']") - if count, _ := modelSelector.Count(); count > 0 { - if err := modelSelector.First().Click(); err != nil { - fmt.Printf("[GeminiWeb] Warning: could not click model selector: %v\n", err) - } else { - time.Sleep(500 * time.Millisecond) - - // 選擇對應的模型選項 - optionSelector := p.page.Locator(fmt.Sprintf("button:has-text('%s'), [role='menuitem']:has-text('%s')", modelType, modelType)) - if count, _ := optionSelector.Count(); count > 0 { - if err := optionSelector.First().Click(); err != nil { - fmt.Printf("[GeminiWeb] Warning: could not select model: %v\n", err) - } else { - fmt.Printf("[GeminiWeb] ✓ Model selected: %s\n", modelType) - time.Sleep(500 * time.Millisecond) - } - } - } - } - - return nil -} - -// buildPromptFromMessages 從訊息列表建構提示詞 -func buildPromptFromMessagesPlaywright(messages []apitypes.Message) string { - var prompt string - for _, m := range messages { - switch m.Role { - case "system": - prompt += "System: " + m.Content + "\n\n" - case "user": - prompt += m.Content + "\n\n" - case "assistant": - prompt += "Assistant: " + m.Content + "\n\n" - } - } - return prompt -} diff --git a/internal/providers/geminiweb/pool.go b/internal/providers/geminiweb/pool.go deleted file mode 100644 index 88d4f89..0000000 --- a/internal/providers/geminiweb/pool.go +++ /dev/null @@ -1,169 +0,0 @@ -package geminiweb - -import ( - "encoding/json" - "fmt" - "os" - "path/filepath" - "sync" - "time" -) - -type GeminiSession struct { - Name string `json:"name"` - CookieFile string `json:"cookie_file"` - LastUsed int64 `json:"last_used"` - ActiveCount int `json:"active_count"` - RateLimitEnd int64 `json:"rate_limit_end"` -} - -type SessionPool struct { - mu sync.Mutex - sessions []*GeminiSession - dir string - maxCount int -} - -func NewSessionPool(dir string, maxSessions int) (*SessionPool, error) { - if err := os.MkdirAll(dir, 0755); err != nil { - return nil, fmt.Errorf("failed to create session dir: %w", err) - } - - sessions, err := loadSessions(dir) - if err != nil { - return nil, fmt.Errorf("failed to load sessions: %w", err) - } - - return &SessionPool{ - sessions: sessions, - dir: dir, - maxCount: maxSessions, - }, nil -} - -func loadSessions(dir string) ([]*GeminiSession, error) { - entries, err := os.ReadDir(dir) - if err != nil { - return nil, err - } - - var sessions []*GeminiSession - for _, entry := range entries { - if !entry.IsDir() { - continue - } - name := entry.Name() - metaPath := filepath.Join(dir, name, "session.json") - data, err := os.ReadFile(metaPath) - if err != nil { - continue - } - - var s GeminiSession - if err := json.Unmarshal(data, &s); err != nil { - continue - } - sessions = append(sessions, &s) - } - - return sessions, nil -} - -func (p *SessionPool) Count() int { - p.mu.Lock() - defer p.mu.Unlock() - return len(p.sessions) -} - -func (p *SessionPool) GetAvailable() *GeminiSession { - p.mu.Lock() - defer p.mu.Unlock() - - now := time.Now().UnixMilli() - - var available []*GeminiSession - for _, s := range p.sessions { - if s.RateLimitEnd < now { - available = append(available, s) - } - } - - if len(available) == 0 { - return nil - } - - var best *GeminiSession - for _, s := range available { - if best == nil || s.ActiveCount < best.ActiveCount { - best = s - } else if s.ActiveCount == best.ActiveCount && s.LastUsed < best.LastUsed { - best = s - } - } - - return best -} - -func (p *SessionPool) StartSession(s *GeminiSession) { - p.mu.Lock() - defer p.mu.Unlock() - s.ActiveCount++ - s.LastUsed = time.Now().UnixMilli() - p.saveSession(s) -} - -func (p *SessionPool) EndSession(s *GeminiSession) { - p.mu.Lock() - defer p.mu.Unlock() - if s.ActiveCount > 0 { - s.ActiveCount-- - } - p.saveSession(s) -} - -func (p *SessionPool) RateLimitSession(s *GeminiSession, durationMs int64) { - p.mu.Lock() - defer p.mu.Unlock() - s.RateLimitEnd = time.Now().UnixMilli() + durationMs - p.saveSession(s) -} - -func (p *SessionPool) saveSession(s *GeminiSession) { - metaPath := filepath.Join(p.dir, s.Name, "session.json") - data, err := json.MarshalIndent(s, "", " ") - if err != nil { - return - } - _ = os.WriteFile(metaPath, data, 0644) -} - -func (p *SessionPool) CreateSession(name string) (*GeminiSession, error) { - p.mu.Lock() - defer p.mu.Unlock() - - sessionDir := filepath.Join(p.dir, name) - if err := os.MkdirAll(sessionDir, 0755); err != nil { - return nil, fmt.Errorf("failed to create session dir: %w", err) - } - - s := &GeminiSession{ - Name: name, - CookieFile: filepath.Join(sessionDir, "cookies.json"), - LastUsed: time.Now().UnixMilli(), - } - - p.sessions = append(p.sessions, s) - p.saveSession(s) - - return s, nil -} - -func (p *SessionPool) GetSessionNames() []string { - p.mu.Lock() - defer p.mu.Unlock() - names := make([]string, len(p.sessions)) - for i, s := range p.sessions { - names[i] = s.Name - } - return names -} diff --git a/internal/providers/geminiweb/provider.go b/internal/providers/geminiweb/provider.go deleted file mode 100644 index 257f57e..0000000 --- a/internal/providers/geminiweb/provider.go +++ /dev/null @@ -1,196 +0,0 @@ -package geminiweb - -import ( - "context" - "cursor-api-proxy/internal/apitypes" - "cursor-api-proxy/internal/config" - "fmt" - "os" - "path/filepath" - "strings" - "sync" - "time" -) - -// Provider 使用持久化瀏覽器管理器 -type Provider struct { - cfg config.BridgeConfig - managerOnce sync.Once - manager *BrowserManager - managerErr error -} - -// NewProvider 建立新的 Provider -func NewProvider(cfg config.BridgeConfig) *Provider { - return &Provider{cfg: cfg} -} - -// getName 返回 Provider 名稱 -func (p *Provider) Name() string { - return "gemini-web" -} - -// Close 關閉瀏覽器 -func (p *Provider) Close() error { - if p.manager != nil { - return p.manager.Close() - } - return nil -} - -// getManager 獲取或初始化瀏覽器管理器(單例) -func (p *Provider) getManager() (*BrowserManager, error) { - p.managerOnce.Do(func() { - sessionDir := p.getSessionDir() - p.manager, p.managerErr = GetBrowserManager(sessionDir, p.cfg.GeminiBrowserVisible) - }) - return p.manager, p.managerErr -} - -// getSessionDir 獲取 session 目錄 -func (p *Provider) getSessionDir() string { - // 使用單一 session 目錄(簡化設計) - return filepath.Join(p.cfg.GeminiAccountDir, "default-session") -} - -// Generate 生成回應 -func (p *Provider) Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error { - fmt.Printf("[GeminiWeb] Starting generation with model: %s\n", model) - - // 1. 獲取瀏覽器管理器 - manager, err := p.getManager() - if err != nil { - return fmt.Errorf("failed to get browser manager: %w", err) - } - - // 2. 啟動瀏覽器(如果尚未啟動) - if !manager.IsRunning() { - fmt.Printf("[GeminiWeb] Launching browser...\n") - if err := manager.Launch(); err != nil { - return fmt.Errorf("failed to launch browser: %w", err) - } - } - - // 3. 獲取頁面 - page, err := manager.GetPage() - if err != nil { - return fmt.Errorf("failed to get page: %w", err) - } - - // 4. 檢查當前 URL,如果不是 Gemini 則導航 - currentURL, _ := page.Info() - if !strings.Contains(currentURL.URL, "gemini.google.com") { - fmt.Printf("[GeminiWeb] Navigating to Gemini...\n") - if err := NavigateToGemini(page); err != nil { - return fmt.Errorf("failed to navigate: %w", err) - } - time.Sleep(2 * time.Second) - } - - // 5. 檢查登入狀態 - fmt.Printf("[GeminiWeb] Checking login status...\n") - if !IsLoggedIn(page) { - fmt.Printf("[GeminiWeb] Not logged in, continuing anyway\n") - - if p.cfg.GeminiBrowserVisible { - fmt.Println("\n========================================") - fmt.Println("Browser is open. You can:") - fmt.Println("1. Log in to Gemini now") - fmt.Println("2. Continue without login") - fmt.Println("========================================\n") - } - } else { - fmt.Printf("[GeminiWeb] Logged in\n") - } - - // 6. 等待頁面就緒 - if err := WaitForReady(page); err != nil { - fmt.Printf("[GeminiWeb] Warning: %v\n", err) - } - - // 7. 建構提示詞 - prompt := buildPromptFromMessages(messages) - fmt.Printf("[GeminiWeb] Typing prompt (%d chars)...\n", len(prompt)) - - // 8. 輸入文字 - if err := TypeInput(page, prompt); err != nil { - return fmt.Errorf("failed to type input: %w", err) - } - - // 9. 發送 - fmt.Printf("[GeminiWeb] Sending message...\n") - if err := ClickSend(page); err != nil { - return fmt.Errorf("failed to send: %w", err) - } - - // 10. 提取回應 - fmt.Printf("[GeminiWeb] Waiting for response...\n") - response, err := ExtractResponse(page) - if err != nil { - return fmt.Errorf("failed to extract response: %w", err) - } - - // 11. 串流回調 - cb(apitypes.StreamChunk{Type: apitypes.ChunkText, Text: response}) - cb(apitypes.StreamChunk{Type: apitypes.ChunkDone, Done: true}) - - fmt.Printf("[GeminiWeb] Response complete (%d chars)\n", len(response)) - return nil -} - -// buildPromptFromMessages 從訊息列表建構提示詞 -func buildPromptFromMessages(messages []apitypes.Message) string { - var prompt string - for _, m := range messages { - switch m.Role { - case "system": - prompt += "System: " + m.Content + "\n\n" - case "user": - prompt += m.Content + "\n\n" - case "assistant": - prompt += "Assistant: " + m.Content + "\n\n" - } - } - return prompt -} - -// RunLogin 執行登入流程(供 gemini-login 命令使用) -func RunLogin(cfg config.BridgeConfig, sessionName string) error { - if sessionName == "" { - sessionName = "default-session" - } - - sessionDir := filepath.Join(cfg.GeminiAccountDir, sessionName) - if err := os.MkdirAll(sessionDir, 0755); err != nil { - return fmt.Errorf("failed to create session dir: %w", err) - } - - fmt.Printf("Starting browser for login. Session: %s\n", sessionName) - fmt.Printf("Session directory: %s\n", sessionDir) - fmt.Println("Please log in to your Gemini account in the browser window.") - fmt.Println("Press Ctrl+C when you have completed the login...") - - manager, err := NewBrowserManager(sessionDir, true) // visible=true - if err != nil { - return fmt.Errorf("failed to create browser manager: %w", err) - } - - if err := manager.Launch(); err != nil { - return fmt.Errorf("failed to launch browser: %w", err) - } - defer manager.Close() - - page, err := manager.GetPage() - if err != nil { - return fmt.Errorf("failed to get page: %w", err) - } - - if err := NavigateToGemini(page); err != nil { - return fmt.Errorf("failed to navigate: %w", err) - } - - // 等待用戶手動登入... - // 使用 Ctrl+C 退出,瀏覽器資料會自動保存在 userDataDir - - return nil -} diff --git a/internal/router/router.go b/internal/router/router.go deleted file mode 100644 index 745f958..0000000 --- a/internal/router/router.go +++ /dev/null @@ -1,147 +0,0 @@ -package router - -import ( - "cursor-api-proxy/internal/config" - "cursor-api-proxy/internal/handlers" - "cursor-api-proxy/internal/httputil" - "cursor-api-proxy/internal/logger" - "cursor-api-proxy/internal/pool" - "fmt" - "net/http" - "os" - "time" -) - -type RouterOptions struct { - Version string - Config config.BridgeConfig - ModelCache *handlers.ModelCacheRef - LastModel *string - Pool pool.PoolHandle -} - -func NewRouter(opts RouterOptions) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - cfg := opts.Config - pathname := r.URL.Path - method := r.Method - remoteAddress := r.RemoteAddr - if r.Header.Get("X-Real-IP") != "" { - remoteAddress = r.Header.Get("X-Real-IP") - } - - logger.LogIncoming(method, pathname, remoteAddress) - - defer func() { - logger.AppendSessionLine(cfg.SessionsLogPath, method, pathname, remoteAddress, 200) - }() - - if cfg.RequiredKey != "" { - token := httputil.ExtractBearerToken(r) - if token != cfg.RequiredKey { - httputil.WriteJSON(w, 401, map[string]interface{}{ - "error": map[string]string{"message": "Invalid API key", "code": "unauthorized"}, - }, nil) - return - } - } - - switch { - case method == "GET" && pathname == "/health": - handlers.HandleHealth(w, r, opts.Version, cfg) - - case method == "GET" && pathname == "/v1/models": - opts.ModelCache.HandleModels(w, r, cfg) - - case method == "POST" && pathname == "/v1/chat/completions": - raw, err := httputil.ReadBody(r) - if err != nil { - httputil.WriteJSON(w, 400, map[string]interface{}{ - "error": map[string]string{"message": "failed to read body", "code": "bad_request"}, - }, nil) - return - } - // 根據 Provider 選擇處理方式 - provider := cfg.Provider - if provider == "" { - provider = "cursor" - } - if provider == "gemini-web" { - handlers.HandleGeminiChatCompletions(w, r, cfg, raw, method, pathname, remoteAddress) - } else { - handlers.HandleChatCompletions(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress) - } - - case method == "POST" && pathname == "/v1/messages": - raw, err := httputil.ReadBody(r) - if err != nil { - httputil.WriteJSON(w, 400, map[string]interface{}{ - "error": map[string]string{"message": "failed to read body", "code": "bad_request"}, - }, nil) - return - } - handlers.HandleAnthropicMessages(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress) - - case (method == "POST" || method == "GET") && pathname == "/v1/completions": - httputil.WriteJSON(w, 404, map[string]interface{}{ - "error": map[string]string{ - "message": "Legacy completions endpoint is not supported. Use POST /v1/chat/completions instead.", - "code": "not_found", - }, - }, nil) - - case pathname == "/v1/embeddings": - httputil.WriteJSON(w, 404, map[string]interface{}{ - "error": map[string]string{ - "message": "Embeddings are not supported by this proxy.", - "code": "not_found", - }, - }, nil) - - default: - httputil.WriteJSON(w, 404, map[string]interface{}{ - "error": map[string]string{"message": "Not found", "code": "not_found"}, - }, nil) - } - } -} - -func recoveryMiddleware(logPath string, next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - defer func() { - if rec := recover(); rec != nil { - msg := fmt.Sprintf("%v", rec) - fmt.Fprintf(os.Stderr, "[%s] Proxy panic: %s\n", time.Now().UTC().Format(time.RFC3339), msg) - line := fmt.Sprintf("%s ERROR %s %s %s %s\n", - time.Now().UTC().Format(time.RFC3339), r.Method, r.URL.Path, r.RemoteAddr, - msg[:min(200, len(msg))]) - if f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil { - _, _ = f.WriteString(line) - f.Close() - } - if !isHeaderWritten(w) { - httputil.WriteJSON(w, 500, map[string]interface{}{ - "error": map[string]string{"message": msg, "code": "internal_error"}, - }, nil) - } - } - }() - next(w, r) - } -} - -func isHeaderWritten(w http.ResponseWriter) bool { - // Can't reliably detect without wrapping; always try to write - return false -} - -func min(a, b int) int { - if a < b { - return a - } - return b -} - -func WrapWithRecovery(logPath string, handler http.HandlerFunc) http.HandlerFunc { - return recoveryMiddleware(logPath, handler) -} diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go deleted file mode 100644 index dd514d2..0000000 --- a/internal/sanitize/sanitize.go +++ /dev/null @@ -1,95 +0,0 @@ -package sanitize - -import "regexp" - -type rule struct { - pattern *regexp.Regexp - replacement string -} - -var rules = []rule{ - {regexp.MustCompile(`(?i)x-anthropic-billing-header:[^\n]*\n?`), ""}, - {regexp.MustCompile(`(?i)\bcc_version=[^\s;,\n]+[;,]?\s*`), ""}, - {regexp.MustCompile(`(?i)\bcc_entrypoint=[^\s;,\n]+[;,]?\s*`), ""}, - {regexp.MustCompile(`(?i)\bcch=[a-f0-9]+[;,]?\s*`), ""}, - {regexp.MustCompile(`\bClaude Code\b`), "Cursor"}, - {regexp.MustCompile(`(?i)Anthropic['']s official CLI for Claude`), "Cursor AI assistant"}, - {regexp.MustCompile(`\bAnthropic\b`), "Cursor"}, - {regexp.MustCompile(`(?i)anthropic\.com`), "cursor.com"}, - {regexp.MustCompile(`(?i)claude\.ai`), "cursor.sh"}, - {regexp.MustCompile(`^[;,\s]+`), ""}, -} - -func SanitizeText(text string) string { - for _, r := range rules { - text = r.pattern.ReplaceAllString(text, r.replacement) - } - return text -} - -func SanitizeMessages(messages []interface{}) []interface{} { - result := make([]interface{}, len(messages)) - for i, raw := range messages { - if raw == nil { - result[i] = raw - continue - } - m, ok := raw.(map[string]interface{}) - if !ok { - result[i] = raw - continue - } - newMsg := make(map[string]interface{}, len(m)) - for k, v := range m { - newMsg[k] = v - } - switch c := m["content"].(type) { - case string: - newMsg["content"] = SanitizeText(c) - case []interface{}: - newParts := make([]interface{}, len(c)) - for j, p := range c { - if pm, ok := p.(map[string]interface{}); ok && pm["type"] == "text" { - if t, ok := pm["text"].(string); ok { - newPart := make(map[string]interface{}, len(pm)) - for k, v := range pm { - newPart[k] = v - } - newPart["text"] = SanitizeText(t) - newParts[j] = newPart - continue - } - } - newParts[j] = p - } - newMsg["content"] = newParts - } - result[i] = newMsg - } - return result -} - -func SanitizeSystem(system interface{}) interface{} { - switch v := system.(type) { - case string: - return SanitizeText(v) - case []interface{}: - result := make([]interface{}, len(v)) - for i, p := range v { - if pm, ok := p.(map[string]interface{}); ok && pm["type"] == "text" { - if t, ok := pm["text"].(string); ok { - newPart := make(map[string]interface{}, len(pm)) - for k, val := range pm { - newPart[k] = val - } - newPart["text"] = SanitizeText(t) - result[i] = newPart - continue - } - } - result[i] = p - } - return result - } - return system -} diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go deleted file mode 100644 index 59886a2..0000000 --- a/internal/sanitize/sanitize_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package sanitize - -import ( - "strings" - "testing" -) - -func TestSanitizeTextAnthropicBilling(t *testing.T) { - input := "x-anthropic-billing-header: abc123\nHello" - got := SanitizeText(input) - if strings.Contains(got, "x-anthropic-billing-header") { - t.Errorf("billing header not removed: %q", got) - } -} - -func TestSanitizeTextClaudeCode(t *testing.T) { - input := "I am Claude Code assistant" - got := SanitizeText(input) - if strings.Contains(got, "Claude Code") { - t.Errorf("'Claude Code' not replaced: %q", got) - } - if !strings.Contains(got, "Cursor") { - t.Errorf("'Cursor' not present in output: %q", got) - } -} - -func TestSanitizeTextAnthropic(t *testing.T) { - input := "Powered by Anthropic's technology and anthropic.com" - got := SanitizeText(input) - if strings.Contains(got, "Anthropic") { - t.Errorf("'Anthropic' not replaced: %q", got) - } - if strings.Contains(got, "anthropic.com") { - t.Errorf("'anthropic.com' not replaced: %q", got) - } -} - -func TestSanitizeTextNoChange(t *testing.T) { - input := "Hello, this is a normal message about cursor." - got := SanitizeText(input) - if got != input { - t.Errorf("unexpected change: %q -> %q", input, got) - } -} - -func TestSanitizeMessages(t *testing.T) { - messages := []interface{}{ - map[string]interface{}{"role": "user", "content": "Ask Claude Code something"}, - map[string]interface{}{"role": "system", "content": "Use Anthropic's tools"}, - } - result := SanitizeMessages(messages) - - for _, raw := range result { - m := raw.(map[string]interface{}) - c := m["content"].(string) - if strings.Contains(c, "Claude Code") || strings.Contains(c, "Anthropic") { - t.Errorf("found unsanitized content: %q", c) - } - } -} diff --git a/internal/server/server.go b/internal/server/server.go deleted file mode 100644 index adaa5d6..0000000 --- a/internal/server/server.go +++ /dev/null @@ -1,159 +0,0 @@ -package server - -import ( - "context" - "crypto/tls" - "cursor-api-proxy/internal/config" - "cursor-api-proxy/internal/handlers" - "cursor-api-proxy/internal/pool" - "cursor-api-proxy/internal/process" - "cursor-api-proxy/internal/logger" - "cursor-api-proxy/internal/router" - "fmt" - "net/http" - "os" - "os/signal" - "syscall" - "time" -) - -type ServerOptions struct { - Version string - Config config.BridgeConfig - Pool pool.PoolHandle -} - -func StartBridgeServer(opts ServerOptions) []*http.Server { - cfg := opts.Config - var servers []*http.Server - - if len(cfg.ConfigDirs) > 0 { - if cfg.MultiPort { - for i, dir := range cfg.ConfigDirs { - port := cfg.Port + i - subCfg := cfg - subCfg.Port = port - subCfg.ConfigDirs = []string{dir} - subCfg.MultiPort = false - subPool := pool.NewAccountPool([]string{dir}) - srv := startSingleServer(ServerOptions{Version: opts.Version, Config: subCfg, Pool: subPool}) - servers = append(servers, srv) - } - return servers - } - pool.InitAccountPool(cfg.ConfigDirs) - } - - servers = append(servers, startSingleServer(opts)) - return servers -} - -func startSingleServer(opts ServerOptions) *http.Server { - cfg := opts.Config - - modelCache := &handlers.ModelCacheRef{} - lastModel := cfg.DefaultModel - - ph := opts.Pool - if ph == nil { - ph = pool.GlobalPoolHandle{} - } - handler := router.NewRouter(router.RouterOptions{ - Version: opts.Version, - Config: cfg, - ModelCache: modelCache, - LastModel: &lastModel, - Pool: ph, - }) - handler = router.WrapWithRecovery(cfg.SessionsLogPath, handler) - - useTLS := cfg.TLSCertPath != "" && cfg.TLSKeyPath != "" - - srv := &http.Server{ - Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), - Handler: handler, - } - - if useTLS { - cert, err := tls.LoadX509KeyPair(cfg.TLSCertPath, cfg.TLSKeyPath) - if err != nil { - fmt.Fprintf(os.Stderr, "TLS error: %v\n", err) - os.Exit(1) - } - srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}} - } - - scheme := "http" - if useTLS { - scheme = "https" - } - - go func() { - var err error - if useTLS { - err = srv.ListenAndServeTLS("", "") - } else { - err = srv.ListenAndServe() - } - if err != nil && err != http.ErrServerClosed { - if isAddrInUse(err) { - fmt.Fprintf(os.Stderr, "❌ Port %d is already in use. Set CURSOR_BRIDGE_PORT to use a different port.\n", cfg.Port) - } else { - fmt.Fprintf(os.Stderr, "❌ Server error: %v\n", err) - } - os.Exit(1) - } - }() - - logger.LogServerStart(opts.Version, scheme, cfg.Host, cfg.Port, cfg) - - return srv -} - -func SetupGracefulShutdown(servers []*http.Server, timeoutMs int) { - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT) - - go func() { - sig := <-sigCh - logger.LogShutdown(sig.String()) - - process.KillAllChildProcesses() - - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond) - defer cancel() - - done := make(chan struct{}) - go func() { - for _, srv := range servers { - _ = srv.Shutdown(ctx) - } - close(done) - }() - - select { - case <-done: - os.Exit(0) - case <-ctx.Done(): - fmt.Fprintln(os.Stderr, "[shutdown] Timed out waiting for connections to drain — forcing exit.") - os.Exit(1) - } - }() -} - -func isAddrInUse(err error) bool { - return err != nil && (contains(err.Error(), "address already in use") || contains(err.Error(), "bind: address already in use")) -} - -func contains(s, sub string) bool { - return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsHelper(s, sub)) -} - -func containsHelper(s, sub string) bool { - for i := 0; i <= len(s)-len(sub); i++ { - if s[i:i+len(sub)] == sub { - return true - } - } - return false -} diff --git a/internal/server/server_test.go b/internal/server/server_test.go deleted file mode 100644 index eba9262..0000000 --- a/internal/server/server_test.go +++ /dev/null @@ -1,331 +0,0 @@ -package server_test - -import ( - "cursor-api-proxy/internal/config" - "cursor-api-proxy/internal/server" - "encoding/json" - "fmt" - "io" - "net" - "context" - "net/http" - "os" - "strings" - "testing" - "time" -) - -// freePort 取得一個暫時可用的隨機 port -func freePort(t *testing.T) int { - t.Helper() - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - port := l.Addr().(*net.TCPAddr).Port - l.Close() - return port -} - -// makeFakeAgentBin 建立一個 shell script,模擬 agent 固定輸出 -// sync 模式:直接輸出一行文字 -// stream 模式:輸出 JSON stream 行 -func makeFakeAgentBin(t *testing.T, syncOutput string) string { - t.Helper() - dir := t.TempDir() - script := dir + "/agent" - content := fmt.Sprintf(`#!/bin/sh -# 若有 --stream-json 則輸出 stream 格式 -for arg; do - if [ "$arg" = "--stream-json" ]; then - printf '%%s\n' '{"type":"assistant","message":{"content":[{"type":"text","text":"%s"}]}}' - printf '%%s\n' '{"type":"result","subtype":"success"}' - exit 0 - fi -done -# 否則輸出 sync 格式 -printf '%%s' '%s' -`, syncOutput, syncOutput) - if err := os.WriteFile(script, []byte(content), 0755); err != nil { - t.Fatal(err) - } - return script -} - -// makeFakeAgentBinWithModels 額外支援 --list-models 輸出 -func makeFakeAgentBinWithModels(t *testing.T) string { - t.Helper() - dir := t.TempDir() - script := dir + "/agent" - content := `#!/bin/sh -for arg; do - if [ "$arg" = "--list-models" ]; then - printf 'claude-3-opus - Claude 3 Opus\n' - printf 'claude-3-sonnet - Claude 3 Sonnet\n' - exit 0 - fi - if [ "$arg" = "--stream-json" ]; then - printf '%s\n' '{"type":"assistant","message":{"content":[{"type":"text","text":"Hello"}]}}' - printf '%s\n' '{"type":"result","subtype":"success"}' - exit 0 - fi -done -printf 'Hello from agent' -` - if err := os.WriteFile(script, []byte(content), 0755); err != nil { - t.Fatal(err) - } - return script -} - -func makeTestConfig(agentBin string, port int, overrides ...func(*config.BridgeConfig)) config.BridgeConfig { - cfg := config.BridgeConfig{ - AgentBin: agentBin, - Host: "127.0.0.1", - Port: port, - DefaultModel: "auto", - Mode: "ask", - Force: false, - ApproveMcps: false, - StrictModel: true, - Workspace: os.TempDir(), - TimeoutMs: 30000, - SessionsLogPath: os.TempDir() + "/test-sessions.log", - ChatOnlyWorkspace: true, - Verbose: false, - MaxMode: false, - ConfigDirs: []string{}, - MultiPort: false, - WinCmdlineMax: 30000, - } - for _, fn := range overrides { - fn(&cfg) - } - return cfg -} - -func waitListening(t *testing.T, host string, port int, timeout time.Duration) { - t.Helper() - deadline := time.Now().Add(timeout) - for time.Now().Before(deadline) { - conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), 50*time.Millisecond) - if err == nil { - conn.Close() - return - } - time.Sleep(20 * time.Millisecond) - } - t.Fatalf("server on port %d did not start within %v", port, timeout) -} - -func doRequest(t *testing.T, method, url, body string, headers map[string]string) (int, string) { - t.Helper() - var reqBody io.Reader - if body != "" { - reqBody = strings.NewReader(body) - } - req, err := http.NewRequest(method, url, reqBody) - if err != nil { - t.Fatal(err) - } - if body != "" { - req.Header.Set("Content-Type", "application/json") - } - for k, v := range headers { - req.Header.Set(k, v) - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - data, _ := io.ReadAll(resp.Body) - return resp.StatusCode, string(data) -} - -func TestBridgeServer_Health(t *testing.T) { - port := freePort(t) - agentBin := makeFakeAgentBinWithModels(t) - cfg := makeTestConfig(agentBin, port) - - srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg}) - waitListening(t, "127.0.0.1", port, 3*time.Second) - defer func() { - for _, s := range srvs { - s.Shutdown(context.Background()) - } - }() - - status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", nil) - if status != 200 { - t.Fatalf("status = %d, want 200; body: %s", status, body) - } - var result map[string]interface{} - json.Unmarshal([]byte(body), &result) - if result["ok"] != true { - t.Errorf("ok = %v, want true", result["ok"]) - } - if result["version"] != "1.0.0" { - t.Errorf("version = %v, want 1.0.0", result["version"]) - } -} - -func TestBridgeServer_Models(t *testing.T) { - port := freePort(t) - agentBin := makeFakeAgentBinWithModels(t) - cfg := makeTestConfig(agentBin, port) - - srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg}) - waitListening(t, "127.0.0.1", port, 3*time.Second) - defer func() { - for _, s := range srvs { - s.Shutdown(context.Background()) - } - }() - - status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/v1/models", port), "", nil) - if status != 200 { - t.Fatalf("status = %d, want 200; body: %s", status, body) - } - var result map[string]interface{} - json.Unmarshal([]byte(body), &result) - if result["object"] != "list" { - t.Errorf("object = %v, want list", result["object"]) - } - data := result["data"].([]interface{}) - if len(data) < 2 { - t.Errorf("data len = %d, want >= 2", len(data)) - } -} - -func TestBridgeServer_Unauthorized(t *testing.T) { - port := freePort(t) - agentBin := makeFakeAgentBinWithModels(t) - cfg := makeTestConfig(agentBin, port, func(c *config.BridgeConfig) { - c.RequiredKey = "secret123" - }) - - srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg}) - waitListening(t, "127.0.0.1", port, 3*time.Second) - defer func() { - for _, s := range srvs { - s.Shutdown(context.Background()) - } - }() - - status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", nil) - if status != 401 { - t.Fatalf("status = %d, want 401; body: %s", status, body) - } - var result map[string]interface{} - json.Unmarshal([]byte(body), &result) - errObj := result["error"].(map[string]interface{}) - if errObj["message"] != "Invalid API key" { - t.Errorf("message = %v, want 'Invalid API key'", errObj["message"]) - } -} - -func TestBridgeServer_AuthorizedKey(t *testing.T) { - port := freePort(t) - agentBin := makeFakeAgentBinWithModels(t) - cfg := makeTestConfig(agentBin, port, func(c *config.BridgeConfig) { - c.RequiredKey = "secret123" - }) - - srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg}) - waitListening(t, "127.0.0.1", port, 3*time.Second) - defer func() { - for _, s := range srvs { - s.Shutdown(context.Background()) - } - }() - - status, _ := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", map[string]string{ - "Authorization": "Bearer secret123", - }) - if status != 200 { - t.Errorf("status = %d, want 200", status) - } -} - -func TestBridgeServer_NotFound(t *testing.T) { - port := freePort(t) - agentBin := makeFakeAgentBinWithModels(t) - cfg := makeTestConfig(agentBin, port) - - srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg}) - waitListening(t, "127.0.0.1", port, 3*time.Second) - defer func() { - for _, s := range srvs { - s.Shutdown(context.Background()) - } - }() - - status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/unknown", port), "", nil) - if status != 404 { - t.Fatalf("status = %d, want 404; body: %s", status, body) - } - var result map[string]interface{} - json.Unmarshal([]byte(body), &result) - errObj := result["error"].(map[string]interface{}) - if errObj["code"] != "not_found" { - t.Errorf("code = %v, want not_found", errObj["code"]) - } -} - -func TestBridgeServer_ChatCompletions_Sync(t *testing.T) { - port := freePort(t) - agentBin := makeFakeAgentBin(t, "Hello from agent") - cfg := makeTestConfig(agentBin, port) - - srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg}) - waitListening(t, "127.0.0.1", port, 3*time.Second) - defer func() { - for _, s := range srvs { - s.Shutdown(context.Background()) - } - }() - - reqBody := `{"model":"claude-3-opus","messages":[{"role":"user","content":"Hi"}]}` - status, body := doRequest(t, "POST", fmt.Sprintf("http://127.0.0.1:%d/v1/chat/completions", port), reqBody, nil) - if status != 200 { - t.Fatalf("status = %d, want 200; body: %s", status, body) - } - var result map[string]interface{} - json.Unmarshal([]byte(body), &result) - if result["object"] != "chat.completion" { - t.Errorf("object = %v, want chat.completion", result["object"]) - } - choices := result["choices"].([]interface{}) - msg := choices[0].(map[string]interface{})["message"].(map[string]interface{}) - if msg["content"] != "Hello from agent" { - t.Errorf("content = %v, want 'Hello from agent'", msg["content"]) - } -} - -func TestBridgeServer_MultiPort(t *testing.T) { - basePort := freePort(t) - agentBin := makeFakeAgentBinWithModels(t) - - dir1 := t.TempDir() - dir2 := t.TempDir() - - cfg := makeTestConfig(agentBin, basePort, func(c *config.BridgeConfig) { - c.ConfigDirs = []string{dir1, dir2} - c.MultiPort = true - }) - - srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg}) - if len(srvs) != 2 { - t.Fatalf("got %d servers, want 2", len(srvs)) - } - - // 等待兩個 server 啟動(port 可能會衝突,這裡不嚴格測試 port 分配) - time.Sleep(200 * time.Millisecond) - - defer func() { - for _, s := range srvs { - s.Shutdown(context.Background()) - } - }() -} diff --git a/internal/toolcall/toolcall.go b/internal/toolcall/toolcall.go deleted file mode 100644 index 4d47176..0000000 --- a/internal/toolcall/toolcall.go +++ /dev/null @@ -1,154 +0,0 @@ -package toolcall - -import ( - "encoding/json" - "regexp" - "strings" -) - -type ToolCall struct { - Name string - Arguments string // JSON string -} - -type ParsedResponse struct { - TextContent string - ToolCalls []ToolCall -} - -func (p *ParsedResponse) HasToolCalls() bool { - return len(p.ToolCalls) > 0 -} - -// Modified regex to handle nested JSON -var toolCallTagRe = regexp.MustCompile(`(?s)行政法规\s*(\{(?:[^{}]|\{[^{}]*\})*\})\s*ugalakh`) -var antmlFunctionCallsRe = regexp.MustCompile(`(?s)\s*(.*?)\s*`) -var antmlInvokeRe = regexp.MustCompile(`(?s)\s*(.*?)\s*`) -var antmlParamRe = regexp.MustCompile(`(?s)(.*?)`) - -func ExtractToolCalls(text string, toolNames map[string]bool) *ParsedResponse { - result := &ParsedResponse{} - - if locs := toolCallTagRe.FindAllStringSubmatchIndex(text, -1); len(locs) > 0 { - var calls []ToolCall - var textParts []string - last := 0 - for _, loc := range locs { - if loc[0] > last { - textParts = append(textParts, text[last:loc[0]]) - } - jsonStr := text[loc[2]:loc[3]] - if tc := parseToolCallJSON(jsonStr, toolNames); tc != nil { - calls = append(calls, *tc) - } else { - textParts = append(textParts, text[loc[0]:loc[1]]) - } - last = loc[1] - } - if last < len(text) { - textParts = append(textParts, text[last:]) - } - if len(calls) > 0 { - result.TextContent = strings.TrimSpace(strings.Join(textParts, "")) - result.ToolCalls = calls - return result - } - } - - if locs := antmlFunctionCallsRe.FindAllStringSubmatchIndex(text, -1); len(locs) > 0 { - var calls []ToolCall - var textParts []string - last := 0 - for _, loc := range locs { - if loc[0] > last { - textParts = append(textParts, text[last:loc[0]]) - } - block := text[loc[2]:loc[3]] - invokes := antmlInvokeRe.FindAllStringSubmatch(block, -1) - for _, inv := range invokes { - name := inv[1] - if toolNames != nil && len(toolNames) > 0 && !toolNames[name] { - continue - } - body := inv[2] - args := map[string]interface{}{} - params := antmlParamRe.FindAllStringSubmatch(body, -1) - for _, p := range params { - paramName := p[1] - paramValue := strings.TrimSpace(p[2]) - var jsonVal interface{} - if err := json.Unmarshal([]byte(paramValue), &jsonVal); err == nil { - args[paramName] = jsonVal - } else { - args[paramName] = paramValue - } - } - argsJSON, _ := json.Marshal(args) - calls = append(calls, ToolCall{Name: name, Arguments: string(argsJSON)}) - } - last = loc[1] - } - if last < len(text) { - textParts = append(textParts, text[last:]) - } - if len(calls) > 0 { - result.TextContent = strings.TrimSpace(strings.Join(textParts, "")) - result.ToolCalls = calls - return result - } - } - - result.TextContent = text - return result -} - -func parseToolCallJSON(jsonStr string, toolNames map[string]bool) *ToolCall { - var raw map[string]interface{} - if err := json.Unmarshal([]byte(jsonStr), &raw); err != nil { - return nil - } - name, _ := raw["name"].(string) - if name == "" { - return nil - } - if toolNames != nil && len(toolNames) > 0 && !toolNames[name] { - return nil - } - var argsStr string - switch a := raw["arguments"].(type) { - case string: - argsStr = a - case map[string]interface{}, []interface{}: - b, _ := json.Marshal(a) - argsStr = string(b) - default: - if p, ok := raw["parameters"]; ok { - b, _ := json.Marshal(p) - argsStr = string(b) - } else { - argsStr = "{}" - } - } - return &ToolCall{Name: name, Arguments: argsStr} -} - -func CollectToolNames(tools []interface{}) map[string]bool { - names := map[string]bool{} - for _, t := range tools { - m, ok := t.(map[string]interface{}) - if !ok { - continue - } - if m["type"] == "function" { - if fn, ok := m["function"].(map[string]interface{}); ok { - if name, ok := fn["name"].(string); ok { - names[name] = true - } - } - } - if name, ok := m["name"].(string); ok { - names[name] = true - } - } - return names -} diff --git a/internal/winlimit/winlimit.go b/internal/winlimit/winlimit.go deleted file mode 100644 index 06044b1..0000000 --- a/internal/winlimit/winlimit.go +++ /dev/null @@ -1,181 +0,0 @@ -package winlimit - -import ( - "cursor-api-proxy/internal/env" - "runtime" -) - -const WinPromptOmissionPrefix = "[Earlier messages omitted: Windows command-line length limit.]\n\n" -const LinuxPromptOmissionPrefix = "[Earlier messages omitted: Linux ARG_MAX command-line length limit.]\n\n" - -// safeLinuxArgMax returns a conservative estimate of ARG_MAX on Linux. -// The actual limit is typically 2MB; we use 1.5MB to leave room for env vars. -func safeLinuxArgMax() int { - return 1536 * 1024 -} - -type FitPromptResult struct { - OK bool - Args []string - Truncated bool - OriginalLength int - FinalPromptLength int - Error string -} - -func estimateCmdlineLength(resolved env.AgentCommand) int { - argv := append([]string{resolved.Command}, resolved.Args...) - if resolved.WindowsVerbatimArguments { - n := 0 - for _, a := range argv { - n += len(a) - } - if len(argv) > 1 { - n += len(argv) - 1 - } - return n + 512 - } - dstLen := 0 - for _, a := range argv { - dstLen += len(a) - } - dstLen = dstLen*2 + len(argv)*2 - if len(argv) > 1 { - dstLen += len(argv) - 1 - } - return dstLen + 512 -} - -func FitPromptToWinCmdline(agentBin string, fixedArgs []string, prompt string, maxCmdline int, cwd string) FitPromptResult { - if runtime.GOOS != "windows" { - return fitPromptLinux(fixedArgs, prompt) - } - - e := env.OsEnvToMap() - measured := func(p string) int { - args := make([]string, len(fixedArgs)+1) - copy(args, fixedArgs) - args[len(fixedArgs)] = p - resolved := env.ResolveAgentCommand(agentBin, args, e, cwd) - return estimateCmdlineLength(resolved) - } - - if measured("") > maxCmdline { - return FitPromptResult{ - OK: false, - Error: "Windows command line exceeds the configured limit even without a prompt; shorten workspace path, model id, or CURSOR_BRIDGE_WIN_CMDLINE_MAX.", - } - } - - if measured(prompt) <= maxCmdline { - args := make([]string, len(fixedArgs)+1) - copy(args, fixedArgs) - args[len(fixedArgs)] = prompt - return FitPromptResult{ - OK: true, - Args: args, - Truncated: false, - OriginalLength: len(prompt), - FinalPromptLength: len(prompt), - } - } - - prefix := WinPromptOmissionPrefix - if measured(prefix) > maxCmdline { - return FitPromptResult{ - OK: false, - Error: "Windows command line too long to fit even the truncation notice; shorten workspace path or flags.", - } - } - - lo, hi, best := 0, len(prompt), 0 - for lo <= hi { - mid := (lo + hi) / 2 - var tail string - if mid > 0 { - tail = prompt[len(prompt)-mid:] - } - candidate := prefix + tail - if measured(candidate) <= maxCmdline { - best = mid - lo = mid + 1 - } else { - hi = mid - 1 - } - } - - var finalPrompt string - if best == 0 { - finalPrompt = prefix - } else { - finalPrompt = prefix + prompt[len(prompt)-best:] - } - - args := make([]string, len(fixedArgs)+1) - copy(args, fixedArgs) - args[len(fixedArgs)] = finalPrompt - return FitPromptResult{ - OK: true, - Args: args, - Truncated: true, - OriginalLength: len(prompt), - FinalPromptLength: len(finalPrompt), - } -} - -// fitPromptLinux handles Linux ARG_MAX truncation. -func fitPromptLinux(fixedArgs []string, prompt string) FitPromptResult { - argMax := safeLinuxArgMax() - - // Estimate total cmdline size: sum of all fixed args + prompt + null terminators - fixedLen := 0 - for _, a := range fixedArgs { - fixedLen += len(a) + 1 - } - totalLen := fixedLen + len(prompt) + 1 - - if totalLen <= argMax { - args := make([]string, len(fixedArgs)+1) - copy(args, fixedArgs) - args[len(fixedArgs)] = prompt - return FitPromptResult{ - OK: true, - Args: args, - Truncated: false, - OriginalLength: len(prompt), - FinalPromptLength: len(prompt), - } - } - - // Need to truncate: keep the tail of the prompt (most recent messages) - prefix := LinuxPromptOmissionPrefix - available := argMax - fixedLen - len(prefix) - 1 - if available < 0 { - available = 0 - } - - var finalPrompt string - if available <= 0 { - finalPrompt = prefix - } else if available >= len(prompt) { - finalPrompt = prefix + prompt - } else { - finalPrompt = prefix + prompt[len(prompt)-available:] - } - - args := make([]string, len(fixedArgs)+1) - copy(args, fixedArgs) - args[len(fixedArgs)] = finalPrompt - return FitPromptResult{ - OK: true, - Args: args, - Truncated: true, - OriginalLength: len(prompt), - FinalPromptLength: len(finalPrompt), - } -} - -func WarnPromptTruncated(originalLength, finalLength int) { - _ = originalLength - _ = finalLength -} diff --git a/internal/winlimit/winlimit_test.go b/internal/winlimit/winlimit_test.go deleted file mode 100644 index 03d2488..0000000 --- a/internal/winlimit/winlimit_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package winlimit - -import ( - "runtime" - "strings" - "testing" -) - -func TestNonWindowsPassThrough(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("Skipping non-Windows test on Windows") - } - - fixedArgs := []string{"--print", "--model", "gpt-4"} - prompt := "Hello world" - result := FitPromptToWinCmdline("agent", fixedArgs, prompt, 30000, "/tmp") - - if !result.OK { - t.Fatalf("expected OK=true on non-Windows, got error: %s", result.Error) - } - if result.Truncated { - t.Error("expected no truncation on non-Windows") - } - if result.OriginalLength != len(prompt) { - t.Errorf("expected original length %d, got %d", len(prompt), result.OriginalLength) - } - // Last arg should be the prompt - if len(result.Args) == 0 || result.Args[len(result.Args)-1] != prompt { - t.Errorf("expected last arg to be prompt, got %v", result.Args) - } -} - -func TestOmissionPrefix(t *testing.T) { - if !strings.Contains(WinPromptOmissionPrefix, "Earlier messages omitted") { - t.Errorf("omission prefix should mention earlier messages, got: %q", WinPromptOmissionPrefix) - } -} diff --git a/internal/workspace/workspace.go b/internal/workspace/workspace.go deleted file mode 100644 index 8fe76bd..0000000 --- a/internal/workspace/workspace.go +++ /dev/null @@ -1,30 +0,0 @@ -package workspace - -import ( - "cursor-api-proxy/internal/config" - "os" - "path/filepath" - "strings" -) - -type WorkspaceResult struct { - WorkspaceDir string - TempDir string -} - -func ResolveWorkspace(cfg config.BridgeConfig, workspaceHeader string) WorkspaceResult { - if cfg.ChatOnlyWorkspace { - tempDir, err := os.MkdirTemp("", "cursor-proxy-") - if err != nil { - tempDir = filepath.Join(os.TempDir(), "cursor-proxy-fallback") - _ = os.MkdirAll(tempDir, 0700) - } - return WorkspaceResult{WorkspaceDir: tempDir, TempDir: tempDir} - } - - headerWs := strings.TrimSpace(workspaceHeader) - if headerWs != "" { - return WorkspaceResult{WorkspaceDir: headerWs} - } - return WorkspaceResult{WorkspaceDir: cfg.Workspace} -} diff --git a/main.go b/main.go index 47c2409..75ed6cb 100644 --- a/main.go +++ b/main.go @@ -1,26 +1,56 @@ package main import ( - "cursor-api-proxy/cmd" - "cursor-api-proxy/internal/config" - "cursor-api-proxy/internal/env" - "cursor-api-proxy/internal/server" + "flag" "fmt" "os" + + "cursor-api-proxy/internal/config" + "cursor-api-proxy/internal/handler" + "cursor-api-proxy/internal/svc" + + cmd "cursor-api-proxy/cmd/cli" + + "github.com/zeromicro/go-zero/core/conf" + "github.com/zeromicro/go-zero/rest" ) const version = "1.0.0" +var configFile = flag.String("f", "etc/chat-api.yaml", "the config file") + func main() { - args, err := cmd.ParseArgs(os.Args[1:]) - if err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) + // Check for CLI commands first (before flag.Parse) + args := os.Args[1:] + if len(args) > 0 { + parsed, err := cmd.ParseArgs(args) + if err != nil { + // Not a CLI command, proceed to HTTP server + } else if handleCLICommand(parsed) { + return + } } + // HTTP server mode (go-zero) + flag.Parse() + + var c config.Config + conf.MustLoad(*configFile, &c) + + server := rest.MustNewServer(c.RestConf) + defer server.Stop() + + ctx := svc.NewServiceContext(c) + handler.RegisterHandlers(server, ctx) + + fmt.Printf("Starting server at %s:%d...\n", c.Host, c.Port) + server.Start() +} + +func handleCLICommand(args cmd.ParsedArgs) bool { if args.Help { cmd.PrintHelp(version) - return + return true } if args.Login { @@ -28,7 +58,7 @@ func main() { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } - return + return true } if args.Logout { @@ -36,7 +66,7 @@ func main() { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } - return + return true } if args.AccountsList { @@ -44,7 +74,7 @@ func main() { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } - return + return true } if args.ResetHwid { @@ -52,22 +82,9 @@ func main() { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } - return + return true } - e := env.OsEnvToMap() - if args.Tailscale { - e["CURSOR_BRIDGE_HOST"] = "0.0.0.0" - } - - cwd, _ := os.Getwd() - cfg := config.LoadBridgeConfig(e, cwd) - - servers := server.StartBridgeServer(server.ServerOptions{ - Version: version, - Config: cfg, - }) - server.SetupGracefulShutdown(servers, 10000) - - select {} + // Not a CLI command + return false } diff --git a/pkg/infrastructure/logger/logger.go b/pkg/infrastructure/logger/logger.go index db18a0a..cd17aae 100644 --- a/pkg/infrastructure/logger/logger.go +++ b/pkg/infrastructure/logger/logger.go @@ -1,13 +1,14 @@ package logger import ( - "cursor-api-proxy/internal/config" - "cursor-api-proxy/internal/pool" "fmt" "os" "path/filepath" "strings" "time" + + "cursor-api-proxy/internal/config" + "cursor-api-proxy/pkg/domain/entity" ) const ( @@ -192,7 +193,7 @@ func LogAccountAssigned(configDir string) { fmt.Printf("%s %s→%s account %s%s%s\n", ts(), cBCyan, cReset, cBold, name, cReset) } -func LogAccountStats(verbose bool, stats []pool.AccountStat) { +func LogAccountStats(verbose bool, stats []entity.AccountStat) { if !verbose || len(stats) == 0 { return } diff --git a/pkg/infrastructure/process/process_test.go b/pkg/infrastructure/process/process_test.go index c48d4b5..6c0ea4a 100644 --- a/pkg/infrastructure/process/process_test.go +++ b/pkg/infrastructure/process/process_test.go @@ -2,7 +2,7 @@ package process_test import ( "context" - "cursor-api-proxy/internal/process" + "cursor-api-proxy/pkg/infrastructure/process" "os" "testing" "time" diff --git a/pkg/infrastructure/process/runner.go b/pkg/infrastructure/process/runner.go index 681e42f..4043adb 100644 --- a/pkg/infrastructure/process/runner.go +++ b/pkg/infrastructure/process/runner.go @@ -3,7 +3,7 @@ package process import ( "bufio" "context" - "cursor-api-proxy/internal/env" + "cursor-api-proxy/pkg/infrastructure/env" "fmt" "os/exec" "strings" diff --git a/pkg/infrastructure/winlimit/winlimit.go b/pkg/infrastructure/winlimit/winlimit.go index 06044b1..c62b721 100644 --- a/pkg/infrastructure/winlimit/winlimit.go +++ b/pkg/infrastructure/winlimit/winlimit.go @@ -1,7 +1,7 @@ package winlimit import ( - "cursor-api-proxy/internal/env" + "cursor-api-proxy/pkg/infrastructure/env" "runtime" ) diff --git a/pkg/provider/geminiweb/playwright_provider.go b/pkg/provider/geminiweb/playwright_provider.go index dbfa5b2..6b6ba68 100644 --- a/pkg/provider/geminiweb/playwright_provider.go +++ b/pkg/provider/geminiweb/playwright_provider.go @@ -182,7 +182,7 @@ func (p *PlaywrightProvider) Generate(ctx context.Context, model string, message fmt.Println("Browser is open. You can:") fmt.Println("1. Log in to Gemini now") fmt.Println("2. Continue without login") - fmt.Println("========================================\n") + fmt.Println("========================================") } } else { fmt.Println("[GeminiWeb] ✓ Logged in") diff --git a/pkg/provider/geminiweb/provider.go b/pkg/provider/geminiweb/provider.go index ca0a0cb..3b5032f 100644 --- a/pkg/provider/geminiweb/provider.go +++ b/pkg/provider/geminiweb/provider.go @@ -97,7 +97,7 @@ func (p *Provider) Generate(ctx context.Context, model string, messages []entity fmt.Println("Browser is open. You can:") fmt.Println("1. Log in to Gemini now") fmt.Println("2. Continue without login") - fmt.Println("========================================\n") + fmt.Println("========================================") } } else { fmt.Printf("[GeminiWeb] Logged in\n") diff --git a/scripts/detect-gemini-dom.go b/scripts/detect-gemini-dom.go index 1c482a9..574b6a4 100644 --- a/scripts/detect-gemini-dom.go +++ b/scripts/detect-gemini-dom.go @@ -1,7 +1,7 @@ package main import ( - "cursor-api-proxy/internal/providers/geminiweb" + "cursor-api-proxy/pkg/provider/geminiweb" "fmt" "os" @@ -92,7 +92,7 @@ func analyzeDOM(page *rod.Page) { ariaLabel, _ := el.Attribute("aria-label") placeholder, _ := el.Attribute("placeholder") fmt.Printf(" [%d] tag=%s class=%s aria-label=%s placeholder=%s\n", - i, tag, class, ariaLabel, placeholder) + i, tag, ptrToStr(class), ptrToStr(ariaLabel), ptrToStr(placeholder)) } } } @@ -120,7 +120,7 @@ func analyzeDOM(page *rod.Page) { text, _ := el.Text() text = truncate(text, 30) fmt.Printf(" [%d] tag=%s class=%s aria-label=%s text=%s\n", - i, tag, class, ariaLabel, text) + i, tag, ptrToStr(class), ptrToStr(ariaLabel), text) } } } @@ -145,7 +145,7 @@ func analyzeDOM(page *rod.Page) { ariaLabel, _ := el.Attribute("aria-label") text, _ := el.Text() fmt.Printf(" [%d] tag=%s class=%s aria-label=%s text=%s\n", - i, tag, class, ariaLabel, truncate(text, 30)) + i, tag, ptrToStr(class), ptrToStr(ariaLabel), truncate(text, 30)) } } } @@ -157,3 +157,10 @@ func truncate(s string, max int) string { } return s[:max] + "..." } + +func ptrToStr(s *string) string { + if s == nil { + return "" + } + return *s +}