commit a2f1d053913c062c307deeff9adc91e8b35d980f Author: daniel Date: Mon Mar 30 14:09:15 2026 +0000 first commit diff --git a/README.md b/README.md new file mode 100644 index 0000000..426eb26 --- /dev/null +++ b/README.md @@ -0,0 +1,102 @@ +# Cursor API Proxy + +[English](./README.md) | 繁體中文 + +一個讓你可以透過標準 OpenAI/Anthropic API 格式存取 Cursor AI 編輯器的代理伺服器。 + +## 功能特色 + +- **API 相容**:支援 OpenAI 格式和 Anthropic 格式的 API 呼叫 +- **多帳號管理**:支援新增、移除、切換多個 Cursor 帳號 +- **Tailscale 支援**:可綁定到 `0.0.0.0` 供區域網路存取 +- **HWID 重置**:內建反偵測功能,可重置機器識別碼 +- **連線池**:最佳化的連線管理 + +## 安裝 + +```bash +git clone https://github.com/your-repo/cursor-api-proxy-go.git +cd cursor-api-proxy-go +go build -o cursor-api-proxy . +``` + +## 使用方式 + +### 啟動伺服器 + +```bash +./cursor-api-proxy +``` + +預設監聽 `127.0.0.1:8080`。 + +### 登入帳號 + +```bash +# 登入帳號 +./cursor-api-proxy login myaccount + +# 使用代理登入 +./cursor-api-proxy login myaccount --proxy=http://127.0.0.1:7890 +``` + +### 列出帳號 + +```bash +./cursor-api-proxy accounts +``` + +### 登出帳號 + +```bash +./cursor-api-proxy logout myaccount +``` + +### 重置 HWID(反BAN) + +```bash +# 基本重置 +./cursor-api-proxy reset-hwid + +# 深度清理(清除 session 和 cookies) +./cursor-api-proxy reset-hwid --deep-clean +``` + +### 其他選項 + +| 選項 | 說明 | +|------|------| +| `--tailscale` | 綁定到 `0.0.0.0` 供區域網路存取 | +| `-h, --help` | 顯示說明 | + +## API 端點 + +| 端點 | 方法 | 說明 | +|------|------|------| +| `http://127.0.0.1:8080/v1/chat/completions` | POST | OpenAI 格式聊天完成 | +| `http://127.0.0.1:8080/v1/models` | GET | 列出可用模型 | +| `http://127.0.0.1:8080/v1/chat/messages` | POST | Anthropic 格式聊天 | +| `http://127.0.0.1:8080/health` | GET | 健康檢查 | + +## 環境變數 + +| 變數 | 預設值 | 說明 | +|------|--------|------| +| `CURSOR_BRIDGE_HOST` | `127.0.0.1` | 監聽位址 | +| `CURSOR_BRIDGE_PORT` | `8080` | 監聽連接埠 | +| `HTTPS_PROXY` | - | HTTP 代理伺服器 | + +## 常見問題 + +**Q: 為什麼需要登入帳號?** +A: Cursor API 需要驗證才能使用,請先登入你的 Cursor 帳號。 + +**Q: 如何處理被BAN的問題?** +A: 使用 `reset-hwid` 命令重置機器識別碼,加上 `--deep-clean` 進行更徹底的清理。 + +**Q: 可以在其他設備上使用嗎?** +A: 可以,使用 `--tailscale` 選項啟動伺服器,然後透過區域網路 IP 存取。 + +## 授權 + +MIT License diff --git a/cmd/accounts.go b/cmd/accounts.go new file mode 100644 index 0000000..df1d999 --- /dev/null +++ b/cmd/accounts.go @@ -0,0 +1,196 @@ +package cmd + +import ( + "cursor-api-proxy/internal/agent" + "encoding/json" + "fmt" + "os" + "path/filepath" +) + +type AccountInfo struct { + Name string + ConfigDir string + Authenticated bool + Email string + DisplayName string + AuthID string + Plan string + SubscriptionStatus string + ExpiresAt string +} + +func ReadAccountInfo(name, configDir string) AccountInfo { + info := AccountInfo{Name: name, ConfigDir: configDir} + + configFile := filepath.Join(configDir, "cli-config.json") + data, err := os.ReadFile(configFile) + if err != nil { + return info + } + + var raw struct { + AuthInfo *struct { + Email string `json:"email"` + DisplayName string `json:"displayName"` + AuthID string `json:"authId"` + } `json:"authInfo"` + } + if err := json.Unmarshal(data, &raw); err == nil && raw.AuthInfo != nil { + info.Authenticated = true + info.Email = raw.AuthInfo.Email + info.DisplayName = raw.AuthInfo.DisplayName + info.AuthID = raw.AuthInfo.AuthID + } + + statsigFile := filepath.Join(configDir, "statsig-cache.json") + statsigData, err := os.ReadFile(statsigFile) + if err != nil { + return info + } + + var statsigWrapper struct { + Data string `json:"data"` + } + if err := json.Unmarshal(statsigData, &statsigWrapper); err != nil || statsigWrapper.Data == "" { + return info + } + + var statsig struct { + User *struct { + Custom *struct { + IsEnterpriseUser bool `json:"isEnterpriseUser"` + StripeSubscriptionStatus string `json:"stripeSubscriptionStatus"` + StripeMembershipExpiration string `json:"stripeMembershipExpiration"` + } `json:"custom"` + } `json:"user"` + } + if err := json.Unmarshal([]byte(statsigWrapper.Data), &statsig); err != nil { + return info + } + + if statsig.User != nil && statsig.User.Custom != nil { + c := statsig.User.Custom + if c.IsEnterpriseUser { + info.Plan = "Enterprise" + } else if c.StripeSubscriptionStatus == "active" { + info.Plan = "Pro" + } else { + info.Plan = "Free" + } + info.SubscriptionStatus = c.StripeSubscriptionStatus + info.ExpiresAt = c.StripeMembershipExpiration + } + + return info +} + +func HandleAccountsList() error { + accountsDir := agent.AccountsDir() + + entries, err := os.ReadDir(accountsDir) + if err != nil { + fmt.Println("No accounts found. Use 'cursor-api-proxy login' to add one.") + return nil + } + + var names []string + for _, e := range entries { + if e.IsDir() { + names = append(names, e.Name()) + } + } + + if len(names) == 0 { + fmt.Println("No accounts found. Use 'cursor-api-proxy login' to add one.") + return nil + } + + fmt.Print("Cursor Accounts:\n\n") + + keychainToken := agent.ReadKeychainToken() + + for i, name := range names { + configDir := filepath.Join(accountsDir, name) + info := ReadAccountInfo(name, configDir) + + fmt.Printf(" %d. %s\n", i+1, name) + + if info.Authenticated { + cachedToken := agent.ReadCachedToken(configDir) + keychainMatchesAccount := keychainToken != "" && info.AuthID != "" && TokenSub(keychainToken) == info.AuthID + token := cachedToken + if token == "" && keychainMatchesAccount { + token = keychainToken + } + + var liveProfile *StripeProfile + var liveUsage *UsageData + if token != "" { + liveUsage, _ = FetchAccountUsage(token) + liveProfile, _ = FetchStripeProfile(token) + } + + if info.Email != "" { + display := "" + if info.DisplayName != "" { + display = " (" + info.DisplayName + ")" + } + fmt.Printf(" %s%s\n", info.Email, display) + } + + if info.Plan != "" && liveProfile == nil { + canceled := "" + if info.SubscriptionStatus == "canceled" { + canceled = " · canceled" + } + expiry := "" + if info.ExpiresAt != "" { + expiry = " · expires " + info.ExpiresAt + } + fmt.Printf(" %s%s%s\n", info.Plan, canceled, expiry) + } + fmt.Println(" Authenticated") + + if liveProfile != nil { + fmt.Printf(" %s\n", DescribePlan(liveProfile)) + } + if liveUsage != nil { + for _, line := range FormatUsageSummary(liveUsage) { + fmt.Println(line) + } + } + } else { + fmt.Println(" Not authenticated") + } + + fmt.Println("") + } + + fmt.Println("Tip: run 'cursor-api-proxy logout ' to remove an account.") + return nil +} + +func HandleLogout(accountName string) error { + if accountName == "" { + fmt.Fprintln(os.Stderr, "Error: Please specify the account name to remove.") + fmt.Fprintln(os.Stderr, "Usage: cursor-api-proxy logout ") + os.Exit(1) + } + + accountsDir := agent.AccountsDir() + configDir := filepath.Join(accountsDir, accountName) + + if _, err := os.Stat(configDir); os.IsNotExist(err) { + fmt.Fprintf(os.Stderr, "Account '%s' not found.\n", accountName) + os.Exit(1) + } + + if err := os.RemoveAll(configDir); err != nil { + fmt.Fprintf(os.Stderr, "Error removing account: %v\n", err) + os.Exit(1) + } + + fmt.Printf("Account '%s' removed.\n", accountName) + return nil +} diff --git a/cmd/args.go b/cmd/args.go new file mode 100644 index 0000000..05e1c05 --- /dev/null +++ b/cmd/args.go @@ -0,0 +1,118 @@ +package cmd + +import "fmt" + +type ParsedArgs struct { + Tailscale bool + Help bool + Login bool + AccountsList bool + Logout bool + AccountName string + Proxies []string + ResetHwid bool + DeepClean bool + DryRun bool +} + +func ParseArgs(argv []string) (ParsedArgs, error) { + var args ParsedArgs + + for i := 0; i < len(argv); i++ { + arg := argv[i] + + switch arg { + case "login", "add-account": + args.Login = true + if i+1 < len(argv) && len(argv[i+1]) > 0 && argv[i+1][0] != '-' { + i++ + args.AccountName = argv[i] + } + + case "logout", "remove-account": + args.Logout = true + if i+1 < len(argv) && len(argv[i+1]) > 0 && argv[i+1][0] != '-' { + i++ + args.AccountName = argv[i] + } + + case "accounts", "list-accounts": + args.AccountsList = true + + case "reset-hwid", "reset": + args.ResetHwid = true + + case "--deep-clean": + args.DeepClean = true + + case "--dry-run": + args.DryRun = true + + case "--tailscale": + args.Tailscale = true + + case "--help", "-h": + args.Help = true + + default: + if len(arg) > len("--proxy=") && arg[:len("--proxy=")] == "--proxy=" { + raw := arg[len("--proxy="):] + parts := splitComma(raw) + for _, p := range parts { + if p != "" { + args.Proxies = append(args.Proxies, p) + } + } + } else { + return args, fmt.Errorf("Unknown argument: %s", arg) + } + } + } + + return args, nil +} + +func splitComma(s string) []string { + var result []string + start := 0 + for i := 0; i <= len(s); i++ { + if i == len(s) || s[i] == ',' { + part := trim(s[start:i]) + if part != "" { + result = append(result, part) + } + start = i + 1 + } + } + return result +} + +func trim(s string) string { + start := 0 + end := len(s) + for start < end && (s[start] == ' ' || s[start] == '\t') { + start++ + } + for end > start && (s[end-1] == ' ' || s[end-1] == '\t') { + end-- + } + return s[start:end] +} + +func PrintHelp(version string) { + fmt.Printf("cursor-api-proxy v%s\n\n", version) + fmt.Println("Usage:") + fmt.Println(" cursor-api-proxy [options]") + fmt.Println("") + fmt.Println("Commands:") + fmt.Println(" login [name] Log into a Cursor account (saved to ~/.cursor-api-proxy/accounts/)") + fmt.Println(" login [name] --proxy=... Same, but with a proxy from a comma-separated list") + fmt.Println(" logout Remove a saved Cursor account") + fmt.Println(" accounts List saved accounts with plan info") + fmt.Println(" reset-hwid Reset Cursor machine/telemetry IDs (anti-ban)") + fmt.Println(" reset-hwid --deep-clean Also wipe session storage and cookies") + fmt.Println("") + fmt.Println("Options:") + fmt.Println(" --tailscale Bind to 0.0.0.0 for tailnet/LAN access") + fmt.Println(" -h, --help Show this help message") +} diff --git a/cmd/login.go b/cmd/login.go new file mode 100644 index 0000000..f1eb22d --- /dev/null +++ b/cmd/login.go @@ -0,0 +1,125 @@ +package cmd + +import ( + "bufio" + "cursor-api-proxy/internal/agent" + "cursor-api-proxy/internal/env" + "fmt" + "os" + "os/exec" + "os/signal" + "path/filepath" + "regexp" + "syscall" + "time" +) + +var loginURLRe = regexp.MustCompile(`https://cursor\.com/loginDeepControl.*?redirectTarget=cli`) + +func HandleLogin(accountName string, proxies []string) error { + e := env.OsEnvToMap() + loaded := env.LoadEnvConfig(e, "") + agentBin := loaded.AgentBin + + if accountName == "" { + accountName = fmt.Sprintf("account-%d", time.Now().UnixMilli()%10000) + } + + accountsDir := agent.AccountsDir() + configDir := filepath.Join(accountsDir, accountName) + dirWasNew := !fileExists(configDir) + + if err := os.MkdirAll(accountsDir, 0755); err != nil { + return fmt.Errorf("failed to create accounts dir: %w", err) + } + if err := os.MkdirAll(configDir, 0755); err != nil { + return fmt.Errorf("failed to create config dir: %w", err) + } + + fmt.Printf("Logging into Cursor account: %s\n", accountName) + fmt.Printf("Config: %s\n\n", configDir) + fmt.Println("Run the login command — complete the login in your browser.") + fmt.Println("") + + cleanupDir := func() { + if dirWasNew { + _ = os.RemoveAll(configDir) + } + } + + cmdEnv := make([]string, 0, len(e)+2) + for k, v := range e { + cmdEnv = append(cmdEnv, k+"="+v) + } + cmdEnv = append(cmdEnv, "CURSOR_CONFIG_DIR="+configDir) + cmdEnv = append(cmdEnv, "NO_OPEN_BROWSER=1") + + child := exec.Command(agentBin, "login") + child.Env = cmdEnv + child.Stdin = os.Stdin + child.Stderr = os.Stderr + + stdoutPipe, err := child.StdoutPipe() + if err != nil { + return fmt.Errorf("failed to create stdout pipe: %w", err) + } + + if err := child.Start(); err != nil { + cleanupDir() + if os.IsNotExist(err) { + return fmt.Errorf("could not find '%s'. Make sure the Cursor CLI is installed", agentBin) + } + return fmt.Errorf("error launching agent login: %w", err) + } + + // Handle cancellation signals + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) + go func() { + sig := <-sigCh + _ = child.Process.Kill() + cleanupDir() + if sig == syscall.SIGINT { + fmt.Println("\n\nLogin cancelled.") + } + os.Exit(0) + }() + defer signal.Stop(sigCh) + + var stdoutBuf string + scanner := bufio.NewScanner(stdoutPipe) + for scanner.Scan() { + line := scanner.Text() + fmt.Println(line) + stdoutBuf += line + "\n" + + if loginURLRe.MatchString(stdoutBuf) { + match := loginURLRe.FindString(stdoutBuf) + if match != "" { + fmt.Printf("\nOpen this URL in your browser (incognito recommended):\n%s\n\n", match) + } + } + } + + if err := child.Wait(); err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + cleanupDir() + return fmt.Errorf("login failed with code %d", exitErr.ExitCode()) + } + return err + } + + // Cache keychain token for this account + token := agent.ReadKeychainToken() + if token != "" { + agent.WriteCachedToken(configDir, token) + } + + fmt.Printf("\nAccount '%s' saved — it will be auto-discovered when you start the proxy.\n", accountName) + return nil +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} diff --git a/cmd/resethwid.go b/cmd/resethwid.go new file mode 100644 index 0000000..ba53c2b --- /dev/null +++ b/cmd/resethwid.go @@ -0,0 +1,261 @@ +package cmd + +import ( + "crypto/rand" + "crypto/sha256" + "crypto/sha512" + "encoding/hex" + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + "time" + + "github.com/google/uuid" +) + +func sha256hex() string { + b := make([]byte, 32) + _, _ = rand.Read(b) + h := sha256.Sum256(b) + return hex.EncodeToString(h[:]) +} + +func sha512hex() string { + b := make([]byte, 64) + _, _ = rand.Read(b) + h := sha512.Sum512(b) + return hex.EncodeToString(h[:]) +} + +func newUUID() string { + return uuid.New().String() +} + +func log(icon, msg string) { + fmt.Printf(" %s %s\n", icon, msg) +} + +func getCursorGlobalStorage() string { + switch runtime.GOOS { + case "darwin": + home, _ := os.UserHomeDir() + return filepath.Join(home, "Library", "Application Support", "Cursor", "User", "globalStorage") + case "windows": + appdata := os.Getenv("APPDATA") + return filepath.Join(appdata, "Cursor", "User", "globalStorage") + default: + xdg := os.Getenv("XDG_CONFIG_HOME") + if xdg == "" { + home, _ := os.UserHomeDir() + xdg = filepath.Join(home, ".config") + } + return filepath.Join(xdg, "Cursor", "User", "globalStorage") + } +} + +func getCursorRoot() string { + gs := getCursorGlobalStorage() + return filepath.Dir(filepath.Dir(gs)) +} + +func generateNewIDs() map[string]string { + return map[string]string{ + "telemetry.machineId": sha256hex(), + "telemetry.macMachineId": sha512hex(), + "telemetry.devDeviceId": newUUID(), + "telemetry.sqmId": "{" + fmt.Sprintf("%s", newUUID()+"") + "}", + "storage.serviceMachineId": newUUID(), + } +} + +func killCursor() { + log("", "Stopping Cursor processes...") + switch runtime.GOOS { + case "windows": + exec.Command("taskkill", "/F", "/IM", "Cursor.exe").Run() + default: + exec.Command("pkill", "-x", "Cursor").Run() + exec.Command("pkill", "-f", "Cursor.app").Run() + } + log("", "Cursor stopped (or was not running)") +} + +func updateStorageJSON(storagePath string, ids map[string]string) { + if _, err := os.Stat(storagePath); os.IsNotExist(err) { + log("", fmt.Sprintf("storage.json not found: %s", storagePath)) + return + } + + if runtime.GOOS == "darwin" { + exec.Command("chflags", "nouchg", storagePath).Run() + exec.Command("chmod", "644", storagePath).Run() + } + + data, err := os.ReadFile(storagePath) + if err != nil { + log("", fmt.Sprintf("storage.json read error: %v", err)) + return + } + + var obj map[string]interface{} + if err := json.Unmarshal(data, &obj); err != nil { + log("", fmt.Sprintf("storage.json parse error: %v", err)) + return + } + + for k, v := range ids { + obj[k] = v + } + + out, err := json.MarshalIndent(obj, "", " ") + if err != nil { + log("", fmt.Sprintf("storage.json marshal error: %v", err)) + return + } + + if err := os.WriteFile(storagePath, out, 0644); err != nil { + log("", fmt.Sprintf("storage.json write error: %v", err)) + return + } + log("", "storage.json updated") +} + +func updateStateVscdb(dbPath string, ids map[string]string) { + if _, err := os.Stat(dbPath); os.IsNotExist(err) { + log("", fmt.Sprintf("state.vscdb not found: %s", dbPath)) + return + } + + if runtime.GOOS == "darwin" { + exec.Command("chflags", "nouchg", dbPath).Run() + exec.Command("chmod", "644", dbPath).Run() + } + + if err := updateVscdbPureGo(dbPath, ids); err != nil { + log("", fmt.Sprintf("state.vscdb error: %v", err)) + } else { + log("", "state.vscdb updated") + } +} + +func updateMachineIDFile(machineID, cursorRoot string) { + var candidates []string + if runtime.GOOS == "linux" { + candidates = []string{ + filepath.Join(cursorRoot, "machineid"), + filepath.Join(cursorRoot, "machineId"), + } + } else { + candidates = []string{filepath.Join(cursorRoot, "machineId")} + } + + filePath := candidates[0] + for _, c := range candidates { + if _, err := os.Stat(c); err == nil { + filePath = c + break + } + } + + if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { + log("", fmt.Sprintf("machineId dir error: %v", err)) + return + } + + if runtime.GOOS == "darwin" { + if _, err := os.Stat(filePath); err == nil { + exec.Command("chflags", "nouchg", filePath).Run() + exec.Command("chmod", "644", filePath).Run() + } + } + + if err := os.WriteFile(filePath, []byte(machineID+"\n"), 0644); err != nil { + log("", fmt.Sprintf("machineId write error: %v", err)) + return + } + log("", fmt.Sprintf("machineId file updated (%s)", filepath.Base(filePath))) +} + +var dirsToWipe = []string{ + "Session Storage", "Local Storage", "IndexedDB", "Cache", "Code Cache", + "GPUCache", "Service Worker", "Network", "Cookies", "Cookies-journal", +} + +func deepClean(cursorRoot string) { + log("", "Deep-cleaning session data...") + wiped := 0 + for _, name := range dirsToWipe { + target := filepath.Join(cursorRoot, name) + if _, err := os.Stat(target); os.IsNotExist(err) { + continue + } + info, err := os.Stat(target) + if err != nil { + continue + } + if info.IsDir() { + if err := os.RemoveAll(target); err == nil { + wiped++ + } + } else { + if err := os.Remove(target); err == nil { + wiped++ + } + } + } + log("", fmt.Sprintf("Wiped %d cache/session items", wiped)) +} + +func HandleResetHwid(doDeepClean, dryRun bool) error { + fmt.Print("\nCursor HWID Reset\n\n") + fmt.Println(" Resets all machine / telemetry IDs so Cursor sees a fresh install.") + fmt.Print(" Cursor must be closed — it will be killed automatically.\n\n") + + globalStorage := getCursorGlobalStorage() + cursorRoot := getCursorRoot() + + if _, err := os.Stat(globalStorage); os.IsNotExist(err) { + fmt.Printf("Cursor config not found at:\n %s\n", globalStorage) + fmt.Println(" Make sure Cursor is installed and has been run at least once.") + os.Exit(1) + } + + if dryRun { + fmt.Println(" [DRY RUN] Would reset IDs in:") + fmt.Printf(" %s\n", filepath.Join(globalStorage, "storage.json")) + fmt.Printf(" %s\n", filepath.Join(globalStorage, "state.vscdb")) + fmt.Printf(" %s\n", filepath.Join(cursorRoot, "machineId")) + return nil + } + + killCursor() + + time.Sleep(800 * time.Millisecond) + + newIDs := generateNewIDs() + log("", "Generated new IDs:") + for k, v := range newIDs { + fmt.Printf(" %s: %s\n", k, v) + } + fmt.Println() + + log("", "Updating storage.json...") + updateStorageJSON(filepath.Join(globalStorage, "storage.json"), newIDs) + + log("", "Updating state.vscdb...") + updateStateVscdb(filepath.Join(globalStorage, "state.vscdb"), newIDs) + + log("", "Updating machineId file...") + updateMachineIDFile(newIDs["telemetry.machineId"], cursorRoot) + + if doDeepClean { + fmt.Println() + deepClean(cursorRoot) + } + + fmt.Print("\nHWID reset complete. You can now restart Cursor.\n\n") + return nil +} diff --git a/cmd/sqlite.go b/cmd/sqlite.go new file mode 100644 index 0000000..a8173b2 --- /dev/null +++ b/cmd/sqlite.go @@ -0,0 +1,29 @@ +package cmd + +import ( + "database/sql" + "fmt" + + _ "modernc.org/sqlite" +) + +func updateVscdbPureGo(dbPath string, ids map[string]string) error { + db, err := sql.Open("sqlite", dbPath) + if err != nil { + return fmt.Errorf("open db: %w", err) + } + defer db.Close() + + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS ItemTable (key TEXT PRIMARY KEY, value TEXT NOT NULL)`) + if err != nil { + return fmt.Errorf("create table: %w", err) + } + + for k, v := range ids { + _, err = db.Exec(`INSERT OR REPLACE INTO ItemTable (key, value) VALUES (?, ?)`, k, v) + if err != nil { + return fmt.Errorf("insert %s: %w", k, err) + } + } + return nil +} diff --git a/cmd/usage.go b/cmd/usage.go new file mode 100644 index 0000000..4999020 --- /dev/null +++ b/cmd/usage.go @@ -0,0 +1,255 @@ +package cmd + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +type ModelUsage struct { + NumRequests int `json:"numRequests"` + NumRequestsTotal int `json:"numRequestsTotal"` + NumTokens int `json:"numTokens"` + MaxTokenUsage *int `json:"maxTokenUsage"` + MaxRequestUsage *int `json:"maxRequestUsage"` +} + +type UsageData struct { + StartOfMonth string `json:"startOfMonth"` + Models map[string]ModelUsage `json:"-"` +} + +type StripeProfile struct { + MembershipType string `json:"membershipType"` + SubscriptionStatus string `json:"subscriptionStatus"` + DaysRemainingOnTrial *int `json:"daysRemainingOnTrial"` + IsTeamMember bool `json:"isTeamMember"` + IsYearlyPlan bool `json:"isYearlyPlan"` +} + +func DecodeJWTPayload(token string) map[string]interface{} { + parts := strings.Split(token, ".") + if len(parts) < 2 { + return nil + } + padded := strings.ReplaceAll(parts[1], "-", "+") + padded = strings.ReplaceAll(padded, "_", "/") + data, err := base64.StdEncoding.DecodeString(padded + strings.Repeat("=", (4-len(padded)%4)%4)) + if err != nil { + return nil + } + var result map[string]interface{} + if err := json.Unmarshal(data, &result); err != nil { + return nil + } + return result +} + +func TokenSub(token string) string { + payload := DecodeJWTPayload(token) + if payload == nil { + return "" + } + if sub, ok := payload["sub"].(string); ok { + return sub + } + return "" +} + +func apiGet(path, token string) (map[string]interface{}, error) { + client := &http.Client{Timeout: 8 * time.Second} + req, err := http.NewRequest("GET", "https://api2.cursor.sh"+path, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token) + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var result map[string]interface{} + if err := json.Unmarshal(data, &result); err != nil { + return nil, nil + } + return result, nil +} + +func FetchAccountUsage(token string) (*UsageData, error) { + raw, err := apiGet("/auth/usage", token) + if err != nil || raw == nil { + return nil, err + } + + startOfMonth, _ := raw["startOfMonth"].(string) + usage := &UsageData{ + StartOfMonth: startOfMonth, + Models: make(map[string]ModelUsage), + } + + for k, v := range raw { + if k == "startOfMonth" { + continue + } + data, err := json.Marshal(v) + if err != nil { + continue + } + var mu ModelUsage + if err := json.Unmarshal(data, &mu); err == nil { + usage.Models[k] = mu + } + } + return usage, nil +} + +func FetchStripeProfile(token string) (*StripeProfile, error) { + raw, err := apiGet("/auth/full_stripe_profile", token) + if err != nil || raw == nil { + return nil, err + } + + profile := &StripeProfile{ + MembershipType: fmt.Sprintf("%v", raw["membershipType"]), + SubscriptionStatus: fmt.Sprintf("%v", raw["subscriptionStatus"]), + IsTeamMember: raw["isTeamMember"] == true, + IsYearlyPlan: raw["isYearlyPlan"] == true, + } + if d, ok := raw["daysRemainingOnTrial"].(float64); ok { + di := int(d) + profile.DaysRemainingOnTrial = &di + } + return profile, nil +} + +func DescribePlan(profile *StripeProfile) string { + if profile == nil { + return "" + } + switch profile.MembershipType { + case "free_trial": + days := 0 + if profile.DaysRemainingOnTrial != nil { + days = *profile.DaysRemainingOnTrial + } + return fmt.Sprintf("Pro Trial (%dd left) — unlimited fast requests", days) + case "pro": + return "Pro — extended limits" + case "pro_plus": + return "Pro+ — extended limits" + case "ultra": + return "Ultra — extended limits" + case "free", "hobby": + return "Hobby (free) — limited agent requests" + default: + return fmt.Sprintf("%s · %s", profile.MembershipType, profile.SubscriptionStatus) + } +} + +var modelLabels = map[string]string{ + "gpt-4": "Fast Premium Requests", + "claude-sonnet-4-6": "Claude Sonnet 4.6", + "claude-sonnet-4-5-20250929-v1": "Claude Sonnet 4.5", + "claude-sonnet-4-20250514-v1": "Claude Sonnet 4", + "claude-opus-4-6-v1": "Claude Opus 4.6", + "claude-opus-4-5-20251101-v1": "Claude Opus 4.5", + "claude-opus-4-1-20250805-v1": "Claude Opus 4.1", + "claude-opus-4-20250514-v1": "Claude Opus 4", + "claude-haiku-4-5-20251001-v1": "Claude Haiku 4.5", + "claude-3-5-haiku-20241022-v1": "Claude 3.5 Haiku", + "gpt-5": "GPT-5", + "gpt-4o": "GPT-4o", + "o1": "o1", + "o3-mini": "o3-mini", + "cursor-small": "Cursor Small (free)", +} + +func modelLabel(key string) string { + if label, ok := modelLabels[key]; ok { + return label + } + return key +} + +func FormatUsageSummary(usage *UsageData) []string { + if usage == nil { + return nil + } + var lines []string + + start := "?" + if usage.StartOfMonth != "" { + if t, err := time.Parse(time.RFC3339, usage.StartOfMonth); err == nil { + start = t.Format("2006-01-02") + } else { + start = usage.StartOfMonth + } + } + lines = append(lines, fmt.Sprintf(" Billing period from %s", start)) + + if len(usage.Models) == 0 { + lines = append(lines, " No requests this billing period") + return lines + } + + type entry struct { + key string + usage ModelUsage + } + var entries []entry + for k, v := range usage.Models { + entries = append(entries, entry{k, v}) + } + + // Sort: entries with limits first, then by usage descending + for i := 1; i < len(entries); i++ { + for j := i; j > 0; j-- { + a, b := entries[j-1], entries[j] + aHasLimit := a.usage.MaxRequestUsage != nil + bHasLimit := b.usage.MaxRequestUsage != nil + if !aHasLimit && bHasLimit { + entries[j-1], entries[j] = entries[j], entries[j-1] + } else if aHasLimit == bHasLimit && a.usage.NumRequests < b.usage.NumRequests { + entries[j-1], entries[j] = entries[j], entries[j-1] + } else { + break + } + } + } + + for _, e := range entries { + used := e.usage.NumRequests + max := e.usage.MaxRequestUsage + label := modelLabel(e.key) + if max != nil && *max > 0 { + pct := int(float64(used) / float64(*max) * 100) + bar := makeBar(used, *max, 12) + lines = append(lines, fmt.Sprintf(" %s: %d/%d (%d%%) [%s]", label, used, *max, pct, bar)) + } else if used > 0 { + lines = append(lines, fmt.Sprintf(" %s: %d requests", label, used)) + } else { + lines = append(lines, fmt.Sprintf(" %s: 0 requests (unlimited)", label)) + } + } + + return lines +} + +func makeBar(used, max, width int) string { + fill := int(float64(used) / float64(max) * float64(width)) + if fill > width { + fill = width + } + return strings.Repeat("█", fill) + strings.Repeat("░", width-fill) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..27de6b7 --- /dev/null +++ b/go.mod @@ -0,0 +1,16 @@ +module cursor-api-proxy + +go 1.25.0 + +require ( + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + golang.org/x/sys v0.42.0 // indirect + modernc.org/libc v1.70.0 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect + modernc.org/sqlite v1.48.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..d96a814 --- /dev/null +++ b/go.sum @@ -0,0 +1,21 @@ +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw= +modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/sqlite v1.48.0 h1:ElZyLop3Q2mHYk5IFPPXADejZrlHu7APbpB0sF78bq4= +modernc.org/sqlite v1.48.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig= diff --git a/internal/agent/cmdargs.go b/internal/agent/cmdargs.go new file mode 100644 index 0000000..aeab0b7 --- /dev/null +++ b/internal/agent/cmdargs.go @@ -0,0 +1,29 @@ +package agent + +import "cursor-api-proxy/internal/config" + +func BuildAgentFixedArgs(cfg config.BridgeConfig, workspaceDir, model string, stream bool) []string { + args := []string{"--print"} + if cfg.ApproveMcps { + args = append(args, "--approve-mcps") + } + if cfg.Force { + args = append(args, "--force") + } + if cfg.ChatOnlyWorkspace { + args = append(args, "--trust") + } + args = append(args, "--mode", "ask") + args = append(args, "--workspace", workspaceDir) + args = append(args, "--model", model) + if stream { + args = append(args, "--stream-partial-output", "--output-format", "stream-json") + } else { + args = append(args, "--output-format", "text") + } + return args +} + +func BuildAgentCmdArgs(cfg config.BridgeConfig, workspaceDir, model, prompt string, stream bool) []string { + return append(BuildAgentFixedArgs(cfg, workspaceDir, model, stream), prompt) +} diff --git a/internal/agent/maxmode.go b/internal/agent/maxmode.go new file mode 100644 index 0000000..0a094c7 --- /dev/null +++ b/internal/agent/maxmode.go @@ -0,0 +1,85 @@ +package agent + +import ( + "encoding/json" + "os" + "path/filepath" + "runtime" +) + +func getCandidates(agentScriptPath, configDirOverride string) []string { + if configDirOverride != "" { + return []string{filepath.Join(configDirOverride, "cli-config.json")} + } + + var result []string + + if dir := os.Getenv("CURSOR_CONFIG_DIR"); dir != "" { + result = append(result, filepath.Join(dir, "cli-config.json")) + } + + if agentScriptPath != "" { + agentDir := filepath.Dir(agentScriptPath) + result = append(result, filepath.Join(agentDir, "..", "data", "config", "cli-config.json")) + } + + home := os.Getenv("HOME") + if home == "" { + home = os.Getenv("USERPROFILE") + } + + switch runtime.GOOS { + case "windows": + local := os.Getenv("LOCALAPPDATA") + if local == "" { + local = filepath.Join(home, "AppData", "Local") + } + result = append(result, filepath.Join(local, "cursor-agent", "cli-config.json")) + case "darwin": + result = append(result, filepath.Join(home, "Library", "Application Support", "cursor-agent", "cli-config.json")) + default: + xdg := os.Getenv("XDG_CONFIG_HOME") + if xdg == "" { + xdg = filepath.Join(home, ".config") + } + result = append(result, filepath.Join(xdg, "cursor-agent", "cli-config.json")) + } + + return result +} + +func RunMaxModePreflight(agentScriptPath, configDirOverride string) { + for _, candidate := range getCandidates(agentScriptPath, configDirOverride) { + data, err := os.ReadFile(candidate) + if err != nil { + continue + } + + // Strip BOM if present + if len(data) >= 3 && data[0] == 0xEF && data[1] == 0xBB && data[2] == 0xBF { + data = data[3:] + } + + var raw map[string]interface{} + if err := json.Unmarshal(data, &raw); err != nil { + continue + } + if raw == nil || len(raw) <= 1 { + continue + } + + raw["maxMode"] = true + if model, ok := raw["model"].(map[string]interface{}); ok { + model["maxMode"] = true + } + + out, err := json.MarshalIndent(raw, "", " ") + if err != nil { + continue + } + if err := os.WriteFile(candidate, out, 0644); err != nil { + continue + } + return + } +} diff --git a/internal/agent/runner.go b/internal/agent/runner.go new file mode 100644 index 0000000..2428b4c --- /dev/null +++ b/internal/agent/runner.go @@ -0,0 +1,72 @@ +package agent + +import ( + "context" + "cursor-api-proxy/internal/config" + "cursor-api-proxy/internal/process" + "os" + "path/filepath" +) + +func init() { + process.MaxModeFn = RunMaxModePreflight +} + +func cacheTokenForAccount(configDir string) { + if configDir == "" { + return + } + token := ReadKeychainToken() + if token != "" { + WriteCachedToken(configDir, token) + } +} + +func AccountsDir() string { + home := os.Getenv("HOME") + if home == "" { + home = os.Getenv("USERPROFILE") + } + return filepath.Join(home, ".cursor-api-proxy", "accounts") +} + +func RunAgentSync(cfg config.BridgeConfig, workspaceDir string, cmdArgs []string, tempDir, configDir string, ctx context.Context) (process.RunResult, error) { + opts := process.RunOptions{ + Cwd: workspaceDir, + TimeoutMs: cfg.TimeoutMs, + MaxMode: cfg.MaxMode, + ConfigDir: configDir, + Ctx: ctx, + } + + result, err := process.Run(cfg.AgentBin, cmdArgs, opts) + + cacheTokenForAccount(configDir) + if tempDir != "" { + os.RemoveAll(tempDir) + } + + return result, err +} + +func RunAgentStreamWithContext(cfg config.BridgeConfig, workspaceDir string, cmdArgs []string, onLine func(string), tempDir, configDir string, ctx context.Context) (process.StreamResult, error) { + opts := process.RunStreamingOptions{ + RunOptions: process.RunOptions{ + Cwd: workspaceDir, + TimeoutMs: cfg.TimeoutMs, + MaxMode: cfg.MaxMode, + ConfigDir: configDir, + Ctx: ctx, + }, + OnLine: onLine, + } + + result, err := process.RunStreaming(cfg.AgentBin, cmdArgs, opts) + + cacheTokenForAccount(configDir) + if tempDir != "" { + os.RemoveAll(tempDir) + } + + return result, err +} diff --git a/internal/agent/token.go b/internal/agent/token.go new file mode 100644 index 0000000..cb9d025 --- /dev/null +++ b/internal/agent/token.go @@ -0,0 +1,36 @@ +package agent + +import ( + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" +) + +const tokenFile = ".cursor-token" + +func ReadCachedToken(configDir string) string { + p := filepath.Join(configDir, tokenFile) + data, err := os.ReadFile(p) + if err != nil { + return "" + } + return strings.TrimSpace(string(data)) +} + +func WriteCachedToken(configDir, token string) { + p := filepath.Join(configDir, tokenFile) + _ = os.WriteFile(p, []byte(token), 0600) +} + +func ReadKeychainToken() string { + if runtime.GOOS != "darwin" { + return "" + } + out, err := exec.Command("security", "find-generic-password", "-s", "cursor-access-token", "-w").Output() + if err != nil { + return "" + } + return strings.TrimSpace(string(out)) +} diff --git a/internal/anthropic/anthropic.go b/internal/anthropic/anthropic.go new file mode 100644 index 0000000..fb1ac95 --- /dev/null +++ b/internal/anthropic/anthropic.go @@ -0,0 +1,134 @@ +package anthropic + +import ( + "cursor-api-proxy/internal/openai" + "strings" +) + +type MessageParam struct { + Role string `json:"role"` + Content interface{} `json:"content"` +} + +type MessagesRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + Messages []MessageParam `json:"messages"` + System interface{} `json:"system"` + Stream bool `json:"stream"` + Tools []interface{} `json:"tools"` +} + +func systemToText(system interface{}) string { + if system == nil { + return "" + } + switch v := system.(type) { + case string: + return strings.TrimSpace(v) + case []interface{}: + var parts []string + for _, p := range v { + if m, ok := p.(map[string]interface{}); ok { + if m["type"] == "text" { + if t, ok := m["text"].(string); ok { + parts = append(parts, t) + } + } + } + } + return strings.Join(parts, "\n") + } + return "" +} + +func anthropicBlockToText(p interface{}) string { + if p == nil { + return "" + } + switch v := p.(type) { + case string: + return v + case map[string]interface{}: + typ, _ := v["type"].(string) + switch typ { + case "text": + if t, ok := v["text"].(string); ok { + return t + } + case "image": + if src, ok := v["source"].(map[string]interface{}); ok { + srcType, _ := src["type"].(string) + switch srcType { + case "base64": + mt, _ := src["media_type"].(string) + if mt == "" { + mt = "image" + } + return "[Image: base64 " + mt + "]" + case "url": + url, _ := src["url"].(string) + return "[Image: " + url + "]" + } + } + return "[Image]" + case "document": + title, _ := v["title"].(string) + if title == "" { + if src, ok := v["source"].(map[string]interface{}); ok { + title, _ = src["url"].(string) + } + } + if title != "" { + return "[Document: " + title + "]" + } + return "[Document]" + } + } + return "" +} + +func anthropicContentToText(content interface{}) string { + switch v := content.(type) { + case string: + return v + case []interface{}: + var parts []string + for _, p := range v { + if t := anthropicBlockToText(p); t != "" { + parts = append(parts, t) + } + } + return strings.Join(parts, " ") + } + return "" +} + +func BuildPromptFromAnthropicMessages(messages []MessageParam, system interface{}) string { + var oaiMessages []interface{} + + systemText := systemToText(system) + if systemText != "" { + oaiMessages = append(oaiMessages, map[string]interface{}{ + "role": "system", + "content": systemText, + }) + } + + for _, m := range messages { + text := anthropicContentToText(m.Content) + if text == "" { + continue + } + role := m.Role + if role != "user" && role != "assistant" { + role = "user" + } + oaiMessages = append(oaiMessages, map[string]interface{}{ + "role": role, + "content": text, + }) + } + + return openai.BuildPromptFromMessages(oaiMessages) +} diff --git a/internal/anthropic/anthropic_test.go b/internal/anthropic/anthropic_test.go new file mode 100644 index 0000000..8237a5a --- /dev/null +++ b/internal/anthropic/anthropic_test.go @@ -0,0 +1,109 @@ +package anthropic_test + +import ( + "cursor-api-proxy/internal/anthropic" + "strings" + "testing" +) + +func TestBuildPromptFromAnthropicMessages_Simple(t *testing.T) { + messages := []anthropic.MessageParam{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi there"}, + } + prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil) + if !strings.Contains(prompt, "Hello") { + t.Errorf("prompt missing user message: %q", prompt) + } + if !strings.Contains(prompt, "Hi there") { + t.Errorf("prompt missing assistant message: %q", prompt) + } +} + +func TestBuildPromptFromAnthropicMessages_WithSystem(t *testing.T) { + messages := []anthropic.MessageParam{ + {Role: "user", Content: "ping"}, + } + prompt := anthropic.BuildPromptFromAnthropicMessages(messages, "You are a helpful bot.") + if !strings.Contains(prompt, "You are a helpful bot.") { + t.Errorf("prompt missing system: %q", prompt) + } + if !strings.Contains(prompt, "ping") { + t.Errorf("prompt missing user: %q", prompt) + } +} + +func TestBuildPromptFromAnthropicMessages_SystemArray(t *testing.T) { + system := []interface{}{ + map[string]interface{}{"type": "text", "text": "Part A"}, + map[string]interface{}{"type": "text", "text": "Part B"}, + } + messages := []anthropic.MessageParam{ + {Role: "user", Content: "test"}, + } + prompt := anthropic.BuildPromptFromAnthropicMessages(messages, system) + if !strings.Contains(prompt, "Part A") { + t.Errorf("prompt missing Part A: %q", prompt) + } + if !strings.Contains(prompt, "Part B") { + t.Errorf("prompt missing Part B: %q", prompt) + } +} + +func TestBuildPromptFromAnthropicMessages_ContentBlocks(t *testing.T) { + content := []interface{}{ + map[string]interface{}{"type": "text", "text": "block one"}, + map[string]interface{}{"type": "text", "text": "block two"}, + } + messages := []anthropic.MessageParam{ + {Role: "user", Content: content}, + } + prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil) + if !strings.Contains(prompt, "block one") { + t.Errorf("prompt missing 'block one': %q", prompt) + } + if !strings.Contains(prompt, "block two") { + t.Errorf("prompt missing 'block two': %q", prompt) + } +} + +func TestBuildPromptFromAnthropicMessages_ImageBlock(t *testing.T) { + content := []interface{}{ + map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "base64", + "media_type": "image/png", + "data": "abc123", + }, + }, + } + messages := []anthropic.MessageParam{ + {Role: "user", Content: content}, + } + prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil) + if !strings.Contains(prompt, "[Image") { + t.Errorf("prompt missing [Image]: %q", prompt) + } +} + +func TestBuildPromptFromAnthropicMessages_EmptyContentSkipped(t *testing.T) { + messages := []anthropic.MessageParam{ + {Role: "user", Content: ""}, + {Role: "assistant", Content: "response"}, + } + prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil) + if !strings.Contains(prompt, "response") { + t.Errorf("prompt missing 'response': %q", prompt) + } +} + +func TestBuildPromptFromAnthropicMessages_UnknownRoleBecomesUser(t *testing.T) { + messages := []anthropic.MessageParam{ + {Role: "system", Content: "system-as-user"}, + } + prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil) + if !strings.Contains(prompt, "system-as-user") { + t.Errorf("prompt missing 'system-as-user': %q", prompt) + } +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..b217919 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,54 @@ +package config + +import ( + "cursor-api-proxy/internal/env" +) + +type BridgeConfig struct { + AgentBin string + Host string + Port int + RequiredKey string + DefaultModel string + Mode string + Force bool + ApproveMcps bool + StrictModel bool + Workspace string + TimeoutMs int + TLSCertPath string + TLSKeyPath string + SessionsLogPath string + ChatOnlyWorkspace bool + Verbose bool + MaxMode bool + ConfigDirs []string + MultiPort bool + WinCmdlineMax int +} + +func LoadBridgeConfig(e env.EnvSource, cwd string) BridgeConfig { + loaded := env.LoadEnvConfig(e, cwd) + return BridgeConfig{ + AgentBin: loaded.AgentBin, + Host: loaded.Host, + Port: loaded.Port, + RequiredKey: loaded.RequiredKey, + DefaultModel: loaded.DefaultModel, + Mode: "ask", + Force: loaded.Force, + ApproveMcps: loaded.ApproveMcps, + StrictModel: loaded.StrictModel, + Workspace: loaded.Workspace, + TimeoutMs: loaded.TimeoutMs, + TLSCertPath: loaded.TLSCertPath, + TLSKeyPath: loaded.TLSKeyPath, + SessionsLogPath: loaded.SessionsLogPath, + ChatOnlyWorkspace: loaded.ChatOnlyWorkspace, + Verbose: loaded.Verbose, + MaxMode: loaded.MaxMode, + ConfigDirs: loaded.ConfigDirs, + MultiPort: loaded.MultiPort, + WinCmdlineMax: loaded.WinCmdlineMax, + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..15f2abc --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,123 @@ +package config_test + +import ( + "cursor-api-proxy/internal/config" + "cursor-api-proxy/internal/env" + "path/filepath" + "strings" + "testing" +) + +func TestLoadBridgeConfig_Defaults(t *testing.T) { + cfg := config.LoadBridgeConfig(env.EnvSource{}, "/workspace") + + if cfg.AgentBin != "agent" { + t.Errorf("AgentBin = %q, want %q", cfg.AgentBin, "agent") + } + if cfg.Host != "127.0.0.1" { + t.Errorf("Host = %q, want %q", cfg.Host, "127.0.0.1") + } + if cfg.Port != 8765 { + t.Errorf("Port = %d, want 8765", cfg.Port) + } + if cfg.RequiredKey != "" { + t.Errorf("RequiredKey = %q, want empty", cfg.RequiredKey) + } + if cfg.DefaultModel != "auto" { + t.Errorf("DefaultModel = %q, want %q", cfg.DefaultModel, "auto") + } + if cfg.Force { + t.Error("Force should be false") + } + if cfg.ApproveMcps { + t.Error("ApproveMcps should be false") + } + if !cfg.StrictModel { + t.Error("StrictModel should be true") + } + if cfg.Mode != "ask" { + t.Errorf("Mode = %q, want %q", cfg.Mode, "ask") + } + if cfg.Workspace != "/workspace" { + t.Errorf("Workspace = %q, want /workspace", cfg.Workspace) + } + if !cfg.ChatOnlyWorkspace { + t.Error("ChatOnlyWorkspace should be true") + } + if cfg.WinCmdlineMax != 30000 { + t.Errorf("WinCmdlineMax = %d, want 30000", cfg.WinCmdlineMax) + } +} + +func TestLoadBridgeConfig_FromEnv(t *testing.T) { + e := env.EnvSource{ + "CURSOR_AGENT_BIN": "/usr/bin/agent", + "CURSOR_BRIDGE_HOST": "0.0.0.0", + "CURSOR_BRIDGE_PORT": "9999", + "CURSOR_BRIDGE_API_KEY": "sk-secret", + "CURSOR_BRIDGE_DEFAULT_MODEL": "org/claude-3-opus", + "CURSOR_BRIDGE_FORCE": "true", + "CURSOR_BRIDGE_APPROVE_MCPS": "yes", + "CURSOR_BRIDGE_STRICT_MODEL": "false", + "CURSOR_BRIDGE_WORKSPACE": "./my-workspace", + "CURSOR_BRIDGE_TIMEOUT_MS": "60000", + "CURSOR_BRIDGE_CHAT_ONLY_WORKSPACE": "false", + "CURSOR_BRIDGE_VERBOSE": "1", + "CURSOR_BRIDGE_TLS_CERT": "./certs/test.crt", + "CURSOR_BRIDGE_TLS_KEY": "./certs/test.key", + } + cfg := config.LoadBridgeConfig(e, "/tmp/project") + + if cfg.AgentBin != "/usr/bin/agent" { + t.Errorf("AgentBin = %q, want /usr/bin/agent", cfg.AgentBin) + } + if cfg.Host != "0.0.0.0" { + t.Errorf("Host = %q, want 0.0.0.0", cfg.Host) + } + if cfg.Port != 9999 { + t.Errorf("Port = %d, want 9999", cfg.Port) + } + if cfg.RequiredKey != "sk-secret" { + t.Errorf("RequiredKey = %q, want sk-secret", cfg.RequiredKey) + } + if cfg.DefaultModel != "claude-3-opus" { + t.Errorf("DefaultModel = %q, want claude-3-opus", cfg.DefaultModel) + } + if !cfg.Force { + t.Error("Force should be true") + } + if !cfg.ApproveMcps { + t.Error("ApproveMcps should be true") + } + if cfg.StrictModel { + t.Error("StrictModel should be false") + } + if !filepath.IsAbs(cfg.Workspace) { + t.Errorf("Workspace should be absolute, got %q", cfg.Workspace) + } + if !strings.Contains(cfg.Workspace, "my-workspace") { + t.Errorf("Workspace %q should contain 'my-workspace'", cfg.Workspace) + } + if cfg.TimeoutMs != 60000 { + t.Errorf("TimeoutMs = %d, want 60000", cfg.TimeoutMs) + } + if cfg.ChatOnlyWorkspace { + t.Error("ChatOnlyWorkspace should be false") + } + if !cfg.Verbose { + t.Error("Verbose should be true") + } + if cfg.TLSCertPath != "/tmp/project/certs/test.crt" { + t.Errorf("TLSCertPath = %q, want /tmp/project/certs/test.crt", cfg.TLSCertPath) + } + if cfg.TLSKeyPath != "/tmp/project/certs/test.key" { + t.Errorf("TLSKeyPath = %q, want /tmp/project/certs/test.key", cfg.TLSKeyPath) + } +} + +func TestLoadBridgeConfig_WideHost(t *testing.T) { + cfg := config.LoadBridgeConfig(env.EnvSource{"CURSOR_BRIDGE_HOST": "0.0.0.0"}, "/workspace") + if cfg.Host != "0.0.0.0" { + t.Errorf("Host = %q, want 0.0.0.0", cfg.Host) + } +} diff --git a/internal/env/env.go b/internal/env/env.go new file mode 100644 index 0000000..d70d848 --- /dev/null +++ b/internal/env/env.go @@ -0,0 +1,331 @@ +package env + +import ( + "encoding/json" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" +) + +type EnvSource map[string]string + +type LoadedEnv struct { + AgentBin string + AgentNode string + AgentScript string + CommandShell string + Host string + Port int + RequiredKey string + DefaultModel string + Force bool + ApproveMcps bool + StrictModel bool + Workspace string + TimeoutMs int + TLSCertPath string + TLSKeyPath string + SessionsLogPath string + ChatOnlyWorkspace bool + Verbose bool + MaxMode bool + ConfigDirs []string + MultiPort bool + WinCmdlineMax int +} + +type AgentCommand struct { + Command string + Args []string + Env map[string]string + WindowsVerbatimArguments bool + AgentScriptPath string + ConfigDir string +} + +func getEnvVal(e EnvSource, names []string) string { + for _, name := range names { + if v, ok := e[name]; ok && strings.TrimSpace(v) != "" { + return strings.TrimSpace(v) + } + } + return "" +} + +func envBool(e EnvSource, names []string, def bool) bool { + raw := getEnvVal(e, names) + if raw == "" { + return def + } + switch strings.ToLower(raw) { + case "1", "true", "yes", "on": + return true + case "0", "false", "no", "off": + return false + } + return def +} + +func envInt(e EnvSource, names []string, def int) int { + raw := getEnvVal(e, names) + if raw == "" { + return def + } + v, err := strconv.Atoi(raw) + if err != nil { + return def + } + return v +} + +func normalizeModelId(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "auto" + } + parts := strings.Split(raw, "/") + last := parts[len(parts)-1] + if last == "" { + return "auto" + } + return last +} + +func resolveAbs(raw, cwd string) string { + if raw == "" { + return "" + } + if filepath.IsAbs(raw) { + return raw + } + return filepath.Join(cwd, raw) +} + +func isAuthenticatedAccountDir(dir string) bool { + data, err := os.ReadFile(filepath.Join(dir, "cli-config.json")) + if err != nil { + return false + } + var cfg struct { + AuthInfo *struct { + Email string `json:"email"` + } `json:"authInfo"` + } + if err := json.Unmarshal(data, &cfg); err != nil { + return false + } + return cfg.AuthInfo != nil && cfg.AuthInfo.Email != "" +} + +func discoverAccountDirs(homeDir string) []string { + if homeDir == "" { + return nil + } + accountsDir := filepath.Join(homeDir, ".cursor-api-proxy", "accounts") + entries, err := os.ReadDir(accountsDir) + if err != nil { + return nil + } + var dirs []string + for _, e := range entries { + if !e.IsDir() { + continue + } + dir := filepath.Join(accountsDir, e.Name()) + if isAuthenticatedAccountDir(dir) { + dirs = append(dirs, dir) + } + } + return dirs +} + +func OsEnvToMap() EnvSource { + m := make(EnvSource) + for _, kv := range os.Environ() { + parts := strings.SplitN(kv, "=", 2) + if len(parts) == 2 { + m[parts[0]] = parts[1] + } + } + return m +} + +func LoadEnvConfig(e EnvSource, cwd string) LoadedEnv { + if e == nil { + e = OsEnvToMap() + } + if cwd == "" { + var err error + cwd, err = os.Getwd() + if err != nil { + cwd = "." + } + } + + host := getEnvVal(e, []string{"CURSOR_BRIDGE_HOST"}) + if host == "" { + host = "127.0.0.1" + } + port := envInt(e, []string{"CURSOR_BRIDGE_PORT"}, 8765) + if port <= 0 { + port = 8765 + } + + home := getEnvVal(e, []string{"HOME", "USERPROFILE"}) + + sessionsLogPath := func() string { + if p := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_SESSIONS_LOG"}), cwd); p != "" { + return p + } + if home != "" { + return filepath.Join(home, ".cursor-api-proxy", "sessions.log") + } + return filepath.Join(cwd, "sessions.log") + }() + + var configDirs []string + if raw := getEnvVal(e, []string{"CURSOR_CONFIG_DIRS", "CURSOR_ACCOUNT_DIRS"}); raw != "" { + for _, d := range strings.Split(raw, ",") { + d = strings.TrimSpace(d) + if d != "" { + if p := resolveAbs(d, cwd); p != "" { + configDirs = append(configDirs, p) + } + } + } + } + if len(configDirs) == 0 { + configDirs = discoverAccountDirs(home) + } + + winMax := envInt(e, []string{"CURSOR_BRIDGE_WIN_CMDLINE_MAX"}, 30000) + if winMax < 4096 { + winMax = 4096 + } + if winMax > 32700 { + winMax = 32700 + } + + agentBin := getEnvVal(e, []string{"CURSOR_AGENT_BIN", "CURSOR_CLI_BIN", "CURSOR_CLI_PATH"}) + if agentBin == "" { + agentBin = "agent" + } + commandShell := getEnvVal(e, []string{"COMSPEC"}) + if commandShell == "" { + commandShell = "cmd.exe" + } + workspace := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_WORKSPACE"}), cwd) + if workspace == "" { + workspace = cwd + } + + return LoadedEnv{ + AgentBin: agentBin, + AgentNode: getEnvVal(e, []string{"CURSOR_AGENT_NODE"}), + AgentScript: getEnvVal(e, []string{"CURSOR_AGENT_SCRIPT"}), + CommandShell: commandShell, + Host: host, + Port: port, + RequiredKey: getEnvVal(e, []string{"CURSOR_BRIDGE_API_KEY"}), + DefaultModel: normalizeModelId(getEnvVal(e, []string{"CURSOR_BRIDGE_DEFAULT_MODEL"})), + Force: envBool(e, []string{"CURSOR_BRIDGE_FORCE"}, false), + ApproveMcps: envBool(e, []string{"CURSOR_BRIDGE_APPROVE_MCPS"}, false), + StrictModel: envBool(e, []string{"CURSOR_BRIDGE_STRICT_MODEL"}, true), + Workspace: workspace, + TimeoutMs: envInt(e, []string{"CURSOR_BRIDGE_TIMEOUT_MS"}, 300000), + TLSCertPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_CERT"}), cwd), + TLSKeyPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_KEY"}), cwd), + SessionsLogPath: sessionsLogPath, + ChatOnlyWorkspace: envBool(e, []string{"CURSOR_BRIDGE_CHAT_ONLY_WORKSPACE"}, true), + Verbose: envBool(e, []string{"CURSOR_BRIDGE_VERBOSE"}, false), + MaxMode: envBool(e, []string{"CURSOR_BRIDGE_MAX_MODE"}, false), + ConfigDirs: configDirs, + MultiPort: envBool(e, []string{"CURSOR_BRIDGE_MULTI_PORT"}, false), + WinCmdlineMax: winMax, + } +} + +func ResolveAgentCommand(cmd string, args []string, e EnvSource, cwd string) AgentCommand { + if e == nil { + e = OsEnvToMap() + } + loaded := LoadEnvConfig(e, cwd) + + cloneEnv := func() map[string]string { + m := make(map[string]string, len(e)) + for k, v := range e { + m[k] = v + } + return m + } + + if runtime.GOOS == "windows" { + if loaded.AgentNode != "" && loaded.AgentScript != "" { + agentScriptPath := loaded.AgentScript + if !filepath.IsAbs(agentScriptPath) { + agentScriptPath = filepath.Join(cwd, agentScriptPath) + } + agentDir := filepath.Dir(agentScriptPath) + configDir := filepath.Join(agentDir, "..", "data", "config") + env2 := cloneEnv() + env2["CURSOR_INVOKED_AS"] = "agent.cmd" + ac := AgentCommand{ + Command: loaded.AgentNode, + Args: append([]string{loaded.AgentScript}, args...), + Env: env2, + AgentScriptPath: agentScriptPath, + } + if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil { + ac.ConfigDir = configDir + } + return ac + } + + if strings.HasSuffix(strings.ToLower(cmd), ".cmd") { + cmdResolved := cmd + if !filepath.IsAbs(cmd) { + cmdResolved = filepath.Join(cwd, cmd) + } + dir := filepath.Dir(cmdResolved) + nodeBin := filepath.Join(dir, "node.exe") + script := filepath.Join(dir, "index.js") + if _, err1 := os.Stat(nodeBin); err1 == nil { + if _, err2 := os.Stat(script); err2 == nil { + configDir := filepath.Join(dir, "..", "data", "config") + env2 := cloneEnv() + env2["CURSOR_INVOKED_AS"] = "agent.cmd" + ac := AgentCommand{ + Command: nodeBin, + Args: append([]string{script}, args...), + Env: env2, + AgentScriptPath: script, + } + if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil { + ac.ConfigDir = configDir + } + return ac + } + } + + quotedArgs := make([]string, len(args)) + for i, a := range args { + if strings.Contains(a, " ") { + quotedArgs[i] = `"` + a + `"` + } else { + quotedArgs[i] = a + } + } + cmdLine := `""` + cmd + `" ` + strings.Join(quotedArgs, " ") + `"` + return AgentCommand{ + Command: loaded.CommandShell, + Args: []string{"/d", "/s", "/c", cmdLine}, + Env: cloneEnv(), + WindowsVerbatimArguments: true, + } + } + } + + return AgentCommand{Command: cmd, Args: args, Env: cloneEnv()} +} diff --git a/internal/env/env_test.go b/internal/env/env_test.go new file mode 100644 index 0000000..e589d59 --- /dev/null +++ b/internal/env/env_test.go @@ -0,0 +1,65 @@ +package env + +import "testing" + +func TestLoadEnvConfigDefaults(t *testing.T) { + e := EnvSource{} + loaded := LoadEnvConfig(e, "/tmp") + + if loaded.Host != "127.0.0.1" { + t.Errorf("expected 127.0.0.1, got %s", loaded.Host) + } + if loaded.Port != 8765 { + t.Errorf("expected 8765, got %d", loaded.Port) + } + if loaded.DefaultModel != "auto" { + t.Errorf("expected auto, got %s", loaded.DefaultModel) + } + if loaded.AgentBin != "agent" { + t.Errorf("expected agent, got %s", loaded.AgentBin) + } + if !loaded.StrictModel { + t.Error("expected strictModel=true by default") + } +} + +func TestLoadEnvConfigOverride(t *testing.T) { + e := EnvSource{ + "CURSOR_BRIDGE_HOST": "0.0.0.0", + "CURSOR_BRIDGE_PORT": "9000", + "CURSOR_BRIDGE_DEFAULT_MODEL": "gpt-4", + "CURSOR_AGENT_BIN": "/usr/local/bin/agent", + } + loaded := LoadEnvConfig(e, "/tmp") + + if loaded.Host != "0.0.0.0" { + t.Errorf("expected 0.0.0.0, got %s", loaded.Host) + } + if loaded.Port != 9000 { + t.Errorf("expected 9000, got %d", loaded.Port) + } + if loaded.DefaultModel != "gpt-4" { + t.Errorf("expected gpt-4, got %s", loaded.DefaultModel) + } + if loaded.AgentBin != "/usr/local/bin/agent" { + t.Errorf("expected /usr/local/bin/agent, got %s", loaded.AgentBin) + } +} + +func TestNormalizeModelID(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"gpt-4", "gpt-4"}, + {"openai/gpt-4", "gpt-4"}, + {"", "auto"}, + {" ", "auto"}, + } + for _, tc := range tests { + got := normalizeModelId(tc.input) + if got != tc.want { + t.Errorf("normalizeModelId(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} diff --git a/internal/handlers/anthropic_handler.go b/internal/handlers/anthropic_handler.go new file mode 100644 index 0000000..9691a36 --- /dev/null +++ b/internal/handlers/anthropic_handler.go @@ -0,0 +1,285 @@ +package handlers + +import ( + "cursor-api-proxy/internal/agent" + "cursor-api-proxy/internal/anthropic" + "cursor-api-proxy/internal/config" + "cursor-api-proxy/internal/httputil" + "cursor-api-proxy/internal/logger" + "cursor-api-proxy/internal/models" + "cursor-api-proxy/internal/openai" + "cursor-api-proxy/internal/parser" + "cursor-api-proxy/internal/pool" + "cursor-api-proxy/internal/sanitize" + "cursor-api-proxy/internal/winlimit" + "cursor-api-proxy/internal/workspace" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/google/uuid" +) + +func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, lastModelRef *string, rawBody, method, pathname, remoteAddress string) { + var req anthropic.MessagesRequest + if err := json.Unmarshal([]byte(rawBody), &req); err != nil { + httputil.WriteJSON(w, 400, map[string]interface{}{ + "error": map[string]string{"type": "invalid_request_error", "message": "invalid JSON body"}, + }, nil) + return + } + + requested := openai.NormalizeModelID(req.Model) + model := ResolveModel(requested, lastModelRef, cfg) + + // Parse system from raw body to handle both string and array + var rawMap map[string]interface{} + _ = json.Unmarshal([]byte(rawBody), &rawMap) + + cleanSystem := sanitize.SanitizeSystem(req.System) + + // SanitizeMessages expects []interface{} + rawMessages := make([]interface{}, len(req.Messages)) + for i, m := range req.Messages { + rawMessages[i] = map[string]interface{}{"role": m.Role, "content": m.Content} + } + cleanRawMessages := sanitize.SanitizeMessages(rawMessages) + + var cleanMessages []anthropic.MessageParam + for _, raw := range cleanRawMessages { + if m, ok := raw.(map[string]interface{}); ok { + role, _ := m["role"].(string) + cleanMessages = append(cleanMessages, anthropic.MessageParam{Role: role, Content: m["content"]}) + } + } + + toolsText := openai.ToolsToSystemText(req.Tools, nil) + var systemWithTools interface{} + if toolsText != "" { + sysStr := "" + switch v := cleanSystem.(type) { + case string: + sysStr = v + } + if sysStr != "" { + systemWithTools = sysStr + "\n\n" + toolsText + } else { + systemWithTools = toolsText + } + } else { + systemWithTools = cleanSystem + } + + prompt := anthropic.BuildPromptFromAnthropicMessages(cleanMessages, systemWithTools) + + if req.MaxTokens == 0 { + httputil.WriteJSON(w, 400, map[string]interface{}{ + "error": map[string]string{"type": "invalid_request_error", "message": "max_tokens is required"}, + }, nil) + return + } + + cursorModel := models.ResolveToCursorModel(model) + if cursorModel == "" { + cursorModel = model + } + + var trafficMsgs []logger.TrafficMessage + if s := systemToString(cleanSystem); s != "" { + trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: "system", Content: s}) + } + for _, m := range cleanMessages { + text := contentToString(m.Content) + if text != "" { + trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: m.Role, Content: text}) + } + } + logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, req.Stream) + + headerWs := r.Header.Get("x-cursor-workspace") + ws := workspace.ResolveWorkspace(cfg, headerWs) + + fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, req.Stream) + fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir) + + if !fit.OK { + httputil.WriteJSON(w, 500, map[string]interface{}{ + "error": map[string]string{"type": "api_error", "message": fit.Error}, + }, nil) + return + } + if fit.Truncated { + fmt.Printf("[%s] Windows: prompt truncated (%d -> %d chars).\n", + time.Now().UTC().Format(time.RFC3339), fit.OriginalLength, fit.FinalPromptLength) + } + + cmdArgs := fit.Args + msgID := "msg_" + uuid.New().String() + + var truncatedHeaders map[string]string + if fit.Truncated { + truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"} + } + + if req.Stream { + httputil.WriteSSEHeaders(w, truncatedHeaders) + flusher, _ := w.(http.Flusher) + + writeEvent := func(evt interface{}) { + data, _ := json.Marshal(evt) + fmt.Fprintf(w, "data: %s\n\n", data) + if flusher != nil { + flusher.Flush() + } + } + + writeEvent(map[string]interface{}{ + "type": "message_start", + "message": map[string]interface{}{ + "id": msgID, + "type": "message", + "role": "assistant", + "model": model, + "content": []interface{}{}, + }, + }) + writeEvent(map[string]interface{}{ + "type": "content_block_start", + "index": 0, + "content_block": map[string]string{"type": "text", "text": ""}, + }) + + var accumulated string + parseLine := parser.CreateStreamParser( + func(text string) { + accumulated += text + writeEvent(map[string]interface{}{ + "type": "content_block_delta", + "index": 0, + "delta": map[string]string{"type": "text_delta", "text": text}, + }) + }, + func() { + logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true) + writeEvent(map[string]interface{}{"type": "content_block_stop", "index": 0}) + writeEvent(map[string]interface{}{ + "type": "message_delta", + "delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil}, + "usage": map[string]int{"output_tokens": 0}, + }) + writeEvent(map[string]interface{}{"type": "message_stop"}) + }, + ) + + configDir := pool.GetNextAccountConfigDir() + logger.LogAccountAssigned(configDir) + pool.ReportRequestStart(configDir) + streamStart := time.Now().UnixMilli() + + ctx := r.Context() + result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, parseLine, ws.TempDir, configDir, ctx) + + latencyMs := time.Now().UnixMilli() - streamStart + pool.ReportRequestEnd(configDir) + + if err == nil && isRateLimited(result.Stderr) { + pool.ReportRateLimit(configDir, 60000) + } + if err != nil || result.Code != 0 { + pool.ReportRequestError(configDir, latencyMs) + if err != nil { + logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error()) + } else { + logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr) + } + } else { + pool.ReportRequestSuccess(configDir, latencyMs) + } + logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats()) + return + } + + configDir := pool.GetNextAccountConfigDir() + logger.LogAccountAssigned(configDir) + pool.ReportRequestStart(configDir) + syncStart := time.Now().UnixMilli() + + out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context()) + syncLatency := time.Now().UnixMilli() - syncStart + pool.ReportRequestEnd(configDir) + + if err != nil { + pool.ReportRequestError(configDir, syncLatency) + logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats()) + httputil.WriteJSON(w, 500, map[string]interface{}{ + "error": map[string]string{"type": "api_error", "message": err.Error()}, + }, nil) + return + } + + if isRateLimited(out.Stderr) { + pool.ReportRateLimit(configDir, 60000) + } + + if out.Code != 0 { + pool.ReportRequestError(configDir, syncLatency) + logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats()) + errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr) + httputil.WriteJSON(w, 500, map[string]interface{}{ + "error": map[string]string{"type": "api_error", "message": errMsg}, + }, nil) + return + } + + pool.ReportRequestSuccess(configDir, syncLatency) + content := trimSpace(out.Stdout) + logger.LogTrafficResponse(cfg.Verbose, model, content, false) + logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats()) + + httputil.WriteJSON(w, 200, map[string]interface{}{ + "id": msgID, + "type": "message", + "role": "assistant", + "content": []map[string]string{{"type": "text", "text": content}}, + "model": model, + "stop_reason": "end_turn", + "usage": map[string]int{"input_tokens": 0, "output_tokens": 0}, + }, truncatedHeaders) +} + +func systemToString(system interface{}) string { + switch v := system.(type) { + case string: + return v + case []interface{}: + result := "" + for _, p := range v { + if m, ok := p.(map[string]interface{}); ok && m["type"] == "text" { + if t, ok := m["text"].(string); ok { + result += t + } + } + } + return result + } + return "" +} + +func contentToString(content interface{}) string { + switch v := content.(type) { + case string: + return v + case []interface{}: + result := "" + for _, p := range v { + if m, ok := p.(map[string]interface{}); ok && m["type"] == "text" { + if t, ok := m["text"].(string); ok { + result += t + } + } + } + return result + } + return "" +} diff --git a/internal/handlers/chat.go b/internal/handlers/chat.go new file mode 100644 index 0000000..8a92021 --- /dev/null +++ b/internal/handlers/chat.go @@ -0,0 +1,249 @@ +package handlers + +import ( + "cursor-api-proxy/internal/agent" + "cursor-api-proxy/internal/config" + "cursor-api-proxy/internal/httputil" + "cursor-api-proxy/internal/logger" + "cursor-api-proxy/internal/models" + "cursor-api-proxy/internal/openai" + "cursor-api-proxy/internal/parser" + "cursor-api-proxy/internal/pool" + "cursor-api-proxy/internal/sanitize" + "cursor-api-proxy/internal/winlimit" + "cursor-api-proxy/internal/workspace" + "encoding/json" + "fmt" + "net/http" + "regexp" + "time" + + "github.com/google/uuid" +) + +var rateLimitRe = regexp.MustCompile(`(?i)\b429\b|rate.?limit|too many requests`) + +func isRateLimited(stderr string) bool { + return rateLimitRe.MatchString(stderr) +} + +func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, lastModelRef *string, rawBody, method, pathname, remoteAddress string) { + var bodyMap map[string]interface{} + if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil { + httputil.WriteJSON(w, 400, map[string]interface{}{ + "error": map[string]string{"message": "invalid JSON body", "code": "bad_request"}, + }, nil) + return + } + + rawModel, _ := bodyMap["model"].(string) + requested := openai.NormalizeModelID(rawModel) + model := ResolveModel(requested, lastModelRef, cfg) + cursorModel := models.ResolveToCursorModel(model) + if cursorModel == "" { + cursorModel = model + } + + var messages []interface{} + if m, ok := bodyMap["messages"].([]interface{}); ok { + messages = m + } + + var tools []interface{} + if t, ok := bodyMap["tools"].([]interface{}); ok { + tools = t + } + var funcs []interface{} + if f, ok := bodyMap["functions"].([]interface{}); ok { + funcs = f + } + + cleanMessages := sanitize.SanitizeMessages(messages) + + toolsText := openai.ToolsToSystemText(tools, funcs) + messagesWithTools := cleanMessages + if toolsText != "" { + messagesWithTools = append([]interface{}{map[string]interface{}{"role": "system", "content": toolsText}}, cleanMessages...) + } + prompt := openai.BuildPromptFromMessages(messagesWithTools) + + var trafficMsgs []logger.TrafficMessage + for _, raw := range cleanMessages { + if m, ok := raw.(map[string]interface{}); ok { + role, _ := m["role"].(string) + content := openai.MessageContentToText(m["content"]) + trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: role, Content: content}) + } + } + + isStream := false + if s, ok := bodyMap["stream"].(bool); ok { + isStream = s + } + + logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, isStream) + + headerWs := r.Header.Get("x-cursor-workspace") + ws := workspace.ResolveWorkspace(cfg, headerWs) + + fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, isStream) + fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir) + + if !fit.OK { + httputil.WriteJSON(w, 500, map[string]interface{}{ + "error": map[string]string{"message": fit.Error, "code": "windows_cmdline_limit"}, + }, nil) + return + } + if fit.Truncated { + fmt.Printf("[%s] Windows: prompt truncated for CreateProcess limit (%d -> %d chars, tail preserved).\n", + time.Now().UTC().Format(time.RFC3339), fit.OriginalLength, fit.FinalPromptLength) + } + + cmdArgs := fit.Args + id := "chatcmpl_" + uuid.New().String() + created := time.Now().Unix() + + var truncatedHeaders map[string]string + if fit.Truncated { + truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"} + } + + if isStream { + httputil.WriteSSEHeaders(w, truncatedHeaders) + flusher, _ := w.(http.Flusher) + + var accumulated string + parseLine := parser.CreateStreamParser( + func(text string) { + accumulated += text + chunk := map[string]interface{}{ + "id": id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil}, + }, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + if flusher != nil { + flusher.Flush() + } + }, + func() { + logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true) + stopChunk := map[string]interface{}{ + "id": id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"}, + }, + } + data, _ := json.Marshal(stopChunk) + fmt.Fprintf(w, "data: %s\n\n", data) + fmt.Fprintf(w, "data: [DONE]\n\n") + if flusher != nil { + flusher.Flush() + } + }, + ) + + configDir := pool.GetNextAccountConfigDir() + logger.LogAccountAssigned(configDir) + pool.ReportRequestStart(configDir) + streamStart := time.Now().UnixMilli() + + ctx := r.Context() + result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, parseLine, ws.TempDir, configDir, ctx) + + latencyMs := time.Now().UnixMilli() - streamStart + pool.ReportRequestEnd(configDir) + + if err == nil && isRateLimited(result.Stderr) { + pool.ReportRateLimit(configDir, 60000) + } + + if err != nil || (result.Code != 0 && ctx.Err() == nil) { + pool.ReportRequestError(configDir, latencyMs) + if err != nil { + logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error()) + } else { + logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr) + } + } else { + pool.ReportRequestSuccess(configDir, latencyMs) + } + logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats()) + return + } + + configDir := pool.GetNextAccountConfigDir() + logger.LogAccountAssigned(configDir) + pool.ReportRequestStart(configDir) + syncStart := time.Now().UnixMilli() + + out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context()) + syncLatency := time.Now().UnixMilli() - syncStart + pool.ReportRequestEnd(configDir) + + if err != nil { + pool.ReportRequestError(configDir, syncLatency) + logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats()) + httputil.WriteJSON(w, 500, map[string]interface{}{ + "error": map[string]string{"message": err.Error(), "code": "cursor_cli_error"}, + }, nil) + return + } + + if isRateLimited(out.Stderr) { + pool.ReportRateLimit(configDir, 60000) + } + + if out.Code != 0 { + pool.ReportRequestError(configDir, syncLatency) + logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats()) + errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr) + httputil.WriteJSON(w, 500, map[string]interface{}{ + "error": map[string]string{"message": errMsg, "code": "cursor_cli_error"}, + }, nil) + return + } + + pool.ReportRequestSuccess(configDir, syncLatency) + content := trimSpace(out.Stdout) + logger.LogTrafficResponse(cfg.Verbose, model, content, false) + logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats()) + + httputil.WriteJSON(w, 200, map[string]interface{}{ + "id": id, + "object": "chat.completion", + "created": created, + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "message": map[string]string{"role": "assistant", "content": content}, + "finish_reason": "stop", + }, + }, + "usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + }, truncatedHeaders) +} + +func trimSpace(s string) string { + result := "" + start := 0 + end := len(s) + for start < end && (s[start] == ' ' || s[start] == '\t' || s[start] == '\n' || s[start] == '\r') { + start++ + } + for end > start && (s[end-1] == ' ' || s[end-1] == '\t' || s[end-1] == '\n' || s[end-1] == '\r') { + end-- + } + result = s[start:end] + return result +} diff --git a/internal/handlers/health.go b/internal/handlers/health.go new file mode 100644 index 0000000..6b26713 --- /dev/null +++ b/internal/handlers/health.go @@ -0,0 +1,20 @@ +package handlers + +import ( + "cursor-api-proxy/internal/config" + "cursor-api-proxy/internal/httputil" + "net/http" +) + +func HandleHealth(w http.ResponseWriter, r *http.Request, version string, cfg config.BridgeConfig) { + httputil.WriteJSON(w, 200, map[string]interface{}{ + "ok": true, + "version": version, + "workspace": cfg.Workspace, + "mode": cfg.Mode, + "defaultModel": cfg.DefaultModel, + "force": cfg.Force, + "approveMcps": cfg.ApproveMcps, + "strictModel": cfg.StrictModel, + }, nil) +} diff --git a/internal/handlers/models.go b/internal/handlers/models.go new file mode 100644 index 0000000..16c7873 --- /dev/null +++ b/internal/handlers/models.go @@ -0,0 +1,107 @@ +package handlers + +import ( + "cursor-api-proxy/internal/config" + "cursor-api-proxy/internal/httputil" + "cursor-api-proxy/internal/models" + "net/http" + "sync" + "time" +) + +const modelCacheTTLMs = 5 * 60 * 1000 + +type ModelCache struct { + At int64 + Models []models.CursorCliModel +} + +type ModelCacheRef struct { + mu sync.Mutex + cache *ModelCache + inflight bool + waiters []chan struct{} +} + +func (ref *ModelCacheRef) HandleModels(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig) { + now := time.Now().UnixMilli() + + ref.mu.Lock() + if ref.cache != nil && now-ref.cache.At <= modelCacheTTLMs { + cache := ref.cache + ref.mu.Unlock() + writeModels(w, cache.Models) + return + } + + if ref.inflight { + // Wait for the in-flight fetch + ch := make(chan struct{}, 1) + ref.waiters = append(ref.waiters, ch) + ref.mu.Unlock() + <-ch + ref.mu.Lock() + cache := ref.cache + ref.mu.Unlock() + writeModels(w, cache.Models) + return + } + + ref.inflight = true + ref.mu.Unlock() + + fetched, err := models.ListCursorCliModels(cfg.AgentBin, 60000) + + ref.mu.Lock() + ref.inflight = false + if err == nil { + ref.cache = &ModelCache{At: time.Now().UnixMilli(), Models: fetched} + } + waiters := ref.waiters + ref.waiters = nil + ref.mu.Unlock() + + for _, ch := range waiters { + ch <- struct{}{} + } + + if err != nil { + httputil.WriteJSON(w, 500, map[string]interface{}{ + "error": map[string]string{"message": err.Error(), "code": "models_fetch_error"}, + }, nil) + return + } + + writeModels(w, fetched) +} + +func writeModels(w http.ResponseWriter, mods []models.CursorCliModel) { + cursorModels := make([]map[string]interface{}, len(mods)) + for i, m := range mods { + cursorModels[i] = map[string]interface{}{ + "id": m.ID, + "object": "model", + "owned_by": "cursor", + "name": m.Name, + } + } + + ids := make([]string, len(mods)) + for i, m := range mods { + ids[i] = m.ID + } + aliases := models.GetAnthropicModelAliases(ids) + for _, a := range aliases { + cursorModels = append(cursorModels, map[string]interface{}{ + "id": a.ID, + "object": "model", + "owned_by": "cursor", + "name": a.Name, + }) + } + + httputil.WriteJSON(w, 200, map[string]interface{}{ + "object": "list", + "data": cursorModels, + }, nil) +} diff --git a/internal/handlers/resolve_model.go b/internal/handlers/resolve_model.go new file mode 100644 index 0000000..e20b353 --- /dev/null +++ b/internal/handlers/resolve_model.go @@ -0,0 +1,27 @@ +package handlers + +import "cursor-api-proxy/internal/config" + +func ResolveModel(requested string, lastModelRef *string, cfg config.BridgeConfig) string { + isAuto := requested == "auto" + var explicitModel string + if requested != "" && !isAuto { + explicitModel = requested + } + if explicitModel != "" { + *lastModelRef = explicitModel + } + if isAuto { + return "auto" + } + if explicitModel != "" { + return explicitModel + } + if cfg.StrictModel && *lastModelRef != "" { + return *lastModelRef + } + if *lastModelRef != "" { + return *lastModelRef + } + return cfg.DefaultModel +} diff --git a/internal/httputil/httputil.go b/internal/httputil/httputil.go new file mode 100644 index 0000000..bb39663 --- /dev/null +++ b/internal/httputil/httputil.go @@ -0,0 +1,50 @@ +package httputil + +import ( + "encoding/json" + "io" + "net/http" + "regexp" +) + +var bearerRe = regexp.MustCompile(`(?i)^Bearer\s+(.+)$`) + +func ExtractBearerToken(r *http.Request) string { + h := r.Header.Get("Authorization") + if h == "" { + return "" + } + m := bearerRe.FindStringSubmatch(h) + if m == nil { + return "" + } + return m[1] +} + +func WriteJSON(w http.ResponseWriter, status int, body interface{}, extraHeaders map[string]string) { + for k, v := range extraHeaders { + w.Header().Set(k, v) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(body) +} + +func WriteSSEHeaders(w http.ResponseWriter, extraHeaders map[string]string) { + for k, v := range extraHeaders { + w.Header().Set(k, v) + } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + w.WriteHeader(200) +} + +func ReadBody(r *http.Request) (string, error) { + data, err := io.ReadAll(r.Body) + if err != nil { + return "", err + } + return string(data), nil +} diff --git a/internal/httputil/httputil_test.go b/internal/httputil/httputil_test.go new file mode 100644 index 0000000..530e536 --- /dev/null +++ b/internal/httputil/httputil_test.go @@ -0,0 +1,50 @@ +package httputil + +import ( + "net/http/httptest" + "testing" +) + +func TestExtractBearerToken(t *testing.T) { + tests := []struct { + header string + want string + }{ + {"Bearer mytoken123", "mytoken123"}, + {"bearer MYTOKEN", "MYTOKEN"}, + {"", ""}, + {"Basic abc", ""}, + {"Bearer ", ""}, + } + for _, tc := range tests { + req := httptest.NewRequest("GET", "/", nil) + if tc.header != "" { + req.Header.Set("Authorization", tc.header) + } + got := ExtractBearerToken(req) + if got != tc.want { + t.Errorf("ExtractBearerToken(%q) = %q, want %q", tc.header, got, tc.want) + } + } +} + +func TestWriteJSON(t *testing.T) { + w := httptest.NewRecorder() + WriteJSON(w, 200, map[string]string{"ok": "true"}, nil) + + if w.Code != 200 { + t.Errorf("expected 200, got %d", w.Code) + } + if w.Header().Get("Content-Type") != "application/json" { + t.Errorf("expected application/json, got %s", w.Header().Get("Content-Type")) + } +} + +func TestWriteJSONWithExtraHeaders(t *testing.T) { + w := httptest.NewRecorder() + WriteJSON(w, 201, nil, map[string]string{"X-Custom": "value"}) + + if w.Header().Get("X-Custom") != "value" { + t.Errorf("expected X-Custom=value, got %s", w.Header().Get("X-Custom")) + } +} diff --git a/internal/logger/logger.go b/internal/logger/logger.go new file mode 100644 index 0000000..d0db1e7 --- /dev/null +++ b/internal/logger/logger.go @@ -0,0 +1,184 @@ +package logger + +import ( + "cursor-api-proxy/internal/pool" + "fmt" + "os" + "path/filepath" + "strings" + "time" +) + +const ( + cReset = "\x1b[0m" + cBold = "\x1b[1m" + cDim = "\x1b[2m" + cCyan = "\x1b[36m" + cBCyan = "\x1b[1;96m" + cGreen = "\x1b[32m" + cBGreen = "\x1b[1;92m" + cYellow = "\x1b[33m" + cMagenta = "\x1b[35m" + cBMagenta = "\x1b[1;95m" + cRed = "\x1b[31m" + cGray = "\x1b[90m" + cWhite = "\x1b[97m" +) + +var roleStyle = map[string]string{ + "system": cYellow, + "user": cCyan, + "assistant": cGreen, +} + +var roleEmoji = map[string]string{ + "system": "🔧", + "user": "👤", + "assistant": "🤖", +} + +func ts() string { + return cGray + time.Now().UTC().Format(time.RFC3339Nano) + cReset +} + +func truncate(s string, max int) string { + if len(s) <= max { + return s + } + head := int(float64(max) * 0.6) + tail := max - head + omitted := len(s) - head - tail + return s[:head] + fmt.Sprintf("%s … (%d chars omitted) … ", cDim, omitted) + s[len(s)-tail:] + cReset +} + +func hr(ch string, length int) string { + return cGray + strings.Repeat(ch, length) + cReset +} + +type TrafficMessage struct { + Role string + Content string +} + +func LogIncoming(method, pathname, remoteAddress string) { + fmt.Printf("[%s] Incoming: %s %s (from %s)\n", time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress) +} + +func LogAccountAssigned(configDir string) { + if configDir == "" { + return + } + name := filepath.Base(configDir) + fmt.Printf("[%s] %s→ account%s %s%s%s\n", time.Now().UTC().Format(time.RFC3339), cBCyan, cReset, cBold, name, cReset) +} + +func LogAccountStats(verbose bool, stats []pool.AccountStat) { + if !verbose || len(stats) == 0 { + return + } + now := time.Now().UnixMilli() + fmt.Printf("%s┌─ Account Stats %s┐%s\n", cGray, strings.Repeat("─", 44), cReset) + for _, s := range stats { + name := fmt.Sprintf("%-20s", filepath.Base(s.ConfigDir)) + active := fmt.Sprintf("%sdim%sactive:0%s", cDim, "", cReset) + if s.ActiveRequests > 0 { + active = fmt.Sprintf("%sactive:%d%s", cBCyan, s.ActiveRequests, cReset) + } + total := fmt.Sprintf("total:%s%d%s", cBold, s.TotalRequests, cReset) + ok := fmt.Sprintf("%sok:%d%s", cGreen, s.TotalSuccess, cReset) + errStr := fmt.Sprintf("%sdim%serr:0%s", cDim, "", cReset) + if s.TotalErrors > 0 { + errStr = fmt.Sprintf("%serr:%d%s", cRed, s.TotalErrors, cReset) + } + rl := fmt.Sprintf("%sdim%srl:0%s", cDim, "", cReset) + if s.TotalRateLimits > 0 { + rl = fmt.Sprintf("%srl:%d%s", cYellow, s.TotalRateLimits, cReset) + } + avg := "avg:-" + if s.TotalRequests > 0 { + avg = fmt.Sprintf("avg:%dms", s.TotalLatencyMs/int64(s.TotalRequests)) + } + status := fmt.Sprintf("%s✓%s", cGreen, cReset) + if s.IsRateLimited { + recovers := time.UnixMilli(s.RateLimitUntil).UTC().Format(time.RFC3339) + _ = now + status = fmt.Sprintf("%s⛔ rate-limited (recovers %s)%s", cRed, recovers, cReset) + } + fmt.Printf(" %s%s%s %s %s %s %s %s %s%s%s %s\n", + cBold, name, cReset, active, total, ok, errStr, rl, cDim, avg, cReset, status) + } + fmt.Printf("%s└%s┘%s\n", cGray, strings.Repeat("─", 60), cReset) +} + +func LogTrafficRequest(verbose bool, model string, messages []TrafficMessage, isStream bool) { + if !verbose { + return + } + modeTag := fmt.Sprintf("%sdim%ssync%s", cDim, "", cReset) + if isStream { + modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset) + } + modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset) + fmt.Println(hr("─", 60)) + fmt.Printf("%s 📤 %s%sREQUEST%s %s %s\n", ts(), cBCyan, cBold, cReset, modelStr, modeTag) + for _, m := range messages { + roleColor := cWhite + if c, ok := roleStyle[m.Role]; ok { + roleColor = c + } + emoji := "💬" + if e, ok := roleEmoji[m.Role]; ok { + emoji = e + } + label := fmt.Sprintf("%s%s[%s]%s", roleColor, cBold, m.Role, cReset) + charCount := fmt.Sprintf("%s(%d chars)%s", cDim, len(m.Content), cReset) + preview := truncate(strings.ReplaceAll(m.Content, "\n", "↵ "), 280) + fmt.Printf(" %s %s %s\n", emoji, label, charCount) + fmt.Printf(" %s%s%s\n", cDim, preview, cReset) + } +} + +func LogTrafficResponse(verbose bool, model, text string, isStream bool) { + if !verbose { + return + } + modeTag := fmt.Sprintf("%sdim%ssync%s", cDim, "", cReset) + if isStream { + modeTag = fmt.Sprintf("%s⚡ stream%s", cBGreen, cReset) + } + modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset) + charCount := fmt.Sprintf("%s%d%s%s chars%s", cBold, len(text), cReset, cDim, cReset) + preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 480) + fmt.Printf("%s 📥 %s%sRESPONSE%s %s %s %s\n", ts(), cBGreen, cBold, cReset, modelStr, modeTag, charCount) + fmt.Printf(" 🤖 %s%s%s\n", cGreen, preview, cReset) + fmt.Println(hr("─", 60)) +} + +func AppendSessionLine(logPath, method, pathname, remoteAddress string, statusCode int) { + line := fmt.Sprintf("%s %s %s %s %d\n", time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, statusCode) + dir := filepath.Dir(logPath) + if err := os.MkdirAll(dir, 0755); err == nil { + f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err == nil { + _, _ = f.WriteString(line) + f.Close() + } + } +} + +func LogAgentError(logPath, method, pathname, remoteAddress string, exitCode int, stderr string) string { + errMsg := fmt.Sprintf("Cursor CLI failed (exit %d): %s", exitCode, strings.TrimSpace(stderr)) + fmt.Fprintf(os.Stderr, "[%s] Agent error: %s\n", time.Now().UTC().Format(time.RFC3339), errMsg) + truncated := strings.TrimSpace(stderr) + if len(truncated) > 200 { + truncated = truncated[:200] + } + truncated = strings.ReplaceAll(truncated, "\n", " ") + line := fmt.Sprintf("%s ERROR %s %s %s agent_exit_%d %s\n", + time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, exitCode, truncated) + if f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil { + _, _ = f.WriteString(line) + f.Close() + } + return errMsg +} diff --git a/internal/models/cursorcli.go b/internal/models/cursorcli.go new file mode 100644 index 0000000..69d1f6c --- /dev/null +++ b/internal/models/cursorcli.go @@ -0,0 +1,62 @@ +package models + +import ( + "cursor-api-proxy/internal/process" + "fmt" + "os" + "regexp" + "strings" +) + +type CursorCliModel struct { + ID string + Name string +} + +var modelLineRe = regexp.MustCompile(`^([A-Za-z0-9][A-Za-z0-9._:/-]*)\s+-\s+(.*)$`) +var trailingParenRe = regexp.MustCompile(`\s*\([^)]*\)\s*$`) + +func ParseCursorCliModels(output string) []CursorCliModel { + lines := strings.Split(output, "\n") + seen := make(map[string]CursorCliModel) + var order []string + + for _, line := range lines { + line = strings.TrimSpace(line) + m := modelLineRe.FindStringSubmatch(line) + if m == nil { + continue + } + id := m[1] + rawName := m[2] + name := strings.TrimSpace(trailingParenRe.ReplaceAllString(rawName, "")) + if name == "" { + name = id + } + if _, exists := seen[id]; !exists { + seen[id] = CursorCliModel{ID: id, Name: name} + order = append(order, id) + } + } + + result := make([]CursorCliModel, 0, len(order)) + for _, id := range order { + result = append(result, seen[id]) + } + return result +} + +func ListCursorCliModels(agentBin string, timeoutMs int) ([]CursorCliModel, error) { + tmpDir := os.TempDir() + result, err := process.Run(agentBin, []string{"--list-models"}, process.RunOptions{ + Cwd: tmpDir, + TimeoutMs: timeoutMs, + }) + if err != nil { + return nil, err + } + if result.Code != 0 { + return nil, fmt.Errorf("agent --list-models failed: %s", strings.TrimSpace(result.Stderr)) + } + return ParseCursorCliModels(result.Stdout), nil +} diff --git a/internal/models/cursorcli_test.go b/internal/models/cursorcli_test.go new file mode 100644 index 0000000..9a911ac --- /dev/null +++ b/internal/models/cursorcli_test.go @@ -0,0 +1,33 @@ +package models + +import "testing" + +func TestParseCursorCliModels(t *testing.T) { + output := ` +gpt-4o - GPT-4o (some info) +claude-3-5-sonnet - Claude 3.5 Sonnet +gpt-4o - GPT-4o duplicate +invalid line without dash +` + result := ParseCursorCliModels(output) + + if len(result) != 2 { + t.Fatalf("expected 2 unique models, got %d: %v", len(result), result) + } + if result[0].ID != "gpt-4o" { + t.Errorf("expected gpt-4o, got %s", result[0].ID) + } + if result[0].Name != "GPT-4o" { + t.Errorf("expected 'GPT-4o', got %s", result[0].Name) + } + if result[1].ID != "claude-3-5-sonnet" { + t.Errorf("expected claude-3-5-sonnet, got %s", result[1].ID) + } +} + +func TestParseCursorCliModelsEmpty(t *testing.T) { + result := ParseCursorCliModels("") + if len(result) != 0 { + t.Fatalf("expected empty, got %v", result) + } +} diff --git a/internal/models/cursormap.go b/internal/models/cursormap.go new file mode 100644 index 0000000..0f2e00e --- /dev/null +++ b/internal/models/cursormap.go @@ -0,0 +1,71 @@ +package models + +import "strings" + +var anthropicToCursor = map[string]string{ + "claude-opus-4-6": "opus-4.6", + "claude-opus-4.6": "opus-4.6", + "claude-sonnet-4-6": "sonnet-4.6", + "claude-sonnet-4.6": "sonnet-4.6", + "claude-opus-4-5": "opus-4.5", + "claude-opus-4.5": "opus-4.5", + "claude-sonnet-4-5": "sonnet-4.5", + "claude-sonnet-4.5": "sonnet-4.5", + "claude-opus-4": "opus-4.6", + "claude-sonnet-4": "sonnet-4.6", + "claude-haiku-4-5-20251001": "sonnet-4.5", + "claude-haiku-4-5": "sonnet-4.5", + "claude-haiku-4-6": "sonnet-4.6", + "claude-haiku-4": "sonnet-4.5", + "claude-opus-4-6-thinking": "opus-4.6-thinking", + "claude-sonnet-4-6-thinking": "sonnet-4.6-thinking", + "claude-opus-4-5-thinking": "opus-4.5-thinking", + "claude-sonnet-4-5-thinking": "sonnet-4.5-thinking", +} + +type ModelAlias struct { + CursorID string + AnthropicID string + Name string +} + +var cursorToAnthropicAlias = []ModelAlias{ + {"opus-4.6", "claude-opus-4-6", "Claude 4.6 Opus"}, + {"opus-4.6-thinking", "claude-opus-4-6-thinking", "Claude 4.6 Opus (Thinking)"}, + {"sonnet-4.6", "claude-sonnet-4-6", "Claude 4.6 Sonnet"}, + {"sonnet-4.6-thinking", "claude-sonnet-4-6-thinking", "Claude 4.6 Sonnet (Thinking)"}, + {"opus-4.5", "claude-opus-4-5", "Claude 4.5 Opus"}, + {"opus-4.5-thinking", "claude-opus-4-5-thinking", "Claude 4.5 Opus (Thinking)"}, + {"sonnet-4.5", "claude-sonnet-4-5", "Claude 4.5 Sonnet"}, + {"sonnet-4.5-thinking", "claude-sonnet-4-5-thinking", "Claude 4.5 Sonnet (Thinking)"}, +} + +func ResolveToCursorModel(requested string) string { + if strings.TrimSpace(requested) == "" { + return "" + } + key := strings.ToLower(strings.TrimSpace(requested)) + if v, ok := anthropicToCursor[key]; ok { + return v + } + return strings.TrimSpace(requested) +} + +type AnthropicAlias struct { + ID string + Name string +} + +func GetAnthropicModelAliases(availableCursorIDs []string) []AnthropicAlias { + set := make(map[string]bool, len(availableCursorIDs)) + for _, id := range availableCursorIDs { + set[id] = true + } + var result []AnthropicAlias + for _, a := range cursorToAnthropicAlias { + if set[a.CursorID] { + result = append(result, AnthropicAlias{ID: a.AnthropicID, Name: a.Name}) + } + } + return result +} diff --git a/internal/openai/openai.go b/internal/openai/openai.go new file mode 100644 index 0000000..120187b --- /dev/null +++ b/internal/openai/openai.go @@ -0,0 +1,184 @@ +package openai + +import ( + "encoding/json" + "strings" +) + +type ChatCompletionRequest struct { + Model string `json:"model"` + Messages []interface{} `json:"messages"` + Stream bool `json:"stream"` + Tools []interface{} `json:"tools"` + ToolChoice interface{} `json:"tool_choice"` + Functions []interface{} `json:"functions"` + FunctionCall interface{} `json:"function_call"` +} + +func NormalizeModelID(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + parts := strings.Split(trimmed, "/") + last := parts[len(parts)-1] + if last == "" { + return "" + } + return last +} + +func imageURLToText(imageURL interface{}) string { + if imageURL == nil { + return "[Image]" + } + var url string + switch v := imageURL.(type) { + case string: + url = v + case map[string]interface{}: + if u, ok := v["url"].(string); ok { + url = u + } + } + if url == "" { + return "[Image]" + } + if strings.HasPrefix(url, "data:") { + end := strings.Index(url, ";") + mime := "image" + if end > 5 { + mime = url[5:end] + } + return "[Image: base64 " + mime + "]" + } + return "[Image: " + url + "]" +} + +func MessageContentToText(content interface{}) string { + if content == nil { + return "" + } + switch v := content.(type) { + case string: + return v + case []interface{}: + var parts []string + for _, p := range v { + if p == nil { + continue + } + switch part := p.(type) { + case string: + parts = append(parts, part) + case map[string]interface{}: + typ, _ := part["type"].(string) + switch typ { + case "text": + if t, ok := part["text"].(string); ok { + parts = append(parts, t) + } + case "image_url": + parts = append(parts, imageURLToText(part["image_url"])) + case "image": + src := part["source"] + if src == nil { + src = part["url"] + } + parts = append(parts, imageURLToText(src)) + } + } + } + return strings.Join(parts, " ") + } + return "" +} + +func ToolsToSystemText(tools []interface{}, functions []interface{}) string { + var defs []interface{} + + for _, t := range tools { + if m, ok := t.(map[string]interface{}); ok { + if m["type"] == "function" { + if fn := m["function"]; fn != nil { + defs = append(defs, fn) + } + } else { + defs = append(defs, t) + } + } + } + defs = append(defs, functions...) + + if len(defs) == 0 { + return "" + } + + var lines []string + lines = append(lines, "Available tools (respond with a JSON object to call one):", "") + + for _, raw := range defs { + fn, ok := raw.(map[string]interface{}) + if !ok { + continue + } + name, _ := fn["name"].(string) + desc, _ := fn["description"].(string) + params := "{}" + if p := fn["parameters"]; p != nil { + if b, err := json.MarshalIndent(p, "", " "); err == nil { + params = string(b) + } + } + lines = append(lines, "Function: "+name+"\nDescription: "+desc+"\nParameters: "+params) + } + + return strings.Join(lines, "\n") +} + +type SimpleMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +func BuildPromptFromMessages(messages []interface{}) string { + var systemParts []string + var convo []string + + for _, raw := range messages { + m, ok := raw.(map[string]interface{}) + if !ok { + continue + } + role, _ := m["role"].(string) + text := MessageContentToText(m["content"]) + if text == "" { + continue + } + switch role { + case "system", "developer": + systemParts = append(systemParts, text) + case "user": + convo = append(convo, "User: "+text) + case "assistant": + convo = append(convo, "Assistant: "+text) + case "tool", "function": + convo = append(convo, "Tool: "+text) + } + } + + system := "" + if len(systemParts) > 0 { + system = "System:\n" + strings.Join(systemParts, "\n\n") + "\n\n" + } + transcript := strings.Join(convo, "\n\n") + return system + transcript + "\n\nAssistant:" +} + +func BuildPromptFromSimpleMessages(messages []SimpleMessage) string { + ifaces := make([]interface{}, len(messages)) + for i, m := range messages { + ifaces[i] = map[string]interface{}{"role": m.Role, "content": m.Content} + } + return BuildPromptFromMessages(ifaces) +} diff --git a/internal/openai/openai_test.go b/internal/openai/openai_test.go new file mode 100644 index 0000000..04ede1b --- /dev/null +++ b/internal/openai/openai_test.go @@ -0,0 +1,80 @@ +package openai + +import "testing" + +func TestNormalizeModelID(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"gpt-4", "gpt-4"}, + {"openai/gpt-4", "gpt-4"}, + {"anthropic/claude-3", "claude-3"}, + {"", ""}, + {" ", ""}, + {"a/b/c", "c"}, + } + for _, tc := range tests { + got := NormalizeModelID(tc.input) + if got != tc.want { + t.Errorf("NormalizeModelID(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} + +func TestBuildPromptFromMessages(t *testing.T) { + messages := []interface{}{ + map[string]interface{}{"role": "system", "content": "You are helpful."}, + map[string]interface{}{"role": "user", "content": "Hello"}, + map[string]interface{}{"role": "assistant", "content": "Hi there"}, + } + got := BuildPromptFromMessages(messages) + if got == "" { + t.Fatal("expected non-empty prompt") + } + containsSystem := false + containsUser := false + containsAssistant := false + for i := 0; i < len(got)-10; i++ { + if got[i:i+6] == "System" { + containsSystem = true + } + if got[i:i+4] == "User" { + containsUser = true + } + if got[i:i+9] == "Assistant" { + containsAssistant = true + } + } + if !containsSystem || !containsUser || !containsAssistant { + t.Errorf("prompt missing sections: system=%v user=%v assistant=%v\n%s", + containsSystem, containsUser, containsAssistant, got) + } +} + +func TestToolsToSystemText(t *testing.T) { + tools := []interface{}{ + map[string]interface{}{ + "type": "function", + "function": map[string]interface{}{ + "name": "get_weather", + "description": "Get weather", + "parameters": map[string]interface{}{"type": "object"}, + }, + }, + } + got := ToolsToSystemText(tools, nil) + if got == "" { + t.Fatal("expected non-empty tools text") + } + if len(got) < 10 { + t.Errorf("tools text too short: %q", got) + } +} + +func TestToolsToSystemTextEmpty(t *testing.T) { + got := ToolsToSystemText(nil, nil) + if got != "" { + t.Errorf("expected empty string for no tools, got %q", got) + } +} diff --git a/internal/parser/stream.go b/internal/parser/stream.go new file mode 100644 index 0000000..aa9dc98 --- /dev/null +++ b/internal/parser/stream.go @@ -0,0 +1,61 @@ +package parser + +import "encoding/json" + +type StreamParser func(line string) + +func CreateStreamParser(onText func(string), onDone func()) StreamParser { + accumulated := "" + done := false + + return func(line string) { + if done { + return + } + + var obj struct { + Type string `json:"type"` + Subtype string `json:"subtype"` + Message *struct { + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + } `json:"message"` + } + + if err := json.Unmarshal([]byte(line), &obj); err != nil { + return + } + + if obj.Type == "assistant" && obj.Message != nil { + text := "" + for _, p := range obj.Message.Content { + if p.Type == "text" && p.Text != "" { + text += p.Text + } + } + if text == "" { + return + } + if text == accumulated { + return + } + if len(accumulated) > 0 && len(text) > len(accumulated) && text[:len(accumulated)] == accumulated { + delta := text[len(accumulated):] + if delta != "" { + onText(delta) + } + accumulated = text + } else { + onText(text) + accumulated += text + } + } + + if obj.Type == "result" && obj.Subtype == "success" { + done = true + onDone() + } + } +} diff --git a/internal/parser/stream_test.go b/internal/parser/stream_test.go new file mode 100644 index 0000000..b54f7c8 --- /dev/null +++ b/internal/parser/stream_test.go @@ -0,0 +1,164 @@ +package parser + +import ( + "encoding/json" + "testing" +) + +func makeAssistantLine(text string) string { + obj := map[string]interface{}{ + "type": "assistant", + "message": map[string]interface{}{ + "content": []map[string]interface{}{ + {"type": "text", "text": text}, + }, + }, + } + b, _ := json.Marshal(obj) + return string(b) +} + +func makeResultLine() string { + b, _ := json.Marshal(map[string]string{"type": "result", "subtype": "success"}) + return string(b) +} + +func TestStreamParserIncrementalDeltas(t *testing.T) { + var texts []string + doneCount := 0 + parse := CreateStreamParser( + func(text string) { texts = append(texts, text) }, + func() { doneCount++ }, + ) + + parse(makeAssistantLine("Hello")) + if len(texts) != 1 || texts[0] != "Hello" { + t.Fatalf("expected [Hello], got %v", texts) + } + + parse(makeAssistantLine("Hello world")) + if len(texts) != 2 || texts[1] != " world" { + t.Fatalf("expected second call with ' world', got %v", texts) + } +} + +func TestStreamParserDeduplicatesFinalMessage(t *testing.T) { + var texts []string + parse := CreateStreamParser( + func(text string) { texts = append(texts, text) }, + func() {}, + ) + + parse(makeAssistantLine("Hi")) + parse(makeAssistantLine("Hi there")) + if len(texts) != 2 { + t.Fatalf("expected 2 calls, got %d: %v", len(texts), texts) + } + if texts[0] != "Hi" || texts[1] != " there" { + t.Fatalf("unexpected texts: %v", texts) + } + + // Final duplicate: full accumulated text again + parse(makeAssistantLine("Hi there")) + if len(texts) != 2 { + t.Fatalf("expected no new call after duplicate, got %d: %v", len(texts), texts) + } +} + +func TestStreamParserCallsOnDone(t *testing.T) { + var texts []string + doneCount := 0 + parse := CreateStreamParser( + func(text string) { texts = append(texts, text) }, + func() { doneCount++ }, + ) + + parse(makeResultLine()) + if doneCount != 1 { + t.Fatalf("expected onDone called once, got %d", doneCount) + } + if len(texts) != 0 { + t.Fatalf("expected no text, got %v", texts) + } +} + +func TestStreamParserIgnoresLinesAfterDone(t *testing.T) { + var texts []string + doneCount := 0 + parse := CreateStreamParser( + func(text string) { texts = append(texts, text) }, + func() { doneCount++ }, + ) + + parse(makeResultLine()) + parse(makeAssistantLine("late")) + if len(texts) != 0 { + t.Fatalf("expected no text after done, got %v", texts) + } + if doneCount != 1 { + t.Fatalf("expected onDone called once, got %d", doneCount) + } +} + +func TestStreamParserIgnoresNonAssistantLines(t *testing.T) { + var texts []string + parse := CreateStreamParser( + func(text string) { texts = append(texts, text) }, + func() {}, + ) + + b1, _ := json.Marshal(map[string]interface{}{"type": "user", "message": map[string]interface{}{}}) + parse(string(b1)) + b2, _ := json.Marshal(map[string]interface{}{ + "type": "assistant", + "message": map[string]interface{}{"content": []interface{}{}}, + }) + parse(string(b2)) + parse(`{"type":"assistant","message":{"content":[{"type":"code","text":"x"}]}}`) + + if len(texts) != 0 { + t.Fatalf("expected no texts, got %v", texts) + } +} + +func TestStreamParserIgnoresParseErrors(t *testing.T) { + var texts []string + doneCount := 0 + parse := CreateStreamParser( + func(text string) { texts = append(texts, text) }, + func() { doneCount++ }, + ) + + parse("not json") + parse("{") + parse("") + + if len(texts) != 0 || doneCount != 0 { + t.Fatalf("expected nothing, got texts=%v done=%d", texts, doneCount) + } +} + +func TestStreamParserJoinsMultipleTextParts(t *testing.T) { + var texts []string + parse := CreateStreamParser( + func(text string) { texts = append(texts, text) }, + func() {}, + ) + + obj := map[string]interface{}{ + "type": "assistant", + "message": map[string]interface{}{ + "content": []map[string]interface{}{ + {"type": "text", "text": "Hello"}, + {"type": "text", "text": " "}, + {"type": "text", "text": "world"}, + }, + }, + } + b, _ := json.Marshal(obj) + parse(string(b)) + + if len(texts) != 1 || texts[0] != "Hello world" { + t.Fatalf("expected ['Hello world'], got %v", texts) + } +} diff --git a/internal/pool/pool.go b/internal/pool/pool.go new file mode 100644 index 0000000..0bf8013 --- /dev/null +++ b/internal/pool/pool.go @@ -0,0 +1,259 @@ +package pool + +import ( + "sync" + "time" +) + +type accountStatus struct { + configDir string + activeRequests int + lastUsed int64 + rateLimitUntil int64 + totalRequests int + totalSuccess int + totalErrors int + totalRateLimits int + totalLatencyMs int64 +} + +type AccountStat struct { + ConfigDir string + ActiveRequests int + TotalRequests int + TotalSuccess int + TotalErrors int + TotalRateLimits int + TotalLatencyMs int64 + IsRateLimited bool + RateLimitUntil int64 +} + +type AccountPool struct { + mu sync.Mutex + accounts []*accountStatus +} + +func NewAccountPool(configDirs []string) *AccountPool { + accounts := make([]*accountStatus, 0, len(configDirs)) + for _, dir := range configDirs { + accounts = append(accounts, &accountStatus{configDir: dir}) + } + return &AccountPool{accounts: accounts} +} + +func (p *AccountPool) GetNextConfigDir() string { + p.mu.Lock() + defer p.mu.Unlock() + + if len(p.accounts) == 0 { + return "" + } + + now := time.Now().UnixMilli() + + available := make([]*accountStatus, 0, len(p.accounts)) + for _, a := range p.accounts { + if a.rateLimitUntil < now { + available = append(available, a) + } + } + + target := available + if len(target) == 0 { + target = make([]*accountStatus, len(p.accounts)) + copy(target, p.accounts) + // sort by earliest recovery + for i := 1; i < len(target); i++ { + for j := i; j > 0 && target[j].rateLimitUntil < target[j-1].rateLimitUntil; j-- { + target[j], target[j-1] = target[j-1], target[j] + } + } + } + + // pick least busy then least recently used + best := target[0] + for _, a := range target[1:] { + if a.activeRequests < best.activeRequests { + best = a + } else if a.activeRequests == best.activeRequests && a.lastUsed < best.lastUsed { + best = a + } + } + best.lastUsed = now + return best.configDir +} + +func (p *AccountPool) find(configDir string) *accountStatus { + for _, a := range p.accounts { + if a.configDir == configDir { + return a + } + } + return nil +} + +func (p *AccountPool) ReportRequestStart(configDir string) { + if configDir == "" { + return + } + p.mu.Lock() + defer p.mu.Unlock() + if a := p.find(configDir); a != nil { + a.activeRequests++ + a.totalRequests++ + } +} + +func (p *AccountPool) ReportRequestEnd(configDir string) { + if configDir == "" { + return + } + p.mu.Lock() + defer p.mu.Unlock() + if a := p.find(configDir); a != nil && a.activeRequests > 0 { + a.activeRequests-- + } +} + +func (p *AccountPool) ReportRequestSuccess(configDir string, latencyMs int64) { + if configDir == "" { + return + } + p.mu.Lock() + defer p.mu.Unlock() + if a := p.find(configDir); a != nil { + a.totalSuccess++ + a.totalLatencyMs += latencyMs + } +} + +func (p *AccountPool) ReportRequestError(configDir string, latencyMs int64) { + if configDir == "" { + return + } + p.mu.Lock() + defer p.mu.Unlock() + if a := p.find(configDir); a != nil { + a.totalErrors++ + a.totalLatencyMs += latencyMs + } +} + +func (p *AccountPool) ReportRateLimit(configDir string, penaltyMs int64) { + if configDir == "" { + return + } + if penaltyMs <= 0 { + penaltyMs = 60000 + } + p.mu.Lock() + defer p.mu.Unlock() + if a := p.find(configDir); a != nil { + a.rateLimitUntil = time.Now().UnixMilli() + penaltyMs + a.totalRateLimits++ + } +} + +func (p *AccountPool) GetStats() []AccountStat { + p.mu.Lock() + defer p.mu.Unlock() + now := time.Now().UnixMilli() + stats := make([]AccountStat, len(p.accounts)) + for i, a := range p.accounts { + stats[i] = AccountStat{ + ConfigDir: a.configDir, + ActiveRequests: a.activeRequests, + TotalRequests: a.totalRequests, + TotalSuccess: a.totalSuccess, + TotalErrors: a.totalErrors, + TotalRateLimits: a.totalRateLimits, + TotalLatencyMs: a.totalLatencyMs, + IsRateLimited: a.rateLimitUntil > now, + RateLimitUntil: a.rateLimitUntil, + } + } + return stats +} + +func (p *AccountPool) Count() int { + return len(p.accounts) +} + +// ─── Global pool ─────────────────────────────────────────────────────────── + +var ( + globalPool *AccountPool + globalMu sync.Mutex +) + +func InitAccountPool(configDirs []string) { + globalMu.Lock() + defer globalMu.Unlock() + globalPool = NewAccountPool(configDirs) +} + +func GetNextAccountConfigDir() string { + globalMu.Lock() + p := globalPool + globalMu.Unlock() + if p == nil { + return "" + } + return p.GetNextConfigDir() +} + +func ReportRequestStart(configDir string) { + globalMu.Lock() + p := globalPool + globalMu.Unlock() + if p != nil { + p.ReportRequestStart(configDir) + } +} + +func ReportRequestEnd(configDir string) { + globalMu.Lock() + p := globalPool + globalMu.Unlock() + if p != nil { + p.ReportRequestEnd(configDir) + } +} + +func ReportRequestSuccess(configDir string, latencyMs int64) { + globalMu.Lock() + p := globalPool + globalMu.Unlock() + if p != nil { + p.ReportRequestSuccess(configDir, latencyMs) + } +} + +func ReportRequestError(configDir string, latencyMs int64) { + globalMu.Lock() + p := globalPool + globalMu.Unlock() + if p != nil { + p.ReportRequestError(configDir, latencyMs) + } +} + +func ReportRateLimit(configDir string, penaltyMs int64) { + globalMu.Lock() + p := globalPool + globalMu.Unlock() + if p != nil { + p.ReportRateLimit(configDir, penaltyMs) + } +} + +func GetAccountStats() []AccountStat { + globalMu.Lock() + p := globalPool + globalMu.Unlock() + if p == nil { + return nil + } + return p.GetStats() +} diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go new file mode 100644 index 0000000..27c1621 --- /dev/null +++ b/internal/pool/pool_test.go @@ -0,0 +1,152 @@ +package pool + +import ( + "testing" + "time" +) + +func TestEmptyPool(t *testing.T) { + p := NewAccountPool(nil) + if got := p.GetNextConfigDir(); got != "" { + t.Fatalf("expected empty string for empty pool, got %q", got) + } + if p.Count() != 0 { + t.Fatalf("expected count 0, got %d", p.Count()) + } +} + +func TestSingleDir(t *testing.T) { + p := NewAccountPool([]string{"/dir1"}) + if got := p.GetNextConfigDir(); got != "/dir1" { + t.Fatalf("expected /dir1, got %q", got) + } + if got := p.GetNextConfigDir(); got != "/dir1" { + t.Fatalf("expected /dir1 again, got %q", got) + } +} + +func TestRoundRobin(t *testing.T) { + p := NewAccountPool([]string{"/a", "/b", "/c"}) + got := []string{ + p.GetNextConfigDir(), + p.GetNextConfigDir(), + p.GetNextConfigDir(), + p.GetNextConfigDir(), + } + want := []string{"/a", "/b", "/c", "/a"} + for i, w := range want { + if got[i] != w { + t.Fatalf("call %d: expected %q, got %q", i, w, got[i]) + } + } +} + +func TestLeastBusy(t *testing.T) { + p := NewAccountPool([]string{"/dir1", "/dir2", "/dir3"}) + p.ReportRequestStart("/dir1") + p.ReportRequestStart("/dir2") + + if got := p.GetNextConfigDir(); got != "/dir3" { + t.Fatalf("expected /dir3 (least busy), got %q", got) + } + + p.ReportRequestStart("/dir3") + p.ReportRequestEnd("/dir1") + + if got := p.GetNextConfigDir(); got != "/dir1" { + t.Fatalf("expected /dir1 after end, got %q", got) + } +} + +func TestSkipsRateLimited(t *testing.T) { + p := NewAccountPool([]string{"/dir1", "/dir2"}) + p.ReportRateLimit("/dir1", 60000) + + if got := p.GetNextConfigDir(); got != "/dir2" { + t.Fatalf("expected /dir2, got %q", got) + } + if got := p.GetNextConfigDir(); got != "/dir2" { + t.Fatalf("expected /dir2 again, got %q", got) + } +} + +func TestFallbackToSoonestRecovery(t *testing.T) { + p := NewAccountPool([]string{"/dir1", "/dir2"}) + p.ReportRateLimit("/dir1", 60000) + p.ReportRateLimit("/dir2", 30000) + + // dir2 recovers sooner — should be selected + if got := p.GetNextConfigDir(); got != "/dir2" { + t.Fatalf("expected /dir2 (sooner recovery), got %q", got) + } +} + +func TestActiveRequestsDoesNotGoNegative(t *testing.T) { + p := NewAccountPool([]string{"/dir1"}) + p.ReportRequestEnd("/dir1") + p.ReportRequestEnd("/dir1") + if got := p.GetNextConfigDir(); got != "/dir1" { + t.Fatalf("pool should still work, got %q", got) + } +} + +func TestIgnoreUnknownConfigDir(t *testing.T) { + p := NewAccountPool([]string{"/dir1"}) + p.ReportRequestStart("/nonexistent") + p.ReportRequestEnd("/nonexistent") + p.ReportRateLimit("/nonexistent", 60000) + if got := p.GetNextConfigDir(); got != "/dir1" { + t.Fatalf("expected /dir1, got %q", got) + } +} + +func TestRateLimitExpires(t *testing.T) { + p := NewAccountPool([]string{"/dir1", "/dir2"}) + p.ReportRateLimit("/dir1", 50) + + if got := p.GetNextConfigDir(); got != "/dir2" { + t.Fatalf("immediately expected /dir2, got %q", got) + } + + time.Sleep(100 * time.Millisecond) + + if got := p.GetNextConfigDir(); got != "/dir1" { + t.Fatalf("after expiry expected /dir1, got %q", got) + } +} + +func TestGlobalPool(t *testing.T) { + InitAccountPool([]string{"/g1", "/g2"}) + if got := GetNextAccountConfigDir(); got != "/g1" { + t.Fatalf("expected /g1, got %q", got) + } + if got := GetNextAccountConfigDir(); got != "/g2" { + t.Fatalf("expected /g2, got %q", got) + } + if got := GetNextAccountConfigDir(); got != "/g1" { + t.Fatalf("expected /g1 again, got %q", got) + } +} + +func TestGlobalPoolEmpty(t *testing.T) { + InitAccountPool(nil) + if got := GetNextAccountConfigDir(); got != "" { + t.Fatalf("expected empty string for empty global pool, got %q", got) + } +} + +func TestGlobalPoolReinit(t *testing.T) { + InitAccountPool([]string{"/old1", "/old2"}) + GetNextAccountConfigDir() + InitAccountPool([]string{"/new1"}) + if got := GetNextAccountConfigDir(); got != "/new1" { + t.Fatalf("expected /new1 after reinit, got %q", got) + } +} + +func TestGlobalPoolFunctionsNoopBeforeInit(t *testing.T) { + InitAccountPool(nil) + ReportRequestStart("/dir1") + ReportRequestEnd("/dir1") + ReportRateLimit("/dir1", 1000) +} diff --git a/internal/process/kill_unix.go b/internal/process/kill_unix.go new file mode 100644 index 0000000..b235d96 --- /dev/null +++ b/internal/process/kill_unix.go @@ -0,0 +1,21 @@ +//go:build !windows + +package process + +import ( + "os/exec" + "syscall" +) + +func killProcessGroup(c *exec.Cmd) error { + if c.Process == nil { + return nil + } + // 殺死整個 process group(負號表示 group) + pgid, err := syscall.Getpgid(c.Process.Pid) + if err == nil { + _ = syscall.Kill(-pgid, syscall.SIGKILL) + } + // 同時也 kill 主程序,以防萬一 + return c.Process.Kill() +} diff --git a/internal/process/kill_windows.go b/internal/process/kill_windows.go new file mode 100644 index 0000000..1874bbf --- /dev/null +++ b/internal/process/kill_windows.go @@ -0,0 +1,14 @@ +//go:build windows + +package process + +import ( + "os/exec" +) + +func killProcessGroup(c *exec.Cmd) error { + if c.Process == nil { + return nil + } + return c.Process.Kill() +} diff --git a/internal/process/process.go b/internal/process/process.go new file mode 100644 index 0000000..44aa14a --- /dev/null +++ b/internal/process/process.go @@ -0,0 +1,248 @@ +package process + +import ( + "bufio" + "context" + "cursor-api-proxy/internal/env" + "fmt" + "os/exec" + "strings" + "sync" + "syscall" + "time" +) + +type RunResult struct { + Code int + Stdout string + Stderr string +} + +type RunOptions struct { + Cwd string + TimeoutMs int + MaxMode bool + ConfigDir string + Ctx context.Context +} + +type RunStreamingOptions struct { + RunOptions + OnLine func(line string) +} + +// ─── Global child process registry ────────────────────────────────────────── + +var ( + activeMu sync.Mutex + activeChildren []*exec.Cmd +) + +func registerChild(c *exec.Cmd) { + activeMu.Lock() + activeChildren = append(activeChildren, c) + activeMu.Unlock() +} + +func unregisterChild(c *exec.Cmd) { + activeMu.Lock() + for i, ch := range activeChildren { + if ch == c { + activeChildren = append(activeChildren[:i], activeChildren[i+1:]...) + break + } + } + activeMu.Unlock() +} + +func KillAllChildProcesses() { + activeMu.Lock() + all := make([]*exec.Cmd, len(activeChildren)) + copy(all, activeChildren) + activeChildren = nil + activeMu.Unlock() + for _, c := range all { + killProcessGroup(c) + } +} + +// ─── Spawn ──────────────────────────────────────────────────────────────── + +func spawnChild(cmdStr string, args []string, opts *RunOptions, maxModeFn func(scriptPath, configDir string)) *exec.Cmd { + envSrc := env.OsEnvToMap() + resolved := env.ResolveAgentCommand(cmdStr, args, envSrc, opts.Cwd) + + if opts.MaxMode && maxModeFn != nil { + maxModeFn(resolved.AgentScriptPath, opts.ConfigDir) + } + + envMap := make(map[string]string, len(resolved.Env)) + for k, v := range resolved.Env { + envMap[k] = v + } + if opts.ConfigDir != "" { + envMap["CURSOR_CONFIG_DIR"] = opts.ConfigDir + } else if resolved.ConfigDir != "" { + if _, exists := envMap["CURSOR_CONFIG_DIR"]; !exists { + envMap["CURSOR_CONFIG_DIR"] = resolved.ConfigDir + } + } + + envSlice := make([]string, 0, len(envMap)) + for k, v := range envMap { + envSlice = append(envSlice, k+"="+v) + } + + ctx := opts.Ctx + if ctx == nil { + ctx = context.Background() + } + + // 使用 WaitDelay 確保 context cancel 後子程序 goroutine 能及時退出 + c := exec.CommandContext(ctx, resolved.Command, resolved.Args...) + c.Dir = opts.Cwd + c.Env = envSlice + // 設定新的 process group,使 kill 能傳遞給所有子孫程序 + c.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + // WaitDelay:context cancel 後額外等待這麼久再強制關閉 pipes + c.WaitDelay = 5 * time.Second + // Cancel 函式:殺死整個 process group + c.Cancel = func() error { + return killProcessGroup(c) + } + return c +} + +// MaxModeFn is set by the agent package to avoid import cycle. +var MaxModeFn func(agentScriptPath, configDir string) + +func Run(cmdStr string, args []string, opts RunOptions) (RunResult, error) { + ctx := opts.Ctx + var cancel context.CancelFunc + if opts.TimeoutMs > 0 { + if ctx == nil { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond) + } else { + ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond) + } + defer cancel() + opts.Ctx = ctx + } else if ctx == nil { + opts.Ctx = context.Background() + } + + c := spawnChild(cmdStr, args, &opts, MaxModeFn) + var stdoutBuf, stderrBuf strings.Builder + c.Stdout = &stdoutBuf + c.Stderr = &stderrBuf + + if err := c.Start(); err != nil { + // context 已取消或命令找不到時 + if opts.Ctx != nil && opts.Ctx.Err() != nil { + return RunResult{Code: -1}, nil + } + if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") { + return RunResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr) + } + return RunResult{}, err + } + registerChild(c) + defer unregisterChild(c) + + err := c.Wait() + code := 0 + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + code = exitErr.ExitCode() + if code == 0 { + code = -1 + } + } else { + // context cancelled or killed — return -1 but no error + return RunResult{Code: -1, Stdout: stdoutBuf.String(), Stderr: stderrBuf.String()}, nil + } + } + return RunResult{ + Code: code, + Stdout: stdoutBuf.String(), + Stderr: stderrBuf.String(), + }, nil +} + +type StreamResult struct { + Code int + Stderr string +} + +func RunStreaming(cmdStr string, args []string, opts RunStreamingOptions) (StreamResult, error) { + ctx := opts.Ctx + var cancel context.CancelFunc + if opts.TimeoutMs > 0 { + if ctx == nil { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond) + } else { + ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond) + } + defer cancel() + opts.RunOptions.Ctx = ctx + } else if opts.RunOptions.Ctx == nil { + opts.RunOptions.Ctx = context.Background() + } + + c := spawnChild(cmdStr, args, &opts.RunOptions, MaxModeFn) + stdoutPipe, err := c.StdoutPipe() + if err != nil { + return StreamResult{}, err + } + stderrPipe, err := c.StderrPipe() + if err != nil { + return StreamResult{}, err + } + + if err := c.Start(); err != nil { + if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") { + return StreamResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr) + } + return StreamResult{}, err + } + registerChild(c) + defer unregisterChild(c) + + var stderrBuf strings.Builder + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + scanner := bufio.NewScanner(stdoutPipe) + for scanner.Scan() { + line := scanner.Text() + if strings.TrimSpace(line) != "" { + opts.OnLine(line) + } + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + scanner := bufio.NewScanner(stderrPipe) + for scanner.Scan() { + stderrBuf.WriteString(scanner.Text()) + stderrBuf.WriteString("\n") + } + }() + + wg.Wait() + err = c.Wait() + code := 0 + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + code = exitErr.ExitCode() + if code == 0 { + code = -1 + } + } + } + return StreamResult{Code: code, Stderr: stderrBuf.String()}, nil +} diff --git a/internal/process/process_test.go b/internal/process/process_test.go new file mode 100644 index 0000000..c48d4b5 --- /dev/null +++ b/internal/process/process_test.go @@ -0,0 +1,283 @@ +package process_test + +import ( + "context" + "cursor-api-proxy/internal/process" + "os" + "testing" + "time" +) + +// sh 是跨平台 shell 執行小 script 的輔助函式 +func sh(t *testing.T, script string, opts process.RunOptions) (process.RunResult, error) { + t.Helper() + return process.Run("sh", []string{"-c", script}, opts) +} + +func TestRun_StdoutAndStderr(t *testing.T) { + result, err := sh(t, "echo hello; echo world >&2", process.RunOptions{}) + if err != nil { + t.Fatal(err) + } + if result.Code != 0 { + t.Errorf("Code = %d, want 0", result.Code) + } + if result.Stdout != "hello\n" { + t.Errorf("Stdout = %q, want %q", result.Stdout, "hello\n") + } + if result.Stderr != "world\n" { + t.Errorf("Stderr = %q, want %q", result.Stderr, "world\n") + } +} + +func TestRun_BasicSpawn(t *testing.T) { + result, err := sh(t, "printf ok", process.RunOptions{}) + if err != nil { + t.Fatal(err) + } + if result.Code != 0 { + t.Errorf("Code = %d, want 0", result.Code) + } + if result.Stdout != "ok" { + t.Errorf("Stdout = %q, want ok", result.Stdout) + } +} + +func TestRun_ConfigDir_Propagated(t *testing.T) { + result, err := process.Run("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`}, + process.RunOptions{ConfigDir: "/test/account/dir"}) + if err != nil { + t.Fatal(err) + } + if result.Stdout != "/test/account/dir" { + t.Errorf("Stdout = %q, want /test/account/dir", result.Stdout) + } +} + +func TestRun_ConfigDir_Absent(t *testing.T) { + // 確保沒有殘留的環境變數 + _ = os.Unsetenv("CURSOR_CONFIG_DIR") + result, err := process.Run("sh", []string{"-c", `printf "${CURSOR_CONFIG_DIR:-unset}"`}, + process.RunOptions{}) + if err != nil { + t.Fatal(err) + } + if result.Stdout != "unset" { + t.Errorf("Stdout = %q, want unset", result.Stdout) + } +} + +func TestRun_NonZeroExit(t *testing.T) { + result, err := sh(t, "exit 42", process.RunOptions{}) + if err != nil { + t.Fatal(err) + } + if result.Code != 42 { + t.Errorf("Code = %d, want 42", result.Code) + } +} + +func TestRun_Timeout(t *testing.T) { + start := time.Now() + result, err := sh(t, "sleep 30", process.RunOptions{TimeoutMs: 300}) + elapsed := time.Since(start) + if err != nil { + t.Fatal(err) + } + if result.Code == 0 { + t.Error("expected non-zero exit code after timeout") + } + if elapsed > 2*time.Second { + t.Errorf("elapsed %v, want < 2s", elapsed) + } +} + +func TestRunStreaming_OnLine(t *testing.T) { + var lines []string + result, err := process.RunStreaming("sh", []string{"-c", "printf 'a\nb\nc\n'"}, + process.RunStreamingOptions{ + OnLine: func(line string) { lines = append(lines, line) }, + }) + if err != nil { + t.Fatal(err) + } + if result.Code != 0 { + t.Errorf("Code = %d, want 0", result.Code) + } + if len(lines) != 3 { + t.Errorf("got %d lines, want 3: %v", len(lines), lines) + } + if lines[0] != "a" || lines[1] != "b" || lines[2] != "c" { + t.Errorf("lines = %v, want [a b c]", lines) + } +} + +func TestRunStreaming_FlushFinalLine(t *testing.T) { + var lines []string + result, err := process.RunStreaming("sh", []string{"-c", "printf tail"}, + process.RunStreamingOptions{ + OnLine: func(line string) { lines = append(lines, line) }, + }) + if err != nil { + t.Fatal(err) + } + if result.Code != 0 { + t.Errorf("Code = %d, want 0", result.Code) + } + if len(lines) != 1 { + t.Errorf("got %d lines, want 1: %v", len(lines), lines) + } + if lines[0] != "tail" { + t.Errorf("lines[0] = %q, want tail", lines[0]) + } +} + +func TestRunStreaming_ConfigDir(t *testing.T) { + var lines []string + _, err := process.RunStreaming("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`}, + process.RunStreamingOptions{ + RunOptions: process.RunOptions{ConfigDir: "/my/config/dir"}, + OnLine: func(line string) { lines = append(lines, line) }, + }) + if err != nil { + t.Fatal(err) + } + if len(lines) != 1 || lines[0] != "/my/config/dir" { + t.Errorf("lines = %v, want [/my/config/dir]", lines) + } +} + +func TestRunStreaming_Stderr(t *testing.T) { + result, err := process.RunStreaming("sh", []string{"-c", "echo err-output >&2"}, + process.RunStreamingOptions{OnLine: func(string) {}}) + if err != nil { + t.Fatal(err) + } + if result.Stderr == "" { + t.Error("expected stderr to contain output") + } +} + +func TestRunStreaming_Timeout(t *testing.T) { + start := time.Now() + result, err := process.RunStreaming("sh", []string{"-c", "sleep 30"}, + process.RunStreamingOptions{ + RunOptions: process.RunOptions{TimeoutMs: 300}, + OnLine: func(string) {}, + }) + elapsed := time.Since(start) + if err != nil { + t.Fatal(err) + } + if result.Code == 0 { + t.Error("expected non-zero exit code after timeout") + } + if elapsed > 2*time.Second { + t.Errorf("elapsed %v, want < 2s", elapsed) + } +} + +func TestRunStreaming_Concurrent(t *testing.T) { + var lines1, lines2 []string + done := make(chan struct{}, 2) + + run := func(label string, target *[]string) { + process.RunStreaming("sh", []string{"-c", "printf '" + label + "'"}, + process.RunStreamingOptions{ + OnLine: func(line string) { *target = append(*target, line) }, + }) + done <- struct{}{} + } + + go run("stream1", &lines1) + go run("stream2", &lines2) + + <-done + <-done + + if len(lines1) != 1 || lines1[0] != "stream1" { + t.Errorf("lines1 = %v, want [stream1]", lines1) + } + if len(lines2) != 1 || lines2[0] != "stream2" { + t.Errorf("lines2 = %v, want [stream2]", lines2) + } +} + +func TestRunStreaming_ContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + start := time.Now() + done := make(chan struct{}) + + go func() { + process.RunStreaming("sh", []string{"-c", "sleep 30"}, + process.RunStreamingOptions{ + RunOptions: process.RunOptions{Ctx: ctx}, + OnLine: func(string) {}, + }) + close(done) + }() + + time.AfterFunc(100*time.Millisecond, cancel) + <-done + elapsed := time.Since(start) + + if elapsed > 2*time.Second { + t.Errorf("elapsed %v, want < 2s", elapsed) + } +} + +func TestRun_ContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + start := time.Now() + done := make(chan process.RunResult, 1) + + go func() { + r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx}) + done <- r + }() + + time.AfterFunc(100*time.Millisecond, cancel) + result := <-done + elapsed := time.Since(start) + + if result.Code == 0 { + t.Error("expected non-zero exit code after cancel") + } + if elapsed > 2*time.Second { + t.Errorf("elapsed %v, want < 2s", elapsed) + } +} + +func TestRun_AlreadyCancelledContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // 已取消 + + start := time.Now() + result, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx}) + elapsed := time.Since(start) + + if result.Code == 0 { + t.Error("expected non-zero exit code") + } + if elapsed > 2*time.Second { + t.Errorf("elapsed %v, want < 2s", elapsed) + } +} + +func TestKillAllChildProcesses(t *testing.T) { + done := make(chan process.RunResult, 1) + go func() { + r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{}) + done <- r + }() + + time.Sleep(80 * time.Millisecond) + process.KillAllChildProcesses() + result := <-done + + if result.Code == 0 { + t.Error("expected non-zero exit code after kill") + } + // 再次呼叫不應 panic + process.KillAllChildProcesses() +} diff --git a/internal/router/router.go b/internal/router/router.go new file mode 100644 index 0000000..5d9f6e2 --- /dev/null +++ b/internal/router/router.go @@ -0,0 +1,136 @@ +package router + +import ( + "cursor-api-proxy/internal/config" + "cursor-api-proxy/internal/handlers" + "cursor-api-proxy/internal/httputil" + "cursor-api-proxy/internal/logger" + "fmt" + "net/http" + "os" + "time" +) + +type RouterOptions struct { + Version string + Config config.BridgeConfig + ModelCache *handlers.ModelCacheRef + LastModel *string +} + +func NewRouter(opts RouterOptions) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + cfg := opts.Config + pathname := r.URL.Path + method := r.Method + remoteAddress := r.RemoteAddr + if r.Header.Get("X-Real-IP") != "" { + remoteAddress = r.Header.Get("X-Real-IP") + } + + logger.LogIncoming(method, pathname, remoteAddress) + + defer func() { + logger.AppendSessionLine(cfg.SessionsLogPath, method, pathname, remoteAddress, 200) + }() + + if cfg.RequiredKey != "" { + token := httputil.ExtractBearerToken(r) + if token != cfg.RequiredKey { + httputil.WriteJSON(w, 401, map[string]interface{}{ + "error": map[string]string{"message": "Invalid API key", "code": "unauthorized"}, + }, nil) + return + } + } + + switch { + case method == "GET" && pathname == "/health": + handlers.HandleHealth(w, r, opts.Version, cfg) + + case method == "GET" && pathname == "/v1/models": + opts.ModelCache.HandleModels(w, r, cfg) + + case method == "POST" && pathname == "/v1/chat/completions": + raw, err := httputil.ReadBody(r) + if err != nil { + httputil.WriteJSON(w, 400, map[string]interface{}{ + "error": map[string]string{"message": "failed to read body", "code": "bad_request"}, + }, nil) + return + } + handlers.HandleChatCompletions(w, r, cfg, opts.LastModel, raw, method, pathname, remoteAddress) + + case method == "POST" && pathname == "/v1/messages": + raw, err := httputil.ReadBody(r) + if err != nil { + httputil.WriteJSON(w, 400, map[string]interface{}{ + "error": map[string]string{"message": "failed to read body", "code": "bad_request"}, + }, nil) + return + } + handlers.HandleAnthropicMessages(w, r, cfg, opts.LastModel, raw, method, pathname, remoteAddress) + + case (method == "POST" || method == "GET") && pathname == "/v1/completions": + httputil.WriteJSON(w, 404, map[string]interface{}{ + "error": map[string]string{ + "message": "Legacy completions endpoint is not supported. Use POST /v1/chat/completions instead.", + "code": "not_found", + }, + }, nil) + + case pathname == "/v1/embeddings": + httputil.WriteJSON(w, 404, map[string]interface{}{ + "error": map[string]string{ + "message": "Embeddings are not supported by this proxy.", + "code": "not_found", + }, + }, nil) + + default: + httputil.WriteJSON(w, 404, map[string]interface{}{ + "error": map[string]string{"message": "Not found", "code": "not_found"}, + }, nil) + } + } +} + +func recoveryMiddleware(logPath string, next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + defer func() { + if rec := recover(); rec != nil { + msg := fmt.Sprintf("%v", rec) + fmt.Fprintf(os.Stderr, "[%s] Proxy panic: %s\n", time.Now().UTC().Format(time.RFC3339), msg) + line := fmt.Sprintf("%s ERROR %s %s %s %s\n", + time.Now().UTC().Format(time.RFC3339), r.Method, r.URL.Path, r.RemoteAddr, + msg[:min(200, len(msg))]) + if f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil { + _, _ = f.WriteString(line) + f.Close() + } + if !isHeaderWritten(w) { + httputil.WriteJSON(w, 500, map[string]interface{}{ + "error": map[string]string{"message": msg, "code": "internal_error"}, + }, nil) + } + } + }() + next(w, r) + } +} + +func isHeaderWritten(w http.ResponseWriter) bool { + // Can't reliably detect without wrapping; always try to write + return false +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func WrapWithRecovery(logPath string, handler http.HandlerFunc) http.HandlerFunc { + return recoveryMiddleware(logPath, handler) +} diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go new file mode 100644 index 0000000..dd514d2 --- /dev/null +++ b/internal/sanitize/sanitize.go @@ -0,0 +1,95 @@ +package sanitize + +import "regexp" + +type rule struct { + pattern *regexp.Regexp + replacement string +} + +var rules = []rule{ + {regexp.MustCompile(`(?i)x-anthropic-billing-header:[^\n]*\n?`), ""}, + {regexp.MustCompile(`(?i)\bcc_version=[^\s;,\n]+[;,]?\s*`), ""}, + {regexp.MustCompile(`(?i)\bcc_entrypoint=[^\s;,\n]+[;,]?\s*`), ""}, + {regexp.MustCompile(`(?i)\bcch=[a-f0-9]+[;,]?\s*`), ""}, + {regexp.MustCompile(`\bClaude Code\b`), "Cursor"}, + {regexp.MustCompile(`(?i)Anthropic['']s official CLI for Claude`), "Cursor AI assistant"}, + {regexp.MustCompile(`\bAnthropic\b`), "Cursor"}, + {regexp.MustCompile(`(?i)anthropic\.com`), "cursor.com"}, + {regexp.MustCompile(`(?i)claude\.ai`), "cursor.sh"}, + {regexp.MustCompile(`^[;,\s]+`), ""}, +} + +func SanitizeText(text string) string { + for _, r := range rules { + text = r.pattern.ReplaceAllString(text, r.replacement) + } + return text +} + +func SanitizeMessages(messages []interface{}) []interface{} { + result := make([]interface{}, len(messages)) + for i, raw := range messages { + if raw == nil { + result[i] = raw + continue + } + m, ok := raw.(map[string]interface{}) + if !ok { + result[i] = raw + continue + } + newMsg := make(map[string]interface{}, len(m)) + for k, v := range m { + newMsg[k] = v + } + switch c := m["content"].(type) { + case string: + newMsg["content"] = SanitizeText(c) + case []interface{}: + newParts := make([]interface{}, len(c)) + for j, p := range c { + if pm, ok := p.(map[string]interface{}); ok && pm["type"] == "text" { + if t, ok := pm["text"].(string); ok { + newPart := make(map[string]interface{}, len(pm)) + for k, v := range pm { + newPart[k] = v + } + newPart["text"] = SanitizeText(t) + newParts[j] = newPart + continue + } + } + newParts[j] = p + } + newMsg["content"] = newParts + } + result[i] = newMsg + } + return result +} + +func SanitizeSystem(system interface{}) interface{} { + switch v := system.(type) { + case string: + return SanitizeText(v) + case []interface{}: + result := make([]interface{}, len(v)) + for i, p := range v { + if pm, ok := p.(map[string]interface{}); ok && pm["type"] == "text" { + if t, ok := pm["text"].(string); ok { + newPart := make(map[string]interface{}, len(pm)) + for k, val := range pm { + newPart[k] = val + } + newPart["text"] = SanitizeText(t) + result[i] = newPart + continue + } + } + result[i] = p + } + return result + } + return system +} diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go new file mode 100644 index 0000000..59886a2 --- /dev/null +++ b/internal/sanitize/sanitize_test.go @@ -0,0 +1,60 @@ +package sanitize + +import ( + "strings" + "testing" +) + +func TestSanitizeTextAnthropicBilling(t *testing.T) { + input := "x-anthropic-billing-header: abc123\nHello" + got := SanitizeText(input) + if strings.Contains(got, "x-anthropic-billing-header") { + t.Errorf("billing header not removed: %q", got) + } +} + +func TestSanitizeTextClaudeCode(t *testing.T) { + input := "I am Claude Code assistant" + got := SanitizeText(input) + if strings.Contains(got, "Claude Code") { + t.Errorf("'Claude Code' not replaced: %q", got) + } + if !strings.Contains(got, "Cursor") { + t.Errorf("'Cursor' not present in output: %q", got) + } +} + +func TestSanitizeTextAnthropic(t *testing.T) { + input := "Powered by Anthropic's technology and anthropic.com" + got := SanitizeText(input) + if strings.Contains(got, "Anthropic") { + t.Errorf("'Anthropic' not replaced: %q", got) + } + if strings.Contains(got, "anthropic.com") { + t.Errorf("'anthropic.com' not replaced: %q", got) + } +} + +func TestSanitizeTextNoChange(t *testing.T) { + input := "Hello, this is a normal message about cursor." + got := SanitizeText(input) + if got != input { + t.Errorf("unexpected change: %q -> %q", input, got) + } +} + +func TestSanitizeMessages(t *testing.T) { + messages := []interface{}{ + map[string]interface{}{"role": "user", "content": "Ask Claude Code something"}, + map[string]interface{}{"role": "system", "content": "Use Anthropic's tools"}, + } + result := SanitizeMessages(messages) + + for _, raw := range result { + m := raw.(map[string]interface{}) + c := m["content"].(string) + if strings.Contains(c, "Claude Code") || strings.Contains(c, "Anthropic") { + t.Errorf("found unsanitized content: %q", c) + } + } +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..a992c06 --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,180 @@ +package server + +import ( + "context" + "crypto/tls" + "cursor-api-proxy/internal/config" + "cursor-api-proxy/internal/handlers" + "cursor-api-proxy/internal/pool" + "cursor-api-proxy/internal/process" + "cursor-api-proxy/internal/router" + "fmt" + "net/http" + "os" + "os/signal" + "syscall" + "time" +) + +type ServerOptions struct { + Version string + Config config.BridgeConfig +} + +func StartBridgeServer(opts ServerOptions) []*http.Server { + cfg := opts.Config + var servers []*http.Server + + if len(cfg.ConfigDirs) > 0 { + if cfg.MultiPort { + for i, dir := range cfg.ConfigDirs { + port := cfg.Port + i + subCfg := cfg + subCfg.Port = port + subCfg.ConfigDirs = []string{dir} + subCfg.MultiPort = false + pool.InitAccountPool([]string{dir}) + srv := startSingleServer(ServerOptions{Version: opts.Version, Config: subCfg}) + servers = append(servers, srv) + } + return servers + } + pool.InitAccountPool(cfg.ConfigDirs) + } + + servers = append(servers, startSingleServer(opts)) + return servers +} + +func startSingleServer(opts ServerOptions) *http.Server { + cfg := opts.Config + + modelCache := &handlers.ModelCacheRef{} + lastModel := cfg.DefaultModel + + handler := router.NewRouter(router.RouterOptions{ + Version: opts.Version, + Config: cfg, + ModelCache: modelCache, + LastModel: &lastModel, + }) + handler = router.WrapWithRecovery(cfg.SessionsLogPath, handler) + + useTLS := cfg.TLSCertPath != "" && cfg.TLSKeyPath != "" + + srv := &http.Server{ + Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), + Handler: handler, + } + + if useTLS { + cert, err := tls.LoadX509KeyPair(cfg.TLSCertPath, cfg.TLSKeyPath) + if err != nil { + fmt.Fprintf(os.Stderr, "TLS error: %v\n", err) + os.Exit(1) + } + srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}} + } + + scheme := "http" + if useTLS { + scheme = "https" + } + + go func() { + var err error + if useTLS { + err = srv.ListenAndServeTLS("", "") + } else { + err = srv.ListenAndServe() + } + if err != nil && err != http.ErrServerClosed { + if isAddrInUse(err) { + fmt.Fprintf(os.Stderr, "❌ Port %d is already in use. Set CURSOR_BRIDGE_PORT to use a different port.\n", cfg.Port) + } else { + fmt.Fprintf(os.Stderr, "❌ Server error: %v\n", err) + } + os.Exit(1) + } + }() + + fmt.Printf("cursor-api-proxy listening on %s://%s:%d\n", scheme, cfg.Host, cfg.Port) + fmt.Printf("- agent bin: %s\n", cfg.AgentBin) + fmt.Printf("- workspace: %s\n", cfg.Workspace) + fmt.Printf("- mode: %s\n", cfg.Mode) + fmt.Printf("- default model: %s\n", cfg.DefaultModel) + fmt.Printf("- force: %v\n", cfg.Force) + fmt.Printf("- approve mcps: %v\n", cfg.ApproveMcps) + fmt.Printf("- required api key: %v\n", cfg.RequiredKey != "") + fmt.Printf("- sessions log: %s\n", cfg.SessionsLogPath) + if cfg.ChatOnlyWorkspace { + fmt.Println("- chat-only workspace: yes (isolated temp dir)") + } else { + fmt.Println("- chat-only workspace: no") + } + if cfg.Verbose { + fmt.Println("- verbose traffic: yes (CURSOR_BRIDGE_VERBOSE=true)") + } else { + fmt.Println("- verbose traffic: no") + } + if cfg.MaxMode { + fmt.Println("- max mode: yes (CURSOR_BRIDGE_MAX_MODE=true)") + } else { + fmt.Println("- max mode: no") + } + fmt.Printf("- Windows cmdline budget: %d (prompt tail truncation when over limit; Windows only)\n", cfg.WinCmdlineMax) + if len(cfg.ConfigDirs) > 0 { + fmt.Printf("- account pool: enabled with %d configuration directories\n", len(cfg.ConfigDirs)) + } + + return srv +} + +func SetupGracefulShutdown(servers []*http.Server, timeoutMs int) { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT) + + go func() { + sig := <-sigCh + fmt.Printf("\n[%s] %s received — shutting down gracefully…\n", + time.Now().UTC().Format(time.RFC3339), sig) + + process.KillAllChildProcesses() + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond) + defer cancel() + + done := make(chan struct{}) + go func() { + for _, srv := range servers { + _ = srv.Shutdown(ctx) + } + close(done) + }() + + select { + case <-done: + os.Exit(0) + case <-ctx.Done(): + fmt.Fprintln(os.Stderr, "[shutdown] Timed out waiting for connections to drain — forcing exit.") + os.Exit(1) + } + }() +} + +func isAddrInUse(err error) bool { + return err != nil && (contains(err.Error(), "address already in use") || contains(err.Error(), "bind: address already in use")) +} + +func contains(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsHelper(s, sub)) +} + +func containsHelper(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..eba9262 --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,331 @@ +package server_test + +import ( + "cursor-api-proxy/internal/config" + "cursor-api-proxy/internal/server" + "encoding/json" + "fmt" + "io" + "net" + "context" + "net/http" + "os" + "strings" + "testing" + "time" +) + +// freePort 取得一個暫時可用的隨機 port +func freePort(t *testing.T) int { + t.Helper() + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + port := l.Addr().(*net.TCPAddr).Port + l.Close() + return port +} + +// makeFakeAgentBin 建立一個 shell script,模擬 agent 固定輸出 +// sync 模式:直接輸出一行文字 +// stream 模式:輸出 JSON stream 行 +func makeFakeAgentBin(t *testing.T, syncOutput string) string { + t.Helper() + dir := t.TempDir() + script := dir + "/agent" + content := fmt.Sprintf(`#!/bin/sh +# 若有 --stream-json 則輸出 stream 格式 +for arg; do + if [ "$arg" = "--stream-json" ]; then + printf '%%s\n' '{"type":"assistant","message":{"content":[{"type":"text","text":"%s"}]}}' + printf '%%s\n' '{"type":"result","subtype":"success"}' + exit 0 + fi +done +# 否則輸出 sync 格式 +printf '%%s' '%s' +`, syncOutput, syncOutput) + if err := os.WriteFile(script, []byte(content), 0755); err != nil { + t.Fatal(err) + } + return script +} + +// makeFakeAgentBinWithModels 額外支援 --list-models 輸出 +func makeFakeAgentBinWithModels(t *testing.T) string { + t.Helper() + dir := t.TempDir() + script := dir + "/agent" + content := `#!/bin/sh +for arg; do + if [ "$arg" = "--list-models" ]; then + printf 'claude-3-opus - Claude 3 Opus\n' + printf 'claude-3-sonnet - Claude 3 Sonnet\n' + exit 0 + fi + if [ "$arg" = "--stream-json" ]; then + printf '%s\n' '{"type":"assistant","message":{"content":[{"type":"text","text":"Hello"}]}}' + printf '%s\n' '{"type":"result","subtype":"success"}' + exit 0 + fi +done +printf 'Hello from agent' +` + if err := os.WriteFile(script, []byte(content), 0755); err != nil { + t.Fatal(err) + } + return script +} + +func makeTestConfig(agentBin string, port int, overrides ...func(*config.BridgeConfig)) config.BridgeConfig { + cfg := config.BridgeConfig{ + AgentBin: agentBin, + Host: "127.0.0.1", + Port: port, + DefaultModel: "auto", + Mode: "ask", + Force: false, + ApproveMcps: false, + StrictModel: true, + Workspace: os.TempDir(), + TimeoutMs: 30000, + SessionsLogPath: os.TempDir() + "/test-sessions.log", + ChatOnlyWorkspace: true, + Verbose: false, + MaxMode: false, + ConfigDirs: []string{}, + MultiPort: false, + WinCmdlineMax: 30000, + } + for _, fn := range overrides { + fn(&cfg) + } + return cfg +} + +func waitListening(t *testing.T, host string, port int, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), 50*time.Millisecond) + if err == nil { + conn.Close() + return + } + time.Sleep(20 * time.Millisecond) + } + t.Fatalf("server on port %d did not start within %v", port, timeout) +} + +func doRequest(t *testing.T, method, url, body string, headers map[string]string) (int, string) { + t.Helper() + var reqBody io.Reader + if body != "" { + reqBody = strings.NewReader(body) + } + req, err := http.NewRequest(method, url, reqBody) + if err != nil { + t.Fatal(err) + } + if body != "" { + req.Header.Set("Content-Type", "application/json") + } + for k, v := range headers { + req.Header.Set(k, v) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + data, _ := io.ReadAll(resp.Body) + return resp.StatusCode, string(data) +} + +func TestBridgeServer_Health(t *testing.T) { + port := freePort(t) + agentBin := makeFakeAgentBinWithModels(t) + cfg := makeTestConfig(agentBin, port) + + srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg}) + waitListening(t, "127.0.0.1", port, 3*time.Second) + defer func() { + for _, s := range srvs { + s.Shutdown(context.Background()) + } + }() + + status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", nil) + if status != 200 { + t.Fatalf("status = %d, want 200; body: %s", status, body) + } + var result map[string]interface{} + json.Unmarshal([]byte(body), &result) + if result["ok"] != true { + t.Errorf("ok = %v, want true", result["ok"]) + } + if result["version"] != "1.0.0" { + t.Errorf("version = %v, want 1.0.0", result["version"]) + } +} + +func TestBridgeServer_Models(t *testing.T) { + port := freePort(t) + agentBin := makeFakeAgentBinWithModels(t) + cfg := makeTestConfig(agentBin, port) + + srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg}) + waitListening(t, "127.0.0.1", port, 3*time.Second) + defer func() { + for _, s := range srvs { + s.Shutdown(context.Background()) + } + }() + + status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/v1/models", port), "", nil) + if status != 200 { + t.Fatalf("status = %d, want 200; body: %s", status, body) + } + var result map[string]interface{} + json.Unmarshal([]byte(body), &result) + if result["object"] != "list" { + t.Errorf("object = %v, want list", result["object"]) + } + data := result["data"].([]interface{}) + if len(data) < 2 { + t.Errorf("data len = %d, want >= 2", len(data)) + } +} + +func TestBridgeServer_Unauthorized(t *testing.T) { + port := freePort(t) + agentBin := makeFakeAgentBinWithModels(t) + cfg := makeTestConfig(agentBin, port, func(c *config.BridgeConfig) { + c.RequiredKey = "secret123" + }) + + srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg}) + waitListening(t, "127.0.0.1", port, 3*time.Second) + defer func() { + for _, s := range srvs { + s.Shutdown(context.Background()) + } + }() + + status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", nil) + if status != 401 { + t.Fatalf("status = %d, want 401; body: %s", status, body) + } + var result map[string]interface{} + json.Unmarshal([]byte(body), &result) + errObj := result["error"].(map[string]interface{}) + if errObj["message"] != "Invalid API key" { + t.Errorf("message = %v, want 'Invalid API key'", errObj["message"]) + } +} + +func TestBridgeServer_AuthorizedKey(t *testing.T) { + port := freePort(t) + agentBin := makeFakeAgentBinWithModels(t) + cfg := makeTestConfig(agentBin, port, func(c *config.BridgeConfig) { + c.RequiredKey = "secret123" + }) + + srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg}) + waitListening(t, "127.0.0.1", port, 3*time.Second) + defer func() { + for _, s := range srvs { + s.Shutdown(context.Background()) + } + }() + + status, _ := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", map[string]string{ + "Authorization": "Bearer secret123", + }) + if status != 200 { + t.Errorf("status = %d, want 200", status) + } +} + +func TestBridgeServer_NotFound(t *testing.T) { + port := freePort(t) + agentBin := makeFakeAgentBinWithModels(t) + cfg := makeTestConfig(agentBin, port) + + srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg}) + waitListening(t, "127.0.0.1", port, 3*time.Second) + defer func() { + for _, s := range srvs { + s.Shutdown(context.Background()) + } + }() + + status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/unknown", port), "", nil) + if status != 404 { + t.Fatalf("status = %d, want 404; body: %s", status, body) + } + var result map[string]interface{} + json.Unmarshal([]byte(body), &result) + errObj := result["error"].(map[string]interface{}) + if errObj["code"] != "not_found" { + t.Errorf("code = %v, want not_found", errObj["code"]) + } +} + +func TestBridgeServer_ChatCompletions_Sync(t *testing.T) { + port := freePort(t) + agentBin := makeFakeAgentBin(t, "Hello from agent") + cfg := makeTestConfig(agentBin, port) + + srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg}) + waitListening(t, "127.0.0.1", port, 3*time.Second) + defer func() { + for _, s := range srvs { + s.Shutdown(context.Background()) + } + }() + + reqBody := `{"model":"claude-3-opus","messages":[{"role":"user","content":"Hi"}]}` + status, body := doRequest(t, "POST", fmt.Sprintf("http://127.0.0.1:%d/v1/chat/completions", port), reqBody, nil) + if status != 200 { + t.Fatalf("status = %d, want 200; body: %s", status, body) + } + var result map[string]interface{} + json.Unmarshal([]byte(body), &result) + if result["object"] != "chat.completion" { + t.Errorf("object = %v, want chat.completion", result["object"]) + } + choices := result["choices"].([]interface{}) + msg := choices[0].(map[string]interface{})["message"].(map[string]interface{}) + if msg["content"] != "Hello from agent" { + t.Errorf("content = %v, want 'Hello from agent'", msg["content"]) + } +} + +func TestBridgeServer_MultiPort(t *testing.T) { + basePort := freePort(t) + agentBin := makeFakeAgentBinWithModels(t) + + dir1 := t.TempDir() + dir2 := t.TempDir() + + cfg := makeTestConfig(agentBin, basePort, func(c *config.BridgeConfig) { + c.ConfigDirs = []string{dir1, dir2} + c.MultiPort = true + }) + + srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg}) + if len(srvs) != 2 { + t.Fatalf("got %d servers, want 2", len(srvs)) + } + + // 等待兩個 server 啟動(port 可能會衝突,這裡不嚴格測試 port 分配) + time.Sleep(200 * time.Millisecond) + + defer func() { + for _, s := range srvs { + s.Shutdown(context.Background()) + } + }() +} diff --git a/internal/winlimit/winlimit.go b/internal/winlimit/winlimit.go new file mode 100644 index 0000000..035c447 --- /dev/null +++ b/internal/winlimit/winlimit.go @@ -0,0 +1,132 @@ +package winlimit + +import ( + "cursor-api-proxy/internal/env" + "runtime" +) + +const WinPromptOmissionPrefix = "[Earlier messages omitted: Windows command-line length limit.]\n\n" + +type FitPromptResult struct { + OK bool + Args []string + Truncated bool + OriginalLength int + FinalPromptLength int + Error string +} + +func estimateCmdlineLength(resolved env.AgentCommand) int { + argv := append([]string{resolved.Command}, resolved.Args...) + if resolved.WindowsVerbatimArguments { + n := 0 + for _, a := range argv { + n += len(a) + } + if len(argv) > 1 { + n += len(argv) - 1 + } + return n + 512 + } + dstLen := 0 + for _, a := range argv { + dstLen += len(a) + } + dstLen = dstLen*2 + len(argv)*2 + if len(argv) > 1 { + dstLen += len(argv) - 1 + } + return dstLen + 512 +} + +func FitPromptToWinCmdline(agentBin string, fixedArgs []string, prompt string, maxCmdline int, cwd string) FitPromptResult { + if runtime.GOOS != "windows" { + args := make([]string, len(fixedArgs)+1) + copy(args, fixedArgs) + args[len(fixedArgs)] = prompt + return FitPromptResult{ + OK: true, + Args: args, + Truncated: false, + OriginalLength: len(prompt), + FinalPromptLength: len(prompt), + } + } + + e := env.OsEnvToMap() + measured := func(p string) int { + args := make([]string, len(fixedArgs)+1) + copy(args, fixedArgs) + args[len(fixedArgs)] = p + resolved := env.ResolveAgentCommand(agentBin, args, e, cwd) + return estimateCmdlineLength(resolved) + } + + if measured("") > maxCmdline { + return FitPromptResult{ + OK: false, + Error: "Windows command line exceeds the configured limit even without a prompt; shorten workspace path, model id, or CURSOR_BRIDGE_WIN_CMDLINE_MAX.", + } + } + + if measured(prompt) <= maxCmdline { + args := make([]string, len(fixedArgs)+1) + copy(args, fixedArgs) + args[len(fixedArgs)] = prompt + return FitPromptResult{ + OK: true, + Args: args, + Truncated: false, + OriginalLength: len(prompt), + FinalPromptLength: len(prompt), + } + } + + prefix := WinPromptOmissionPrefix + if measured(prefix) > maxCmdline { + return FitPromptResult{ + OK: false, + Error: "Windows command line too long to fit even the truncation notice; shorten workspace path or flags.", + } + } + + lo, hi, best := 0, len(prompt), 0 + for lo <= hi { + mid := (lo + hi) / 2 + var tail string + if mid > 0 { + tail = prompt[len(prompt)-mid:] + } + candidate := prefix + tail + if measured(candidate) <= maxCmdline { + best = mid + lo = mid + 1 + } else { + hi = mid - 1 + } + } + + var finalPrompt string + if best == 0 { + finalPrompt = prefix + } else { + finalPrompt = prefix + prompt[len(prompt)-best:] + } + + args := make([]string, len(fixedArgs)+1) + copy(args, fixedArgs) + args[len(fixedArgs)] = finalPrompt + return FitPromptResult{ + OK: true, + Args: args, + Truncated: true, + OriginalLength: len(prompt), + FinalPromptLength: len(finalPrompt), + } +} + +func WarnPromptTruncated(originalLength, finalLength int) { + _ = originalLength + _ = finalLength + // fmt.Fprintf skipped to avoid import; caller may log as needed +} diff --git a/internal/winlimit/winlimit_test.go b/internal/winlimit/winlimit_test.go new file mode 100644 index 0000000..03d2488 --- /dev/null +++ b/internal/winlimit/winlimit_test.go @@ -0,0 +1,37 @@ +package winlimit + +import ( + "runtime" + "strings" + "testing" +) + +func TestNonWindowsPassThrough(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping non-Windows test on Windows") + } + + fixedArgs := []string{"--print", "--model", "gpt-4"} + prompt := "Hello world" + result := FitPromptToWinCmdline("agent", fixedArgs, prompt, 30000, "/tmp") + + if !result.OK { + t.Fatalf("expected OK=true on non-Windows, got error: %s", result.Error) + } + if result.Truncated { + t.Error("expected no truncation on non-Windows") + } + if result.OriginalLength != len(prompt) { + t.Errorf("expected original length %d, got %d", len(prompt), result.OriginalLength) + } + // Last arg should be the prompt + if len(result.Args) == 0 || result.Args[len(result.Args)-1] != prompt { + t.Errorf("expected last arg to be prompt, got %v", result.Args) + } +} + +func TestOmissionPrefix(t *testing.T) { + if !strings.Contains(WinPromptOmissionPrefix, "Earlier messages omitted") { + t.Errorf("omission prefix should mention earlier messages, got: %q", WinPromptOmissionPrefix) + } +} diff --git a/internal/workspace/workspace.go b/internal/workspace/workspace.go new file mode 100644 index 0000000..8fe76bd --- /dev/null +++ b/internal/workspace/workspace.go @@ -0,0 +1,30 @@ +package workspace + +import ( + "cursor-api-proxy/internal/config" + "os" + "path/filepath" + "strings" +) + +type WorkspaceResult struct { + WorkspaceDir string + TempDir string +} + +func ResolveWorkspace(cfg config.BridgeConfig, workspaceHeader string) WorkspaceResult { + if cfg.ChatOnlyWorkspace { + tempDir, err := os.MkdirTemp("", "cursor-proxy-") + if err != nil { + tempDir = filepath.Join(os.TempDir(), "cursor-proxy-fallback") + _ = os.MkdirAll(tempDir, 0700) + } + return WorkspaceResult{WorkspaceDir: tempDir, TempDir: tempDir} + } + + headerWs := strings.TrimSpace(workspaceHeader) + if headerWs != "" { + return WorkspaceResult{WorkspaceDir: headerWs} + } + return WorkspaceResult{WorkspaceDir: cfg.Workspace} +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..47c2409 --- /dev/null +++ b/main.go @@ -0,0 +1,73 @@ +package main + +import ( + "cursor-api-proxy/cmd" + "cursor-api-proxy/internal/config" + "cursor-api-proxy/internal/env" + "cursor-api-proxy/internal/server" + "fmt" + "os" +) + +const version = "1.0.0" + +func main() { + args, err := cmd.ParseArgs(os.Args[1:]) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + if args.Help { + cmd.PrintHelp(version) + return + } + + if args.Login { + if err := cmd.HandleLogin(args.AccountName, args.Proxies); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + return + } + + if args.Logout { + if err := cmd.HandleLogout(args.AccountName); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + return + } + + if args.AccountsList { + if err := cmd.HandleAccountsList(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + return + } + + if args.ResetHwid { + if err := cmd.HandleResetHwid(args.DeepClean, args.DryRun); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + return + } + + e := env.OsEnvToMap() + if args.Tailscale { + e["CURSOR_BRIDGE_HOST"] = "0.0.0.0" + } + + cwd, _ := os.Getwd() + cfg := config.LoadBridgeConfig(e, cwd) + + servers := server.StartBridgeServer(server.ServerOptions{ + Version: version, + Config: cfg, + }) + server.SetupGracefulShutdown(servers, 10000) + + select {} +}