Task 9: Cleanup - remove old internal files, update import paths, add go-zero entry point

- Removed old files from internal/* (migrated to pkg/*)
- Removed old CLI files from cmd/ (now in cmd/cli/)
- Updated import paths (internal/* -> pkg/*)
- Rewrote main.go to support CLI commands + go-zero HTTP server
- Fixed AccountStat type references (use entity.AccountStat)
- Fixed cmd/cli/* to use usecase package instead of agent
- Fixed logger to use entity.AccountStat
- Fixed geminiweb fmt.Println redundant newlines
- Fixed scripts/detect-gemini-dom.go pointer format issues
This commit is contained in:
王性驊 2026-04-03 22:54:18 +08:00
parent 7e0b7a970c
commit 081f404f77
68 changed files with 81 additions and 8771 deletions

View File

@ -1,196 +0,0 @@
package cmd
import (
"cursor-api-proxy/internal/agent"
"encoding/json"
"fmt"
"os"
"path/filepath"
)
type AccountInfo struct {
Name string
ConfigDir string
Authenticated bool
Email string
DisplayName string
AuthID string
Plan string
SubscriptionStatus string
ExpiresAt string
}
func ReadAccountInfo(name, configDir string) AccountInfo {
info := AccountInfo{Name: name, ConfigDir: configDir}
configFile := filepath.Join(configDir, "cli-config.json")
data, err := os.ReadFile(configFile)
if err != nil {
return info
}
var raw struct {
AuthInfo *struct {
Email string `json:"email"`
DisplayName string `json:"displayName"`
AuthID string `json:"authId"`
} `json:"authInfo"`
}
if err := json.Unmarshal(data, &raw); err == nil && raw.AuthInfo != nil {
info.Authenticated = true
info.Email = raw.AuthInfo.Email
info.DisplayName = raw.AuthInfo.DisplayName
info.AuthID = raw.AuthInfo.AuthID
}
statsigFile := filepath.Join(configDir, "statsig-cache.json")
statsigData, err := os.ReadFile(statsigFile)
if err != nil {
return info
}
var statsigWrapper struct {
Data string `json:"data"`
}
if err := json.Unmarshal(statsigData, &statsigWrapper); err != nil || statsigWrapper.Data == "" {
return info
}
var statsig struct {
User *struct {
Custom *struct {
IsEnterpriseUser bool `json:"isEnterpriseUser"`
StripeSubscriptionStatus string `json:"stripeSubscriptionStatus"`
StripeMembershipExpiration string `json:"stripeMembershipExpiration"`
} `json:"custom"`
} `json:"user"`
}
if err := json.Unmarshal([]byte(statsigWrapper.Data), &statsig); err != nil {
return info
}
if statsig.User != nil && statsig.User.Custom != nil {
c := statsig.User.Custom
if c.IsEnterpriseUser {
info.Plan = "Enterprise"
} else if c.StripeSubscriptionStatus == "active" {
info.Plan = "Pro"
} else {
info.Plan = "Free"
}
info.SubscriptionStatus = c.StripeSubscriptionStatus
info.ExpiresAt = c.StripeMembershipExpiration
}
return info
}
func HandleAccountsList() error {
accountsDir := agent.AccountsDir()
entries, err := os.ReadDir(accountsDir)
if err != nil {
fmt.Println("No accounts found. Use 'cursor-api-proxy login' to add one.")
return nil
}
var names []string
for _, e := range entries {
if e.IsDir() {
names = append(names, e.Name())
}
}
if len(names) == 0 {
fmt.Println("No accounts found. Use 'cursor-api-proxy login' to add one.")
return nil
}
fmt.Print("Cursor Accounts:\n\n")
keychainToken := agent.ReadKeychainToken()
for i, name := range names {
configDir := filepath.Join(accountsDir, name)
info := ReadAccountInfo(name, configDir)
fmt.Printf(" %d. %s\n", i+1, name)
if info.Authenticated {
cachedToken := agent.ReadCachedToken(configDir)
keychainMatchesAccount := keychainToken != "" && info.AuthID != "" && TokenSub(keychainToken) == info.AuthID
token := cachedToken
if token == "" && keychainMatchesAccount {
token = keychainToken
}
var liveProfile *StripeProfile
var liveUsage *UsageData
if token != "" {
liveUsage, _ = FetchAccountUsage(token)
liveProfile, _ = FetchStripeProfile(token)
}
if info.Email != "" {
display := ""
if info.DisplayName != "" {
display = " (" + info.DisplayName + ")"
}
fmt.Printf(" %s%s\n", info.Email, display)
}
if info.Plan != "" && liveProfile == nil {
canceled := ""
if info.SubscriptionStatus == "canceled" {
canceled = " · canceled"
}
expiry := ""
if info.ExpiresAt != "" {
expiry = " · expires " + info.ExpiresAt
}
fmt.Printf(" %s%s%s\n", info.Plan, canceled, expiry)
}
fmt.Println(" Authenticated")
if liveProfile != nil {
fmt.Printf(" %s\n", DescribePlan(liveProfile))
}
if liveUsage != nil {
for _, line := range FormatUsageSummary(liveUsage) {
fmt.Println(line)
}
}
} else {
fmt.Println(" Not authenticated")
}
fmt.Println("")
}
fmt.Println("Tip: run 'cursor-api-proxy logout <name>' to remove an account.")
return nil
}
func HandleLogout(accountName string) error {
if accountName == "" {
fmt.Fprintln(os.Stderr, "Error: Please specify the account name to remove.")
fmt.Fprintln(os.Stderr, "Usage: cursor-api-proxy logout <account-name>")
os.Exit(1)
}
accountsDir := agent.AccountsDir()
configDir := filepath.Join(accountsDir, accountName)
if _, err := os.Stat(configDir); os.IsNotExist(err) {
fmt.Fprintf(os.Stderr, "Account '%s' not found.\n", accountName)
os.Exit(1)
}
if err := os.RemoveAll(configDir); err != nil {
fmt.Fprintf(os.Stderr, "Error removing account: %v\n", err)
os.Exit(1)
}
fmt.Printf("Account '%s' removed.\n", accountName)
return nil
}

View File

@ -1,118 +0,0 @@
package cmd
import "fmt"
type ParsedArgs struct {
Tailscale bool
Help bool
Login bool
AccountsList bool
Logout bool
AccountName string
Proxies []string
ResetHwid bool
DeepClean bool
DryRun bool
}
func ParseArgs(argv []string) (ParsedArgs, error) {
var args ParsedArgs
for i := 0; i < len(argv); i++ {
arg := argv[i]
switch arg {
case "login", "add-account":
args.Login = true
if i+1 < len(argv) && len(argv[i+1]) > 0 && argv[i+1][0] != '-' {
i++
args.AccountName = argv[i]
}
case "logout", "remove-account":
args.Logout = true
if i+1 < len(argv) && len(argv[i+1]) > 0 && argv[i+1][0] != '-' {
i++
args.AccountName = argv[i]
}
case "accounts", "list-accounts":
args.AccountsList = true
case "reset-hwid", "reset":
args.ResetHwid = true
case "--deep-clean":
args.DeepClean = true
case "--dry-run":
args.DryRun = true
case "--tailscale":
args.Tailscale = true
case "--help", "-h":
args.Help = true
default:
if len(arg) > len("--proxy=") && arg[:len("--proxy=")] == "--proxy=" {
raw := arg[len("--proxy="):]
parts := splitComma(raw)
for _, p := range parts {
if p != "" {
args.Proxies = append(args.Proxies, p)
}
}
} else {
return args, fmt.Errorf("Unknown argument: %s", arg)
}
}
}
return args, nil
}
func splitComma(s string) []string {
var result []string
start := 0
for i := 0; i <= len(s); i++ {
if i == len(s) || s[i] == ',' {
part := trim(s[start:i])
if part != "" {
result = append(result, part)
}
start = i + 1
}
}
return result
}
func trim(s string) string {
start := 0
end := len(s)
for start < end && (s[start] == ' ' || s[start] == '\t') {
start++
}
for end > start && (s[end-1] == ' ' || s[end-1] == '\t') {
end--
}
return s[start:end]
}
func PrintHelp(version string) {
fmt.Printf("cursor-api-proxy v%s\n\n", version)
fmt.Println("Usage:")
fmt.Println(" cursor-api-proxy [options]")
fmt.Println("")
fmt.Println("Commands:")
fmt.Println(" login [name] Log into a Cursor account (saved to ~/.cursor-api-proxy/accounts/)")
fmt.Println(" login [name] --proxy=... Same, but with a proxy from a comma-separated list")
fmt.Println(" logout <name> Remove a saved Cursor account")
fmt.Println(" accounts List saved accounts with plan info")
fmt.Println(" reset-hwid Reset Cursor machine/telemetry IDs (anti-ban)")
fmt.Println(" reset-hwid --deep-clean Also wipe session storage and cookies")
fmt.Println("")
fmt.Println("Options:")
fmt.Println(" --tailscale Bind to 0.0.0.0 for tailnet/LAN access")
fmt.Println(" -h, --help Show this help message")
}

View File

@ -1,11 +1,12 @@
package cmd package cmd
import ( import (
"cursor-api-proxy/internal/agent"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"cursor-api-proxy/pkg/usecase"
) )
type AccountInfo struct { type AccountInfo struct {
@ -86,7 +87,7 @@ func ReadAccountInfo(name, configDir string) AccountInfo {
} }
func HandleAccountsList() error { func HandleAccountsList() error {
accountsDir := agent.AccountsDir() accountsDir := usecase.AccountsDir()
entries, err := os.ReadDir(accountsDir) entries, err := os.ReadDir(accountsDir)
if err != nil { if err != nil {
@ -108,7 +109,7 @@ func HandleAccountsList() error {
fmt.Print("Cursor Accounts:\n\n") fmt.Print("Cursor Accounts:\n\n")
keychainToken := agent.ReadKeychainToken() keychainToken := usecase.ReadKeychainToken()
for i, name := range names { for i, name := range names {
configDir := filepath.Join(accountsDir, name) configDir := filepath.Join(accountsDir, name)
@ -117,7 +118,7 @@ func HandleAccountsList() error {
fmt.Printf(" %d. %s\n", i+1, name) fmt.Printf(" %d. %s\n", i+1, name)
if info.Authenticated { if info.Authenticated {
cachedToken := agent.ReadCachedToken(configDir) cachedToken := usecase.ReadCachedToken(configDir)
keychainMatchesAccount := keychainToken != "" && info.AuthID != "" && TokenSub(keychainToken) == info.AuthID keychainMatchesAccount := keychainToken != "" && info.AuthID != "" && TokenSub(keychainToken) == info.AuthID
token := cachedToken token := cachedToken
if token == "" && keychainMatchesAccount { if token == "" && keychainMatchesAccount {
@ -178,7 +179,7 @@ func HandleLogout(accountName string) error {
os.Exit(1) os.Exit(1)
} }
accountsDir := agent.AccountsDir() accountsDir := usecase.AccountsDir()
configDir := filepath.Join(accountsDir, accountName) configDir := filepath.Join(accountsDir, accountName)
if _, err := os.Stat(configDir); os.IsNotExist(err) { if _, err := os.Stat(configDir); os.IsNotExist(err) {

View File

@ -2,8 +2,6 @@ package cmd
import ( import (
"bufio" "bufio"
"cursor-api-proxy/internal/agent"
"cursor-api-proxy/internal/env"
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
@ -12,6 +10,9 @@ import (
"regexp" "regexp"
"syscall" "syscall"
"time" "time"
"cursor-api-proxy/pkg/infrastructure/env"
"cursor-api-proxy/pkg/usecase"
) )
var loginURLRe = regexp.MustCompile(`https://cursor\.com/loginDeepControl.*?redirectTarget=cli`) var loginURLRe = regexp.MustCompile(`https://cursor\.com/loginDeepControl.*?redirectTarget=cli`)
@ -25,7 +26,7 @@ func HandleLogin(accountName string, proxies []string) error {
accountName = fmt.Sprintf("account-%d", time.Now().UnixMilli()%10000) accountName = fmt.Sprintf("account-%d", time.Now().UnixMilli()%10000)
} }
accountsDir := agent.AccountsDir() accountsDir := usecase.AccountsDir()
configDir := filepath.Join(accountsDir, accountName) configDir := filepath.Join(accountsDir, accountName)
dirWasNew := !fileExists(configDir) dirWasNew := !fileExists(configDir)
@ -110,9 +111,9 @@ func HandleLogin(accountName string, proxies []string) error {
} }
// Cache keychain token for this account // Cache keychain token for this account
token := agent.ReadKeychainToken() token := usecase.ReadKeychainToken()
if token != "" { if token != "" {
agent.WriteCachedToken(configDir, token) usecase.WriteCachedToken(configDir, token)
} }
fmt.Printf("\nAccount '%s' saved — it will be auto-discovered when you start the proxy.\n", accountName) fmt.Printf("\nAccount '%s' saved — it will be auto-discovered when you start the proxy.\n", accountName)

View File

@ -2,8 +2,8 @@ package main
import ( import (
"cursor-api-proxy/internal/config" "cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/env" "cursor-api-proxy/pkg/infrastructure/env"
"cursor-api-proxy/internal/providers/geminiweb" "cursor-api-proxy/pkg/provider/geminiweb"
"fmt" "fmt"
"os" "os"
"strings" "strings"

View File

@ -1,125 +0,0 @@
package cmd
import (
"bufio"
"cursor-api-proxy/internal/agent"
"cursor-api-proxy/internal/env"
"fmt"
"os"
"os/exec"
"os/signal"
"path/filepath"
"regexp"
"syscall"
"time"
)
var loginURLRe = regexp.MustCompile(`https://cursor\.com/loginDeepControl.*?redirectTarget=cli`)
func HandleLogin(accountName string, proxies []string) error {
e := env.OsEnvToMap()
loaded := env.LoadEnvConfig(e, "")
agentBin := loaded.AgentBin
if accountName == "" {
accountName = fmt.Sprintf("account-%d", time.Now().UnixMilli()%10000)
}
accountsDir := agent.AccountsDir()
configDir := filepath.Join(accountsDir, accountName)
dirWasNew := !fileExists(configDir)
if err := os.MkdirAll(accountsDir, 0755); err != nil {
return fmt.Errorf("failed to create accounts dir: %w", err)
}
if err := os.MkdirAll(configDir, 0755); err != nil {
return fmt.Errorf("failed to create config dir: %w", err)
}
fmt.Printf("Logging into Cursor account: %s\n", accountName)
fmt.Printf("Config: %s\n\n", configDir)
fmt.Println("Run the login command — complete the login in your browser.")
fmt.Println("")
cleanupDir := func() {
if dirWasNew {
_ = os.RemoveAll(configDir)
}
}
cmdEnv := make([]string, 0, len(e)+2)
for k, v := range e {
cmdEnv = append(cmdEnv, k+"="+v)
}
cmdEnv = append(cmdEnv, "CURSOR_CONFIG_DIR="+configDir)
cmdEnv = append(cmdEnv, "NO_OPEN_BROWSER=1")
child := exec.Command(agentBin, "login")
child.Env = cmdEnv
child.Stdin = os.Stdin
child.Stderr = os.Stderr
stdoutPipe, err := child.StdoutPipe()
if err != nil {
return fmt.Errorf("failed to create stdout pipe: %w", err)
}
if err := child.Start(); err != nil {
cleanupDir()
if os.IsNotExist(err) {
return fmt.Errorf("could not find '%s'. Make sure the Cursor CLI is installed", agentBin)
}
return fmt.Errorf("error launching agent login: %w", err)
}
// Handle cancellation signals
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
go func() {
sig := <-sigCh
_ = child.Process.Kill()
cleanupDir()
if sig == syscall.SIGINT {
fmt.Println("\n\nLogin cancelled.")
}
os.Exit(0)
}()
defer signal.Stop(sigCh)
var stdoutBuf string
scanner := bufio.NewScanner(stdoutPipe)
for scanner.Scan() {
line := scanner.Text()
fmt.Println(line)
stdoutBuf += line + "\n"
if loginURLRe.MatchString(stdoutBuf) {
match := loginURLRe.FindString(stdoutBuf)
if match != "" {
fmt.Printf("\nOpen this URL in your browser (incognito recommended):\n%s\n\n", match)
}
}
}
if err := child.Wait(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
cleanupDir()
return fmt.Errorf("login failed with code %d", exitErr.ExitCode())
}
return err
}
// Cache keychain token for this account
token := agent.ReadKeychainToken()
if token != "" {
agent.WriteCachedToken(configDir, token)
}
fmt.Printf("\nAccount '%s' saved — it will be auto-discovered when you start the proxy.\n", accountName)
return nil
}
func fileExists(path string) bool {
_, err := os.Stat(path)
return err == nil
}

View File

@ -1,261 +0,0 @@
package cmd
import (
"crypto/rand"
"crypto/sha256"
"crypto/sha512"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"time"
"github.com/google/uuid"
)
func sha256hex() string {
b := make([]byte, 32)
_, _ = rand.Read(b)
h := sha256.Sum256(b)
return hex.EncodeToString(h[:])
}
func sha512hex() string {
b := make([]byte, 64)
_, _ = rand.Read(b)
h := sha512.Sum512(b)
return hex.EncodeToString(h[:])
}
func newUUID() string {
return uuid.New().String()
}
func log(icon, msg string) {
fmt.Printf(" %s %s\n", icon, msg)
}
func getCursorGlobalStorage() string {
switch runtime.GOOS {
case "darwin":
home, _ := os.UserHomeDir()
return filepath.Join(home, "Library", "Application Support", "Cursor", "User", "globalStorage")
case "windows":
appdata := os.Getenv("APPDATA")
return filepath.Join(appdata, "Cursor", "User", "globalStorage")
default:
xdg := os.Getenv("XDG_CONFIG_HOME")
if xdg == "" {
home, _ := os.UserHomeDir()
xdg = filepath.Join(home, ".config")
}
return filepath.Join(xdg, "Cursor", "User", "globalStorage")
}
}
func getCursorRoot() string {
gs := getCursorGlobalStorage()
return filepath.Dir(filepath.Dir(gs))
}
func generateNewIDs() map[string]string {
return map[string]string{
"telemetry.machineId": sha256hex(),
"telemetry.macMachineId": sha512hex(),
"telemetry.devDeviceId": newUUID(),
"telemetry.sqmId": "{" + fmt.Sprintf("%s", newUUID()+"") + "}",
"storage.serviceMachineId": newUUID(),
}
}
func killCursor() {
log("", "Stopping Cursor processes...")
switch runtime.GOOS {
case "windows":
exec.Command("taskkill", "/F", "/IM", "Cursor.exe").Run()
default:
exec.Command("pkill", "-x", "Cursor").Run()
exec.Command("pkill", "-f", "Cursor.app").Run()
}
log("", "Cursor stopped (or was not running)")
}
func updateStorageJSON(storagePath string, ids map[string]string) {
if _, err := os.Stat(storagePath); os.IsNotExist(err) {
log("", fmt.Sprintf("storage.json not found: %s", storagePath))
return
}
if runtime.GOOS == "darwin" {
exec.Command("chflags", "nouchg", storagePath).Run()
exec.Command("chmod", "644", storagePath).Run()
}
data, err := os.ReadFile(storagePath)
if err != nil {
log("", fmt.Sprintf("storage.json read error: %v", err))
return
}
var obj map[string]interface{}
if err := json.Unmarshal(data, &obj); err != nil {
log("", fmt.Sprintf("storage.json parse error: %v", err))
return
}
for k, v := range ids {
obj[k] = v
}
out, err := json.MarshalIndent(obj, "", " ")
if err != nil {
log("", fmt.Sprintf("storage.json marshal error: %v", err))
return
}
if err := os.WriteFile(storagePath, out, 0644); err != nil {
log("", fmt.Sprintf("storage.json write error: %v", err))
return
}
log("", "storage.json updated")
}
func updateStateVscdb(dbPath string, ids map[string]string) {
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
log("", fmt.Sprintf("state.vscdb not found: %s", dbPath))
return
}
if runtime.GOOS == "darwin" {
exec.Command("chflags", "nouchg", dbPath).Run()
exec.Command("chmod", "644", dbPath).Run()
}
if err := updateVscdbPureGo(dbPath, ids); err != nil {
log("", fmt.Sprintf("state.vscdb error: %v", err))
} else {
log("", "state.vscdb updated")
}
}
func updateMachineIDFile(machineID, cursorRoot string) {
var candidates []string
if runtime.GOOS == "linux" {
candidates = []string{
filepath.Join(cursorRoot, "machineid"),
filepath.Join(cursorRoot, "machineId"),
}
} else {
candidates = []string{filepath.Join(cursorRoot, "machineId")}
}
filePath := candidates[0]
for _, c := range candidates {
if _, err := os.Stat(c); err == nil {
filePath = c
break
}
}
if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
log("", fmt.Sprintf("machineId dir error: %v", err))
return
}
if runtime.GOOS == "darwin" {
if _, err := os.Stat(filePath); err == nil {
exec.Command("chflags", "nouchg", filePath).Run()
exec.Command("chmod", "644", filePath).Run()
}
}
if err := os.WriteFile(filePath, []byte(machineID+"\n"), 0644); err != nil {
log("", fmt.Sprintf("machineId write error: %v", err))
return
}
log("", fmt.Sprintf("machineId file updated (%s)", filepath.Base(filePath)))
}
var dirsToWipe = []string{
"Session Storage", "Local Storage", "IndexedDB", "Cache", "Code Cache",
"GPUCache", "Service Worker", "Network", "Cookies", "Cookies-journal",
}
func deepClean(cursorRoot string) {
log("", "Deep-cleaning session data...")
wiped := 0
for _, name := range dirsToWipe {
target := filepath.Join(cursorRoot, name)
if _, err := os.Stat(target); os.IsNotExist(err) {
continue
}
info, err := os.Stat(target)
if err != nil {
continue
}
if info.IsDir() {
if err := os.RemoveAll(target); err == nil {
wiped++
}
} else {
if err := os.Remove(target); err == nil {
wiped++
}
}
}
log("", fmt.Sprintf("Wiped %d cache/session items", wiped))
}
func HandleResetHwid(doDeepClean, dryRun bool) error {
fmt.Print("\nCursor HWID Reset\n\n")
fmt.Println(" Resets all machine / telemetry IDs so Cursor sees a fresh install.")
fmt.Print(" Cursor must be closed — it will be killed automatically.\n\n")
globalStorage := getCursorGlobalStorage()
cursorRoot := getCursorRoot()
if _, err := os.Stat(globalStorage); os.IsNotExist(err) {
fmt.Printf("Cursor config not found at:\n %s\n", globalStorage)
fmt.Println(" Make sure Cursor is installed and has been run at least once.")
os.Exit(1)
}
if dryRun {
fmt.Println(" [DRY RUN] Would reset IDs in:")
fmt.Printf(" %s\n", filepath.Join(globalStorage, "storage.json"))
fmt.Printf(" %s\n", filepath.Join(globalStorage, "state.vscdb"))
fmt.Printf(" %s\n", filepath.Join(cursorRoot, "machineId"))
return nil
}
killCursor()
time.Sleep(800 * time.Millisecond)
newIDs := generateNewIDs()
log("", "Generated new IDs:")
for k, v := range newIDs {
fmt.Printf(" %s: %s\n", k, v)
}
fmt.Println()
log("", "Updating storage.json...")
updateStorageJSON(filepath.Join(globalStorage, "storage.json"), newIDs)
log("", "Updating state.vscdb...")
updateStateVscdb(filepath.Join(globalStorage, "state.vscdb"), newIDs)
log("", "Updating machineId file...")
updateMachineIDFile(newIDs["telemetry.machineId"], cursorRoot)
if doDeepClean {
fmt.Println()
deepClean(cursorRoot)
}
fmt.Print("\nHWID reset complete. You can now restart Cursor.\n\n")
return nil
}

View File

@ -1,29 +0,0 @@
package cmd
import (
"database/sql"
"fmt"
_ "modernc.org/sqlite"
)
func updateVscdbPureGo(dbPath string, ids map[string]string) error {
db, err := sql.Open("sqlite", dbPath)
if err != nil {
return fmt.Errorf("open db: %w", err)
}
defer db.Close()
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS ItemTable (key TEXT PRIMARY KEY, value TEXT NOT NULL)`)
if err != nil {
return fmt.Errorf("create table: %w", err)
}
for k, v := range ids {
_, err = db.Exec(`INSERT OR REPLACE INTO ItemTable (key, value) VALUES (?, ?)`, k, v)
if err != nil {
return fmt.Errorf("insert %s: %w", k, err)
}
}
return nil
}

View File

@ -1,255 +0,0 @@
package cmd
import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
type ModelUsage struct {
NumRequests int `json:"numRequests"`
NumRequestsTotal int `json:"numRequestsTotal"`
NumTokens int `json:"numTokens"`
MaxTokenUsage *int `json:"maxTokenUsage"`
MaxRequestUsage *int `json:"maxRequestUsage"`
}
type UsageData struct {
StartOfMonth string `json:"startOfMonth"`
Models map[string]ModelUsage `json:"-"`
}
type StripeProfile struct {
MembershipType string `json:"membershipType"`
SubscriptionStatus string `json:"subscriptionStatus"`
DaysRemainingOnTrial *int `json:"daysRemainingOnTrial"`
IsTeamMember bool `json:"isTeamMember"`
IsYearlyPlan bool `json:"isYearlyPlan"`
}
func DecodeJWTPayload(token string) map[string]interface{} {
parts := strings.Split(token, ".")
if len(parts) < 2 {
return nil
}
padded := strings.ReplaceAll(parts[1], "-", "+")
padded = strings.ReplaceAll(padded, "_", "/")
data, err := base64.StdEncoding.DecodeString(padded + strings.Repeat("=", (4-len(padded)%4)%4))
if err != nil {
return nil
}
var result map[string]interface{}
if err := json.Unmarshal(data, &result); err != nil {
return nil
}
return result
}
func TokenSub(token string) string {
payload := DecodeJWTPayload(token)
if payload == nil {
return ""
}
if sub, ok := payload["sub"].(string); ok {
return sub
}
return ""
}
func apiGet(path, token string) (map[string]interface{}, error) {
client := &http.Client{Timeout: 8 * time.Second}
req, err := http.NewRequest("GET", "https://api2.cursor.sh"+path, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var result map[string]interface{}
if err := json.Unmarshal(data, &result); err != nil {
return nil, nil
}
return result, nil
}
func FetchAccountUsage(token string) (*UsageData, error) {
raw, err := apiGet("/auth/usage", token)
if err != nil || raw == nil {
return nil, err
}
startOfMonth, _ := raw["startOfMonth"].(string)
usage := &UsageData{
StartOfMonth: startOfMonth,
Models: make(map[string]ModelUsage),
}
for k, v := range raw {
if k == "startOfMonth" {
continue
}
data, err := json.Marshal(v)
if err != nil {
continue
}
var mu ModelUsage
if err := json.Unmarshal(data, &mu); err == nil {
usage.Models[k] = mu
}
}
return usage, nil
}
func FetchStripeProfile(token string) (*StripeProfile, error) {
raw, err := apiGet("/auth/full_stripe_profile", token)
if err != nil || raw == nil {
return nil, err
}
profile := &StripeProfile{
MembershipType: fmt.Sprintf("%v", raw["membershipType"]),
SubscriptionStatus: fmt.Sprintf("%v", raw["subscriptionStatus"]),
IsTeamMember: raw["isTeamMember"] == true,
IsYearlyPlan: raw["isYearlyPlan"] == true,
}
if d, ok := raw["daysRemainingOnTrial"].(float64); ok {
di := int(d)
profile.DaysRemainingOnTrial = &di
}
return profile, nil
}
func DescribePlan(profile *StripeProfile) string {
if profile == nil {
return ""
}
switch profile.MembershipType {
case "free_trial":
days := 0
if profile.DaysRemainingOnTrial != nil {
days = *profile.DaysRemainingOnTrial
}
return fmt.Sprintf("Pro Trial (%dd left) — unlimited fast requests", days)
case "pro":
return "Pro — extended limits"
case "pro_plus":
return "Pro+ — extended limits"
case "ultra":
return "Ultra — extended limits"
case "free", "hobby":
return "Hobby (free) — limited agent requests"
default:
return fmt.Sprintf("%s · %s", profile.MembershipType, profile.SubscriptionStatus)
}
}
var modelLabels = map[string]string{
"gpt-4": "Fast Premium Requests",
"claude-sonnet-4-6": "Claude Sonnet 4.6",
"claude-sonnet-4-5-20250929-v1": "Claude Sonnet 4.5",
"claude-sonnet-4-20250514-v1": "Claude Sonnet 4",
"claude-opus-4-6-v1": "Claude Opus 4.6",
"claude-opus-4-5-20251101-v1": "Claude Opus 4.5",
"claude-opus-4-1-20250805-v1": "Claude Opus 4.1",
"claude-opus-4-20250514-v1": "Claude Opus 4",
"claude-haiku-4-5-20251001-v1": "Claude Haiku 4.5",
"claude-3-5-haiku-20241022-v1": "Claude 3.5 Haiku",
"gpt-5": "GPT-5",
"gpt-4o": "GPT-4o",
"o1": "o1",
"o3-mini": "o3-mini",
"cursor-small": "Cursor Small (free)",
}
func modelLabel(key string) string {
if label, ok := modelLabels[key]; ok {
return label
}
return key
}
func FormatUsageSummary(usage *UsageData) []string {
if usage == nil {
return nil
}
var lines []string
start := "?"
if usage.StartOfMonth != "" {
if t, err := time.Parse(time.RFC3339, usage.StartOfMonth); err == nil {
start = t.Format("2006-01-02")
} else {
start = usage.StartOfMonth
}
}
lines = append(lines, fmt.Sprintf(" Billing period from %s", start))
if len(usage.Models) == 0 {
lines = append(lines, " No requests this billing period")
return lines
}
type entry struct {
key string
usage ModelUsage
}
var entries []entry
for k, v := range usage.Models {
entries = append(entries, entry{k, v})
}
// Sort: entries with limits first, then by usage descending
for i := 1; i < len(entries); i++ {
for j := i; j > 0; j-- {
a, b := entries[j-1], entries[j]
aHasLimit := a.usage.MaxRequestUsage != nil
bHasLimit := b.usage.MaxRequestUsage != nil
if !aHasLimit && bHasLimit {
entries[j-1], entries[j] = entries[j], entries[j-1]
} else if aHasLimit == bHasLimit && a.usage.NumRequests < b.usage.NumRequests {
entries[j-1], entries[j] = entries[j], entries[j-1]
} else {
break
}
}
}
for _, e := range entries {
used := e.usage.NumRequests
max := e.usage.MaxRequestUsage
label := modelLabel(e.key)
if max != nil && *max > 0 {
pct := int(float64(used) / float64(*max) * 100)
bar := makeBar(used, *max, 12)
lines = append(lines, fmt.Sprintf(" %s: %d/%d (%d%%) [%s]", label, used, *max, pct, bar))
} else if used > 0 {
lines = append(lines, fmt.Sprintf(" %s: %d requests", label, used))
} else {
lines = append(lines, fmt.Sprintf(" %s: 0 requests (unlimited)", label))
}
}
return lines
}
func makeBar(used, max, width int) string {
fill := int(float64(used) / float64(max) * float64(width))
if fill > width {
fill = width
}
return strings.Repeat("█", fill) + strings.Repeat("░", width-fill)
}

View File

@ -1,28 +0,0 @@
package agent
import "cursor-api-proxy/internal/config"
func BuildAgentFixedArgs(cfg config.BridgeConfig, workspaceDir, model string, stream bool) []string {
args := []string{"--print"}
if cfg.ApproveMcps {
args = append(args, "--approve-mcps")
}
if cfg.Force {
args = append(args, "--force")
}
if cfg.ChatOnlyWorkspace {
args = append(args, "--trust")
}
args = append(args, "--workspace", workspaceDir)
args = append(args, "--model", model)
if stream {
args = append(args, "--stream-partial-output", "--output-format", "stream-json")
} else {
args = append(args, "--output-format", "text")
}
return args
}
func BuildAgentCmdArgs(cfg config.BridgeConfig, workspaceDir, model, prompt string, stream bool) []string {
return append(BuildAgentFixedArgs(cfg, workspaceDir, model, stream), prompt)
}

View File

@ -1,85 +0,0 @@
package agent
import (
"encoding/json"
"os"
"path/filepath"
"runtime"
)
func getCandidates(agentScriptPath, configDirOverride string) []string {
if configDirOverride != "" {
return []string{filepath.Join(configDirOverride, "cli-config.json")}
}
var result []string
if dir := os.Getenv("CURSOR_CONFIG_DIR"); dir != "" {
result = append(result, filepath.Join(dir, "cli-config.json"))
}
if agentScriptPath != "" {
agentDir := filepath.Dir(agentScriptPath)
result = append(result, filepath.Join(agentDir, "..", "data", "config", "cli-config.json"))
}
home := os.Getenv("HOME")
if home == "" {
home = os.Getenv("USERPROFILE")
}
switch runtime.GOOS {
case "windows":
local := os.Getenv("LOCALAPPDATA")
if local == "" {
local = filepath.Join(home, "AppData", "Local")
}
result = append(result, filepath.Join(local, "cursor-agent", "cli-config.json"))
case "darwin":
result = append(result, filepath.Join(home, "Library", "Application Support", "cursor-agent", "cli-config.json"))
default:
xdg := os.Getenv("XDG_CONFIG_HOME")
if xdg == "" {
xdg = filepath.Join(home, ".config")
}
result = append(result, filepath.Join(xdg, "cursor-agent", "cli-config.json"))
}
return result
}
func RunMaxModePreflight(agentScriptPath, configDirOverride string) {
for _, candidate := range getCandidates(agentScriptPath, configDirOverride) {
data, err := os.ReadFile(candidate)
if err != nil {
continue
}
// Strip BOM if present
if len(data) >= 3 && data[0] == 0xEF && data[1] == 0xBB && data[2] == 0xBF {
data = data[3:]
}
var raw map[string]interface{}
if err := json.Unmarshal(data, &raw); err != nil {
continue
}
if raw == nil || len(raw) <= 1 {
continue
}
raw["maxMode"] = true
if model, ok := raw["model"].(map[string]interface{}); ok {
model["maxMode"] = true
}
out, err := json.MarshalIndent(raw, "", " ")
if err != nil {
continue
}
if err := os.WriteFile(candidate, out, 0644); err != nil {
continue
}
return
}
}

View File

@ -1,72 +0,0 @@
package agent
import (
"context"
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/process"
"os"
"path/filepath"
)
func init() {
process.MaxModeFn = RunMaxModePreflight
}
func cacheTokenForAccount(configDir string) {
if configDir == "" {
return
}
token := ReadKeychainToken()
if token != "" {
WriteCachedToken(configDir, token)
}
}
func AccountsDir() string {
home := os.Getenv("HOME")
if home == "" {
home = os.Getenv("USERPROFILE")
}
return filepath.Join(home, ".cursor-api-proxy", "accounts")
}
func RunAgentSync(cfg config.BridgeConfig, workspaceDir string, cmdArgs []string, tempDir, configDir string, ctx context.Context) (process.RunResult, error) {
opts := process.RunOptions{
Cwd: workspaceDir,
TimeoutMs: cfg.TimeoutMs,
MaxMode: cfg.MaxMode,
ConfigDir: configDir,
Ctx: ctx,
}
result, err := process.Run(cfg.AgentBin, cmdArgs, opts)
cacheTokenForAccount(configDir)
if tempDir != "" {
os.RemoveAll(tempDir)
}
return result, err
}
func RunAgentStreamWithContext(cfg config.BridgeConfig, workspaceDir string, cmdArgs []string, onLine func(string), tempDir, configDir string, ctx context.Context) (process.StreamResult, error) {
opts := process.RunStreamingOptions{
RunOptions: process.RunOptions{
Cwd: workspaceDir,
TimeoutMs: cfg.TimeoutMs,
MaxMode: cfg.MaxMode,
ConfigDir: configDir,
Ctx: ctx,
},
OnLine: onLine,
}
result, err := process.RunStreaming(cfg.AgentBin, cmdArgs, opts)
cacheTokenForAccount(configDir)
if tempDir != "" {
os.RemoveAll(tempDir)
}
return result, err
}

View File

@ -1,36 +0,0 @@
package agent
import (
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
)
const tokenFile = ".cursor-token"
func ReadCachedToken(configDir string) string {
p := filepath.Join(configDir, tokenFile)
data, err := os.ReadFile(p)
if err != nil {
return ""
}
return strings.TrimSpace(string(data))
}
func WriteCachedToken(configDir, token string) {
p := filepath.Join(configDir, tokenFile)
_ = os.WriteFile(p, []byte(token), 0600)
}
func ReadKeychainToken() string {
if runtime.GOOS != "darwin" {
return ""
}
out, err := exec.Command("security", "find-generic-password", "-s", "cursor-access-token", "-w").Output()
if err != nil {
return ""
}
return strings.TrimSpace(string(out))
}

View File

@ -1,174 +0,0 @@
package anthropic
import (
"cursor-api-proxy/internal/openai"
"encoding/json"
"fmt"
"strings"
)
type MessageParam struct {
Role string `json:"role"`
Content interface{} `json:"content"`
}
type MessagesRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
Messages []MessageParam `json:"messages"`
System interface{} `json:"system"`
Stream bool `json:"stream"`
Tools []interface{} `json:"tools"`
}
func systemToText(system interface{}) string {
if system == nil {
return ""
}
switch v := system.(type) {
case string:
return strings.TrimSpace(v)
case []interface{}:
var parts []string
for _, p := range v {
if m, ok := p.(map[string]interface{}); ok {
if m["type"] == "text" {
if t, ok := m["text"].(string); ok {
parts = append(parts, t)
}
}
}
}
return strings.Join(parts, "\n")
}
return ""
}
func anthropicBlockToText(p interface{}) string {
if p == nil {
return ""
}
switch v := p.(type) {
case string:
return v
case map[string]interface{}:
typ, _ := v["type"].(string)
switch typ {
case "text":
if t, ok := v["text"].(string); ok {
return t
}
case "image":
if src, ok := v["source"].(map[string]interface{}); ok {
srcType, _ := src["type"].(string)
switch srcType {
case "base64":
mt, _ := src["media_type"].(string)
if mt == "" {
mt = "image"
}
return "[Image: base64 " + mt + "]"
case "url":
url, _ := src["url"].(string)
return "[Image: " + url + "]"
}
}
return "[Image]"
case "document":
title, _ := v["title"].(string)
if title == "" {
if src, ok := v["source"].(map[string]interface{}); ok {
title, _ = src["url"].(string)
}
}
if title != "" {
return "[Document: " + title + "]"
}
return "[Document]"
case "tool_use":
name, _ := v["name"].(string)
id, _ := v["id"].(string)
input := v["input"]
inputJSON, _ := json.Marshal(input)
if inputJSON == nil {
inputJSON = []byte("{}")
}
tag := fmt.Sprintf("<tool_call>\n{\"name\": \"%s\", \"arguments\": %s}\n</tool_call>", name, string(inputJSON))
if id != "" {
tag = fmt.Sprintf("[tool_use_id=%s] ", id) + tag
}
return tag
case "tool_result":
toolUseID, _ := v["tool_use_id"].(string)
content := v["content"]
var contentText string
switch c := content.(type) {
case string:
contentText = c
case []interface{}:
var parts []string
for _, block := range c {
if bm, ok := block.(map[string]interface{}); ok {
if bm["type"] == "text" {
if t, ok := bm["text"].(string); ok {
parts = append(parts, t)
}
}
}
}
contentText = strings.Join(parts, "\n")
}
label := "Tool result"
if toolUseID != "" {
label += " [id=" + toolUseID + "]"
}
return label + ": " + contentText
}
}
return ""
}
func anthropicContentToText(content interface{}) string {
switch v := content.(type) {
case string:
return v
case []interface{}:
var parts []string
for _, p := range v {
if t := anthropicBlockToText(p); t != "" {
parts = append(parts, t)
}
}
return strings.Join(parts, " ")
}
return ""
}
func BuildPromptFromAnthropicMessages(messages []MessageParam, system interface{}) string {
var oaiMessages []interface{}
systemText := systemToText(system)
if systemText != "" {
oaiMessages = append(oaiMessages, map[string]interface{}{
"role": "system",
"content": systemText,
})
}
for _, m := range messages {
text := anthropicContentToText(m.Content)
if text == "" {
continue
}
role := m.Role
if role != "user" && role != "assistant" {
role = "user"
}
oaiMessages = append(oaiMessages, map[string]interface{}{
"role": role,
"content": text,
})
}
return openai.BuildPromptFromMessages(oaiMessages)
}

View File

@ -1,109 +0,0 @@
package anthropic_test
import (
"cursor-api-proxy/internal/anthropic"
"strings"
"testing"
)
func TestBuildPromptFromAnthropicMessages_Simple(t *testing.T) {
messages := []anthropic.MessageParam{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there"},
}
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
if !strings.Contains(prompt, "Hello") {
t.Errorf("prompt missing user message: %q", prompt)
}
if !strings.Contains(prompt, "Hi there") {
t.Errorf("prompt missing assistant message: %q", prompt)
}
}
func TestBuildPromptFromAnthropicMessages_WithSystem(t *testing.T) {
messages := []anthropic.MessageParam{
{Role: "user", Content: "ping"},
}
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, "You are a helpful bot.")
if !strings.Contains(prompt, "You are a helpful bot.") {
t.Errorf("prompt missing system: %q", prompt)
}
if !strings.Contains(prompt, "ping") {
t.Errorf("prompt missing user: %q", prompt)
}
}
func TestBuildPromptFromAnthropicMessages_SystemArray(t *testing.T) {
system := []interface{}{
map[string]interface{}{"type": "text", "text": "Part A"},
map[string]interface{}{"type": "text", "text": "Part B"},
}
messages := []anthropic.MessageParam{
{Role: "user", Content: "test"},
}
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, system)
if !strings.Contains(prompt, "Part A") {
t.Errorf("prompt missing Part A: %q", prompt)
}
if !strings.Contains(prompt, "Part B") {
t.Errorf("prompt missing Part B: %q", prompt)
}
}
func TestBuildPromptFromAnthropicMessages_ContentBlocks(t *testing.T) {
content := []interface{}{
map[string]interface{}{"type": "text", "text": "block one"},
map[string]interface{}{"type": "text", "text": "block two"},
}
messages := []anthropic.MessageParam{
{Role: "user", Content: content},
}
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
if !strings.Contains(prompt, "block one") {
t.Errorf("prompt missing 'block one': %q", prompt)
}
if !strings.Contains(prompt, "block two") {
t.Errorf("prompt missing 'block two': %q", prompt)
}
}
func TestBuildPromptFromAnthropicMessages_ImageBlock(t *testing.T) {
content := []interface{}{
map[string]interface{}{
"type": "image",
"source": map[string]interface{}{
"type": "base64",
"media_type": "image/png",
"data": "abc123",
},
},
}
messages := []anthropic.MessageParam{
{Role: "user", Content: content},
}
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
if !strings.Contains(prompt, "[Image") {
t.Errorf("prompt missing [Image]: %q", prompt)
}
}
func TestBuildPromptFromAnthropicMessages_EmptyContentSkipped(t *testing.T) {
messages := []anthropic.MessageParam{
{Role: "user", Content: ""},
{Role: "assistant", Content: "response"},
}
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
if !strings.Contains(prompt, "response") {
t.Errorf("prompt missing 'response': %q", prompt)
}
}
func TestBuildPromptFromAnthropicMessages_UnknownRoleBecomesUser(t *testing.T) {
messages := []anthropic.MessageParam{
{Role: "system", Content: "system-as-user"},
}
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
if !strings.Contains(prompt, "system-as-user") {
t.Errorf("prompt missing 'system-as-user': %q", prompt)
}
}

View File

@ -1,40 +0,0 @@
package apitypes
type Message struct {
Role string
Content string
}
type Tool struct {
Type string
Function ToolFunction
}
type ToolFunction struct {
Name string
Description string
Parameters interface{}
}
type ToolCall struct {
ID string
Name string
Arguments string
}
type StreamChunk struct {
Type ChunkType
Text string
Thinking string
ToolCall *ToolCall
Done bool
}
type ChunkType int
const (
ChunkText ChunkType = iota
ChunkThinking
ChunkToolCall
ChunkDone
)

View File

@ -1,7 +1,7 @@
package config package config
import ( import (
"cursor-api-proxy/internal/env" "cursor-api-proxy/pkg/infrastructure/env"
"github.com/zeromicro/go-zero/rest" "github.com/zeromicro/go-zero/rest"
) )

View File

@ -2,7 +2,7 @@ package config_test
import ( import (
"cursor-api-proxy/internal/config" "cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/env" "cursor-api-proxy/pkg/infrastructure/env"
"path/filepath" "path/filepath"
"strings" "strings"
"testing" "testing"

381
internal/env/env.go vendored
View File

@ -1,381 +0,0 @@
package env
import (
"encoding/json"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
)
type EnvSource map[string]string
type LoadedEnv struct {
AgentBin string
AgentNode string
AgentScript string
CommandShell string
Host string
Port int
RequiredKey string
DefaultModel string
Provider string
Force bool
ApproveMcps bool
StrictModel bool
Workspace string
TimeoutMs int
TLSCertPath string
TLSKeyPath string
SessionsLogPath string
ChatOnlyWorkspace bool
Verbose bool
MaxMode bool
ConfigDirs []string
MultiPort bool
WinCmdlineMax int
GeminiAccountDir string
GeminiBrowserVisible bool
GeminiMaxSessions int
}
type AgentCommand struct {
Command string
Args []string
Env map[string]string
WindowsVerbatimArguments bool
AgentScriptPath string
ConfigDir string
}
func getEnvVal(e EnvSource, names []string) string {
for _, name := range names {
if v, ok := e[name]; ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
}
return ""
}
func envBool(e EnvSource, names []string, def bool) bool {
raw := getEnvVal(e, names)
if raw == "" {
return def
}
switch strings.ToLower(raw) {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
}
return def
}
func envInt(e EnvSource, names []string, def int) int {
raw := getEnvVal(e, names)
if raw == "" {
return def
}
v, err := strconv.Atoi(raw)
if err != nil {
return def
}
return v
}
func normalizeModelId(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return "auto"
}
parts := strings.Split(raw, "/")
last := parts[len(parts)-1]
if last == "" {
return "auto"
}
return last
}
func resolveAbs(raw, cwd string) string {
if raw == "" {
return ""
}
if filepath.IsAbs(raw) {
return raw
}
return filepath.Join(cwd, raw)
}
func isAuthenticatedAccountDir(dir string) bool {
data, err := os.ReadFile(filepath.Join(dir, "cli-config.json"))
if err != nil {
return false
}
var cfg struct {
AuthInfo *struct {
Email string `json:"email"`
} `json:"authInfo"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return false
}
return cfg.AuthInfo != nil && cfg.AuthInfo.Email != ""
}
func discoverAccountDirs(homeDir string) []string {
if homeDir == "" {
return nil
}
accountsDir := filepath.Join(homeDir, ".cursor-api-proxy", "accounts")
entries, err := os.ReadDir(accountsDir)
if err != nil {
return nil
}
var dirs []string
for _, e := range entries {
if !e.IsDir() {
continue
}
dir := filepath.Join(accountsDir, e.Name())
if isAuthenticatedAccountDir(dir) {
dirs = append(dirs, dir)
}
}
return dirs
}
func parseDotEnv(path string) EnvSource {
data, err := os.ReadFile(path)
if err != nil {
return nil
}
m := make(EnvSource)
for _, line := range strings.Split(string(data), "\n") {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) == 2 {
m[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
}
}
return m
}
func OsEnvToMap(cwdHint ...string) EnvSource {
m := make(EnvSource)
for _, kv := range os.Environ() {
parts := strings.SplitN(kv, "=", 2)
if len(parts) == 2 {
m[parts[0]] = parts[1]
}
}
cwd := ""
if len(cwdHint) > 0 && cwdHint[0] != "" {
cwd = cwdHint[0]
} else {
cwd, _ = os.Getwd()
}
if dotenv := parseDotEnv(filepath.Join(cwd, ".env")); dotenv != nil {
for k, v := range dotenv {
if _, exists := m[k]; !exists {
m[k] = v
}
}
}
return m
}
func LoadEnvConfig(e EnvSource, cwd string) LoadedEnv {
if e == nil {
e = OsEnvToMap()
}
if cwd == "" {
var err error
cwd, err = os.Getwd()
if err != nil {
cwd = "."
}
}
host := getEnvVal(e, []string{"CURSOR_BRIDGE_HOST"})
if host == "" {
host = "127.0.0.1"
}
port := envInt(e, []string{"CURSOR_BRIDGE_PORT"}, 8765)
if port <= 0 {
port = 8765
}
home := getEnvVal(e, []string{"HOME", "USERPROFILE"})
sessionsLogPath := func() string {
if p := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_SESSIONS_LOG"}), cwd); p != "" {
return p
}
if home != "" {
return filepath.Join(home, ".cursor-api-proxy", "sessions.log")
}
return filepath.Join(cwd, "sessions.log")
}()
var configDirs []string
if raw := getEnvVal(e, []string{"CURSOR_CONFIG_DIRS", "CURSOR_ACCOUNT_DIRS"}); raw != "" {
for _, d := range strings.Split(raw, ",") {
d = strings.TrimSpace(d)
if d != "" {
if p := resolveAbs(d, cwd); p != "" {
configDirs = append(configDirs, p)
}
}
}
}
if len(configDirs) == 0 {
configDirs = discoverAccountDirs(home)
}
winMax := envInt(e, []string{"CURSOR_BRIDGE_WIN_CMDLINE_MAX"}, 30000)
if winMax < 4096 {
winMax = 4096
}
if winMax > 32700 {
winMax = 32700
}
agentBin := getEnvVal(e, []string{"CURSOR_AGENT_BIN", "CURSOR_CLI_BIN", "CURSOR_CLI_PATH"})
if agentBin == "" {
agentBin = "agent"
}
commandShell := getEnvVal(e, []string{"COMSPEC"})
if commandShell == "" {
commandShell = "cmd.exe"
}
workspace := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_WORKSPACE"}), cwd)
if workspace == "" {
workspace = cwd
}
geminiAccountDir := getEnvVal(e, []string{"GEMINI_ACCOUNT_DIR"})
if geminiAccountDir == "" {
geminiAccountDir = filepath.Join(home, ".cursor-api-proxy", "gemini-accounts")
} else {
geminiAccountDir = resolveAbs(geminiAccountDir, cwd)
}
return LoadedEnv{
AgentBin: agentBin,
AgentNode: getEnvVal(e, []string{"CURSOR_AGENT_NODE"}),
AgentScript: getEnvVal(e, []string{"CURSOR_AGENT_SCRIPT"}),
CommandShell: commandShell,
Host: host,
Port: port,
RequiredKey: getEnvVal(e, []string{"CURSOR_BRIDGE_API_KEY"}),
DefaultModel: normalizeModelId(getEnvVal(e, []string{"CURSOR_BRIDGE_DEFAULT_MODEL"})),
Provider: getEnvVal(e, []string{"CURSOR_BRIDGE_PROVIDER"}),
Force: envBool(e, []string{"CURSOR_BRIDGE_FORCE"}, false),
ApproveMcps: envBool(e, []string{"CURSOR_BRIDGE_APPROVE_MCPS"}, false),
StrictModel: envBool(e, []string{"CURSOR_BRIDGE_STRICT_MODEL"}, true),
Workspace: workspace,
TimeoutMs: envInt(e, []string{"CURSOR_BRIDGE_TIMEOUT_MS"}, 300000),
TLSCertPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_CERT"}), cwd),
TLSKeyPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_KEY"}), cwd),
SessionsLogPath: sessionsLogPath,
ChatOnlyWorkspace: envBool(e, []string{"CURSOR_BRIDGE_CHAT_ONLY_WORKSPACE"}, true),
Verbose: envBool(e, []string{"CURSOR_BRIDGE_VERBOSE"}, false),
MaxMode: envBool(e, []string{"CURSOR_BRIDGE_MAX_MODE"}, false),
ConfigDirs: configDirs,
MultiPort: envBool(e, []string{"CURSOR_BRIDGE_MULTI_PORT"}, false),
WinCmdlineMax: winMax,
GeminiAccountDir: geminiAccountDir,
GeminiBrowserVisible: envBool(e, []string{"GEMINI_BROWSER_VISIBLE"}, false),
GeminiMaxSessions: envInt(e, []string{"GEMINI_MAX_SESSIONS"}, 3),
}
}
func ResolveAgentCommand(cmd string, args []string, e EnvSource, cwd string) AgentCommand {
if e == nil {
e = OsEnvToMap()
}
loaded := LoadEnvConfig(e, cwd)
cloneEnv := func() map[string]string {
m := make(map[string]string, len(e))
for k, v := range e {
m[k] = v
}
return m
}
if runtime.GOOS == "windows" {
if loaded.AgentNode != "" && loaded.AgentScript != "" {
agentScriptPath := loaded.AgentScript
if !filepath.IsAbs(agentScriptPath) {
agentScriptPath = filepath.Join(cwd, agentScriptPath)
}
agentDir := filepath.Dir(agentScriptPath)
configDir := filepath.Join(agentDir, "..", "data", "config")
env2 := cloneEnv()
env2["CURSOR_INVOKED_AS"] = "agent.cmd"
ac := AgentCommand{
Command: loaded.AgentNode,
Args: append([]string{loaded.AgentScript}, args...),
Env: env2,
AgentScriptPath: agentScriptPath,
}
if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil {
ac.ConfigDir = configDir
}
return ac
}
if strings.HasSuffix(strings.ToLower(cmd), ".cmd") {
cmdResolved := cmd
if !filepath.IsAbs(cmd) {
cmdResolved = filepath.Join(cwd, cmd)
}
dir := filepath.Dir(cmdResolved)
nodeBin := filepath.Join(dir, "node.exe")
script := filepath.Join(dir, "index.js")
if _, err1 := os.Stat(nodeBin); err1 == nil {
if _, err2 := os.Stat(script); err2 == nil {
configDir := filepath.Join(dir, "..", "data", "config")
env2 := cloneEnv()
env2["CURSOR_INVOKED_AS"] = "agent.cmd"
ac := AgentCommand{
Command: nodeBin,
Args: append([]string{script}, args...),
Env: env2,
AgentScriptPath: script,
}
if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil {
ac.ConfigDir = configDir
}
return ac
}
}
quotedArgs := make([]string, len(args))
for i, a := range args {
if strings.Contains(a, " ") {
quotedArgs[i] = `"` + a + `"`
} else {
quotedArgs[i] = a
}
}
cmdLine := `""` + cmd + `" ` + strings.Join(quotedArgs, " ") + `"`
return AgentCommand{
Command: loaded.CommandShell,
Args: []string{"/d", "/s", "/c", cmdLine},
Env: cloneEnv(),
WindowsVerbatimArguments: true,
}
}
}
return AgentCommand{Command: cmd, Args: args, Env: cloneEnv()}
}

View File

@ -1,65 +0,0 @@
package env
import "testing"
func TestLoadEnvConfigDefaults(t *testing.T) {
e := EnvSource{}
loaded := LoadEnvConfig(e, "/tmp")
if loaded.Host != "127.0.0.1" {
t.Errorf("expected 127.0.0.1, got %s", loaded.Host)
}
if loaded.Port != 8765 {
t.Errorf("expected 8765, got %d", loaded.Port)
}
if loaded.DefaultModel != "auto" {
t.Errorf("expected auto, got %s", loaded.DefaultModel)
}
if loaded.AgentBin != "agent" {
t.Errorf("expected agent, got %s", loaded.AgentBin)
}
if !loaded.StrictModel {
t.Error("expected strictModel=true by default")
}
}
func TestLoadEnvConfigOverride(t *testing.T) {
e := EnvSource{
"CURSOR_BRIDGE_HOST": "0.0.0.0",
"CURSOR_BRIDGE_PORT": "9000",
"CURSOR_BRIDGE_DEFAULT_MODEL": "gpt-4",
"CURSOR_AGENT_BIN": "/usr/local/bin/agent",
}
loaded := LoadEnvConfig(e, "/tmp")
if loaded.Host != "0.0.0.0" {
t.Errorf("expected 0.0.0.0, got %s", loaded.Host)
}
if loaded.Port != 9000 {
t.Errorf("expected 9000, got %d", loaded.Port)
}
if loaded.DefaultModel != "gpt-4" {
t.Errorf("expected gpt-4, got %s", loaded.DefaultModel)
}
if loaded.AgentBin != "/usr/local/bin/agent" {
t.Errorf("expected /usr/local/bin/agent, got %s", loaded.AgentBin)
}
}
func TestNormalizeModelID(t *testing.T) {
tests := []struct {
input string
want string
}{
{"gpt-4", "gpt-4"},
{"openai/gpt-4", "gpt-4"},
{"", "auto"},
{" ", "auto"},
}
for _, tc := range tests {
got := normalizeModelId(tc.input)
if got != tc.want {
t.Errorf("normalizeModelId(%q) = %q, want %q", tc.input, got, tc.want)
}
}
}

View File

@ -1,577 +0,0 @@
package handlers
import (
"context"
"cursor-api-proxy/internal/agent"
"cursor-api-proxy/internal/anthropic"
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/httputil"
"cursor-api-proxy/internal/logger"
"cursor-api-proxy/internal/models"
"cursor-api-proxy/internal/openai"
"cursor-api-proxy/internal/parser"
"cursor-api-proxy/internal/pool"
"cursor-api-proxy/internal/sanitize"
"cursor-api-proxy/internal/toolcall"
"cursor-api-proxy/internal/winlimit"
"cursor-api-proxy/internal/workspace"
"encoding/json"
"fmt"
"net/http"
"regexp"
"strings"
"time"
"github.com/google/uuid"
)
func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, ph pool.PoolHandle, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
var req anthropic.MessagesRequest
if err := json.Unmarshal([]byte(rawBody), &req); err != nil {
httputil.WriteJSON(w, 400, map[string]interface{}{
"error": map[string]string{"type": "invalid_request_error", "message": "invalid JSON body"},
}, nil)
return
}
requested := openai.NormalizeModelID(req.Model)
model := ResolveModel(requested, lastModelRef, cfg)
var rawMap map[string]interface{}
_ = json.Unmarshal([]byte(rawBody), &rawMap)
cleanSystem := sanitize.SanitizeSystem(req.System)
rawMessages := make([]interface{}, len(req.Messages))
for i, m := range req.Messages {
rawMessages[i] = map[string]interface{}{"role": m.Role, "content": m.Content}
}
cleanRawMessages := sanitize.SanitizeMessages(rawMessages)
var cleanMessages []anthropic.MessageParam
for _, raw := range cleanRawMessages {
if m, ok := raw.(map[string]interface{}); ok {
role, _ := m["role"].(string)
cleanMessages = append(cleanMessages, anthropic.MessageParam{Role: role, Content: m["content"]})
}
}
toolsText := openai.ToolsToSystemText(req.Tools, nil)
var systemWithTools interface{}
if toolsText != "" {
sysStr := ""
switch v := cleanSystem.(type) {
case string:
sysStr = v
}
if sysStr != "" {
systemWithTools = sysStr + "\n\n" + toolsText
} else {
systemWithTools = toolsText
}
} else {
systemWithTools = cleanSystem
}
prompt := anthropic.BuildPromptFromAnthropicMessages(cleanMessages, systemWithTools)
if req.MaxTokens == 0 {
httputil.WriteJSON(w, 400, map[string]interface{}{
"error": map[string]string{"type": "invalid_request_error", "message": "max_tokens is required"},
}, nil)
return
}
cursorModel := models.ResolveToCursorModel(model)
if cursorModel == "" {
cursorModel = model
}
var trafficMsgs []logger.TrafficMessage
if s := systemToString(cleanSystem); s != "" {
trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: "system", Content: s})
}
for _, m := range cleanMessages {
text := contentToString(m.Content)
if text != "" {
trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: m.Role, Content: text})
}
}
logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, req.Stream)
headerWs := r.Header.Get("x-cursor-workspace")
ws := workspace.ResolveWorkspace(cfg, headerWs)
fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, req.Stream)
fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir)
if cfg.Verbose {
if len(prompt) > 200 {
logger.LogDebug("model=%s prompt_len=%d prompt_preview=%q", cursorModel, len(prompt), prompt[:200]+"...")
} else {
logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, len(prompt), prompt)
}
logger.LogDebug("cmd_args=%v", fit.Args)
}
if !fit.OK {
httputil.WriteJSON(w, 500, map[string]interface{}{
"error": map[string]string{"type": "api_error", "message": fit.Error},
}, nil)
return
}
if fit.Truncated {
logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength)
}
cmdArgs := fit.Args
msgID := "msg_" + uuid.New().String()
var truncatedHeaders map[string]string
if fit.Truncated {
truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
}
hasTools := len(req.Tools) > 0
var toolNames map[string]bool
if hasTools {
toolNames = toolcall.CollectToolNames(req.Tools)
}
if req.Stream {
httputil.WriteSSEHeaders(w, truncatedHeaders)
flusher, _ := w.(http.Flusher)
writeEvent := func(evt interface{}) {
data, _ := json.Marshal(evt)
fmt.Fprintf(w, "data: %s\n\n", data)
if flusher != nil {
flusher.Flush()
}
}
var accumulated string
var accumulatedThinking string
var chunkNum int
var p parser.Parser
writeEvent(map[string]interface{}{
"type": "message_start",
"message": map[string]interface{}{
"id": msgID,
"type": "message",
"role": "assistant",
"model": model,
"content": []interface{}{},
},
})
if hasTools {
toolCallMarkerRe := regexp.MustCompile(`行政法规|<function_calls>`)
var toolCallMode bool
textBlockOpen := false
textBlockIndex := 0
thinkingOpen := false
thinkingBlockIndex := 0
blockCount := 0
p = parser.CreateStreamParserWithThinking(
func(text string) {
accumulated += text
chunkNum++
logger.LogStreamChunk(model, text, chunkNum)
if toolCallMode {
return
}
if toolCallMarkerRe.MatchString(text) {
if textBlockOpen {
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": textBlockIndex})
textBlockOpen = false
}
if thinkingOpen {
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": thinkingBlockIndex})
thinkingOpen = false
}
toolCallMode = true
return
}
if !textBlockOpen && !thinkingOpen {
textBlockIndex = blockCount
writeEvent(map[string]interface{}{
"type": "content_block_start",
"index": textBlockIndex,
"content_block": map[string]string{"type": "text", "text": ""},
})
textBlockOpen = true
blockCount++
}
if thinkingOpen {
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": thinkingBlockIndex})
thinkingOpen = false
}
writeEvent(map[string]interface{}{
"type": "content_block_delta",
"index": textBlockIndex,
"delta": map[string]string{"type": "text_delta", "text": text},
})
},
func(thinking string) {
accumulatedThinking += thinking
chunkNum++
if toolCallMode {
return
}
if !thinkingOpen {
thinkingBlockIndex = blockCount
writeEvent(map[string]interface{}{
"type": "content_block_start",
"index": thinkingBlockIndex,
"content_block": map[string]string{"type": "thinking", "thinking": ""},
})
thinkingOpen = true
blockCount++
}
writeEvent(map[string]interface{}{
"type": "content_block_delta",
"index": thinkingBlockIndex,
"delta": map[string]string{"type": "thinking_delta", "thinking": thinking},
})
},
func() {
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
parsed := toolcall.ExtractToolCalls(accumulated, toolNames)
blockIndex := 0
if thinkingOpen {
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": thinkingBlockIndex})
thinkingOpen = false
}
if parsed.HasToolCalls() {
if textBlockOpen {
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": textBlockIndex})
blockIndex = textBlockIndex + 1
}
if parsed.TextContent != "" && !textBlockOpen && !toolCallMode {
writeEvent(map[string]interface{}{
"type": "content_block_start", "index": blockIndex,
"content_block": map[string]string{"type": "text", "text": ""},
})
writeEvent(map[string]interface{}{
"type": "content_block_delta", "index": blockIndex,
"delta": map[string]string{"type": "text_delta", "text": parsed.TextContent},
})
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
blockIndex++
}
for _, tc := range parsed.ToolCalls {
toolID := "toolu_" + uuid.New().String()[:12]
var inputObj interface{}
_ = json.Unmarshal([]byte(tc.Arguments), &inputObj)
if inputObj == nil {
inputObj = map[string]interface{}{}
}
writeEvent(map[string]interface{}{
"type": "content_block_start", "index": blockIndex,
"content_block": map[string]interface{}{
"type": "tool_use", "id": toolID, "name": tc.Name, "input": map[string]interface{}{},
},
})
writeEvent(map[string]interface{}{
"type": "content_block_delta", "index": blockIndex,
"delta": map[string]interface{}{
"type": "input_json_delta", "partial_json": tc.Arguments,
},
})
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
blockIndex++
}
writeEvent(map[string]interface{}{
"type": "message_delta",
"delta": map[string]interface{}{"stop_reason": "tool_use", "stop_sequence": nil},
"usage": map[string]int{"output_tokens": 0},
})
writeEvent(map[string]interface{}{"type": "message_stop"})
} else {
if textBlockOpen {
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": textBlockIndex})
} else if accumulated != "" {
writeEvent(map[string]interface{}{
"type": "content_block_start", "index": blockIndex,
"content_block": map[string]string{"type": "text", "text": ""},
})
writeEvent(map[string]interface{}{
"type": "content_block_delta", "index": blockIndex,
"delta": map[string]string{"type": "text_delta", "text": accumulated},
})
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
blockIndex++
} else {
writeEvent(map[string]interface{}{
"type": "content_block_start", "index": blockIndex,
"content_block": map[string]string{"type": "text", "text": ""},
})
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
blockIndex++
}
writeEvent(map[string]interface{}{
"type": "message_delta",
"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
"usage": map[string]int{"output_tokens": 0},
})
writeEvent(map[string]interface{}{"type": "message_stop"})
}
},
)
} else {
// 非 tools 模式:即時串流 thinking 和 text
blockCount := 0
thinkingOpen := false
textOpen := false
p = parser.CreateStreamParserWithThinking(
func(text string) {
accumulated += text
chunkNum++
logger.LogStreamChunk(model, text, chunkNum)
if thinkingOpen {
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
thinkingOpen = false
}
if !textOpen {
writeEvent(map[string]interface{}{
"type": "content_block_start",
"index": blockCount,
"content_block": map[string]string{"type": "text", "text": ""},
})
textOpen = true
blockCount++
}
writeEvent(map[string]interface{}{
"type": "content_block_delta",
"index": blockCount - 1,
"delta": map[string]string{"type": "text_delta", "text": text},
})
},
func(thinking string) {
accumulatedThinking += thinking
chunkNum++
if !thinkingOpen {
writeEvent(map[string]interface{}{
"type": "content_block_start",
"index": blockCount,
"content_block": map[string]string{"type": "thinking", "thinking": ""},
})
thinkingOpen = true
blockCount++
}
writeEvent(map[string]interface{}{
"type": "content_block_delta",
"index": blockCount - 1,
"delta": map[string]string{"type": "thinking_delta", "thinking": thinking},
})
},
func() {
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
if thinkingOpen {
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
thinkingOpen = false
}
if !textOpen {
writeEvent(map[string]interface{}{
"type": "content_block_start",
"index": blockCount,
"content_block": map[string]string{"type": "text", "text": ""},
})
blockCount++
}
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
writeEvent(map[string]interface{}{
"type": "message_delta",
"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
"usage": map[string]int{"output_tokens": 0},
})
writeEvent(map[string]interface{}{"type": "message_stop"})
},
)
}
configDir := ph.GetNextConfigDir()
logger.LogAccountAssigned(configDir)
ph.ReportRequestStart(configDir)
logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true)
streamStart := time.Now().UnixMilli()
ctx := r.Context()
wrappedParser := func(line string) {
logger.LogRawLine(line)
p.Parse(line)
}
result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, ctx)
if ctx.Err() == nil {
p.Flush()
}
latencyMs := time.Now().UnixMilli() - streamStart
ph.ReportRequestEnd(configDir)
if ctx.Err() == context.DeadlineExceeded {
logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
} else if ctx.Err() == context.Canceled {
logger.LogClientDisconnect(method, pathname, model, latencyMs)
} else if err == nil && isRateLimited(result.Stderr) {
ph.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr))
}
if err != nil || (result.Code != 0 && ctx.Err() == nil) {
ph.ReportRequestError(configDir, latencyMs)
if err != nil {
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
} else {
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr)
}
logger.LogRequestDone(method, pathname, model, latencyMs, result.Code)
} else if ctx.Err() == nil {
ph.ReportRequestSuccess(configDir, latencyMs)
logger.LogRequestDone(method, pathname, model, latencyMs, 0)
}
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
return
}
configDir := ph.GetNextConfigDir()
logger.LogAccountAssigned(configDir)
ph.ReportRequestStart(configDir)
logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, false)
syncStart := time.Now().UnixMilli()
out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context())
syncLatency := time.Now().UnixMilli() - syncStart
ph.ReportRequestEnd(configDir)
ctx := r.Context()
if ctx.Err() == context.DeadlineExceeded {
logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
httputil.WriteJSON(w, 504, map[string]interface{}{
"error": map[string]string{"type": "api_error", "message": fmt.Sprintf("request timed out after %dms", cfg.TimeoutMs)},
}, nil)
return
}
if ctx.Err() == context.Canceled {
logger.LogClientDisconnect(method, pathname, model, syncLatency)
return
}
if err != nil {
ph.ReportRequestError(configDir, syncLatency)
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
logger.LogRequestDone(method, pathname, model, syncLatency, -1)
httputil.WriteJSON(w, 500, map[string]interface{}{
"error": map[string]string{"type": "api_error", "message": err.Error()},
}, nil)
return
}
if isRateLimited(out.Stderr) {
ph.ReportRateLimit(configDir, extractRetryAfterMs(out.Stderr))
}
if out.Code != 0 {
ph.ReportRequestError(configDir, syncLatency)
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr)
logger.LogRequestDone(method, pathname, model, syncLatency, out.Code)
httputil.WriteJSON(w, 500, map[string]interface{}{
"error": map[string]string{"type": "api_error", "message": errMsg},
}, nil)
return
}
ph.ReportRequestSuccess(configDir, syncLatency)
content := strings.TrimSpace(out.Stdout)
logger.LogTrafficResponse(cfg.Verbose, model, content, false)
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
logger.LogRequestDone(method, pathname, model, syncLatency, 0)
if hasTools {
parsed := toolcall.ExtractToolCalls(content, toolNames)
if parsed.HasToolCalls() {
var contentBlocks []map[string]interface{}
if parsed.TextContent != "" {
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "text", "text": parsed.TextContent,
})
}
for _, tc := range parsed.ToolCalls {
toolID := "toolu_" + uuid.New().String()[:12]
var inputObj interface{}
_ = json.Unmarshal([]byte(tc.Arguments), &inputObj)
if inputObj == nil {
inputObj = map[string]interface{}{}
}
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "tool_use", "id": toolID, "name": tc.Name, "input": inputObj,
})
}
httputil.WriteJSON(w, 200, map[string]interface{}{
"id": msgID,
"type": "message",
"role": "assistant",
"content": contentBlocks,
"model": model,
"stop_reason": "tool_use",
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
}, truncatedHeaders)
return
}
}
httputil.WriteJSON(w, 200, map[string]interface{}{
"id": msgID,
"type": "message",
"role": "assistant",
"content": []map[string]string{{"type": "text", "text": content}},
"model": model,
"stop_reason": "end_turn",
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
}, truncatedHeaders)
}
func systemToString(system interface{}) string {
switch v := system.(type) {
case string:
return v
case []interface{}:
result := ""
for _, p := range v {
if m, ok := p.(map[string]interface{}); ok && m["type"] == "text" {
if t, ok := m["text"].(string); ok {
result += t
}
}
}
return result
}
return ""
}
func contentToString(content interface{}) string {
switch v := content.(type) {
case string:
return v
case []interface{}:
result := ""
for _, p := range v {
if m, ok := p.(map[string]interface{}); ok && m["type"] == "text" {
if t, ok := m["text"].(string); ok {
result += t
}
}
}
return result
}
return ""
}

View File

@ -1,471 +0,0 @@
package handlers
import (
"context"
"cursor-api-proxy/internal/agent"
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/httputil"
"cursor-api-proxy/internal/logger"
"cursor-api-proxy/internal/models"
"cursor-api-proxy/internal/openai"
"cursor-api-proxy/internal/parser"
"cursor-api-proxy/internal/pool"
"cursor-api-proxy/internal/sanitize"
"cursor-api-proxy/internal/toolcall"
"cursor-api-proxy/internal/winlimit"
"cursor-api-proxy/internal/workspace"
"encoding/json"
"fmt"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"github.com/google/uuid"
)
var rateLimitRe = regexp.MustCompile(`(?i)\b429\b|rate.?limit|too many requests`)
var retryAfterRe = regexp.MustCompile(`(?i)retry-after:\s*(\d+)`)
func isRateLimited(stderr string) bool {
return rateLimitRe.MatchString(stderr)
}
func extractRetryAfterMs(stderr string) int64 {
if m := retryAfterRe.FindStringSubmatch(stderr); len(m) > 1 {
if secs, err := strconv.ParseInt(m[1], 10, 64); err == nil && secs > 0 {
return secs * 1000
}
}
return 60000
}
func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, ph pool.PoolHandle, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
var bodyMap map[string]interface{}
if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil {
httputil.WriteJSON(w, 400, map[string]interface{}{
"error": map[string]string{"message": "invalid JSON body", "code": "bad_request"},
}, nil)
return
}
rawModel, _ := bodyMap["model"].(string)
requested := openai.NormalizeModelID(rawModel)
model := ResolveModel(requested, lastModelRef, cfg)
cursorModel := models.ResolveToCursorModel(model)
if cursorModel == "" {
cursorModel = model
}
var messages []interface{}
if m, ok := bodyMap["messages"].([]interface{}); ok {
messages = m
}
var tools []interface{}
if t, ok := bodyMap["tools"].([]interface{}); ok {
tools = t
}
var funcs []interface{}
if f, ok := bodyMap["functions"].([]interface{}); ok {
funcs = f
}
cleanMessages := sanitize.SanitizeMessages(messages)
toolsText := openai.ToolsToSystemText(tools, funcs)
messagesWithTools := cleanMessages
if toolsText != "" {
messagesWithTools = append([]interface{}{map[string]interface{}{"role": "system", "content": toolsText}}, cleanMessages...)
}
prompt := openai.BuildPromptFromMessages(messagesWithTools)
var trafficMsgs []logger.TrafficMessage
for _, raw := range cleanMessages {
if m, ok := raw.(map[string]interface{}); ok {
role, _ := m["role"].(string)
content := openai.MessageContentToText(m["content"])
trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: role, Content: content})
}
}
isStream := false
if s, ok := bodyMap["stream"].(bool); ok {
isStream = s
}
logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, isStream)
headerWs := r.Header.Get("x-cursor-workspace")
ws := workspace.ResolveWorkspace(cfg, headerWs)
promptLen := len(prompt)
if cfg.Verbose {
if promptLen > 200 {
logger.LogDebug("model=%s prompt_len=%d prompt_start=%q", cursorModel, promptLen, prompt[:200])
} else {
logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, promptLen, prompt)
}
}
fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, isStream)
fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir)
if cfg.Verbose {
logger.LogDebug("cmd=%s args=%v", cfg.AgentBin, fit.Args)
}
if !fit.OK {
httputil.WriteJSON(w, 500, map[string]interface{}{
"error": map[string]string{"message": fit.Error, "code": "windows_cmdline_limit"},
}, nil)
return
}
if fit.Truncated {
logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength)
}
cmdArgs := fit.Args
id := "chatcmpl_" + uuid.New().String()
created := time.Now().Unix()
var truncatedHeaders map[string]string
if fit.Truncated {
truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
}
hasTools := len(tools) > 0 || len(funcs) > 0
var toolNames map[string]bool
if hasTools {
toolNames = toolcall.CollectToolNames(tools)
for _, f := range funcs {
if fm, ok := f.(map[string]interface{}); ok {
if name, ok := fm["name"].(string); ok {
toolNames[name] = true
}
}
}
}
if isStream {
httputil.WriteSSEHeaders(w, truncatedHeaders)
flusher, _ := w.(http.Flusher)
var accumulated string
var chunkNum int
var p parser.Parser
// toolCallMarkerRe 偵測 tool call 開頭標記,一旦出現就停止即時輸出並進入累積模式
toolCallMarkerRe := regexp.MustCompile(`<tool_call>|<function_calls>`)
if hasTools {
var toolCallMode bool // 是否已進入 tool call 累積模式
p = parser.CreateStreamParserWithThinking(
func(text string) {
accumulated += text
chunkNum++
logger.LogStreamChunk(model, text, chunkNum)
if toolCallMode {
// 已進入累積模式,不即時輸出
return
}
if toolCallMarkerRe.MatchString(text) {
// 偵測到 tool call 標記,切換為累積模式
toolCallMode = true
return
}
chunk := map[string]interface{}{
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
"choices": []map[string]interface{}{
{"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil},
},
}
data, _ := json.Marshal(chunk)
fmt.Fprintf(w, "data: %s\n\n", data)
if flusher != nil {
flusher.Flush()
}
},
func(_ string) {}, // thinking ignored in tools mode
func() {
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
parsed := toolcall.ExtractToolCalls(accumulated, toolNames)
if parsed.HasToolCalls() {
if parsed.TextContent != "" && toolCallMode {
// 已有部分 text 被即時輸出,只補發剩餘的
chunk := map[string]interface{}{
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
"choices": []map[string]interface{}{
{"index": 0, "delta": map[string]interface{}{"role": "assistant", "content": parsed.TextContent}, "finish_reason": nil},
},
}
data, _ := json.Marshal(chunk)
fmt.Fprintf(w, "data: %s\n\n", data)
if flusher != nil {
flusher.Flush()
}
}
for i, tc := range parsed.ToolCalls {
callID := "call_" + uuid.New().String()[:8]
chunk := map[string]interface{}{
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
"choices": []map[string]interface{}{
{"index": 0, "delta": map[string]interface{}{
"tool_calls": []map[string]interface{}{
{
"index": i,
"id": callID,
"type": "function",
"function": map[string]interface{}{
"name": tc.Name,
"arguments": tc.Arguments,
},
},
},
}, "finish_reason": nil},
},
}
data, _ := json.Marshal(chunk)
fmt.Fprintf(w, "data: %s\n\n", data)
if flusher != nil {
flusher.Flush()
}
}
stopChunk := map[string]interface{}{
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
"choices": []map[string]interface{}{
{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "tool_calls"},
},
}
data, _ := json.Marshal(stopChunk)
fmt.Fprintf(w, "data: %s\n\n", data)
fmt.Fprintf(w, "data: [DONE]\n\n")
if flusher != nil {
flusher.Flush()
}
} else {
stopChunk := map[string]interface{}{
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
"choices": []map[string]interface{}{
{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"},
},
}
data, _ := json.Marshal(stopChunk)
fmt.Fprintf(w, "data: %s\n\n", data)
fmt.Fprintf(w, "data: [DONE]\n\n")
if flusher != nil {
flusher.Flush()
}
}
},
)
} else {
p = parser.CreateStreamParserWithThinking(
func(text string) {
accumulated += text
chunkNum++
logger.LogStreamChunk(model, text, chunkNum)
chunk := map[string]interface{}{
"id": id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": []map[string]interface{}{
{"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil},
},
}
data, _ := json.Marshal(chunk)
fmt.Fprintf(w, "data: %s\n\n", data)
if flusher != nil {
flusher.Flush()
}
},
func(thinking string) {
chunk := map[string]interface{}{
"id": id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": []map[string]interface{}{
{"index": 0, "delta": map[string]interface{}{"reasoning_content": thinking}, "finish_reason": nil},
},
}
data, _ := json.Marshal(chunk)
fmt.Fprintf(w, "data: %s\n\n", data)
if flusher != nil {
flusher.Flush()
}
},
func() {
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
stopChunk := map[string]interface{}{
"id": id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": []map[string]interface{}{
{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"},
},
}
data, _ := json.Marshal(stopChunk)
fmt.Fprintf(w, "data: %s\n\n", data)
fmt.Fprintf(w, "data: [DONE]\n\n")
if flusher != nil {
flusher.Flush()
}
},
)
}
configDir := ph.GetNextConfigDir()
logger.LogAccountAssigned(configDir)
ph.ReportRequestStart(configDir)
logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true)
streamStart := time.Now().UnixMilli()
ctx := r.Context()
wrappedParser := func(line string) {
logger.LogRawLine(line)
p.Parse(line)
}
result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, ctx)
// agent 結束後,若未收到 result/success 訊號,強制 flush 以確保 SSE stream 正確結尾
if ctx.Err() == nil {
p.Flush()
}
latencyMs := time.Now().UnixMilli() - streamStart
ph.ReportRequestEnd(configDir)
if ctx.Err() == context.DeadlineExceeded {
logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
} else if ctx.Err() == context.Canceled {
logger.LogClientDisconnect(method, pathname, model, latencyMs)
} else if err == nil && isRateLimited(result.Stderr) {
ph.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr))
}
if err != nil || (result.Code != 0 && ctx.Err() == nil) {
ph.ReportRequestError(configDir, latencyMs)
if err != nil {
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
} else {
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr)
}
logger.LogRequestDone(method, pathname, model, latencyMs, result.Code)
} else if ctx.Err() == nil {
ph.ReportRequestSuccess(configDir, latencyMs)
logger.LogRequestDone(method, pathname, model, latencyMs, 0)
}
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
return
}
configDir := ph.GetNextConfigDir()
logger.LogAccountAssigned(configDir)
ph.ReportRequestStart(configDir)
logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, false)
syncStart := time.Now().UnixMilli()
out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context())
syncLatency := time.Now().UnixMilli() - syncStart
ph.ReportRequestEnd(configDir)
ctx := r.Context()
if ctx.Err() == context.DeadlineExceeded {
logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
httputil.WriteJSON(w, 504, map[string]interface{}{
"error": map[string]string{"message": fmt.Sprintf("request timed out after %dms", cfg.TimeoutMs), "code": "timeout"},
}, nil)
return
}
if ctx.Err() == context.Canceled {
logger.LogClientDisconnect(method, pathname, model, syncLatency)
return
}
if err != nil {
ph.ReportRequestError(configDir, syncLatency)
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
logger.LogRequestDone(method, pathname, model, syncLatency, -1)
httputil.WriteJSON(w, 500, map[string]interface{}{
"error": map[string]string{"message": err.Error(), "code": "cursor_cli_error"},
}, nil)
return
}
if isRateLimited(out.Stderr) {
ph.ReportRateLimit(configDir, extractRetryAfterMs(out.Stderr))
}
if out.Code != 0 {
ph.ReportRequestError(configDir, syncLatency)
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr)
logger.LogRequestDone(method, pathname, model, syncLatency, out.Code)
httputil.WriteJSON(w, 500, map[string]interface{}{
"error": map[string]string{"message": errMsg, "code": "cursor_cli_error"},
}, nil)
return
}
ph.ReportRequestSuccess(configDir, syncLatency)
content := strings.TrimSpace(out.Stdout)
logger.LogTrafficResponse(cfg.Verbose, model, content, false)
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
logger.LogRequestDone(method, pathname, model, syncLatency, 0)
if hasTools {
parsed := toolcall.ExtractToolCalls(content, toolNames)
if parsed.HasToolCalls() {
msg := map[string]interface{}{"role": "assistant"}
if parsed.TextContent != "" {
msg["content"] = parsed.TextContent
} else {
msg["content"] = nil
}
var tcArr []map[string]interface{}
for _, tc := range parsed.ToolCalls {
callID := "call_" + uuid.New().String()[:8]
tcArr = append(tcArr, map[string]interface{}{
"id": callID,
"type": "function",
"function": map[string]interface{}{
"name": tc.Name,
"arguments": tc.Arguments,
},
})
}
msg["tool_calls"] = tcArr
httputil.WriteJSON(w, 200, map[string]interface{}{
"id": id,
"object": "chat.completion",
"created": created,
"model": model,
"choices": []map[string]interface{}{
{"index": 0, "message": msg, "finish_reason": "tool_calls"},
},
"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}, truncatedHeaders)
return
}
}
httputil.WriteJSON(w, 200, map[string]interface{}{
"id": id,
"object": "chat.completion",
"created": created,
"model": model,
"choices": []map[string]interface{}{
{
"index": 0,
"message": map[string]string{"role": "assistant", "content": content},
"finish_reason": "stop",
},
},
"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}, truncatedHeaders)
}

View File

@ -1,203 +0,0 @@
package handlers
import (
"context"
"cursor-api-proxy/internal/apitypes"
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/httputil"
"cursor-api-proxy/internal/logger"
"cursor-api-proxy/internal/providers/geminiweb"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/google/uuid"
)
func HandleGeminiChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, rawBody, method, pathname, remoteAddress string) {
_ = context.Background() // 確保 context 被使用
var bodyMap map[string]interface{}
if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil {
httputil.WriteJSON(w, 400, map[string]interface{}{
"error": map[string]string{"message": "invalid JSON body", "code": "bad_request"},
}, nil)
return
}
rawModel, _ := bodyMap["model"].(string)
if rawModel == "" {
rawModel = "gemini-2.0-flash"
}
var messages []interface{}
if m, ok := bodyMap["messages"].([]interface{}); ok {
messages = m
}
isStream := false
if s, ok := bodyMap["stream"].(bool); ok {
isStream = s
}
// 轉換 messages 為 apitypes.Message
var apiMessages []apitypes.Message
for _, m := range messages {
if msgMap, ok := m.(map[string]interface{}); ok {
role, _ := msgMap["role"].(string)
content := ""
if c, ok := msgMap["content"].(string); ok {
content = c
}
apiMessages = append(apiMessages, apitypes.Message{
Role: role,
Content: content,
})
}
}
logger.LogRequestStart(method, pathname, rawModel, cfg.TimeoutMs, isStream)
start := time.Now().UnixMilli()
// 創建 Gemini provider (使用 Playwright)
provider, provErr := geminiweb.NewPlaywrightProvider(cfg)
if provErr != nil {
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, provErr.Error())
httputil.WriteJSON(w, 500, map[string]interface{}{
"error": map[string]string{"message": provErr.Error(), "code": "provider_error"},
}, nil)
return
}
if isStream {
httputil.WriteSSEHeaders(w, nil)
flusher, _ := w.(http.Flusher)
id := "chatcmpl_" + uuid.New().String()
created := time.Now().Unix()
var accumulated string
err := provider.Generate(r.Context(), rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) {
if chunk.Type == apitypes.ChunkText {
accumulated += chunk.Text
respChunk := map[string]interface{}{
"id": id,
"object": "chat.completion.chunk",
"created": created,
"model": rawModel,
"choices": []map[string]interface{}{
{
"index": 0,
"delta": map[string]string{"content": chunk.Text},
"finish_reason": nil,
},
},
}
data, _ := json.Marshal(respChunk)
fmt.Fprintf(w, "data: %s\n\n", data)
if flusher != nil {
flusher.Flush()
}
} else if chunk.Type == apitypes.ChunkThinking {
respChunk := map[string]interface{}{
"id": id,
"object": "chat.completion.chunk",
"created": created,
"model": rawModel,
"choices": []map[string]interface{}{
{
"index": 0,
"delta": map[string]interface{}{"reasoning_content": chunk.Thinking},
"finish_reason": nil,
},
},
}
data, _ := json.Marshal(respChunk)
fmt.Fprintf(w, "data: %s\n\n", data)
if flusher != nil {
flusher.Flush()
}
} else if chunk.Type == apitypes.ChunkDone {
stopChunk := map[string]interface{}{
"id": id,
"object": "chat.completion.chunk",
"created": created,
"model": rawModel,
"choices": []map[string]interface{}{
{
"index": 0,
"delta": map[string]interface{}{},
"finish_reason": "stop",
},
},
}
data, _ := json.Marshal(stopChunk)
fmt.Fprintf(w, "data: %s\n\n", data)
fmt.Fprintf(w, "data: [DONE]\n\n")
if flusher != nil {
flusher.Flush()
}
}
})
latencyMs := time.Now().UnixMilli() - start
if err != nil {
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1)
return
}
logger.LogTrafficResponse(cfg.Verbose, rawModel, accumulated, true)
logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0)
return
}
// 非串流模式
var resultText string
var resultThinking string
err := provider.Generate(r.Context(), rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) {
if chunk.Type == apitypes.ChunkText {
resultText += chunk.Text
} else if chunk.Type == apitypes.ChunkThinking {
resultThinking += chunk.Thinking
}
})
latencyMs := time.Now().UnixMilli() - start
if err != nil {
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1)
httputil.WriteJSON(w, 500, map[string]interface{}{
"error": map[string]string{"message": err.Error(), "code": "gemini_error"},
}, nil)
return
}
logger.LogTrafficResponse(cfg.Verbose, rawModel, resultText, false)
logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0)
id := "chatcmpl_" + uuid.New().String()
created := time.Now().Unix()
resp := map[string]interface{}{
"id": id,
"object": "chat.completion",
"created": created,
"model": rawModel,
"choices": []map[string]interface{}{
{
"index": 0,
"message": map[string]interface{}{
"role": "assistant",
"content": resultText,
},
"finish_reason": "stop",
},
},
"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}
httputil.WriteJSON(w, 200, resp, nil)
}

View File

@ -1,20 +0,0 @@
package handlers
import (
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/httputil"
"net/http"
)
func HandleHealth(w http.ResponseWriter, r *http.Request, version string, cfg config.BridgeConfig) {
httputil.WriteJSON(w, 200, map[string]interface{}{
"ok": true,
"version": version,
"workspace": cfg.Workspace,
"mode": cfg.Mode,
"defaultModel": cfg.DefaultModel,
"force": cfg.Force,
"approveMcps": cfg.ApproveMcps,
"strictModel": cfg.StrictModel,
}, nil)
}

View File

@ -1,107 +0,0 @@
package handlers
import (
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/httputil"
"cursor-api-proxy/internal/models"
"net/http"
"sync"
"time"
)
const modelCacheTTLMs = 5 * 60 * 1000
type ModelCache struct {
At int64
Models []models.CursorCliModel
}
type ModelCacheRef struct {
mu sync.Mutex
cache *ModelCache
inflight bool
waiters []chan struct{}
}
func (ref *ModelCacheRef) HandleModels(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig) {
now := time.Now().UnixMilli()
ref.mu.Lock()
if ref.cache != nil && now-ref.cache.At <= modelCacheTTLMs {
cache := ref.cache
ref.mu.Unlock()
writeModels(w, cache.Models)
return
}
if ref.inflight {
// Wait for the in-flight fetch
ch := make(chan struct{}, 1)
ref.waiters = append(ref.waiters, ch)
ref.mu.Unlock()
<-ch
ref.mu.Lock()
cache := ref.cache
ref.mu.Unlock()
writeModels(w, cache.Models)
return
}
ref.inflight = true
ref.mu.Unlock()
fetched, err := models.ListCursorCliModels(cfg.AgentBin, 60000)
ref.mu.Lock()
ref.inflight = false
if err == nil {
ref.cache = &ModelCache{At: time.Now().UnixMilli(), Models: fetched}
}
waiters := ref.waiters
ref.waiters = nil
ref.mu.Unlock()
for _, ch := range waiters {
ch <- struct{}{}
}
if err != nil {
httputil.WriteJSON(w, 500, map[string]interface{}{
"error": map[string]string{"message": err.Error(), "code": "models_fetch_error"},
}, nil)
return
}
writeModels(w, fetched)
}
func writeModels(w http.ResponseWriter, mods []models.CursorCliModel) {
cursorModels := make([]map[string]interface{}, len(mods))
for i, m := range mods {
cursorModels[i] = map[string]interface{}{
"id": m.ID,
"object": "model",
"owned_by": "cursor",
"name": m.Name,
}
}
ids := make([]string, len(mods))
for i, m := range mods {
ids[i] = m.ID
}
aliases := models.GetAnthropicModelAliases(ids)
for _, a := range aliases {
cursorModels = append(cursorModels, map[string]interface{}{
"id": a.ID,
"object": "model",
"owned_by": "cursor",
"name": a.Name,
})
}
httputil.WriteJSON(w, 200, map[string]interface{}{
"object": "list",
"data": cursorModels,
}, nil)
}

View File

@ -1,27 +0,0 @@
package handlers
import "cursor-api-proxy/internal/config"
func ResolveModel(requested string, lastModelRef *string, cfg config.BridgeConfig) string {
isAuto := requested == "auto"
var explicitModel string
if requested != "" && !isAuto {
explicitModel = requested
}
if explicitModel != "" {
*lastModelRef = explicitModel
}
if isAuto {
return "auto"
}
if explicitModel != "" {
return explicitModel
}
if cfg.StrictModel && *lastModelRef != "" {
return *lastModelRef
}
if *lastModelRef != "" {
return *lastModelRef
}
return cfg.DefaultModel
}

View File

@ -1,50 +0,0 @@
package httputil
import (
"encoding/json"
"io"
"net/http"
"regexp"
)
var bearerRe = regexp.MustCompile(`(?i)^Bearer\s+(.+)$`)
func ExtractBearerToken(r *http.Request) string {
h := r.Header.Get("Authorization")
if h == "" {
return ""
}
m := bearerRe.FindStringSubmatch(h)
if m == nil {
return ""
}
return m[1]
}
func WriteJSON(w http.ResponseWriter, status int, body interface{}, extraHeaders map[string]string) {
for k, v := range extraHeaders {
w.Header().Set(k, v)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(body)
}
func WriteSSEHeaders(w http.ResponseWriter, extraHeaders map[string]string) {
for k, v := range extraHeaders {
w.Header().Set(k, v)
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
w.WriteHeader(200)
}
func ReadBody(r *http.Request) (string, error) {
data, err := io.ReadAll(r.Body)
if err != nil {
return "", err
}
return string(data), nil
}

View File

@ -1,50 +0,0 @@
package httputil
import (
"net/http/httptest"
"testing"
)
func TestExtractBearerToken(t *testing.T) {
tests := []struct {
header string
want string
}{
{"Bearer mytoken123", "mytoken123"},
{"bearer MYTOKEN", "MYTOKEN"},
{"", ""},
{"Basic abc", ""},
{"Bearer ", ""},
}
for _, tc := range tests {
req := httptest.NewRequest("GET", "/", nil)
if tc.header != "" {
req.Header.Set("Authorization", tc.header)
}
got := ExtractBearerToken(req)
if got != tc.want {
t.Errorf("ExtractBearerToken(%q) = %q, want %q", tc.header, got, tc.want)
}
}
}
func TestWriteJSON(t *testing.T) {
w := httptest.NewRecorder()
WriteJSON(w, 200, map[string]string{"ok": "true"}, nil)
if w.Code != 200 {
t.Errorf("expected 200, got %d", w.Code)
}
if w.Header().Get("Content-Type") != "application/json" {
t.Errorf("expected application/json, got %s", w.Header().Get("Content-Type"))
}
}
func TestWriteJSONWithExtraHeaders(t *testing.T) {
w := httptest.NewRecorder()
WriteJSON(w, 201, nil, map[string]string{"X-Custom": "value"})
if w.Header().Get("X-Custom") != "value" {
t.Errorf("expected X-Custom=value, got %s", w.Header().Get("X-Custom"))
}
}

View File

@ -1,309 +0,0 @@
package logger
import (
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/pool"
"fmt"
"os"
"path/filepath"
"strings"
"time"
)
const (
cReset = "\x1b[0m"
cBold = "\x1b[1m"
cDim = "\x1b[2m"
cCyan = "\x1b[36m"
cBCyan = "\x1b[1;96m"
cGreen = "\x1b[32m"
cBGreen = "\x1b[1;92m"
cYellow = "\x1b[33m"
cMagenta = "\x1b[35m"
cBMagenta = "\x1b[1;95m"
cRed = "\x1b[31m"
cGray = "\x1b[90m"
cWhite = "\x1b[97m"
)
var roleStyle = map[string]string{
"system": cYellow,
"user": cCyan,
"assistant": cGreen,
}
var roleEmoji = map[string]string{
"system": "🔧",
"user": "👤",
"assistant": "🤖",
}
func ts() string {
return cGray + time.Now().UTC().Format("15:04:05") + cReset
}
func tsDate() string {
return cGray + time.Now().UTC().Format("2006-01-02 15:04:05") + cReset
}
func truncate(s string, max int) string {
if len(s) <= max {
return s
}
head := int(float64(max) * 0.6)
tail := max - head
omitted := len(s) - head - tail
return s[:head] + fmt.Sprintf("%s … (%d chars omitted) … ", cDim, omitted) + s[len(s)-tail:] + cReset
}
func hr(ch string, length int) string {
return cGray + strings.Repeat(ch, length) + cReset
}
type TrafficMessage struct {
Role string
Content string
}
func LogDebug(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
fmt.Printf("%s %s[DEBUG]%s %s\n", ts(), cGray, cReset, msg)
}
func LogServerStart(version, scheme, host string, port int, cfg config.BridgeConfig) {
provider := cfg.Provider
if provider == "" {
provider = "cursor"
}
fmt.Printf("\n%s%s╔══════════════════════════════════════════╗%s\n", cBold, cBCyan, cReset)
fmt.Printf("%s%s cursor-api-proxy %sv%s%s%s%s ready%s\n",
cBold, cBCyan, cReset, cBold, cWhite, version, cBCyan, cReset)
fmt.Printf("%s%s╚══════════════════════════════════════════╝%s\n\n", cBold, cBCyan, cReset)
url := fmt.Sprintf("%s://%s:%d", scheme, host, port)
fmt.Printf(" %s●%s listening %s%s%s\n", cBGreen, cReset, cBold, url, cReset)
fmt.Printf(" %s▸%s provider %s%s%s\n", cCyan, cReset, cBold, provider, cReset)
fmt.Printf(" %s▸%s agent %s%s%s\n", cCyan, cReset, cDim, cfg.AgentBin, cReset)
fmt.Printf(" %s▸%s workspace %s%s%s\n", cCyan, cReset, cDim, cfg.Workspace, cReset)
fmt.Printf(" %s▸%s model %s%s%s\n", cCyan, cReset, cDim, cfg.DefaultModel, cReset)
fmt.Printf(" %s▸%s mode %s%s%s\n", cCyan, cReset, cDim, cfg.Mode, cReset)
fmt.Printf(" %s▸%s timeout %s%d ms%s\n", cCyan, cReset, cDim, cfg.TimeoutMs, cReset)
// 顯示 Gemini Web Provider 相關設定
if provider == "gemini-web" {
fmt.Printf(" %s▸%s gemini-dir %s%s%s\n", cCyan, cReset, cDim, cfg.GeminiAccountDir, cReset)
fmt.Printf(" %s▸%s max-sess %s%d%s\n", cCyan, cReset, cDim, cfg.GeminiMaxSessions, cReset)
}
flags := []string{}
if cfg.Force {
flags = append(flags, "force")
}
if cfg.ApproveMcps {
flags = append(flags, "approve-mcps")
}
if cfg.MaxMode {
flags = append(flags, "max-mode")
}
if cfg.Verbose {
flags = append(flags, "verbose")
}
if cfg.ChatOnlyWorkspace {
flags = append(flags, "chat-only")
}
if cfg.RequiredKey != "" {
flags = append(flags, "api-key-required")
}
if len(flags) > 0 {
fmt.Printf(" %s▸%s flags %s%s%s\n", cCyan, cReset, cYellow, strings.Join(flags, " · "), cReset)
}
if len(cfg.ConfigDirs) > 0 {
fmt.Printf(" %s▸%s pool %s%d accounts%s\n", cCyan, cReset, cBGreen, len(cfg.ConfigDirs), cReset)
}
fmt.Println()
}
func LogShutdown(sig string) {
fmt.Printf("\n%s %s⊘ %s received — shutting down gracefully…%s\n", tsDate(), cYellow, sig, cReset)
}
func LogRequestStart(method, pathname, model string, timeoutMs int, isStream bool) {
modeTag := fmt.Sprintf("%ssync%s", cDim, cReset)
if isStream {
modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset)
}
fmt.Printf("%s %s▶%s %s %s %s timeout:%dms %s\n",
ts(), cBCyan, cReset, method, pathname, model, timeoutMs, modeTag)
}
func LogRequestDone(method, pathname, model string, latencyMs int64, code int) {
statusColor := cBGreen
if code != 0 {
statusColor = cRed
}
fmt.Printf("%s %s■%s %s %s %s %s%dms exit:%d%s\n",
ts(), statusColor, cReset, method, pathname, model, cDim, latencyMs, code, cReset)
}
func LogRequestTimeout(method, pathname, model string, timeoutMs int) {
fmt.Printf("%s %s⏱%s %s %s %s %stimed-out after %dms%s\n",
ts(), cRed, cReset, method, pathname, model, cRed, timeoutMs, cReset)
}
func LogClientDisconnect(method, pathname, model string, latencyMs int64) {
fmt.Printf("%s %s⚡%s %s %s %s %sclient disconnected after %dms%s\n",
ts(), cYellow, cReset, method, pathname, model, cYellow, latencyMs, cReset)
}
func LogStreamChunk(model string, text string, chunkNum int) {
preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 120)
fmt.Printf("%s %s▸%s #%d %s%s%s\n",
ts(), cDim, cReset, chunkNum, cWhite, preview, cReset)
}
func LogRawLine(line string) {
preview := truncate(strings.ReplaceAll(line, "\n", "↵ "), 200)
fmt.Printf("%s %s│%s %sraw%s %s\n",
ts(), cGray, cReset, cDim, cReset, preview)
}
func LogIncoming(method, pathname, remoteAddress string) {
methodColor := cBCyan
switch method {
case "POST":
methodColor = cBMagenta
case "GET":
methodColor = cBCyan
case "DELETE":
methodColor = cRed
}
fmt.Printf("%s %s%s%s%s %s%s%s %s(%s)%s\n",
ts(),
methodColor, cBold, method, cReset,
cWhite, pathname, cReset,
cDim, remoteAddress, cReset,
)
}
func LogAccountAssigned(configDir string) {
if configDir == "" {
return
}
name := filepath.Base(configDir)
fmt.Printf("%s %s→%s account %s%s%s\n", ts(), cBCyan, cReset, cBold, name, cReset)
}
func LogAccountStats(verbose bool, stats []pool.AccountStat) {
if !verbose || len(stats) == 0 {
return
}
now := time.Now().UnixMilli()
fmt.Printf("%s┌─ Account Stats %s┐%s\n", cGray, strings.Repeat("─", 44), cReset)
for _, s := range stats {
name := fmt.Sprintf("%-20s", filepath.Base(s.ConfigDir))
active := fmt.Sprintf("%sactive:0%s", cDim, cReset)
if s.ActiveRequests > 0 {
active = fmt.Sprintf("%sactive:%d%s", cBCyan, s.ActiveRequests, cReset)
}
total := fmt.Sprintf("total:%s%d%s", cBold, s.TotalRequests, cReset)
ok := fmt.Sprintf("%sok:%d%s", cGreen, s.TotalSuccess, cReset)
errStr := fmt.Sprintf("%serr:0%s", cDim, cReset)
if s.TotalErrors > 0 {
errStr = fmt.Sprintf("%serr:%d%s", cRed, s.TotalErrors, cReset)
}
rl := fmt.Sprintf("%srl:0%s", cDim, cReset)
if s.TotalRateLimits > 0 {
rl = fmt.Sprintf("%srl:%d%s", cYellow, s.TotalRateLimits, cReset)
}
avg := "avg:-"
if s.TotalRequests > 0 {
avg = fmt.Sprintf("avg:%dms", s.TotalLatencyMs/int64(s.TotalRequests))
}
status := fmt.Sprintf("%s✓%s", cGreen, cReset)
if s.IsRateLimited {
recovers := time.UnixMilli(s.RateLimitUntil).UTC().Format(time.RFC3339)
_ = now
status = fmt.Sprintf("%s⛔ rate-limited (recovers %s)%s", cRed, recovers, cReset)
}
fmt.Printf(" %s%s%s %s %s %s %s %s %s%s%s %s\n",
cBold, name, cReset, active, total, ok, errStr, rl, cDim, avg, cReset, status)
}
fmt.Printf("%s└%s┘%s\n", cGray, strings.Repeat("─", 60), cReset)
}
func LogTrafficRequest(verbose bool, model string, messages []TrafficMessage, isStream bool) {
if !verbose {
return
}
modeTag := fmt.Sprintf("%ssync%s", cDim, cReset)
if isStream {
modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset)
}
modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset)
fmt.Println(hr("─", 60))
fmt.Printf("%s 📤 %s%sREQUEST%s %s %s\n", ts(), cBCyan, cBold, cReset, modelStr, modeTag)
for _, m := range messages {
roleColor := cWhite
if c, ok := roleStyle[m.Role]; ok {
roleColor = c
}
emoji := "💬"
if e, ok := roleEmoji[m.Role]; ok {
emoji = e
}
label := fmt.Sprintf("%s%s[%s]%s", roleColor, cBold, m.Role, cReset)
charCount := fmt.Sprintf("%s(%d chars)%s", cDim, len(m.Content), cReset)
preview := truncate(strings.ReplaceAll(m.Content, "\n", "↵ "), 280)
fmt.Printf(" %s %s %s\n", emoji, label, charCount)
fmt.Printf(" %s%s%s\n", cDim, preview, cReset)
}
}
func LogTrafficResponse(verbose bool, model, text string, isStream bool) {
if !verbose {
return
}
modeTag := fmt.Sprintf("%ssync%s", cDim, cReset)
if isStream {
modeTag = fmt.Sprintf("%s⚡ stream%s", cBGreen, cReset)
}
modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset)
charCount := fmt.Sprintf("%s%d%s%s chars%s", cBold, len(text), cReset, cDim, cReset)
preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 480)
fmt.Printf("%s 📥 %s%sRESPONSE%s %s %s %s\n", ts(), cBGreen, cBold, cReset, modelStr, modeTag, charCount)
fmt.Printf(" 🤖 %s%s%s\n", cGreen, preview, cReset)
fmt.Println(hr("─", 60))
}
func AppendSessionLine(logPath, method, pathname, remoteAddress string, statusCode int) {
line := fmt.Sprintf("%s %s %s %s %d\n", time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, statusCode)
dir := filepath.Dir(logPath)
if err := os.MkdirAll(dir, 0755); err == nil {
f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err == nil {
_, _ = f.WriteString(line)
f.Close()
}
}
}
func LogTruncation(originalLen, finalLen int) {
fmt.Printf("%s %s⚠ prompt truncated%s %s(%d → %d chars, tail preserved)%s\n",
ts(), cYellow, cReset, cDim, originalLen, finalLen, cReset)
}
func LogAgentError(logPath, method, pathname, remoteAddress string, exitCode int, stderr string) string {
errMsg := fmt.Sprintf("Cursor CLI failed (exit %d): %s", exitCode, strings.TrimSpace(stderr))
fmt.Fprintf(os.Stderr, "%s %s✗ agent error%s %s%s%s\n", ts(), cRed, cReset, cDim, errMsg, cReset)
truncated := strings.TrimSpace(stderr)
if len(truncated) > 200 {
truncated = truncated[:200]
}
truncated = strings.ReplaceAll(truncated, "\n", " ")
line := fmt.Sprintf("%s ERROR %s %s %s agent_exit_%d %s\n",
time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, exitCode, truncated)
if f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil {
_, _ = f.WriteString(line)
f.Close()
}
return errMsg
}

View File

@ -1,62 +0,0 @@
package models
import (
"cursor-api-proxy/internal/process"
"fmt"
"os"
"regexp"
"strings"
)
type CursorCliModel struct {
ID string
Name string
}
var modelLineRe = regexp.MustCompile(`^([A-Za-z0-9][A-Za-z0-9._:/-]*)\s+-\s+(.*)$`)
var trailingParenRe = regexp.MustCompile(`\s*\([^)]*\)\s*$`)
func ParseCursorCliModels(output string) []CursorCliModel {
lines := strings.Split(output, "\n")
seen := make(map[string]CursorCliModel)
var order []string
for _, line := range lines {
line = strings.TrimSpace(line)
m := modelLineRe.FindStringSubmatch(line)
if m == nil {
continue
}
id := m[1]
rawName := m[2]
name := strings.TrimSpace(trailingParenRe.ReplaceAllString(rawName, ""))
if name == "" {
name = id
}
if _, exists := seen[id]; !exists {
seen[id] = CursorCliModel{ID: id, Name: name}
order = append(order, id)
}
}
result := make([]CursorCliModel, 0, len(order))
for _, id := range order {
result = append(result, seen[id])
}
return result
}
func ListCursorCliModels(agentBin string, timeoutMs int) ([]CursorCliModel, error) {
tmpDir := os.TempDir()
result, err := process.Run(agentBin, []string{"--list-models"}, process.RunOptions{
Cwd: tmpDir,
TimeoutMs: timeoutMs,
})
if err != nil {
return nil, err
}
if result.Code != 0 {
return nil, fmt.Errorf("agent --list-models failed: %s", strings.TrimSpace(result.Stderr))
}
return ParseCursorCliModels(result.Stdout), nil
}

View File

@ -1,33 +0,0 @@
package models
import "testing"
func TestParseCursorCliModels(t *testing.T) {
output := `
gpt-4o - GPT-4o (some info)
claude-3-5-sonnet - Claude 3.5 Sonnet
gpt-4o - GPT-4o duplicate
invalid line without dash
`
result := ParseCursorCliModels(output)
if len(result) != 2 {
t.Fatalf("expected 2 unique models, got %d: %v", len(result), result)
}
if result[0].ID != "gpt-4o" {
t.Errorf("expected gpt-4o, got %s", result[0].ID)
}
if result[0].Name != "GPT-4o" {
t.Errorf("expected 'GPT-4o', got %s", result[0].Name)
}
if result[1].ID != "claude-3-5-sonnet" {
t.Errorf("expected claude-3-5-sonnet, got %s", result[1].ID)
}
}
func TestParseCursorCliModelsEmpty(t *testing.T) {
result := ParseCursorCliModels("")
if len(result) != 0 {
t.Fatalf("expected empty, got %v", result)
}
}

View File

@ -1,123 +0,0 @@
package models
import (
"regexp"
"strings"
)
var anthropicToCursor = map[string]string{
"claude-opus-4-6": "opus-4.6",
"claude-opus-4.6": "opus-4.6",
"claude-sonnet-4-6": "sonnet-4.6",
"claude-sonnet-4.6": "sonnet-4.6",
"claude-opus-4-5": "opus-4.5",
"claude-opus-4.5": "opus-4.5",
"claude-sonnet-4-5": "sonnet-4.5",
"claude-sonnet-4.5": "sonnet-4.5",
"claude-opus-4": "opus-4.6",
"claude-sonnet-4": "sonnet-4.6",
"claude-haiku-4-5-20251001": "sonnet-4.5",
"claude-haiku-4-5": "sonnet-4.5",
"claude-haiku-4-6": "sonnet-4.6",
"claude-haiku-4": "sonnet-4.5",
"claude-opus-4-6-thinking": "opus-4.6-thinking",
"claude-sonnet-4-6-thinking": "sonnet-4.6-thinking",
"claude-opus-4-5-thinking": "opus-4.5-thinking",
"claude-sonnet-4-5-thinking": "sonnet-4.5-thinking",
}
type ModelAlias struct {
CursorID string
AnthropicID string
Name string
}
var cursorToAnthropicAlias = []ModelAlias{
{"opus-4.6", "claude-opus-4-6", "Claude 4.6 Opus"},
{"opus-4.6-thinking", "claude-opus-4-6-thinking", "Claude 4.6 Opus (Thinking)"},
{"sonnet-4.6", "claude-sonnet-4-6", "Claude 4.6 Sonnet"},
{"sonnet-4.6-thinking", "claude-sonnet-4-6-thinking", "Claude 4.6 Sonnet (Thinking)"},
{"opus-4.5", "claude-opus-4-5", "Claude 4.5 Opus"},
{"opus-4.5-thinking", "claude-opus-4-5-thinking", "Claude 4.5 Opus (Thinking)"},
{"sonnet-4.5", "claude-sonnet-4-5", "Claude 4.5 Sonnet"},
{"sonnet-4.5-thinking", "claude-sonnet-4-5-thinking", "Claude 4.5 Sonnet (Thinking)"},
}
// cursorModelPattern matches cursor model IDs like "opus-4.6", "sonnet-4.7-thinking".
var cursorModelPattern = regexp.MustCompile(`^([a-zA-Z]+)-(\d+)\.(\d+)(-thinking)?$`)
// reverseDynamicPattern matches dynamically generated anthropic aliases
// like "claude-opus-4-7", "claude-sonnet-4-7-thinking".
var reverseDynamicPattern = regexp.MustCompile(`^claude-([a-zA-Z]+)-(\d+)-(\d+)(-thinking)?$`)
func generateDynamicAlias(cursorID string) (AnthropicAlias, bool) {
m := cursorModelPattern.FindStringSubmatch(cursorID)
if m == nil {
return AnthropicAlias{}, false
}
family := m[1]
major := m[2]
minor := m[3]
thinking := m[4]
anthropicID := "claude-" + family + "-" + major + "-" + minor + thinking
capFamily := strings.ToUpper(family[:1]) + family[1:]
name := capFamily + " " + major + "." + minor
if thinking == "-thinking" {
name += " (Thinking)"
}
return AnthropicAlias{ID: anthropicID, Name: name}, true
}
func reverseDynamicAlias(anthropicID string) (string, bool) {
m := reverseDynamicPattern.FindStringSubmatch(anthropicID)
if m == nil {
return "", false
}
return m[1] + "-" + m[2] + "." + m[3] + m[4], true
}
func ResolveToCursorModel(requested string) string {
if strings.TrimSpace(requested) == "" {
return ""
}
key := strings.ToLower(strings.TrimSpace(requested))
if v, ok := anthropicToCursor[key]; ok {
return v
}
if v, ok := reverseDynamicAlias(key); ok {
return v
}
return strings.TrimSpace(requested)
}
type AnthropicAlias struct {
ID string
Name string
}
func GetAnthropicModelAliases(availableCursorIDs []string) []AnthropicAlias {
set := make(map[string]bool, len(availableCursorIDs))
for _, id := range availableCursorIDs {
set[id] = true
}
staticSet := make(map[string]bool, len(cursorToAnthropicAlias))
var result []AnthropicAlias
for _, a := range cursorToAnthropicAlias {
if set[a.CursorID] {
staticSet[a.CursorID] = true
result = append(result, AnthropicAlias{ID: a.AnthropicID, Name: a.Name})
}
}
for _, id := range availableCursorIDs {
if staticSet[id] {
continue
}
if alias, ok := generateDynamicAlias(id); ok {
result = append(result, alias)
}
}
return result
}

View File

@ -1,163 +0,0 @@
package models
import "testing"
func TestGetAnthropicModelAliases_StaticOnly(t *testing.T) {
aliases := GetAnthropicModelAliases([]string{"sonnet-4.6", "opus-4.5"})
if len(aliases) != 2 {
t.Fatalf("expected 2 aliases, got %d: %v", len(aliases), aliases)
}
ids := map[string]string{}
for _, a := range aliases {
ids[a.ID] = a.Name
}
if ids["claude-sonnet-4-6"] != "Claude 4.6 Sonnet" {
t.Errorf("unexpected name for claude-sonnet-4-6: %s", ids["claude-sonnet-4-6"])
}
if ids["claude-opus-4-5"] != "Claude 4.5 Opus" {
t.Errorf("unexpected name for claude-opus-4-5: %s", ids["claude-opus-4-5"])
}
}
func TestGetAnthropicModelAliases_DynamicFallback(t *testing.T) {
aliases := GetAnthropicModelAliases([]string{"sonnet-4.7", "opus-5.0-thinking", "gpt-4o"})
ids := map[string]string{}
for _, a := range aliases {
ids[a.ID] = a.Name
}
if ids["claude-sonnet-4-7"] != "Sonnet 4.7" {
t.Errorf("unexpected name for claude-sonnet-4-7: %s", ids["claude-sonnet-4-7"])
}
if ids["claude-opus-5-0-thinking"] != "Opus 5.0 (Thinking)" {
t.Errorf("unexpected name for claude-opus-5-0-thinking: %s", ids["claude-opus-5-0-thinking"])
}
if _, ok := ids["claude-gpt-4o"]; ok {
t.Errorf("gpt-4o should not generate a claude alias")
}
}
func TestGetAnthropicModelAliases_Mixed(t *testing.T) {
aliases := GetAnthropicModelAliases([]string{"sonnet-4.6", "opus-4.7", "gpt-4o"})
ids := map[string]string{}
for _, a := range aliases {
ids[a.ID] = a.Name
}
// static entry keeps its custom name
if ids["claude-sonnet-4-6"] != "Claude 4.6 Sonnet" {
t.Errorf("static alias should keep original name, got: %s", ids["claude-sonnet-4-6"])
}
// dynamic entry uses auto-generated name
if ids["claude-opus-4-7"] != "Opus 4.7" {
t.Errorf("dynamic alias name mismatch: %s", ids["claude-opus-4-7"])
}
}
func TestGetAnthropicModelAliases_UnknownPattern(t *testing.T) {
aliases := GetAnthropicModelAliases([]string{"some-unknown-model"})
if len(aliases) != 0 {
t.Fatalf("expected 0 aliases for unknown pattern, got %d: %v", len(aliases), aliases)
}
}
func TestResolveToCursorModel_Static(t *testing.T) {
tests := []struct {
input string
want string
}{
{"claude-opus-4-6", "opus-4.6"},
{"claude-opus-4.6", "opus-4.6"},
{"claude-sonnet-4-5", "sonnet-4.5"},
{"claude-opus-4-6-thinking", "opus-4.6-thinking"},
}
for _, tc := range tests {
got := ResolveToCursorModel(tc.input)
if got != tc.want {
t.Errorf("ResolveToCursorModel(%q) = %q, want %q", tc.input, got, tc.want)
}
}
}
func TestResolveToCursorModel_DynamicFallback(t *testing.T) {
tests := []struct {
input string
want string
}{
{"claude-opus-4-7", "opus-4.7"},
{"claude-sonnet-5-0", "sonnet-5.0"},
{"claude-opus-4-7-thinking", "opus-4.7-thinking"},
{"claude-sonnet-5-0-thinking", "sonnet-5.0-thinking"},
}
for _, tc := range tests {
got := ResolveToCursorModel(tc.input)
if got != tc.want {
t.Errorf("ResolveToCursorModel(%q) = %q, want %q", tc.input, got, tc.want)
}
}
}
func TestResolveToCursorModel_Passthrough(t *testing.T) {
tests := []string{"sonnet-4.6", "gpt-4o", "custom-model"}
for _, input := range tests {
got := ResolveToCursorModel(input)
if got != input {
t.Errorf("ResolveToCursorModel(%q) = %q, want passthrough %q", input, got, input)
}
}
}
func TestResolveToCursorModel_Empty(t *testing.T) {
if got := ResolveToCursorModel(""); got != "" {
t.Errorf("ResolveToCursorModel(\"\") = %q, want empty", got)
}
if got := ResolveToCursorModel(" "); got != "" {
t.Errorf("ResolveToCursorModel(\" \") = %q, want empty", got)
}
}
func TestGenerateDynamicAlias(t *testing.T) {
tests := []struct {
input string
want AnthropicAlias
ok bool
}{
{"opus-4.7", AnthropicAlias{"claude-opus-4-7", "Opus 4.7"}, true},
{"sonnet-5.0-thinking", AnthropicAlias{"claude-sonnet-5-0-thinking", "Sonnet 5.0 (Thinking)"}, true},
{"gpt-4o", AnthropicAlias{}, false},
{"invalid", AnthropicAlias{}, false},
}
for _, tc := range tests {
got, ok := generateDynamicAlias(tc.input)
if ok != tc.ok {
t.Errorf("generateDynamicAlias(%q) ok = %v, want %v", tc.input, ok, tc.ok)
continue
}
if ok && (got.ID != tc.want.ID || got.Name != tc.want.Name) {
t.Errorf("generateDynamicAlias(%q) = {%q, %q}, want {%q, %q}", tc.input, got.ID, got.Name, tc.want.ID, tc.want.Name)
}
}
}
func TestReverseDynamicAlias(t *testing.T) {
tests := []struct {
input string
want string
ok bool
}{
{"claude-opus-4-7", "opus-4.7", true},
{"claude-sonnet-5-0-thinking", "sonnet-5.0-thinking", true},
{"claude-opus-4-6", "opus-4.6", true},
{"claude-opus-4.6", "", false},
{"claude-haiku-4-5-20251001", "", false},
{"some-model", "", false},
}
for _, tc := range tests {
got, ok := reverseDynamicAlias(tc.input)
if ok != tc.ok {
t.Errorf("reverseDynamicAlias(%q) ok = %v, want %v", tc.input, ok, tc.ok)
continue
}
if ok && got != tc.want {
t.Errorf("reverseDynamicAlias(%q) = %q, want %q", tc.input, got, tc.want)
}
}
}

View File

@ -1,243 +0,0 @@
package openai
import (
"encoding/json"
"fmt"
"strings"
)
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []interface{} `json:"messages"`
Stream bool `json:"stream"`
Tools []interface{} `json:"tools"`
ToolChoice interface{} `json:"tool_choice"`
Functions []interface{} `json:"functions"`
FunctionCall interface{} `json:"function_call"`
}
func NormalizeModelID(raw string) string {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return ""
}
parts := strings.Split(trimmed, "/")
last := parts[len(parts)-1]
if last == "" {
return ""
}
return last
}
func imageURLToText(imageURL interface{}) string {
if imageURL == nil {
return "[Image]"
}
var url string
switch v := imageURL.(type) {
case string:
url = v
case map[string]interface{}:
if u, ok := v["url"].(string); ok {
url = u
}
}
if url == "" {
return "[Image]"
}
if strings.HasPrefix(url, "data:") {
end := strings.Index(url, ";")
mime := "image"
if end > 5 {
mime = url[5:end]
}
return "[Image: base64 " + mime + "]"
}
return "[Image: " + url + "]"
}
func MessageContentToText(content interface{}) string {
if content == nil {
return ""
}
switch v := content.(type) {
case string:
return v
case []interface{}:
var parts []string
for _, p := range v {
if p == nil {
continue
}
switch part := p.(type) {
case string:
parts = append(parts, part)
case map[string]interface{}:
typ, _ := part["type"].(string)
switch typ {
case "text":
if t, ok := part["text"].(string); ok {
parts = append(parts, t)
}
case "image_url":
parts = append(parts, imageURLToText(part["image_url"]))
case "image":
src := part["source"]
if src == nil {
src = part["url"]
}
parts = append(parts, imageURLToText(src))
}
}
}
return strings.Join(parts, " ")
}
return ""
}
func ToolsToSystemText(tools []interface{}, functions []interface{}) string {
var defs []interface{}
for _, t := range tools {
if m, ok := t.(map[string]interface{}); ok {
if m["type"] == "function" {
if fn := m["function"]; fn != nil {
defs = append(defs, fn)
}
} else {
defs = append(defs, t)
}
}
}
defs = append(defs, functions...)
if len(defs) == 0 {
return ""
}
var lines []string
lines = append(lines, "Available tools (respond with a JSON object to call one):", "")
for _, raw := range defs {
fn, ok := raw.(map[string]interface{})
if !ok {
continue
}
name, _ := fn["name"].(string)
desc, _ := fn["description"].(string)
params := "{}"
if p := fn["parameters"]; p != nil {
if b, err := json.MarshalIndent(p, "", " "); err == nil {
params = string(b)
}
} else if p := fn["input_schema"]; p != nil {
if b, err := json.MarshalIndent(p, "", " "); err == nil {
params = string(b)
}
}
lines = append(lines, "Function: "+name+"\nDescription: "+desc+"\nParameters: "+params)
}
lines = append(lines, "",
"When you want to call a tool, use this EXACT format:",
"",
"<tool_call>",
`{"name": "function_name", "arguments": {"param1": "value1"}}`,
"</tool_call>",
"",
"Rules:",
"- Write your reasoning BEFORE the tool call",
"- You may make multiple tool calls by using multiple <tool_call> blocks",
"- STOP writing after the last </tool_call> tag",
"- If no tool is needed, respond normally without <tool_call> tags",
)
return strings.Join(lines, "\n")
}
type SimpleMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
func BuildPromptFromMessages(messages []interface{}) string {
var systemParts []string
var convo []string
for _, raw := range messages {
m, ok := raw.(map[string]interface{})
if !ok {
continue
}
role, _ := m["role"].(string)
text := MessageContentToText(m["content"])
switch role {
case "system", "developer":
if text != "" {
systemParts = append(systemParts, text)
}
case "user":
if text != "" {
convo = append(convo, "User: "+text)
}
case "assistant":
toolCalls, _ := m["tool_calls"].([]interface{})
if len(toolCalls) > 0 {
var parts []string
if text != "" {
parts = append(parts, text)
}
for _, tc := range toolCalls {
tcMap, ok := tc.(map[string]interface{})
if !ok {
continue
}
fn, _ := tcMap["function"].(map[string]interface{})
if fn == nil {
continue
}
name, _ := fn["name"].(string)
args, _ := fn["arguments"].(string)
if args == "" {
args = "{}"
}
parts = append(parts, fmt.Sprintf("<tool_call>\n{\"name\": \"%s\", \"arguments\": %s}\n</tool_call>", name, args))
}
if len(parts) > 0 {
convo = append(convo, "Assistant: "+strings.Join(parts, "\n"))
}
} else if text != "" {
convo = append(convo, "Assistant: "+text)
}
case "tool", "function":
name, _ := m["name"].(string)
toolCallID, _ := m["tool_call_id"].(string)
label := "Tool result"
if name != "" {
label = "Tool result (" + name + ")"
}
if toolCallID != "" {
label += " [id=" + toolCallID + "]"
}
if text != "" {
convo = append(convo, label+": "+text)
}
}
}
system := ""
if len(systemParts) > 0 {
system = "System:\n" + strings.Join(systemParts, "\n\n") + "\n\n"
}
transcript := strings.Join(convo, "\n\n")
return system + transcript + "\n\nAssistant:"
}
func BuildPromptFromSimpleMessages(messages []SimpleMessage) string {
ifaces := make([]interface{}, len(messages))
for i, m := range messages {
ifaces[i] = map[string]interface{}{"role": m.Role, "content": m.Content}
}
return BuildPromptFromMessages(ifaces)
}

View File

@ -1,80 +0,0 @@
package openai
import "testing"
func TestNormalizeModelID(t *testing.T) {
tests := []struct {
input string
want string
}{
{"gpt-4", "gpt-4"},
{"openai/gpt-4", "gpt-4"},
{"anthropic/claude-3", "claude-3"},
{"", ""},
{" ", ""},
{"a/b/c", "c"},
}
for _, tc := range tests {
got := NormalizeModelID(tc.input)
if got != tc.want {
t.Errorf("NormalizeModelID(%q) = %q, want %q", tc.input, got, tc.want)
}
}
}
func TestBuildPromptFromMessages(t *testing.T) {
messages := []interface{}{
map[string]interface{}{"role": "system", "content": "You are helpful."},
map[string]interface{}{"role": "user", "content": "Hello"},
map[string]interface{}{"role": "assistant", "content": "Hi there"},
}
got := BuildPromptFromMessages(messages)
if got == "" {
t.Fatal("expected non-empty prompt")
}
containsSystem := false
containsUser := false
containsAssistant := false
for i := 0; i < len(got)-10; i++ {
if got[i:i+6] == "System" {
containsSystem = true
}
if got[i:i+4] == "User" {
containsUser = true
}
if got[i:i+9] == "Assistant" {
containsAssistant = true
}
}
if !containsSystem || !containsUser || !containsAssistant {
t.Errorf("prompt missing sections: system=%v user=%v assistant=%v\n%s",
containsSystem, containsUser, containsAssistant, got)
}
}
func TestToolsToSystemText(t *testing.T) {
tools := []interface{}{
map[string]interface{}{
"type": "function",
"function": map[string]interface{}{
"name": "get_weather",
"description": "Get weather",
"parameters": map[string]interface{}{"type": "object"},
},
},
}
got := ToolsToSystemText(tools, nil)
if got == "" {
t.Fatal("expected non-empty tools text")
}
if len(got) < 10 {
t.Errorf("tools text too short: %q", got)
}
}
func TestToolsToSystemTextEmpty(t *testing.T) {
got := ToolsToSystemText(nil, nil)
if got != "" {
t.Errorf("expected empty string for no tools, got %q", got)
}
}

View File

@ -1,110 +0,0 @@
package parser
import "encoding/json"
type StreamParser func(line string)
type Parser struct {
Parse StreamParser
Flush func()
}
// CreateStreamParser 建立串流解析器(向後相容,不傳遞 thinking
func CreateStreamParser(onText func(string), onDone func()) Parser {
return CreateStreamParserWithThinking(onText, nil, onDone)
}
// CreateStreamParserWithThinking 建立串流解析器,支援思考過程輸出。
// onThinking 可為 nil表示忽略思考過程。
func CreateStreamParserWithThinking(onText func(string), onThinking func(string), onDone func()) Parser {
// accumulated 是所有已輸出內容的串接
accumulatedText := ""
accumulatedThinking := ""
done := false
parse := func(line string) {
if done {
return
}
var obj struct {
Type string `json:"type"`
Subtype string `json:"subtype"`
Message *struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
Thinking string `json:"thinking"`
} `json:"content"`
} `json:"message"`
}
if err := json.Unmarshal([]byte(line), &obj); err != nil {
return
}
if obj.Type == "assistant" && obj.Message != nil {
fullText := ""
fullThinking := ""
for _, p := range obj.Message.Content {
switch p.Type {
case "text":
if p.Text != "" {
fullText += p.Text
}
case "thinking":
if p.Thinking != "" {
fullThinking += p.Thinking
}
}
}
// 處理思考過程(不因去重而 return避免跳過同行的文字內容
if onThinking != nil && fullThinking != "" && fullThinking != accumulatedThinking {
// 增量模式:新內容以 accumulated 為前綴
if len(fullThinking) >= len(accumulatedThinking) && fullThinking[:len(accumulatedThinking)] == accumulatedThinking {
delta := fullThinking[len(accumulatedThinking):]
if delta != "" {
onThinking(delta)
}
accumulatedThinking = fullThinking
} else {
// 獨立片段:直接輸出,但 accumulated 要串接
onThinking(fullThinking)
accumulatedThinking = accumulatedThinking + fullThinking
}
}
// 處理一般文字
if fullText == "" || fullText == accumulatedText {
return
}
// 增量模式:新內容以 accumulated 為前綴
if len(fullText) >= len(accumulatedText) && fullText[:len(accumulatedText)] == accumulatedText {
delta := fullText[len(accumulatedText):]
if delta != "" {
onText(delta)
}
accumulatedText = fullText
} else {
// 獨立片段:直接輸出,但 accumulated 要串接
onText(fullText)
accumulatedText = accumulatedText + fullText
}
}
if obj.Type == "result" && obj.Subtype == "success" {
done = true
onDone()
}
}
flush := func() {
if !done {
done = true
onDone()
}
}
return Parser{Parse: parse, Flush: flush}
}

View File

@ -1,304 +0,0 @@
package parser
import (
"encoding/json"
"testing"
)
func makeAssistantLine(text string) string {
obj := map[string]interface{}{
"type": "assistant",
"message": map[string]interface{}{
"content": []map[string]interface{}{
{"type": "text", "text": text},
},
},
}
b, _ := json.Marshal(obj)
return string(b)
}
func makeResultLine() string {
b, _ := json.Marshal(map[string]string{"type": "result", "subtype": "success"})
return string(b)
}
func TestStreamParserFragmentMode(t *testing.T) {
// cursor --stream-partial-output 模式:每個訊息是獨立 token fragment
var texts []string
p := CreateStreamParser(
func(text string) { texts = append(texts, text) },
func() {},
)
p.Parse(makeAssistantLine("你"))
p.Parse(makeAssistantLine("好!有"))
p.Parse(makeAssistantLine("什"))
p.Parse(makeAssistantLine("麼"))
if len(texts) != 4 {
t.Fatalf("expected 4 fragments, got %d: %v", len(texts), texts)
}
if texts[0] != "你" || texts[1] != "好!有" || texts[2] != "什" || texts[3] != "麼" {
t.Fatalf("unexpected texts: %v", texts)
}
}
func TestStreamParserDeduplicatesFinalFullText(t *testing.T) {
// 最後一個訊息是完整的累積文字,應被跳過(去重)
var texts []string
p := CreateStreamParser(
func(text string) { texts = append(texts, text) },
func() {},
)
p.Parse(makeAssistantLine("Hello"))
p.Parse(makeAssistantLine(" world"))
// 最後一個是完整累積文字,應被去重
p.Parse(makeAssistantLine("Hello world"))
if len(texts) != 2 {
t.Fatalf("expected 2 fragments (final full text deduplicated), got %d: %v", len(texts), texts)
}
if texts[0] != "Hello" || texts[1] != " world" {
t.Fatalf("unexpected texts: %v", texts)
}
}
func TestStreamParserCallsOnDone(t *testing.T) {
var texts []string
doneCount := 0
p := CreateStreamParser(
func(text string) { texts = append(texts, text) },
func() { doneCount++ },
)
p.Parse(makeResultLine())
if doneCount != 1 {
t.Fatalf("expected onDone called once, got %d", doneCount)
}
if len(texts) != 0 {
t.Fatalf("expected no text, got %v", texts)
}
}
func TestStreamParserIgnoresLinesAfterDone(t *testing.T) {
var texts []string
doneCount := 0
p := CreateStreamParser(
func(text string) { texts = append(texts, text) },
func() { doneCount++ },
)
p.Parse(makeResultLine())
p.Parse(makeAssistantLine("late"))
if len(texts) != 0 {
t.Fatalf("expected no text after done, got %v", texts)
}
if doneCount != 1 {
t.Fatalf("expected onDone called once, got %d", doneCount)
}
}
func TestStreamParserIgnoresNonAssistantLines(t *testing.T) {
var texts []string
p := CreateStreamParser(
func(text string) { texts = append(texts, text) },
func() {},
)
b1, _ := json.Marshal(map[string]interface{}{"type": "user", "message": map[string]interface{}{}})
p.Parse(string(b1))
b2, _ := json.Marshal(map[string]interface{}{
"type": "assistant",
"message": map[string]interface{}{"content": []interface{}{}},
})
p.Parse(string(b2))
p.Parse(`{"type":"assistant","message":{"content":[{"type":"code","text":"x"}]}}`)
if len(texts) != 0 {
t.Fatalf("expected no texts, got %v", texts)
}
}
func TestStreamParserIgnoresParseErrors(t *testing.T) {
var texts []string
doneCount := 0
p := CreateStreamParser(
func(text string) { texts = append(texts, text) },
func() { doneCount++ },
)
p.Parse("not json")
p.Parse("{")
p.Parse("")
if len(texts) != 0 || doneCount != 0 {
t.Fatalf("expected nothing, got texts=%v done=%d", texts, doneCount)
}
}
func TestStreamParserJoinsMultipleTextParts(t *testing.T) {
var texts []string
p := CreateStreamParser(
func(text string) { texts = append(texts, text) },
func() {},
)
obj := map[string]interface{}{
"type": "assistant",
"message": map[string]interface{}{
"content": []map[string]interface{}{
{"type": "text", "text": "Hello"},
{"type": "text", "text": " "},
{"type": "text", "text": "world"},
},
},
}
b, _ := json.Marshal(obj)
p.Parse(string(b))
if len(texts) != 1 || texts[0] != "Hello world" {
t.Fatalf("expected ['Hello world'], got %v", texts)
}
}
func TestStreamParserFlushTriggersDone(t *testing.T) {
var texts []string
doneCount := 0
p := CreateStreamParser(
func(text string) { texts = append(texts, text) },
func() { doneCount++ },
)
p.Parse(makeAssistantLine("Hello"))
// agent 結束但沒有 result/success手動 flush
p.Flush()
if doneCount != 1 {
t.Fatalf("expected onDone called once after Flush, got %d", doneCount)
}
// 再 flush 不應重複觸發
p.Flush()
if doneCount != 1 {
t.Fatalf("expected onDone called only once, got %d", doneCount)
}
}
func TestStreamParserFlushAfterDoneIsNoop(t *testing.T) {
doneCount := 0
p := CreateStreamParser(
func(text string) {},
func() { doneCount++ },
)
p.Parse(makeResultLine())
p.Flush()
if doneCount != 1 {
t.Fatalf("expected onDone called once, got %d", doneCount)
}
}
func makeThinkingLine(thinking string) string {
obj := map[string]interface{}{
"type": "assistant",
"message": map[string]interface{}{
"content": []map[string]interface{}{
{"type": "thinking", "thinking": thinking},
},
},
}
b, _ := json.Marshal(obj)
return string(b)
}
func makeThinkingAndTextLine(thinking, text string) string {
obj := map[string]interface{}{
"type": "assistant",
"message": map[string]interface{}{
"content": []map[string]interface{}{
{"type": "thinking", "thinking": thinking},
{"type": "text", "text": text},
},
},
}
b, _ := json.Marshal(obj)
return string(b)
}
func TestStreamParserWithThinkingCallsOnThinking(t *testing.T) {
var texts []string
var thinkings []string
p := CreateStreamParserWithThinking(
func(text string) { texts = append(texts, text) },
func(thinking string) { thinkings = append(thinkings, thinking) },
func() {},
)
p.Parse(makeThinkingLine("思考中..."))
p.Parse(makeAssistantLine("回答"))
if len(thinkings) != 1 || thinkings[0] != "思考中..." {
t.Fatalf("expected thinkings=['思考中...'], got %v", thinkings)
}
if len(texts) != 1 || texts[0] != "回答" {
t.Fatalf("expected texts=['回答'], got %v", texts)
}
}
func TestStreamParserWithThinkingNilOnThinkingIgnoresThinking(t *testing.T) {
var texts []string
p := CreateStreamParserWithThinking(
func(text string) { texts = append(texts, text) },
nil,
func() {},
)
p.Parse(makeThinkingLine("忽略的思考"))
p.Parse(makeAssistantLine("文字"))
if len(texts) != 1 || texts[0] != "文字" {
t.Fatalf("expected texts=['文字'], got %v", texts)
}
}
func TestStreamParserWithThinkingDeduplication(t *testing.T) {
var thinkings []string
p := CreateStreamParserWithThinking(
func(text string) {},
func(thinking string) { thinkings = append(thinkings, thinking) },
func() {},
)
p.Parse(makeThinkingLine("A"))
p.Parse(makeThinkingLine("B"))
// 重複的完整思考,應被跳過
p.Parse(makeThinkingLine("AB"))
if len(thinkings) != 2 || thinkings[0] != "A" || thinkings[1] != "B" {
t.Fatalf("expected thinkings=['A','B'], got %v", thinkings)
}
}
// TestStreamParserThinkingDuplicateButTextStillEmitted 驗證 bug 修復:
// 當 thinking 重複(去重跳過)但同一行有 text 時text 仍必須輸出。
func TestStreamParserThinkingDuplicateButTextStillEmitted(t *testing.T) {
var texts []string
var thinkings []string
p := CreateStreamParserWithThinking(
func(text string) { texts = append(texts, text) },
func(thinking string) { thinkings = append(thinkings, thinking) },
func() {},
)
// 第一行thinking="思考中" + textthinking 為新增,兩者都應輸出)
p.Parse(makeThinkingAndTextLine("思考中", "第一段"))
// 第二行thinking 與上一行相同(去重),但 text 是新的text 仍應輸出
p.Parse(makeThinkingAndTextLine("思考中", "第二段"))
if len(thinkings) != 1 || thinkings[0] != "思考中" {
t.Fatalf("expected thinkings=['思考中'], got %v", thinkings)
}
if len(texts) != 2 || texts[0] != "第一段" || texts[1] != "第二段" {
t.Fatalf("expected texts=['第一段','第二段'], got %v", texts)
}
}

View File

@ -1,284 +0,0 @@
package pool
import (
"sync"
"time"
)
type accountStatus struct {
configDir string
activeRequests int
lastUsed int64
rateLimitUntil int64
totalRequests int
totalSuccess int
totalErrors int
totalRateLimits int
totalLatencyMs int64
}
type AccountStat struct {
ConfigDir string
ActiveRequests int
TotalRequests int
TotalSuccess int
TotalErrors int
TotalRateLimits int
TotalLatencyMs int64
IsRateLimited bool
RateLimitUntil int64
}
type AccountPool struct {
mu sync.Mutex
accounts []*accountStatus
}
func NewAccountPool(configDirs []string) *AccountPool {
accounts := make([]*accountStatus, 0, len(configDirs))
for _, dir := range configDirs {
accounts = append(accounts, &accountStatus{configDir: dir})
}
return &AccountPool{accounts: accounts}
}
func (p *AccountPool) GetNextConfigDir() string {
p.mu.Lock()
defer p.mu.Unlock()
if len(p.accounts) == 0 {
return ""
}
now := time.Now().UnixMilli()
available := make([]*accountStatus, 0, len(p.accounts))
for _, a := range p.accounts {
if a.rateLimitUntil < now {
available = append(available, a)
}
}
target := available
if len(target) == 0 {
target = make([]*accountStatus, len(p.accounts))
copy(target, p.accounts)
// sort by earliest recovery
for i := 1; i < len(target); i++ {
for j := i; j > 0 && target[j].rateLimitUntil < target[j-1].rateLimitUntil; j-- {
target[j], target[j-1] = target[j-1], target[j]
}
}
}
// pick least busy then least recently used
best := target[0]
for _, a := range target[1:] {
if a.activeRequests < best.activeRequests {
best = a
} else if a.activeRequests == best.activeRequests && a.lastUsed < best.lastUsed {
best = a
}
}
best.lastUsed = now
return best.configDir
}
func (p *AccountPool) find(configDir string) *accountStatus {
for _, a := range p.accounts {
if a.configDir == configDir {
return a
}
}
return nil
}
func (p *AccountPool) ReportRequestStart(configDir string) {
if configDir == "" {
return
}
p.mu.Lock()
defer p.mu.Unlock()
if a := p.find(configDir); a != nil {
a.activeRequests++
a.totalRequests++
}
}
func (p *AccountPool) ReportRequestEnd(configDir string) {
if configDir == "" {
return
}
p.mu.Lock()
defer p.mu.Unlock()
if a := p.find(configDir); a != nil && a.activeRequests > 0 {
a.activeRequests--
}
}
func (p *AccountPool) ReportRequestSuccess(configDir string, latencyMs int64) {
if configDir == "" {
return
}
p.mu.Lock()
defer p.mu.Unlock()
if a := p.find(configDir); a != nil {
a.totalSuccess++
a.totalLatencyMs += latencyMs
}
}
func (p *AccountPool) ReportRequestError(configDir string, latencyMs int64) {
if configDir == "" {
return
}
p.mu.Lock()
defer p.mu.Unlock()
if a := p.find(configDir); a != nil {
a.totalErrors++
a.totalLatencyMs += latencyMs
}
}
func (p *AccountPool) ReportRateLimit(configDir string, penaltyMs int64) {
if configDir == "" {
return
}
if penaltyMs <= 0 {
penaltyMs = 60000
}
p.mu.Lock()
defer p.mu.Unlock()
if a := p.find(configDir); a != nil {
a.rateLimitUntil = time.Now().UnixMilli() + penaltyMs
a.totalRateLimits++
}
}
func (p *AccountPool) GetStats() []AccountStat {
p.mu.Lock()
defer p.mu.Unlock()
now := time.Now().UnixMilli()
stats := make([]AccountStat, len(p.accounts))
for i, a := range p.accounts {
stats[i] = AccountStat{
ConfigDir: a.configDir,
ActiveRequests: a.activeRequests,
TotalRequests: a.totalRequests,
TotalSuccess: a.totalSuccess,
TotalErrors: a.totalErrors,
TotalRateLimits: a.totalRateLimits,
TotalLatencyMs: a.totalLatencyMs,
IsRateLimited: a.rateLimitUntil > now,
RateLimitUntil: a.rateLimitUntil,
}
}
return stats
}
func (p *AccountPool) Count() int {
return len(p.accounts)
}
// ─── PoolHandle interface ──────────────────────────────────────────────────
// PoolHandle 讓 handler 可以注入獨立的 pool 實例,避免多 port 模式共用全域 pool。
type PoolHandle interface {
GetNextConfigDir() string
ReportRequestStart(configDir string)
ReportRequestEnd(configDir string)
ReportRequestSuccess(configDir string, latencyMs int64)
ReportRequestError(configDir string, latencyMs int64)
ReportRateLimit(configDir string, penaltyMs int64)
GetStats() []AccountStat
}
// GlobalPoolHandle 包裝全域函式以實作 PoolHandle 介面(單 port 模式使用)
type GlobalPoolHandle struct{}
func (GlobalPoolHandle) GetNextConfigDir() string { return GetNextAccountConfigDir() }
func (GlobalPoolHandle) ReportRequestStart(d string) { ReportRequestStart(d) }
func (GlobalPoolHandle) ReportRequestEnd(d string) { ReportRequestEnd(d) }
func (GlobalPoolHandle) ReportRequestSuccess(d string, l int64) { ReportRequestSuccess(d, l) }
func (GlobalPoolHandle) ReportRequestError(d string, l int64) { ReportRequestError(d, l) }
func (GlobalPoolHandle) ReportRateLimit(d string, p int64) { ReportRateLimit(d, p) }
func (GlobalPoolHandle) GetStats() []AccountStat { return GetAccountStats() }
// ─── Global pool ───────────────────────────────────────────────────────────
var (
globalPool *AccountPool
globalMu sync.Mutex
)
func InitAccountPool(configDirs []string) {
globalMu.Lock()
defer globalMu.Unlock()
globalPool = NewAccountPool(configDirs)
}
func GetNextAccountConfigDir() string {
globalMu.Lock()
p := globalPool
globalMu.Unlock()
if p == nil {
return ""
}
return p.GetNextConfigDir()
}
func ReportRequestStart(configDir string) {
globalMu.Lock()
p := globalPool
globalMu.Unlock()
if p != nil {
p.ReportRequestStart(configDir)
}
}
func ReportRequestEnd(configDir string) {
globalMu.Lock()
p := globalPool
globalMu.Unlock()
if p != nil {
p.ReportRequestEnd(configDir)
}
}
func ReportRequestSuccess(configDir string, latencyMs int64) {
globalMu.Lock()
p := globalPool
globalMu.Unlock()
if p != nil {
p.ReportRequestSuccess(configDir, latencyMs)
}
}
func ReportRequestError(configDir string, latencyMs int64) {
globalMu.Lock()
p := globalPool
globalMu.Unlock()
if p != nil {
p.ReportRequestError(configDir, latencyMs)
}
}
func ReportRateLimit(configDir string, penaltyMs int64) {
globalMu.Lock()
p := globalPool
globalMu.Unlock()
if p != nil {
p.ReportRateLimit(configDir, penaltyMs)
}
}
func GetAccountStats() []AccountStat {
globalMu.Lock()
p := globalPool
globalMu.Unlock()
if p == nil {
return nil
}
return p.GetStats()
}

View File

@ -1,152 +0,0 @@
package pool
import (
"testing"
"time"
)
func TestEmptyPool(t *testing.T) {
p := NewAccountPool(nil)
if got := p.GetNextConfigDir(); got != "" {
t.Fatalf("expected empty string for empty pool, got %q", got)
}
if p.Count() != 0 {
t.Fatalf("expected count 0, got %d", p.Count())
}
}
func TestSingleDir(t *testing.T) {
p := NewAccountPool([]string{"/dir1"})
if got := p.GetNextConfigDir(); got != "/dir1" {
t.Fatalf("expected /dir1, got %q", got)
}
if got := p.GetNextConfigDir(); got != "/dir1" {
t.Fatalf("expected /dir1 again, got %q", got)
}
}
func TestRoundRobin(t *testing.T) {
p := NewAccountPool([]string{"/a", "/b", "/c"})
got := []string{
p.GetNextConfigDir(),
p.GetNextConfigDir(),
p.GetNextConfigDir(),
p.GetNextConfigDir(),
}
want := []string{"/a", "/b", "/c", "/a"}
for i, w := range want {
if got[i] != w {
t.Fatalf("call %d: expected %q, got %q", i, w, got[i])
}
}
}
func TestLeastBusy(t *testing.T) {
p := NewAccountPool([]string{"/dir1", "/dir2", "/dir3"})
p.ReportRequestStart("/dir1")
p.ReportRequestStart("/dir2")
if got := p.GetNextConfigDir(); got != "/dir3" {
t.Fatalf("expected /dir3 (least busy), got %q", got)
}
p.ReportRequestStart("/dir3")
p.ReportRequestEnd("/dir1")
if got := p.GetNextConfigDir(); got != "/dir1" {
t.Fatalf("expected /dir1 after end, got %q", got)
}
}
func TestSkipsRateLimited(t *testing.T) {
p := NewAccountPool([]string{"/dir1", "/dir2"})
p.ReportRateLimit("/dir1", 60000)
if got := p.GetNextConfigDir(); got != "/dir2" {
t.Fatalf("expected /dir2, got %q", got)
}
if got := p.GetNextConfigDir(); got != "/dir2" {
t.Fatalf("expected /dir2 again, got %q", got)
}
}
func TestFallbackToSoonestRecovery(t *testing.T) {
p := NewAccountPool([]string{"/dir1", "/dir2"})
p.ReportRateLimit("/dir1", 60000)
p.ReportRateLimit("/dir2", 30000)
// dir2 recovers sooner — should be selected
if got := p.GetNextConfigDir(); got != "/dir2" {
t.Fatalf("expected /dir2 (sooner recovery), got %q", got)
}
}
func TestActiveRequestsDoesNotGoNegative(t *testing.T) {
p := NewAccountPool([]string{"/dir1"})
p.ReportRequestEnd("/dir1")
p.ReportRequestEnd("/dir1")
if got := p.GetNextConfigDir(); got != "/dir1" {
t.Fatalf("pool should still work, got %q", got)
}
}
func TestIgnoreUnknownConfigDir(t *testing.T) {
p := NewAccountPool([]string{"/dir1"})
p.ReportRequestStart("/nonexistent")
p.ReportRequestEnd("/nonexistent")
p.ReportRateLimit("/nonexistent", 60000)
if got := p.GetNextConfigDir(); got != "/dir1" {
t.Fatalf("expected /dir1, got %q", got)
}
}
func TestRateLimitExpires(t *testing.T) {
p := NewAccountPool([]string{"/dir1", "/dir2"})
p.ReportRateLimit("/dir1", 50)
if got := p.GetNextConfigDir(); got != "/dir2" {
t.Fatalf("immediately expected /dir2, got %q", got)
}
time.Sleep(100 * time.Millisecond)
if got := p.GetNextConfigDir(); got != "/dir1" {
t.Fatalf("after expiry expected /dir1, got %q", got)
}
}
func TestGlobalPool(t *testing.T) {
InitAccountPool([]string{"/g1", "/g2"})
if got := GetNextAccountConfigDir(); got != "/g1" {
t.Fatalf("expected /g1, got %q", got)
}
if got := GetNextAccountConfigDir(); got != "/g2" {
t.Fatalf("expected /g2, got %q", got)
}
if got := GetNextAccountConfigDir(); got != "/g1" {
t.Fatalf("expected /g1 again, got %q", got)
}
}
func TestGlobalPoolEmpty(t *testing.T) {
InitAccountPool(nil)
if got := GetNextAccountConfigDir(); got != "" {
t.Fatalf("expected empty string for empty global pool, got %q", got)
}
}
func TestGlobalPoolReinit(t *testing.T) {
InitAccountPool([]string{"/old1", "/old2"})
GetNextAccountConfigDir()
InitAccountPool([]string{"/new1"})
if got := GetNextAccountConfigDir(); got != "/new1" {
t.Fatalf("expected /new1 after reinit, got %q", got)
}
}
func TestGlobalPoolFunctionsNoopBeforeInit(t *testing.T) {
InitAccountPool(nil)
ReportRequestStart("/dir1")
ReportRequestEnd("/dir1")
ReportRateLimit("/dir1", 1000)
}

View File

@ -1,21 +0,0 @@
//go:build !windows
package process
import (
"os/exec"
"syscall"
)
func killProcessGroup(c *exec.Cmd) error {
if c.Process == nil {
return nil
}
// 殺死整個 process group負號表示 group
pgid, err := syscall.Getpgid(c.Process.Pid)
if err == nil {
_ = syscall.Kill(-pgid, syscall.SIGKILL)
}
// 同時也 kill 主程序,以防萬一
return c.Process.Kill()
}

View File

@ -1,14 +0,0 @@
//go:build windows
package process
import (
"os/exec"
)
func killProcessGroup(c *exec.Cmd) error {
if c.Process == nil {
return nil
}
return c.Process.Kill()
}

View File

@ -1,250 +0,0 @@
package process
import (
"bufio"
"context"
"cursor-api-proxy/internal/env"
"fmt"
"os/exec"
"strings"
"sync"
"syscall"
"time"
)
type RunResult struct {
Code int
Stdout string
Stderr string
}
type RunOptions struct {
Cwd string
TimeoutMs int
MaxMode bool
ConfigDir string
Ctx context.Context
}
type RunStreamingOptions struct {
RunOptions
OnLine func(line string)
}
// ─── Global child process registry ──────────────────────────────────────────
var (
activeMu sync.Mutex
activeChildren []*exec.Cmd
)
func registerChild(c *exec.Cmd) {
activeMu.Lock()
activeChildren = append(activeChildren, c)
activeMu.Unlock()
}
func unregisterChild(c *exec.Cmd) {
activeMu.Lock()
for i, ch := range activeChildren {
if ch == c {
activeChildren = append(activeChildren[:i], activeChildren[i+1:]...)
break
}
}
activeMu.Unlock()
}
func KillAllChildProcesses() {
activeMu.Lock()
all := make([]*exec.Cmd, len(activeChildren))
copy(all, activeChildren)
activeChildren = nil
activeMu.Unlock()
for _, c := range all {
killProcessGroup(c)
}
}
// ─── Spawn ────────────────────────────────────────────────────────────────
func spawnChild(cmdStr string, args []string, opts *RunOptions, maxModeFn func(scriptPath, configDir string)) *exec.Cmd {
envSrc := env.OsEnvToMap()
resolved := env.ResolveAgentCommand(cmdStr, args, envSrc, opts.Cwd)
if opts.MaxMode && maxModeFn != nil {
maxModeFn(resolved.AgentScriptPath, opts.ConfigDir)
}
envMap := make(map[string]string, len(resolved.Env))
for k, v := range resolved.Env {
envMap[k] = v
}
if opts.ConfigDir != "" {
envMap["CURSOR_CONFIG_DIR"] = opts.ConfigDir
} else if resolved.ConfigDir != "" {
if _, exists := envMap["CURSOR_CONFIG_DIR"]; !exists {
envMap["CURSOR_CONFIG_DIR"] = resolved.ConfigDir
}
}
envSlice := make([]string, 0, len(envMap))
for k, v := range envMap {
envSlice = append(envSlice, k+"="+v)
}
ctx := opts.Ctx
if ctx == nil {
ctx = context.Background()
}
// 使用 WaitDelay 確保 context cancel 後子程序 goroutine 能及時退出
c := exec.CommandContext(ctx, resolved.Command, resolved.Args...)
c.Dir = opts.Cwd
c.Env = envSlice
// 設定新的 process group使 kill 能傳遞給所有子孫程序
c.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
// WaitDelaycontext cancel 後額外等待這麼久再強制關閉 pipes
c.WaitDelay = 5 * time.Second
// Cancel 函式:殺死整個 process group
c.Cancel = func() error {
return killProcessGroup(c)
}
return c
}
// MaxModeFn is set by the agent package to avoid import cycle.
var MaxModeFn func(agentScriptPath, configDir string)
func Run(cmdStr string, args []string, opts RunOptions) (RunResult, error) {
ctx := opts.Ctx
var cancel context.CancelFunc
if opts.TimeoutMs > 0 {
if ctx == nil {
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond)
} else {
ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond)
}
defer cancel()
opts.Ctx = ctx
} else if ctx == nil {
opts.Ctx = context.Background()
}
c := spawnChild(cmdStr, args, &opts, MaxModeFn)
var stdoutBuf, stderrBuf strings.Builder
c.Stdout = &stdoutBuf
c.Stderr = &stderrBuf
if err := c.Start(); err != nil {
// context 已取消或命令找不到時
if opts.Ctx != nil && opts.Ctx.Err() != nil {
return RunResult{Code: -1}, nil
}
if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") {
return RunResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr)
}
return RunResult{}, err
}
registerChild(c)
defer unregisterChild(c)
err := c.Wait()
code := 0
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
code = exitErr.ExitCode()
if code == 0 {
code = -1
}
} else {
// context cancelled or killed — return -1 but no error
return RunResult{Code: -1, Stdout: stdoutBuf.String(), Stderr: stderrBuf.String()}, nil
}
}
return RunResult{
Code: code,
Stdout: stdoutBuf.String(),
Stderr: stderrBuf.String(),
}, nil
}
type StreamResult struct {
Code int
Stderr string
}
func RunStreaming(cmdStr string, args []string, opts RunStreamingOptions) (StreamResult, error) {
ctx := opts.Ctx
var cancel context.CancelFunc
if opts.TimeoutMs > 0 {
if ctx == nil {
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond)
} else {
ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond)
}
defer cancel()
opts.RunOptions.Ctx = ctx
} else if opts.RunOptions.Ctx == nil {
opts.RunOptions.Ctx = context.Background()
}
c := spawnChild(cmdStr, args, &opts.RunOptions, MaxModeFn)
stdoutPipe, err := c.StdoutPipe()
if err != nil {
return StreamResult{}, err
}
stderrPipe, err := c.StderrPipe()
if err != nil {
return StreamResult{}, err
}
if err := c.Start(); err != nil {
if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") {
return StreamResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr)
}
return StreamResult{}, err
}
registerChild(c)
defer unregisterChild(c)
var stderrBuf strings.Builder
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
scanner := bufio.NewScanner(stdoutPipe)
scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024)
for scanner.Scan() {
line := scanner.Text()
if strings.TrimSpace(line) != "" {
opts.OnLine(line)
}
}
}()
wg.Add(1)
go func() {
defer wg.Done()
scanner := bufio.NewScanner(stderrPipe)
scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024)
for scanner.Scan() {
stderrBuf.WriteString(scanner.Text())
stderrBuf.WriteString("\n")
}
}()
wg.Wait()
err = c.Wait()
code := 0
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
code = exitErr.ExitCode()
if code == 0 {
code = -1
}
}
}
return StreamResult{Code: code, Stderr: stderrBuf.String()}, nil
}

View File

@ -1,283 +0,0 @@
package process_test
import (
"context"
"cursor-api-proxy/internal/process"
"os"
"testing"
"time"
)
// sh 是跨平台 shell 執行小 script 的輔助函式
func sh(t *testing.T, script string, opts process.RunOptions) (process.RunResult, error) {
t.Helper()
return process.Run("sh", []string{"-c", script}, opts)
}
func TestRun_StdoutAndStderr(t *testing.T) {
result, err := sh(t, "echo hello; echo world >&2", process.RunOptions{})
if err != nil {
t.Fatal(err)
}
if result.Code != 0 {
t.Errorf("Code = %d, want 0", result.Code)
}
if result.Stdout != "hello\n" {
t.Errorf("Stdout = %q, want %q", result.Stdout, "hello\n")
}
if result.Stderr != "world\n" {
t.Errorf("Stderr = %q, want %q", result.Stderr, "world\n")
}
}
func TestRun_BasicSpawn(t *testing.T) {
result, err := sh(t, "printf ok", process.RunOptions{})
if err != nil {
t.Fatal(err)
}
if result.Code != 0 {
t.Errorf("Code = %d, want 0", result.Code)
}
if result.Stdout != "ok" {
t.Errorf("Stdout = %q, want ok", result.Stdout)
}
}
func TestRun_ConfigDir_Propagated(t *testing.T) {
result, err := process.Run("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`},
process.RunOptions{ConfigDir: "/test/account/dir"})
if err != nil {
t.Fatal(err)
}
if result.Stdout != "/test/account/dir" {
t.Errorf("Stdout = %q, want /test/account/dir", result.Stdout)
}
}
func TestRun_ConfigDir_Absent(t *testing.T) {
// 確保沒有殘留的環境變數
_ = os.Unsetenv("CURSOR_CONFIG_DIR")
result, err := process.Run("sh", []string{"-c", `printf "${CURSOR_CONFIG_DIR:-unset}"`},
process.RunOptions{})
if err != nil {
t.Fatal(err)
}
if result.Stdout != "unset" {
t.Errorf("Stdout = %q, want unset", result.Stdout)
}
}
func TestRun_NonZeroExit(t *testing.T) {
result, err := sh(t, "exit 42", process.RunOptions{})
if err != nil {
t.Fatal(err)
}
if result.Code != 42 {
t.Errorf("Code = %d, want 42", result.Code)
}
}
func TestRun_Timeout(t *testing.T) {
start := time.Now()
result, err := sh(t, "sleep 30", process.RunOptions{TimeoutMs: 300})
elapsed := time.Since(start)
if err != nil {
t.Fatal(err)
}
if result.Code == 0 {
t.Error("expected non-zero exit code after timeout")
}
if elapsed > 2*time.Second {
t.Errorf("elapsed %v, want < 2s", elapsed)
}
}
func TestRunStreaming_OnLine(t *testing.T) {
var lines []string
result, err := process.RunStreaming("sh", []string{"-c", "printf 'a\nb\nc\n'"},
process.RunStreamingOptions{
OnLine: func(line string) { lines = append(lines, line) },
})
if err != nil {
t.Fatal(err)
}
if result.Code != 0 {
t.Errorf("Code = %d, want 0", result.Code)
}
if len(lines) != 3 {
t.Errorf("got %d lines, want 3: %v", len(lines), lines)
}
if lines[0] != "a" || lines[1] != "b" || lines[2] != "c" {
t.Errorf("lines = %v, want [a b c]", lines)
}
}
func TestRunStreaming_FlushFinalLine(t *testing.T) {
var lines []string
result, err := process.RunStreaming("sh", []string{"-c", "printf tail"},
process.RunStreamingOptions{
OnLine: func(line string) { lines = append(lines, line) },
})
if err != nil {
t.Fatal(err)
}
if result.Code != 0 {
t.Errorf("Code = %d, want 0", result.Code)
}
if len(lines) != 1 {
t.Errorf("got %d lines, want 1: %v", len(lines), lines)
}
if lines[0] != "tail" {
t.Errorf("lines[0] = %q, want tail", lines[0])
}
}
func TestRunStreaming_ConfigDir(t *testing.T) {
var lines []string
_, err := process.RunStreaming("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`},
process.RunStreamingOptions{
RunOptions: process.RunOptions{ConfigDir: "/my/config/dir"},
OnLine: func(line string) { lines = append(lines, line) },
})
if err != nil {
t.Fatal(err)
}
if len(lines) != 1 || lines[0] != "/my/config/dir" {
t.Errorf("lines = %v, want [/my/config/dir]", lines)
}
}
func TestRunStreaming_Stderr(t *testing.T) {
result, err := process.RunStreaming("sh", []string{"-c", "echo err-output >&2"},
process.RunStreamingOptions{OnLine: func(string) {}})
if err != nil {
t.Fatal(err)
}
if result.Stderr == "" {
t.Error("expected stderr to contain output")
}
}
func TestRunStreaming_Timeout(t *testing.T) {
start := time.Now()
result, err := process.RunStreaming("sh", []string{"-c", "sleep 30"},
process.RunStreamingOptions{
RunOptions: process.RunOptions{TimeoutMs: 300},
OnLine: func(string) {},
})
elapsed := time.Since(start)
if err != nil {
t.Fatal(err)
}
if result.Code == 0 {
t.Error("expected non-zero exit code after timeout")
}
if elapsed > 2*time.Second {
t.Errorf("elapsed %v, want < 2s", elapsed)
}
}
func TestRunStreaming_Concurrent(t *testing.T) {
var lines1, lines2 []string
done := make(chan struct{}, 2)
run := func(label string, target *[]string) {
process.RunStreaming("sh", []string{"-c", "printf '" + label + "'"},
process.RunStreamingOptions{
OnLine: func(line string) { *target = append(*target, line) },
})
done <- struct{}{}
}
go run("stream1", &lines1)
go run("stream2", &lines2)
<-done
<-done
if len(lines1) != 1 || lines1[0] != "stream1" {
t.Errorf("lines1 = %v, want [stream1]", lines1)
}
if len(lines2) != 1 || lines2[0] != "stream2" {
t.Errorf("lines2 = %v, want [stream2]", lines2)
}
}
func TestRunStreaming_ContextCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
start := time.Now()
done := make(chan struct{})
go func() {
process.RunStreaming("sh", []string{"-c", "sleep 30"},
process.RunStreamingOptions{
RunOptions: process.RunOptions{Ctx: ctx},
OnLine: func(string) {},
})
close(done)
}()
time.AfterFunc(100*time.Millisecond, cancel)
<-done
elapsed := time.Since(start)
if elapsed > 2*time.Second {
t.Errorf("elapsed %v, want < 2s", elapsed)
}
}
func TestRun_ContextCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
start := time.Now()
done := make(chan process.RunResult, 1)
go func() {
r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx})
done <- r
}()
time.AfterFunc(100*time.Millisecond, cancel)
result := <-done
elapsed := time.Since(start)
if result.Code == 0 {
t.Error("expected non-zero exit code after cancel")
}
if elapsed > 2*time.Second {
t.Errorf("elapsed %v, want < 2s", elapsed)
}
}
func TestRun_AlreadyCancelledContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // 已取消
start := time.Now()
result, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx})
elapsed := time.Since(start)
if result.Code == 0 {
t.Error("expected non-zero exit code")
}
if elapsed > 2*time.Second {
t.Errorf("elapsed %v, want < 2s", elapsed)
}
}
func TestKillAllChildProcesses(t *testing.T) {
done := make(chan process.RunResult, 1)
go func() {
r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{})
done <- r
}()
time.Sleep(80 * time.Millisecond)
process.KillAllChildProcesses()
result := <-done
if result.Code == 0 {
t.Error("expected non-zero exit code after kill")
}
// 再次呼叫不應 panic
process.KillAllChildProcesses()
}

View File

@ -1,27 +0,0 @@
package cursor
import (
"context"
"cursor-api-proxy/internal/apitypes"
"cursor-api-proxy/internal/config"
)
type Provider struct {
cfg config.BridgeConfig
}
func NewProvider(cfg config.BridgeConfig) *Provider {
return &Provider{cfg: cfg}
}
func (p *Provider) Name() string {
return "cursor"
}
func (p *Provider) Close() error {
return nil
}
func (p *Provider) Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error {
return nil
}

View File

@ -1,32 +0,0 @@
package providers
import (
"context"
"cursor-api-proxy/internal/apitypes"
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/providers/cursor"
"cursor-api-proxy/internal/providers/geminiweb"
"fmt"
)
type Provider interface {
Name() string
Close() error
Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error
}
func NewProvider(cfg config.BridgeConfig) (Provider, error) {
providerType := cfg.Provider
if providerType == "" {
providerType = "cursor"
}
switch providerType {
case "cursor":
return cursor.NewProvider(cfg), nil
case "gemini-web":
return geminiweb.NewPlaywrightProvider(cfg)
default:
return nil, fmt.Errorf("unknown provider: %s", providerType)
}
}

View File

@ -1,125 +0,0 @@
package geminiweb
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
"github.com/go-rod/rod"
"github.com/go-rod/rod/lib/launcher"
"github.com/go-rod/rod/lib/proto"
)
type Browser struct {
browser *rod.Browser
visible bool
}
func NewBrowser(visible bool) (*Browser, error) {
l := launcher.New()
if visible {
l = l.Headless(false)
} else {
l = l.Headless(true)
}
url, err := l.Launch()
if err != nil {
return nil, fmt.Errorf("failed to launch browser: %w", err)
}
b := rod.New().ControlURL(url)
if err := b.Connect(); err != nil {
return nil, fmt.Errorf("failed to connect browser: %w", err)
}
return &Browser{browser: b, visible: visible}, nil
}
func (b *Browser) Close() error {
if b.browser != nil {
return b.browser.Close()
}
return nil
}
func (b *Browser) NewPage() (*rod.Page, error) {
return b.browser.Page(proto.TargetCreateTarget{URL: "about:blank"})
}
type Cookie struct {
Name string `json:"name"`
Value string `json:"value"`
Domain string `json:"domain"`
Path string `json:"path"`
Expires float64 `json:"expires"`
HTTPOnly bool `json:"httpOnly"`
Secure bool `json:"secure"`
}
func LoadCookiesFromFile(cookieFile string) ([]Cookie, error) {
data, err := os.ReadFile(cookieFile)
if err != nil {
return nil, fmt.Errorf("failed to read cookies: %w", err)
}
var cookies []Cookie
if err := json.Unmarshal(data, &cookies); err != nil {
return nil, fmt.Errorf("failed to parse cookies: %w", err)
}
return cookies, nil
}
func SaveCookiesToFile(cookies []Cookie, cookieFile string) error {
data, err := json.MarshalIndent(cookies, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal cookies: %w", err)
}
dir := filepath.Dir(cookieFile)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create cookie dir: %w", err)
}
if err := os.WriteFile(cookieFile, data, 0644); err != nil {
return fmt.Errorf("failed to write cookies: %w", err)
}
return nil
}
func SetCookiesOnPage(page *rod.Page, cookies []Cookie) error {
var protoCookies []*proto.NetworkCookieParam
for _, c := range cookies {
p := &proto.NetworkCookieParam{
Name: c.Name,
Value: c.Value,
Domain: c.Domain,
Path: c.Path,
HTTPOnly: c.HTTPOnly,
Secure: c.Secure,
}
if c.Expires > 0 {
exp := proto.TimeSinceEpoch(c.Expires)
p.Expires = exp
}
protoCookies = append(protoCookies, p)
}
return page.SetCookies(protoCookies)
}
func WaitForElement(page *rod.Page, selector string, timeout time.Duration) (*rod.Element, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return page.Context(ctx).Element(selector)
}
func WaitForElements(page *rod.Page, selector string, timeout time.Duration) (rod.Elements, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return page.Context(ctx).Elements(selector)
}

View File

@ -1,173 +0,0 @@
package geminiweb
import (
"fmt"
"os"
"path/filepath"
"sync"
"github.com/go-rod/rod"
"github.com/go-rod/rod/lib/launcher"
"github.com/go-rod/rod/lib/proto"
)
// BrowserManager 管理瀏覽器實例的生命週期
type BrowserManager struct {
mu sync.Mutex
browser *rod.Browser
userDataDir string
page *rod.Page
visible bool
isRunning bool
currentModel string
}
var (
globalManager *BrowserManager
globalMu sync.Mutex
)
// GetBrowserManager 獲取全域瀏覽器管理器(單例)
func GetBrowserManager(userDataDir string, visible bool) (*BrowserManager, error) {
globalMu.Lock()
defer globalMu.Unlock()
if globalManager != nil {
return globalManager, nil
}
manager, err := NewBrowserManager(userDataDir, visible)
if err != nil {
return nil, err
}
globalManager = manager
return globalManager, nil
}
// NewBrowserManager 建立新的瀏覽器管理器
func NewBrowserManager(userDataDir string, visible bool) (*BrowserManager, error) {
cleanLockFiles(userDataDir)
if err := os.MkdirAll(userDataDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create user data dir: %w", err)
}
return &BrowserManager{
userDataDir: userDataDir,
visible: visible,
}, nil
}
// cleanLockFiles 清理 Chrome 的殘留鎖檔案
func cleanLockFiles(userDataDir string) {
lockFiles := []string{
"SingletonLock",
"SingletonCookie",
"SingletonSocket",
"Default/SingletonLock",
"Default/SingletonCookie",
"Default/SingletonSocket",
}
for _, file := range lockFiles {
path := filepath.Join(userDataDir, file)
os.Remove(path)
}
}
// Launch 啟動瀏覽器(如果尚未啟動)
func (m *BrowserManager) Launch() error {
m.mu.Lock()
defer m.mu.Unlock()
if m.isRunning && m.browser != nil {
return nil
}
l := launcher.New()
if m.visible {
l = l.Headless(false)
} else {
l = l.Headless(true)
}
l = l.UserDataDir(m.userDataDir)
url, err := l.Launch()
if err != nil {
return fmt.Errorf("failed to launch browser: %w", err)
}
b := rod.New().ControlURL(url)
if err := b.Connect(); err != nil {
return fmt.Errorf("failed to connect browser: %w", err)
}
m.browser = b
page, err := b.Page(proto.TargetCreateTarget{URL: "about:blank"})
if err != nil {
_ = b.Close()
return fmt.Errorf("failed to create page: %w", err)
}
m.page = page
m.isRunning = true
return nil
}
// GetPage 獲取頁面
func (m *BrowserManager) GetPage() (*rod.Page, error) {
m.mu.Lock()
defer m.mu.Unlock()
if !m.isRunning || m.browser == nil {
return nil, fmt.Errorf("browser not running")
}
return m.page, nil
}
// Close 關閉瀏覽器
func (m *BrowserManager) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
if !m.isRunning {
return nil
}
var err error
if m.browser != nil {
err = m.browser.Close()
m.browser = nil
}
m.page = nil
m.isRunning = false
return err
}
// IsRunning 檢查瀏覽器是否正在運行
func (m *BrowserManager) IsRunning() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.isRunning
}
// SetCurrentModel 設定當前模型
func (m *BrowserManager) SetCurrentModel(model string) {
m.mu.Lock()
defer m.mu.Unlock()
m.currentModel = model
}
// GetCurrentModel 獲取當前模型
func (m *BrowserManager) GetCurrentModel() string {
m.mu.Lock()
defer m.mu.Unlock()
return m.currentModel
}

View File

@ -1,250 +0,0 @@
package geminiweb
import (
"context"
"fmt"
"strings"
"time"
"github.com/go-rod/rod"
)
const geminiURL = "https://gemini.google.com/app"
// 輸入框選擇器(依優先順序)
var inputSelectors = []string{
".ProseMirror",
"rich-textarea",
"div[role='textbox'][contenteditable='true']",
"div[contenteditable='true']",
"textarea",
}
// NavigateToGemini 導航到 Gemini
func NavigateToGemini(page *rod.Page) error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := page.Context(ctx).Navigate(geminiURL); err != nil {
return fmt.Errorf("failed to navigate: %w", err)
}
return page.Context(ctx).WaitLoad()
}
// IsLoggedIn 檢查是否已登入
func IsLoggedIn(page *rod.Page) bool {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
for _, sel := range inputSelectors {
if _, err := page.Context(ctx).Element(sel); err == nil {
return true
}
}
return false
}
// SelectModel 選擇模型(可選)
func SelectModel(page *rod.Page, model string) error {
fmt.Printf("[GeminiWeb] Model selection skipped (using current model)\n")
return nil
}
// TypeInput 在輸入框中輸入文字
func TypeInput(page *rod.Page, text string) error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
fmt.Println("[GeminiWeb] Looking for input field...")
// 1. 嘗試所有選擇器
var inputEl *rod.Element
var err error
for _, sel := range inputSelectors {
fmt.Printf(" Trying: %s\n", sel)
inputEl, err = page.Context(ctx).Element(sel)
if err == nil {
fmt.Printf(" ✓ Found with: %s\n", sel)
break
}
}
if err != nil {
// 2. Fallback: 嘗試等待頁面載入完成後重試
fmt.Println("[GeminiWeb] Waiting for page to fully load...")
time.Sleep(3 * time.Second)
for _, sel := range inputSelectors {
fmt.Printf(" Retrying: %s\n", sel)
inputEl, err = page.Context(ctx).Element(sel)
if err == nil {
fmt.Printf(" ✓ Found with: %s\n", sel)
break
}
}
}
if err != nil {
// 3. Debug: 印出頁面標題和 URL
info, _ := page.Info()
fmt.Printf("[GeminiWeb] DEBUG: URL=%s Title=%s\n", info.URL, info.Title)
// 4. Fallback: 嘗試更通用的選擇器
fmt.Println("[GeminiWeb] Trying generic selectors...")
genericSelectors := []string{
"div[contenteditable]",
"[contenteditable]",
"textarea",
"input[type='text']",
}
for _, sel := range genericSelectors {
fmt.Printf(" Trying generic: %s\n", sel)
inputEl, err = page.Context(ctx).Element(sel)
if err == nil {
fmt.Printf(" ✓ Found with: %s\n", sel)
break
}
}
}
if err != nil {
info, _ := page.Info()
return fmt.Errorf("input field not found after trying all selectors (URL=%s)", info.URL)
}
// 2. Focus 輸入框
fmt.Printf("[GeminiWeb] Focusing input field...\n")
if err := inputEl.Focus(); err != nil {
return fmt.Errorf("failed to focus input: %w", err)
}
time.Sleep(500 * time.Millisecond)
// 3. 使用 Input 方法
fmt.Printf("[GeminiWeb] Typing %d chars...\n", len(text))
if err := inputEl.Input(text); err != nil {
return fmt.Errorf("failed to input text: %w", err)
}
time.Sleep(200 * time.Millisecond)
fmt.Println("[GeminiWeb] Input complete")
return nil
}
// ClickSend 發送訊息
func ClickSend(page *rod.Page) error {
// 方法 1: 按 Enter
if err := page.Keyboard.Press('\r'); err != nil {
return fmt.Errorf("failed to press Enter: %w", err)
}
time.Sleep(200 * time.Millisecond)
return nil
}
// WaitForReady 等待頁面空閒
func WaitForReady(page *rod.Page) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
fmt.Println("[GeminiWeb] Checking if page is ready...")
for {
select {
case <-ctx.Done():
fmt.Println("[GeminiWeb] Page ready check timeout, proceeding anyway")
return nil
default:
time.Sleep(500 * time.Millisecond)
// 檢查是否有停止按鈕
hasStopBtn := false
stopBtns, _ := page.Elements("button[aria-label*='Stop'], button[aria-label*='停止']")
for _, btn := range stopBtns {
visible, _ := btn.Visible()
if visible {
hasStopBtn = true
break
}
}
if !hasStopBtn {
fmt.Println("[GeminiWeb] Page is ready")
return nil
}
}
}
}
// ExtractResponse 提取回應文字
func ExtractResponse(page *rod.Page) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
var lastText string
lastUpdate := time.Now()
for {
select {
case <-ctx.Done():
if lastText != "" {
return lastText, nil
}
return "", fmt.Errorf("response timeout")
default:
time.Sleep(500 * time.Millisecond)
// 尋找回應文字
for _, sel := range responseSelectors {
elements, err := page.Elements(sel)
if err != nil || len(elements) == 0 {
continue
}
// 取得最後一個元素的文字
lastEl := elements[len(elements)-1]
text, err := lastEl.Text()
if err != nil {
continue
}
text = strings.TrimSpace(text)
if text != "" && text != lastText && len(text) > len(lastText) {
lastText = text
lastUpdate = time.Now()
fmt.Printf("[GeminiWeb] Response length: %d\n", len(text))
}
}
// 檢查是否已完成2 秒內沒有新內容)
if time.Since(lastUpdate) > 2*time.Second && lastText != "" {
// 最後檢查一次是否還有停止按鈕
hasStopBtn := false
stopBtns, _ := page.Elements("button[aria-label*='Stop'], button[aria-label*='停止']")
for _, btn := range stopBtns {
visible, _ := btn.Visible()
if visible {
hasStopBtn = true
break
}
}
if !hasStopBtn {
return lastText, nil
}
}
}
}
}
// 默認的回應選擇器
var responseSelectors = []string{
".model-response-text",
".message-content",
".markdown",
".prose",
"model-response",
}

View File

@ -1,641 +0,0 @@
package geminiweb
import (
"context"
"cursor-api-proxy/internal/apitypes"
"cursor-api-proxy/internal/config"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/playwright-community/playwright-go"
)
// PlaywrightProvider 使用 Playwright 的 Gemini Provider
type PlaywrightProvider struct {
cfg config.BridgeConfig
pw *playwright.Playwright
browser playwright.Browser
context playwright.BrowserContext
page playwright.Page
mu sync.Mutex
userDataDir string
}
var (
playwrightInstance *playwright.Playwright
playwrightOnce sync.Once
playwrightErr error
)
// NewPlaywrightProvider 建立新的 Playwright Provider
func NewPlaywrightProvider(cfg config.BridgeConfig) (*PlaywrightProvider, error) {
// 確保 Playwright 已初始化(單例)
playwrightOnce.Do(func() {
playwrightInstance, playwrightErr = playwright.Run()
if playwrightErr != nil {
playwrightErr = fmt.Errorf("failed to run playwright: %w", playwrightErr)
}
})
if playwrightErr != nil {
return nil, playwrightErr
}
// 清理 Chrome 鎖檔案
userDataDir := filepath.Join(cfg.GeminiAccountDir, "default-session")
cleanLockFiles(userDataDir)
// 確保目錄存在
if err := os.MkdirAll(userDataDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create user data dir: %w", err)
}
return &PlaywrightProvider{
cfg: cfg,
pw: playwrightInstance,
userDataDir: userDataDir,
}, nil
}
// getName 返回 Provider 名稱
func (p *PlaywrightProvider) Name() string {
return "gemini-web"
}
// launchIfNeeded 如果需要則啟動瀏覽器
func (p *PlaywrightProvider) launchIfNeeded() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.context != nil && p.page != nil {
return nil
}
fmt.Println("[GeminiWeb] Launching Chromium...")
// 使用 LaunchPersistentContext自動保存 session
context, err := p.pw.Chromium.LaunchPersistentContext(p.userDataDir,
playwright.BrowserTypeLaunchPersistentContextOptions{
Headless: playwright.Bool(!p.cfg.GeminiBrowserVisible),
Args: []string{
"--no-first-run",
"--no-default-browser-check",
"--disable-background-networking",
"--disable-extensions",
"--disable-plugins",
"--disable-sync",
},
})
if err != nil {
return fmt.Errorf("failed to launch persistent context: %w", err)
}
p.context = context
// 取得或建立頁面
pages := context.Pages()
if len(pages) > 0 {
p.page = pages[0]
} else {
page, err := context.NewPage()
if err != nil {
_ = context.Close()
return fmt.Errorf("failed to create page: %w", err)
}
p.page = page
}
fmt.Println("[GeminiWeb] Browser launched")
return nil
}
// Generate 生成回應
func (p *PlaywrightProvider) Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) (err error) {
// 確保在返回錯誤時保存診斷
defer func() {
if err != nil {
fmt.Println("[GeminiWeb] Error occurred, saving diagnostics...")
_ = p.saveDiagnostics()
}
}()
fmt.Printf("[GeminiWeb] Starting generation with model: %s\n", model)
// 1. 確保瀏覽器已啟動
if err := p.launchIfNeeded(); err != nil {
return fmt.Errorf("failed to launch browser: %w", err)
}
// 2. 導航到 Gemini如果需要
currentURL := p.page.URL()
if !strings.Contains(currentURL, "gemini.google.com") {
fmt.Println("[GeminiWeb] Navigating to Gemini...")
if _, err := p.page.Goto("https://gemini.google.com/app", playwright.PageGotoOptions{
WaitUntil: playwright.WaitUntilStateDomcontentloaded,
Timeout: playwright.Float(60000),
}); err != nil {
return fmt.Errorf("failed to navigate: %w", err)
}
// 額外等待 JavaScript 載入
fmt.Println("[GeminiWeb] Waiting for page to initialize...")
time.Sleep(3 * time.Second)
}
// 3. 調試模式:等待用戶確認
if p.cfg.GeminiBrowserVisible {
fmt.Println("\n" + strings.Repeat("=", 70))
fmt.Println("🔍 調試模式:瀏覽器已開啟")
fmt.Println("請檢查瀏覽器畫面,然後按 ENTER 繼續...")
fmt.Println("如果有問題,請查看: /tmp/gemini-debug.*")
fmt.Println(strings.Repeat("=", 70))
var input string
fmt.Scanln(&input)
}
// 4. 等待頁面完全載入project-golem 策略)
fmt.Println("[GeminiWeb] Waiting for page to be ready...")
if err := p.waitForPageReady(); err != nil {
fmt.Printf("[GeminiWeb] Warning: %v\n", err)
// 額外調試:輸出頁面 HTML 結構
if p.cfg.GeminiBrowserVisible {
html, _ := p.page.Content()
debugPath := "/tmp/gemini-debug.html"
if err := os.WriteFile(debugPath, []byte(html), 0644); err == nil {
fmt.Printf("[GeminiWeb] HTML saved to: %s\n", debugPath)
}
}
}
// 4. 檢查登入狀態
fmt.Println("[GeminiWeb] Checking login status...")
loggedIn := p.isLoggedIn()
if !loggedIn {
fmt.Println("[GeminiWeb] Not logged in, continuing anyway")
if p.cfg.GeminiBrowserVisible {
fmt.Println("\n========================================")
fmt.Println("Browser is open. You can:")
fmt.Println("1. Log in to Gemini now")
fmt.Println("2. Continue without login")
fmt.Println("========================================\n")
}
} else {
fmt.Println("[GeminiWeb] ✓ Logged in")
}
// 5. 選擇模型(如果支援)
if err := p.selectModel(model); err != nil {
fmt.Printf("[GeminiWeb] Warning: model selection failed: %v\n", err)
}
// 6. 建構提示詞
prompt := buildPromptFromMessagesPlaywright(messages)
fmt.Printf("[GeminiWeb] Typing prompt (%d chars)...\n", len(prompt))
// 7. 輸入文字(使用 Playwright 的 Auto-wait
if err := p.typeInput(prompt); err != nil {
return fmt.Errorf("failed to type: %w", err)
}
// 7. 發送訊息
fmt.Println("[GeminiWeb] Sending message...")
if err := p.sendMessage(); err != nil {
return fmt.Errorf("failed to send: %w", err)
}
// 8. 提取回應
fmt.Println("[GeminiWeb] Waiting for response...")
response, err := p.extractResponse()
if err != nil {
return fmt.Errorf("failed to extract response: %w", err)
}
// 9. 回調
cb(apitypes.StreamChunk{Type: apitypes.ChunkText, Text: response})
cb(apitypes.StreamChunk{Type: apitypes.ChunkDone, Done: true})
fmt.Printf("[GeminiWeb] Response complete (%d chars)\n", len(response))
return nil
}
// Close 關閉 Provider
func (p *PlaywrightProvider) Close() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.context != nil {
if err := p.context.Close(); err != nil {
return err
}
p.context = nil
p.page = nil
}
return nil
}
// saveDiagnostics 保存診斷信息
func (p *PlaywrightProvider) saveDiagnostics() error {
if p.page == nil {
return fmt.Errorf("no page available")
}
// 截圖
screenshotPath := "/tmp/gemini-debug.png"
if _, err := p.page.Screenshot(playwright.PageScreenshotOptions{
Path: playwright.String(screenshotPath),
}); err == nil {
fmt.Printf("[GeminiWeb] Screenshot saved: %s\n", screenshotPath)
}
// HTML
htmlPath := "/tmp/gemini-debug.html"
if html, err := p.page.Content(); err == nil {
if err := os.WriteFile(htmlPath, []byte(html), 0644); err == nil {
fmt.Printf("[GeminiWeb] HTML saved: %s\n", htmlPath)
}
}
// 輸出頁面信息
url := p.page.URL()
title, _ := p.page.Title()
fmt.Printf("[GeminiWeb] Diagnostics: URL=%s, Title=%s\n", url, title)
return nil
}
// waitForPageReady 等待頁面完全就緒project-golem 策略)
func (p *PlaywrightProvider) waitForPageReady() error {
fmt.Println("[GeminiWeb] Checking for ready state...")
// 1. 等待停止按鈕消失(如果存在)
_, _ = p.page.WaitForSelector("button[aria-label*='Stop'], button[aria-label*='停止']", playwright.PageWaitForSelectorOptions{
State: playwright.WaitForSelectorStateDetached,
Timeout: playwright.Float(5000),
})
// 2. 嘗試多種等待策略
inputSelectors := []string{
".ql-editor.ql-blank",
".ql-editor",
"div[contenteditable='true'][role='textbox']",
"div[contenteditable='true']",
".ProseMirror",
"rich-textarea",
"textarea",
}
// 策略 A: 等待任一輸入框出現
for i, sel := range inputSelectors {
fmt.Printf(" [%d/%d] Waiting for: %s\n", i+1, len(inputSelectors), sel)
locator := p.page.Locator(sel)
if err := locator.WaitFor(playwright.LocatorWaitForOptions{
Timeout: playwright.Float(5000),
State: playwright.WaitForSelectorStateVisible,
}); err == nil {
fmt.Printf(" ✓ Input field found: %s\n", sel)
return nil
}
}
// 策略 B: 等待頁面完全載入
fmt.Println("[GeminiWeb] Waiting for page load...")
time.Sleep(3 * time.Second)
// 策略 C: 使用 JavaScript 檢查
fmt.Println("[GeminiWeb] Checking with JavaScript...")
result, err := p.page.Evaluate(`
() => {
// 檢查所有可能的輸入元素
const selectors = [
'.ql-editor.ql-blank',
'.ql-editor',
'div[contenteditable="true"][role="textbox"]',
'div[contenteditable="true"]',
'.ProseMirror',
'rich-textarea',
'textarea'
];
for (const sel of selectors) {
const el = document.querySelector(sel);
if (el) {
return {
found: true,
selector: sel,
tagName: el.tagName,
className: el.className,
visible: el.offsetParent !== null
};
}
}
return { found: false };
}
`)
if err == nil {
if m, ok := result.(map[string]interface{}); ok {
if found, _ := m["found"].(bool); found {
sel, _ := m["selector"].(string)
fmt.Printf(" ✓ JavaScript found: %s\n", sel)
return nil
}
}
}
// 策略 D: 調試模式 - 輸出頁面結構
if p.cfg.GeminiBrowserVisible {
fmt.Println("[GeminiWeb].dump: Page structure analysis")
_, _ = p.page.Evaluate(`
() => {
const allElements = document.querySelectorAll('*');
const inputLike = [];
for (const el of allElements) {
if (el.contentEditable === 'true' ||
el.role === 'textbox' ||
el.tagName === 'TEXTAREA' ||
el.tagName === 'INPUT') {
inputLike.push({
tag: el.tagName,
class: el.className,
id: el.id,
role: el.role,
contentEditable: el.contentEditable
});
}
}
console.log('Input-like elements:', inputLike);
}
`)
}
return fmt.Errorf("no input field found after all strategies")
}
// isLoggedIn 檢查是否已登入
func (p *PlaywrightProvider) isLoggedIn() bool {
// 嘗試找輸入框(登入狀態的主要特徵)
selectors := []string{
".ProseMirror",
"rich-textarea",
"div[role='textbox']",
"div[contenteditable='true']",
"textarea",
}
for _, sel := range selectors {
locator := p.page.Locator(sel)
if count, _ := locator.Count(); count > 0 {
return true
}
}
return false
}
// typeInput 輸入文字(使用 Playwright 的 Auto-wait
func (p *PlaywrightProvider) typeInput(text string) error {
fmt.Println("[GeminiWeb] Looking for input field...")
selectors := []string{
".ql-editor.ql-blank",
".ql-editor",
"div[contenteditable='true'][role='textbox']",
"div[contenteditable='true']",
".ProseMirror",
"rich-textarea",
"textarea",
}
var inputLocator playwright.Locator
var found bool
for _, sel := range selectors {
fmt.Printf(" Trying: %s\n", sel)
locator := p.page.Locator(sel)
if err := locator.WaitFor(playwright.LocatorWaitForOptions{
Timeout: playwright.Float(3000),
}); err == nil {
inputLocator = locator
found = true
fmt.Printf(" ✓ Found with: %s\n", sel)
break
}
}
if !found {
// 錯誤會被 Generate 的 defer 捕獲並保存診斷
url := p.page.URL()
title, _ := p.page.Title()
return fmt.Errorf("input field not found (URL=%s, Title=%s). Diagnostics will be saved to /tmp/", url, title)
}
// Focus 並填充Playwright 自動等待)
fmt.Printf("[GeminiWeb] Typing %d chars...\n", len(text))
if err := inputLocator.Fill(text); err != nil {
return fmt.Errorf("failed to fill: %w", err)
}
fmt.Println("[GeminiWeb] Input complete")
return nil
}
// sendMessage 發送訊息
func (p *PlaywrightProvider) sendMessage() error {
// 方法 1: 按 Enter最可靠
if err := p.page.Keyboard().Press("Enter"); err != nil {
return fmt.Errorf("failed to press Enter: %w", err)
}
time.Sleep(200 * time.Millisecond)
// 方法 2: 嘗試點擊發送按鈕(補強)
_, _ = p.page.Evaluate(`
() => {
const keywords = ['發送', 'Send', '傳送'];
const buttons = Array.from(document.querySelectorAll('button, [role="button"]'));
for (const btn of buttons) {
const text = (btn.innerText || btn.textContent || '').trim();
const label = (btn.getAttribute('aria-label') || '').trim();
// 跳過停止按鈕
if (['停止', 'Stop', '中斷'].includes(text) || label.toLowerCase().includes('stop')) {
continue;
}
if (keywords.some(kw => text.includes(kw) || label.includes(kw))) {
btn.click();
return true;
}
}
return false;
}
`)
return nil
}
// extractResponse 提取回應
func (p *PlaywrightProvider) extractResponse() (string, error) {
var lastText string
var stableCount int
lastUpdate := time.Now()
timeout := 120 * time.Second
startTime := time.Now()
for time.Since(startTime) < timeout {
time.Sleep(500 * time.Millisecond)
// 使用 JavaScript 提取回應文字(更精確)
result, err := p.page.Evaluate(`
() => {
// 尋找所有可能的回應容器
const selectors = [
'model-response',
'.model-response',
'message-content',
'.message-content'
];
for (const sel of selectors) {
const el = document.querySelector(sel);
if (el) {
// 嘗試找markdown內容
const markdown = el.querySelector('.markdown, .prose, [class*="markdown"]');
if (markdown && markdown.innerText.trim()) {
let text = markdown.innerText.trim();
// 移除常見的標籤前綴
text = text.replace(/^Gemini said\s*\n*/i, '').replace(/^Gemini\s*[:]\s*\n*/i, '').trim();
return { text: text, source: sel + ' .markdown' };
}
// 嘗試找純文字內容(排除標籤)
let textContent = el.innerText.trim();
if (textContent) {
// 移除常見的標籤前綴
textContent = textContent.replace(/^Gemini said\s*\n*/i, '').replace(/^Gemini\s*[:]\s*\n*/i, '').trim();
return { text: textContent, source: sel };
}
}
}
return { text: '', source: 'none' };
}
`)
if err == nil {
if m, ok := result.(map[string]interface{}); ok {
text, _ := m["text"].(string)
text = strings.TrimSpace(text)
if text != "" && len(text) > len(lastText) {
lastText = text
lastUpdate = time.Now()
stableCount = 0
fmt.Printf("[GeminiWeb] Response: %d chars\n", len(text))
}
}
}
// 檢查是否完成(需要連續 3 次穩定)
if time.Since(lastUpdate) > 500*time.Millisecond && lastText != "" {
stableCount++
if stableCount >= 3 {
// 最終檢查:停止按鈕是否還存在
stopBtn := p.page.Locator("button[aria-label*='Stop'], button[aria-label*='停止'], button[data-test-id='stop-button']")
count, _ := stopBtn.Count()
if count == 0 {
fmt.Println("[GeminiWeb] ✓ Response complete")
return lastText, nil
}
}
}
}
if lastText != "" {
fmt.Println("[GeminiWeb] ✓ Response complete (timeout)")
return lastText, nil
}
return "", fmt.Errorf("response timeout")
}
// selectModel 選擇 Gemini 模型
// Gemini Web 只有三種模型fast, thinking, pro
func (p *PlaywrightProvider) selectModel(model string) error {
// 映射模型名稱到 Gemini Web 的模型選擇器
modelMap := map[string]string{
"fast": "Fast",
"thinking": "Thinking",
"pro": "Pro",
"gemini-fast": "Fast",
"gemini-thinking": "Thinking",
"gemini-pro": "Pro",
"gemini-2.0-fast": "Fast",
"gemini-2.0-flash": "Fast", // 相容舊名稱
"gemini-2.5-pro": "Pro",
"gemini-2.5-pro-thinking": "Thinking",
}
// 從完整模型名稱中提取類型
modelType := ""
modelLower := strings.ToLower(model)
for key, value := range modelMap {
if strings.Contains(modelLower, strings.ToLower(key)) || modelLower == strings.ToLower(key) {
modelType = value
break
}
}
if modelType == "" {
// 默認使用 Fast
fmt.Printf("[GeminiWeb] Unknown model '%s', defaulting to Fast\n", model)
return nil
}
fmt.Printf("[GeminiWeb] Selecting model: %s\n", modelType)
// 點擊模型選擇器
modelSelector := p.page.Locator("button[aria-label*='Model'], button[aria-label*='模型'], [data-test-id='model-selector']")
if count, _ := modelSelector.Count(); count > 0 {
if err := modelSelector.First().Click(); err != nil {
fmt.Printf("[GeminiWeb] Warning: could not click model selector: %v\n", err)
} else {
time.Sleep(500 * time.Millisecond)
// 選擇對應的模型選項
optionSelector := p.page.Locator(fmt.Sprintf("button:has-text('%s'), [role='menuitem']:has-text('%s')", modelType, modelType))
if count, _ := optionSelector.Count(); count > 0 {
if err := optionSelector.First().Click(); err != nil {
fmt.Printf("[GeminiWeb] Warning: could not select model: %v\n", err)
} else {
fmt.Printf("[GeminiWeb] ✓ Model selected: %s\n", modelType)
time.Sleep(500 * time.Millisecond)
}
}
}
}
return nil
}
// buildPromptFromMessages 從訊息列表建構提示詞
func buildPromptFromMessagesPlaywright(messages []apitypes.Message) string {
var prompt string
for _, m := range messages {
switch m.Role {
case "system":
prompt += "System: " + m.Content + "\n\n"
case "user":
prompt += m.Content + "\n\n"
case "assistant":
prompt += "Assistant: " + m.Content + "\n\n"
}
}
return prompt
}

View File

@ -1,169 +0,0 @@
package geminiweb
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"time"
)
type GeminiSession struct {
Name string `json:"name"`
CookieFile string `json:"cookie_file"`
LastUsed int64 `json:"last_used"`
ActiveCount int `json:"active_count"`
RateLimitEnd int64 `json:"rate_limit_end"`
}
type SessionPool struct {
mu sync.Mutex
sessions []*GeminiSession
dir string
maxCount int
}
func NewSessionPool(dir string, maxSessions int) (*SessionPool, error) {
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create session dir: %w", err)
}
sessions, err := loadSessions(dir)
if err != nil {
return nil, fmt.Errorf("failed to load sessions: %w", err)
}
return &SessionPool{
sessions: sessions,
dir: dir,
maxCount: maxSessions,
}, nil
}
func loadSessions(dir string) ([]*GeminiSession, error) {
entries, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
var sessions []*GeminiSession
for _, entry := range entries {
if !entry.IsDir() {
continue
}
name := entry.Name()
metaPath := filepath.Join(dir, name, "session.json")
data, err := os.ReadFile(metaPath)
if err != nil {
continue
}
var s GeminiSession
if err := json.Unmarshal(data, &s); err != nil {
continue
}
sessions = append(sessions, &s)
}
return sessions, nil
}
func (p *SessionPool) Count() int {
p.mu.Lock()
defer p.mu.Unlock()
return len(p.sessions)
}
func (p *SessionPool) GetAvailable() *GeminiSession {
p.mu.Lock()
defer p.mu.Unlock()
now := time.Now().UnixMilli()
var available []*GeminiSession
for _, s := range p.sessions {
if s.RateLimitEnd < now {
available = append(available, s)
}
}
if len(available) == 0 {
return nil
}
var best *GeminiSession
for _, s := range available {
if best == nil || s.ActiveCount < best.ActiveCount {
best = s
} else if s.ActiveCount == best.ActiveCount && s.LastUsed < best.LastUsed {
best = s
}
}
return best
}
func (p *SessionPool) StartSession(s *GeminiSession) {
p.mu.Lock()
defer p.mu.Unlock()
s.ActiveCount++
s.LastUsed = time.Now().UnixMilli()
p.saveSession(s)
}
func (p *SessionPool) EndSession(s *GeminiSession) {
p.mu.Lock()
defer p.mu.Unlock()
if s.ActiveCount > 0 {
s.ActiveCount--
}
p.saveSession(s)
}
func (p *SessionPool) RateLimitSession(s *GeminiSession, durationMs int64) {
p.mu.Lock()
defer p.mu.Unlock()
s.RateLimitEnd = time.Now().UnixMilli() + durationMs
p.saveSession(s)
}
func (p *SessionPool) saveSession(s *GeminiSession) {
metaPath := filepath.Join(p.dir, s.Name, "session.json")
data, err := json.MarshalIndent(s, "", " ")
if err != nil {
return
}
_ = os.WriteFile(metaPath, data, 0644)
}
func (p *SessionPool) CreateSession(name string) (*GeminiSession, error) {
p.mu.Lock()
defer p.mu.Unlock()
sessionDir := filepath.Join(p.dir, name)
if err := os.MkdirAll(sessionDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create session dir: %w", err)
}
s := &GeminiSession{
Name: name,
CookieFile: filepath.Join(sessionDir, "cookies.json"),
LastUsed: time.Now().UnixMilli(),
}
p.sessions = append(p.sessions, s)
p.saveSession(s)
return s, nil
}
func (p *SessionPool) GetSessionNames() []string {
p.mu.Lock()
defer p.mu.Unlock()
names := make([]string, len(p.sessions))
for i, s := range p.sessions {
names[i] = s.Name
}
return names
}

View File

@ -1,196 +0,0 @@
package geminiweb
import (
"context"
"cursor-api-proxy/internal/apitypes"
"cursor-api-proxy/internal/config"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
// Provider 使用持久化瀏覽器管理器
type Provider struct {
cfg config.BridgeConfig
managerOnce sync.Once
manager *BrowserManager
managerErr error
}
// NewProvider 建立新的 Provider
func NewProvider(cfg config.BridgeConfig) *Provider {
return &Provider{cfg: cfg}
}
// getName 返回 Provider 名稱
func (p *Provider) Name() string {
return "gemini-web"
}
// Close 關閉瀏覽器
func (p *Provider) Close() error {
if p.manager != nil {
return p.manager.Close()
}
return nil
}
// getManager 獲取或初始化瀏覽器管理器(單例)
func (p *Provider) getManager() (*BrowserManager, error) {
p.managerOnce.Do(func() {
sessionDir := p.getSessionDir()
p.manager, p.managerErr = GetBrowserManager(sessionDir, p.cfg.GeminiBrowserVisible)
})
return p.manager, p.managerErr
}
// getSessionDir 獲取 session 目錄
func (p *Provider) getSessionDir() string {
// 使用單一 session 目錄(簡化設計)
return filepath.Join(p.cfg.GeminiAccountDir, "default-session")
}
// Generate 生成回應
func (p *Provider) Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error {
fmt.Printf("[GeminiWeb] Starting generation with model: %s\n", model)
// 1. 獲取瀏覽器管理器
manager, err := p.getManager()
if err != nil {
return fmt.Errorf("failed to get browser manager: %w", err)
}
// 2. 啟動瀏覽器(如果尚未啟動)
if !manager.IsRunning() {
fmt.Printf("[GeminiWeb] Launching browser...\n")
if err := manager.Launch(); err != nil {
return fmt.Errorf("failed to launch browser: %w", err)
}
}
// 3. 獲取頁面
page, err := manager.GetPage()
if err != nil {
return fmt.Errorf("failed to get page: %w", err)
}
// 4. 檢查當前 URL如果不是 Gemini 則導航
currentURL, _ := page.Info()
if !strings.Contains(currentURL.URL, "gemini.google.com") {
fmt.Printf("[GeminiWeb] Navigating to Gemini...\n")
if err := NavigateToGemini(page); err != nil {
return fmt.Errorf("failed to navigate: %w", err)
}
time.Sleep(2 * time.Second)
}
// 5. 檢查登入狀態
fmt.Printf("[GeminiWeb] Checking login status...\n")
if !IsLoggedIn(page) {
fmt.Printf("[GeminiWeb] Not logged in, continuing anyway\n")
if p.cfg.GeminiBrowserVisible {
fmt.Println("\n========================================")
fmt.Println("Browser is open. You can:")
fmt.Println("1. Log in to Gemini now")
fmt.Println("2. Continue without login")
fmt.Println("========================================\n")
}
} else {
fmt.Printf("[GeminiWeb] Logged in\n")
}
// 6. 等待頁面就緒
if err := WaitForReady(page); err != nil {
fmt.Printf("[GeminiWeb] Warning: %v\n", err)
}
// 7. 建構提示詞
prompt := buildPromptFromMessages(messages)
fmt.Printf("[GeminiWeb] Typing prompt (%d chars)...\n", len(prompt))
// 8. 輸入文字
if err := TypeInput(page, prompt); err != nil {
return fmt.Errorf("failed to type input: %w", err)
}
// 9. 發送
fmt.Printf("[GeminiWeb] Sending message...\n")
if err := ClickSend(page); err != nil {
return fmt.Errorf("failed to send: %w", err)
}
// 10. 提取回應
fmt.Printf("[GeminiWeb] Waiting for response...\n")
response, err := ExtractResponse(page)
if err != nil {
return fmt.Errorf("failed to extract response: %w", err)
}
// 11. 串流回調
cb(apitypes.StreamChunk{Type: apitypes.ChunkText, Text: response})
cb(apitypes.StreamChunk{Type: apitypes.ChunkDone, Done: true})
fmt.Printf("[GeminiWeb] Response complete (%d chars)\n", len(response))
return nil
}
// buildPromptFromMessages 從訊息列表建構提示詞
func buildPromptFromMessages(messages []apitypes.Message) string {
var prompt string
for _, m := range messages {
switch m.Role {
case "system":
prompt += "System: " + m.Content + "\n\n"
case "user":
prompt += m.Content + "\n\n"
case "assistant":
prompt += "Assistant: " + m.Content + "\n\n"
}
}
return prompt
}
// RunLogin 執行登入流程(供 gemini-login 命令使用)
func RunLogin(cfg config.BridgeConfig, sessionName string) error {
if sessionName == "" {
sessionName = "default-session"
}
sessionDir := filepath.Join(cfg.GeminiAccountDir, sessionName)
if err := os.MkdirAll(sessionDir, 0755); err != nil {
return fmt.Errorf("failed to create session dir: %w", err)
}
fmt.Printf("Starting browser for login. Session: %s\n", sessionName)
fmt.Printf("Session directory: %s\n", sessionDir)
fmt.Println("Please log in to your Gemini account in the browser window.")
fmt.Println("Press Ctrl+C when you have completed the login...")
manager, err := NewBrowserManager(sessionDir, true) // visible=true
if err != nil {
return fmt.Errorf("failed to create browser manager: %w", err)
}
if err := manager.Launch(); err != nil {
return fmt.Errorf("failed to launch browser: %w", err)
}
defer manager.Close()
page, err := manager.GetPage()
if err != nil {
return fmt.Errorf("failed to get page: %w", err)
}
if err := NavigateToGemini(page); err != nil {
return fmt.Errorf("failed to navigate: %w", err)
}
// 等待用戶手動登入...
// 使用 Ctrl+C 退出,瀏覽器資料會自動保存在 userDataDir
return nil
}

View File

@ -1,147 +0,0 @@
package router
import (
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/handlers"
"cursor-api-proxy/internal/httputil"
"cursor-api-proxy/internal/logger"
"cursor-api-proxy/internal/pool"
"fmt"
"net/http"
"os"
"time"
)
type RouterOptions struct {
Version string
Config config.BridgeConfig
ModelCache *handlers.ModelCacheRef
LastModel *string
Pool pool.PoolHandle
}
func NewRouter(opts RouterOptions) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
cfg := opts.Config
pathname := r.URL.Path
method := r.Method
remoteAddress := r.RemoteAddr
if r.Header.Get("X-Real-IP") != "" {
remoteAddress = r.Header.Get("X-Real-IP")
}
logger.LogIncoming(method, pathname, remoteAddress)
defer func() {
logger.AppendSessionLine(cfg.SessionsLogPath, method, pathname, remoteAddress, 200)
}()
if cfg.RequiredKey != "" {
token := httputil.ExtractBearerToken(r)
if token != cfg.RequiredKey {
httputil.WriteJSON(w, 401, map[string]interface{}{
"error": map[string]string{"message": "Invalid API key", "code": "unauthorized"},
}, nil)
return
}
}
switch {
case method == "GET" && pathname == "/health":
handlers.HandleHealth(w, r, opts.Version, cfg)
case method == "GET" && pathname == "/v1/models":
opts.ModelCache.HandleModels(w, r, cfg)
case method == "POST" && pathname == "/v1/chat/completions":
raw, err := httputil.ReadBody(r)
if err != nil {
httputil.WriteJSON(w, 400, map[string]interface{}{
"error": map[string]string{"message": "failed to read body", "code": "bad_request"},
}, nil)
return
}
// 根據 Provider 選擇處理方式
provider := cfg.Provider
if provider == "" {
provider = "cursor"
}
if provider == "gemini-web" {
handlers.HandleGeminiChatCompletions(w, r, cfg, raw, method, pathname, remoteAddress)
} else {
handlers.HandleChatCompletions(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress)
}
case method == "POST" && pathname == "/v1/messages":
raw, err := httputil.ReadBody(r)
if err != nil {
httputil.WriteJSON(w, 400, map[string]interface{}{
"error": map[string]string{"message": "failed to read body", "code": "bad_request"},
}, nil)
return
}
handlers.HandleAnthropicMessages(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress)
case (method == "POST" || method == "GET") && pathname == "/v1/completions":
httputil.WriteJSON(w, 404, map[string]interface{}{
"error": map[string]string{
"message": "Legacy completions endpoint is not supported. Use POST /v1/chat/completions instead.",
"code": "not_found",
},
}, nil)
case pathname == "/v1/embeddings":
httputil.WriteJSON(w, 404, map[string]interface{}{
"error": map[string]string{
"message": "Embeddings are not supported by this proxy.",
"code": "not_found",
},
}, nil)
default:
httputil.WriteJSON(w, 404, map[string]interface{}{
"error": map[string]string{"message": "Not found", "code": "not_found"},
}, nil)
}
}
}
func recoveryMiddleware(logPath string, next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
defer func() {
if rec := recover(); rec != nil {
msg := fmt.Sprintf("%v", rec)
fmt.Fprintf(os.Stderr, "[%s] Proxy panic: %s\n", time.Now().UTC().Format(time.RFC3339), msg)
line := fmt.Sprintf("%s ERROR %s %s %s %s\n",
time.Now().UTC().Format(time.RFC3339), r.Method, r.URL.Path, r.RemoteAddr,
msg[:min(200, len(msg))])
if f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil {
_, _ = f.WriteString(line)
f.Close()
}
if !isHeaderWritten(w) {
httputil.WriteJSON(w, 500, map[string]interface{}{
"error": map[string]string{"message": msg, "code": "internal_error"},
}, nil)
}
}
}()
next(w, r)
}
}
func isHeaderWritten(w http.ResponseWriter) bool {
// Can't reliably detect without wrapping; always try to write
return false
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func WrapWithRecovery(logPath string, handler http.HandlerFunc) http.HandlerFunc {
return recoveryMiddleware(logPath, handler)
}

View File

@ -1,95 +0,0 @@
package sanitize
import "regexp"
type rule struct {
pattern *regexp.Regexp
replacement string
}
var rules = []rule{
{regexp.MustCompile(`(?i)x-anthropic-billing-header:[^\n]*\n?`), ""},
{regexp.MustCompile(`(?i)\bcc_version=[^\s;,\n]+[;,]?\s*`), ""},
{regexp.MustCompile(`(?i)\bcc_entrypoint=[^\s;,\n]+[;,]?\s*`), ""},
{regexp.MustCompile(`(?i)\bcch=[a-f0-9]+[;,]?\s*`), ""},
{regexp.MustCompile(`\bClaude Code\b`), "Cursor"},
{regexp.MustCompile(`(?i)Anthropic['']s official CLI for Claude`), "Cursor AI assistant"},
{regexp.MustCompile(`\bAnthropic\b`), "Cursor"},
{regexp.MustCompile(`(?i)anthropic\.com`), "cursor.com"},
{regexp.MustCompile(`(?i)claude\.ai`), "cursor.sh"},
{regexp.MustCompile(`^[;,\s]+`), ""},
}
func SanitizeText(text string) string {
for _, r := range rules {
text = r.pattern.ReplaceAllString(text, r.replacement)
}
return text
}
func SanitizeMessages(messages []interface{}) []interface{} {
result := make([]interface{}, len(messages))
for i, raw := range messages {
if raw == nil {
result[i] = raw
continue
}
m, ok := raw.(map[string]interface{})
if !ok {
result[i] = raw
continue
}
newMsg := make(map[string]interface{}, len(m))
for k, v := range m {
newMsg[k] = v
}
switch c := m["content"].(type) {
case string:
newMsg["content"] = SanitizeText(c)
case []interface{}:
newParts := make([]interface{}, len(c))
for j, p := range c {
if pm, ok := p.(map[string]interface{}); ok && pm["type"] == "text" {
if t, ok := pm["text"].(string); ok {
newPart := make(map[string]interface{}, len(pm))
for k, v := range pm {
newPart[k] = v
}
newPart["text"] = SanitizeText(t)
newParts[j] = newPart
continue
}
}
newParts[j] = p
}
newMsg["content"] = newParts
}
result[i] = newMsg
}
return result
}
func SanitizeSystem(system interface{}) interface{} {
switch v := system.(type) {
case string:
return SanitizeText(v)
case []interface{}:
result := make([]interface{}, len(v))
for i, p := range v {
if pm, ok := p.(map[string]interface{}); ok && pm["type"] == "text" {
if t, ok := pm["text"].(string); ok {
newPart := make(map[string]interface{}, len(pm))
for k, val := range pm {
newPart[k] = val
}
newPart["text"] = SanitizeText(t)
result[i] = newPart
continue
}
}
result[i] = p
}
return result
}
return system
}

View File

@ -1,60 +0,0 @@
package sanitize
import (
"strings"
"testing"
)
func TestSanitizeTextAnthropicBilling(t *testing.T) {
input := "x-anthropic-billing-header: abc123\nHello"
got := SanitizeText(input)
if strings.Contains(got, "x-anthropic-billing-header") {
t.Errorf("billing header not removed: %q", got)
}
}
func TestSanitizeTextClaudeCode(t *testing.T) {
input := "I am Claude Code assistant"
got := SanitizeText(input)
if strings.Contains(got, "Claude Code") {
t.Errorf("'Claude Code' not replaced: %q", got)
}
if !strings.Contains(got, "Cursor") {
t.Errorf("'Cursor' not present in output: %q", got)
}
}
func TestSanitizeTextAnthropic(t *testing.T) {
input := "Powered by Anthropic's technology and anthropic.com"
got := SanitizeText(input)
if strings.Contains(got, "Anthropic") {
t.Errorf("'Anthropic' not replaced: %q", got)
}
if strings.Contains(got, "anthropic.com") {
t.Errorf("'anthropic.com' not replaced: %q", got)
}
}
func TestSanitizeTextNoChange(t *testing.T) {
input := "Hello, this is a normal message about cursor."
got := SanitizeText(input)
if got != input {
t.Errorf("unexpected change: %q -> %q", input, got)
}
}
func TestSanitizeMessages(t *testing.T) {
messages := []interface{}{
map[string]interface{}{"role": "user", "content": "Ask Claude Code something"},
map[string]interface{}{"role": "system", "content": "Use Anthropic's tools"},
}
result := SanitizeMessages(messages)
for _, raw := range result {
m := raw.(map[string]interface{})
c := m["content"].(string)
if strings.Contains(c, "Claude Code") || strings.Contains(c, "Anthropic") {
t.Errorf("found unsanitized content: %q", c)
}
}
}

View File

@ -1,159 +0,0 @@
package server
import (
"context"
"crypto/tls"
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/handlers"
"cursor-api-proxy/internal/pool"
"cursor-api-proxy/internal/process"
"cursor-api-proxy/internal/logger"
"cursor-api-proxy/internal/router"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
)
type ServerOptions struct {
Version string
Config config.BridgeConfig
Pool pool.PoolHandle
}
func StartBridgeServer(opts ServerOptions) []*http.Server {
cfg := opts.Config
var servers []*http.Server
if len(cfg.ConfigDirs) > 0 {
if cfg.MultiPort {
for i, dir := range cfg.ConfigDirs {
port := cfg.Port + i
subCfg := cfg
subCfg.Port = port
subCfg.ConfigDirs = []string{dir}
subCfg.MultiPort = false
subPool := pool.NewAccountPool([]string{dir})
srv := startSingleServer(ServerOptions{Version: opts.Version, Config: subCfg, Pool: subPool})
servers = append(servers, srv)
}
return servers
}
pool.InitAccountPool(cfg.ConfigDirs)
}
servers = append(servers, startSingleServer(opts))
return servers
}
func startSingleServer(opts ServerOptions) *http.Server {
cfg := opts.Config
modelCache := &handlers.ModelCacheRef{}
lastModel := cfg.DefaultModel
ph := opts.Pool
if ph == nil {
ph = pool.GlobalPoolHandle{}
}
handler := router.NewRouter(router.RouterOptions{
Version: opts.Version,
Config: cfg,
ModelCache: modelCache,
LastModel: &lastModel,
Pool: ph,
})
handler = router.WrapWithRecovery(cfg.SessionsLogPath, handler)
useTLS := cfg.TLSCertPath != "" && cfg.TLSKeyPath != ""
srv := &http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
Handler: handler,
}
if useTLS {
cert, err := tls.LoadX509KeyPair(cfg.TLSCertPath, cfg.TLSKeyPath)
if err != nil {
fmt.Fprintf(os.Stderr, "TLS error: %v\n", err)
os.Exit(1)
}
srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
}
scheme := "http"
if useTLS {
scheme = "https"
}
go func() {
var err error
if useTLS {
err = srv.ListenAndServeTLS("", "")
} else {
err = srv.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
if isAddrInUse(err) {
fmt.Fprintf(os.Stderr, "❌ Port %d is already in use. Set CURSOR_BRIDGE_PORT to use a different port.\n", cfg.Port)
} else {
fmt.Fprintf(os.Stderr, "❌ Server error: %v\n", err)
}
os.Exit(1)
}
}()
logger.LogServerStart(opts.Version, scheme, cfg.Host, cfg.Port, cfg)
return srv
}
func SetupGracefulShutdown(servers []*http.Server, timeoutMs int) {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
go func() {
sig := <-sigCh
logger.LogShutdown(sig.String())
process.KillAllChildProcesses()
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
defer cancel()
done := make(chan struct{})
go func() {
for _, srv := range servers {
_ = srv.Shutdown(ctx)
}
close(done)
}()
select {
case <-done:
os.Exit(0)
case <-ctx.Done():
fmt.Fprintln(os.Stderr, "[shutdown] Timed out waiting for connections to drain — forcing exit.")
os.Exit(1)
}
}()
}
func isAddrInUse(err error) bool {
return err != nil && (contains(err.Error(), "address already in use") || contains(err.Error(), "bind: address already in use"))
}
func contains(s, sub string) bool {
return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsHelper(s, sub))
}
func containsHelper(s, sub string) bool {
for i := 0; i <= len(s)-len(sub); i++ {
if s[i:i+len(sub)] == sub {
return true
}
}
return false
}

View File

@ -1,331 +0,0 @@
package server_test
import (
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/server"
"encoding/json"
"fmt"
"io"
"net"
"context"
"net/http"
"os"
"strings"
"testing"
"time"
)
// freePort 取得一個暫時可用的隨機 port
func freePort(t *testing.T) int {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
port := l.Addr().(*net.TCPAddr).Port
l.Close()
return port
}
// makeFakeAgentBin 建立一個 shell script模擬 agent 固定輸出
// sync 模式:直接輸出一行文字
// stream 模式:輸出 JSON stream 行
func makeFakeAgentBin(t *testing.T, syncOutput string) string {
t.Helper()
dir := t.TempDir()
script := dir + "/agent"
content := fmt.Sprintf(`#!/bin/sh
# 若有 --stream-json 則輸出 stream 格式
for arg; do
if [ "$arg" = "--stream-json" ]; then
printf '%%s\n' '{"type":"assistant","message":{"content":[{"type":"text","text":"%s"}]}}'
printf '%%s\n' '{"type":"result","subtype":"success"}'
exit 0
fi
done
# 否則輸出 sync 格式
printf '%%s' '%s'
`, syncOutput, syncOutput)
if err := os.WriteFile(script, []byte(content), 0755); err != nil {
t.Fatal(err)
}
return script
}
// makeFakeAgentBinWithModels 額外支援 --list-models 輸出
func makeFakeAgentBinWithModels(t *testing.T) string {
t.Helper()
dir := t.TempDir()
script := dir + "/agent"
content := `#!/bin/sh
for arg; do
if [ "$arg" = "--list-models" ]; then
printf 'claude-3-opus - Claude 3 Opus\n'
printf 'claude-3-sonnet - Claude 3 Sonnet\n'
exit 0
fi
if [ "$arg" = "--stream-json" ]; then
printf '%s\n' '{"type":"assistant","message":{"content":[{"type":"text","text":"Hello"}]}}'
printf '%s\n' '{"type":"result","subtype":"success"}'
exit 0
fi
done
printf 'Hello from agent'
`
if err := os.WriteFile(script, []byte(content), 0755); err != nil {
t.Fatal(err)
}
return script
}
func makeTestConfig(agentBin string, port int, overrides ...func(*config.BridgeConfig)) config.BridgeConfig {
cfg := config.BridgeConfig{
AgentBin: agentBin,
Host: "127.0.0.1",
Port: port,
DefaultModel: "auto",
Mode: "ask",
Force: false,
ApproveMcps: false,
StrictModel: true,
Workspace: os.TempDir(),
TimeoutMs: 30000,
SessionsLogPath: os.TempDir() + "/test-sessions.log",
ChatOnlyWorkspace: true,
Verbose: false,
MaxMode: false,
ConfigDirs: []string{},
MultiPort: false,
WinCmdlineMax: 30000,
}
for _, fn := range overrides {
fn(&cfg)
}
return cfg
}
func waitListening(t *testing.T, host string, port int, timeout time.Duration) {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), 50*time.Millisecond)
if err == nil {
conn.Close()
return
}
time.Sleep(20 * time.Millisecond)
}
t.Fatalf("server on port %d did not start within %v", port, timeout)
}
func doRequest(t *testing.T, method, url, body string, headers map[string]string) (int, string) {
t.Helper()
var reqBody io.Reader
if body != "" {
reqBody = strings.NewReader(body)
}
req, err := http.NewRequest(method, url, reqBody)
if err != nil {
t.Fatal(err)
}
if body != "" {
req.Header.Set("Content-Type", "application/json")
}
for k, v := range headers {
req.Header.Set(k, v)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
data, _ := io.ReadAll(resp.Body)
return resp.StatusCode, string(data)
}
func TestBridgeServer_Health(t *testing.T) {
port := freePort(t)
agentBin := makeFakeAgentBinWithModels(t)
cfg := makeTestConfig(agentBin, port)
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
waitListening(t, "127.0.0.1", port, 3*time.Second)
defer func() {
for _, s := range srvs {
s.Shutdown(context.Background())
}
}()
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", nil)
if status != 200 {
t.Fatalf("status = %d, want 200; body: %s", status, body)
}
var result map[string]interface{}
json.Unmarshal([]byte(body), &result)
if result["ok"] != true {
t.Errorf("ok = %v, want true", result["ok"])
}
if result["version"] != "1.0.0" {
t.Errorf("version = %v, want 1.0.0", result["version"])
}
}
func TestBridgeServer_Models(t *testing.T) {
port := freePort(t)
agentBin := makeFakeAgentBinWithModels(t)
cfg := makeTestConfig(agentBin, port)
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
waitListening(t, "127.0.0.1", port, 3*time.Second)
defer func() {
for _, s := range srvs {
s.Shutdown(context.Background())
}
}()
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/v1/models", port), "", nil)
if status != 200 {
t.Fatalf("status = %d, want 200; body: %s", status, body)
}
var result map[string]interface{}
json.Unmarshal([]byte(body), &result)
if result["object"] != "list" {
t.Errorf("object = %v, want list", result["object"])
}
data := result["data"].([]interface{})
if len(data) < 2 {
t.Errorf("data len = %d, want >= 2", len(data))
}
}
func TestBridgeServer_Unauthorized(t *testing.T) {
port := freePort(t)
agentBin := makeFakeAgentBinWithModels(t)
cfg := makeTestConfig(agentBin, port, func(c *config.BridgeConfig) {
c.RequiredKey = "secret123"
})
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
waitListening(t, "127.0.0.1", port, 3*time.Second)
defer func() {
for _, s := range srvs {
s.Shutdown(context.Background())
}
}()
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", nil)
if status != 401 {
t.Fatalf("status = %d, want 401; body: %s", status, body)
}
var result map[string]interface{}
json.Unmarshal([]byte(body), &result)
errObj := result["error"].(map[string]interface{})
if errObj["message"] != "Invalid API key" {
t.Errorf("message = %v, want 'Invalid API key'", errObj["message"])
}
}
func TestBridgeServer_AuthorizedKey(t *testing.T) {
port := freePort(t)
agentBin := makeFakeAgentBinWithModels(t)
cfg := makeTestConfig(agentBin, port, func(c *config.BridgeConfig) {
c.RequiredKey = "secret123"
})
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
waitListening(t, "127.0.0.1", port, 3*time.Second)
defer func() {
for _, s := range srvs {
s.Shutdown(context.Background())
}
}()
status, _ := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", map[string]string{
"Authorization": "Bearer secret123",
})
if status != 200 {
t.Errorf("status = %d, want 200", status)
}
}
func TestBridgeServer_NotFound(t *testing.T) {
port := freePort(t)
agentBin := makeFakeAgentBinWithModels(t)
cfg := makeTestConfig(agentBin, port)
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
waitListening(t, "127.0.0.1", port, 3*time.Second)
defer func() {
for _, s := range srvs {
s.Shutdown(context.Background())
}
}()
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/unknown", port), "", nil)
if status != 404 {
t.Fatalf("status = %d, want 404; body: %s", status, body)
}
var result map[string]interface{}
json.Unmarshal([]byte(body), &result)
errObj := result["error"].(map[string]interface{})
if errObj["code"] != "not_found" {
t.Errorf("code = %v, want not_found", errObj["code"])
}
}
func TestBridgeServer_ChatCompletions_Sync(t *testing.T) {
port := freePort(t)
agentBin := makeFakeAgentBin(t, "Hello from agent")
cfg := makeTestConfig(agentBin, port)
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
waitListening(t, "127.0.0.1", port, 3*time.Second)
defer func() {
for _, s := range srvs {
s.Shutdown(context.Background())
}
}()
reqBody := `{"model":"claude-3-opus","messages":[{"role":"user","content":"Hi"}]}`
status, body := doRequest(t, "POST", fmt.Sprintf("http://127.0.0.1:%d/v1/chat/completions", port), reqBody, nil)
if status != 200 {
t.Fatalf("status = %d, want 200; body: %s", status, body)
}
var result map[string]interface{}
json.Unmarshal([]byte(body), &result)
if result["object"] != "chat.completion" {
t.Errorf("object = %v, want chat.completion", result["object"])
}
choices := result["choices"].([]interface{})
msg := choices[0].(map[string]interface{})["message"].(map[string]interface{})
if msg["content"] != "Hello from agent" {
t.Errorf("content = %v, want 'Hello from agent'", msg["content"])
}
}
func TestBridgeServer_MultiPort(t *testing.T) {
basePort := freePort(t)
agentBin := makeFakeAgentBinWithModels(t)
dir1 := t.TempDir()
dir2 := t.TempDir()
cfg := makeTestConfig(agentBin, basePort, func(c *config.BridgeConfig) {
c.ConfigDirs = []string{dir1, dir2}
c.MultiPort = true
})
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
if len(srvs) != 2 {
t.Fatalf("got %d servers, want 2", len(srvs))
}
// 等待兩個 server 啟動port 可能會衝突,這裡不嚴格測試 port 分配)
time.Sleep(200 * time.Millisecond)
defer func() {
for _, s := range srvs {
s.Shutdown(context.Background())
}
}()
}

View File

@ -1,154 +0,0 @@
package toolcall
import (
"encoding/json"
"regexp"
"strings"
)
type ToolCall struct {
Name string
Arguments string // JSON string
}
type ParsedResponse struct {
TextContent string
ToolCalls []ToolCall
}
func (p *ParsedResponse) HasToolCalls() bool {
return len(p.ToolCalls) > 0
}
// Modified regex to handle nested JSON
var toolCallTagRe = regexp.MustCompile(`(?s)行政法规\s*(\{(?:[^{}]|\{[^{}]*\})*\})\s*ugalakh`)
var antmlFunctionCallsRe = regexp.MustCompile(`(?s)<function_calls>\s*(.*?)\s*</function_calls>`)
var antmlInvokeRe = regexp.MustCompile(`(?s)<invoke\s+name="([^"]+)">\s*(.*?)\s*</invoke>`)
var antmlParamRe = regexp.MustCompile(`(?s)<parameter\s+name="([^"]+)">(.*?)</parameter>`)
func ExtractToolCalls(text string, toolNames map[string]bool) *ParsedResponse {
result := &ParsedResponse{}
if locs := toolCallTagRe.FindAllStringSubmatchIndex(text, -1); len(locs) > 0 {
var calls []ToolCall
var textParts []string
last := 0
for _, loc := range locs {
if loc[0] > last {
textParts = append(textParts, text[last:loc[0]])
}
jsonStr := text[loc[2]:loc[3]]
if tc := parseToolCallJSON(jsonStr, toolNames); tc != nil {
calls = append(calls, *tc)
} else {
textParts = append(textParts, text[loc[0]:loc[1]])
}
last = loc[1]
}
if last < len(text) {
textParts = append(textParts, text[last:])
}
if len(calls) > 0 {
result.TextContent = strings.TrimSpace(strings.Join(textParts, ""))
result.ToolCalls = calls
return result
}
}
if locs := antmlFunctionCallsRe.FindAllStringSubmatchIndex(text, -1); len(locs) > 0 {
var calls []ToolCall
var textParts []string
last := 0
for _, loc := range locs {
if loc[0] > last {
textParts = append(textParts, text[last:loc[0]])
}
block := text[loc[2]:loc[3]]
invokes := antmlInvokeRe.FindAllStringSubmatch(block, -1)
for _, inv := range invokes {
name := inv[1]
if toolNames != nil && len(toolNames) > 0 && !toolNames[name] {
continue
}
body := inv[2]
args := map[string]interface{}{}
params := antmlParamRe.FindAllStringSubmatch(body, -1)
for _, p := range params {
paramName := p[1]
paramValue := strings.TrimSpace(p[2])
var jsonVal interface{}
if err := json.Unmarshal([]byte(paramValue), &jsonVal); err == nil {
args[paramName] = jsonVal
} else {
args[paramName] = paramValue
}
}
argsJSON, _ := json.Marshal(args)
calls = append(calls, ToolCall{Name: name, Arguments: string(argsJSON)})
}
last = loc[1]
}
if last < len(text) {
textParts = append(textParts, text[last:])
}
if len(calls) > 0 {
result.TextContent = strings.TrimSpace(strings.Join(textParts, ""))
result.ToolCalls = calls
return result
}
}
result.TextContent = text
return result
}
func parseToolCallJSON(jsonStr string, toolNames map[string]bool) *ToolCall {
var raw map[string]interface{}
if err := json.Unmarshal([]byte(jsonStr), &raw); err != nil {
return nil
}
name, _ := raw["name"].(string)
if name == "" {
return nil
}
if toolNames != nil && len(toolNames) > 0 && !toolNames[name] {
return nil
}
var argsStr string
switch a := raw["arguments"].(type) {
case string:
argsStr = a
case map[string]interface{}, []interface{}:
b, _ := json.Marshal(a)
argsStr = string(b)
default:
if p, ok := raw["parameters"]; ok {
b, _ := json.Marshal(p)
argsStr = string(b)
} else {
argsStr = "{}"
}
}
return &ToolCall{Name: name, Arguments: argsStr}
}
func CollectToolNames(tools []interface{}) map[string]bool {
names := map[string]bool{}
for _, t := range tools {
m, ok := t.(map[string]interface{})
if !ok {
continue
}
if m["type"] == "function" {
if fn, ok := m["function"].(map[string]interface{}); ok {
if name, ok := fn["name"].(string); ok {
names[name] = true
}
}
}
if name, ok := m["name"].(string); ok {
names[name] = true
}
}
return names
}

View File

@ -1,181 +0,0 @@
package winlimit
import (
"cursor-api-proxy/internal/env"
"runtime"
)
const WinPromptOmissionPrefix = "[Earlier messages omitted: Windows command-line length limit.]\n\n"
const LinuxPromptOmissionPrefix = "[Earlier messages omitted: Linux ARG_MAX command-line length limit.]\n\n"
// safeLinuxArgMax returns a conservative estimate of ARG_MAX on Linux.
// The actual limit is typically 2MB; we use 1.5MB to leave room for env vars.
func safeLinuxArgMax() int {
return 1536 * 1024
}
type FitPromptResult struct {
OK bool
Args []string
Truncated bool
OriginalLength int
FinalPromptLength int
Error string
}
func estimateCmdlineLength(resolved env.AgentCommand) int {
argv := append([]string{resolved.Command}, resolved.Args...)
if resolved.WindowsVerbatimArguments {
n := 0
for _, a := range argv {
n += len(a)
}
if len(argv) > 1 {
n += len(argv) - 1
}
return n + 512
}
dstLen := 0
for _, a := range argv {
dstLen += len(a)
}
dstLen = dstLen*2 + len(argv)*2
if len(argv) > 1 {
dstLen += len(argv) - 1
}
return dstLen + 512
}
func FitPromptToWinCmdline(agentBin string, fixedArgs []string, prompt string, maxCmdline int, cwd string) FitPromptResult {
if runtime.GOOS != "windows" {
return fitPromptLinux(fixedArgs, prompt)
}
e := env.OsEnvToMap()
measured := func(p string) int {
args := make([]string, len(fixedArgs)+1)
copy(args, fixedArgs)
args[len(fixedArgs)] = p
resolved := env.ResolveAgentCommand(agentBin, args, e, cwd)
return estimateCmdlineLength(resolved)
}
if measured("") > maxCmdline {
return FitPromptResult{
OK: false,
Error: "Windows command line exceeds the configured limit even without a prompt; shorten workspace path, model id, or CURSOR_BRIDGE_WIN_CMDLINE_MAX.",
}
}
if measured(prompt) <= maxCmdline {
args := make([]string, len(fixedArgs)+1)
copy(args, fixedArgs)
args[len(fixedArgs)] = prompt
return FitPromptResult{
OK: true,
Args: args,
Truncated: false,
OriginalLength: len(prompt),
FinalPromptLength: len(prompt),
}
}
prefix := WinPromptOmissionPrefix
if measured(prefix) > maxCmdline {
return FitPromptResult{
OK: false,
Error: "Windows command line too long to fit even the truncation notice; shorten workspace path or flags.",
}
}
lo, hi, best := 0, len(prompt), 0
for lo <= hi {
mid := (lo + hi) / 2
var tail string
if mid > 0 {
tail = prompt[len(prompt)-mid:]
}
candidate := prefix + tail
if measured(candidate) <= maxCmdline {
best = mid
lo = mid + 1
} else {
hi = mid - 1
}
}
var finalPrompt string
if best == 0 {
finalPrompt = prefix
} else {
finalPrompt = prefix + prompt[len(prompt)-best:]
}
args := make([]string, len(fixedArgs)+1)
copy(args, fixedArgs)
args[len(fixedArgs)] = finalPrompt
return FitPromptResult{
OK: true,
Args: args,
Truncated: true,
OriginalLength: len(prompt),
FinalPromptLength: len(finalPrompt),
}
}
// fitPromptLinux handles Linux ARG_MAX truncation.
func fitPromptLinux(fixedArgs []string, prompt string) FitPromptResult {
argMax := safeLinuxArgMax()
// Estimate total cmdline size: sum of all fixed args + prompt + null terminators
fixedLen := 0
for _, a := range fixedArgs {
fixedLen += len(a) + 1
}
totalLen := fixedLen + len(prompt) + 1
if totalLen <= argMax {
args := make([]string, len(fixedArgs)+1)
copy(args, fixedArgs)
args[len(fixedArgs)] = prompt
return FitPromptResult{
OK: true,
Args: args,
Truncated: false,
OriginalLength: len(prompt),
FinalPromptLength: len(prompt),
}
}
// Need to truncate: keep the tail of the prompt (most recent messages)
prefix := LinuxPromptOmissionPrefix
available := argMax - fixedLen - len(prefix) - 1
if available < 0 {
available = 0
}
var finalPrompt string
if available <= 0 {
finalPrompt = prefix
} else if available >= len(prompt) {
finalPrompt = prefix + prompt
} else {
finalPrompt = prefix + prompt[len(prompt)-available:]
}
args := make([]string, len(fixedArgs)+1)
copy(args, fixedArgs)
args[len(fixedArgs)] = finalPrompt
return FitPromptResult{
OK: true,
Args: args,
Truncated: true,
OriginalLength: len(prompt),
FinalPromptLength: len(finalPrompt),
}
}
func WarnPromptTruncated(originalLength, finalLength int) {
_ = originalLength
_ = finalLength
}

View File

@ -1,37 +0,0 @@
package winlimit
import (
"runtime"
"strings"
"testing"
)
func TestNonWindowsPassThrough(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Skipping non-Windows test on Windows")
}
fixedArgs := []string{"--print", "--model", "gpt-4"}
prompt := "Hello world"
result := FitPromptToWinCmdline("agent", fixedArgs, prompt, 30000, "/tmp")
if !result.OK {
t.Fatalf("expected OK=true on non-Windows, got error: %s", result.Error)
}
if result.Truncated {
t.Error("expected no truncation on non-Windows")
}
if result.OriginalLength != len(prompt) {
t.Errorf("expected original length %d, got %d", len(prompt), result.OriginalLength)
}
// Last arg should be the prompt
if len(result.Args) == 0 || result.Args[len(result.Args)-1] != prompt {
t.Errorf("expected last arg to be prompt, got %v", result.Args)
}
}
func TestOmissionPrefix(t *testing.T) {
if !strings.Contains(WinPromptOmissionPrefix, "Earlier messages omitted") {
t.Errorf("omission prefix should mention earlier messages, got: %q", WinPromptOmissionPrefix)
}
}

View File

@ -1,30 +0,0 @@
package workspace
import (
"cursor-api-proxy/internal/config"
"os"
"path/filepath"
"strings"
)
type WorkspaceResult struct {
WorkspaceDir string
TempDir string
}
func ResolveWorkspace(cfg config.BridgeConfig, workspaceHeader string) WorkspaceResult {
if cfg.ChatOnlyWorkspace {
tempDir, err := os.MkdirTemp("", "cursor-proxy-")
if err != nil {
tempDir = filepath.Join(os.TempDir(), "cursor-proxy-fallback")
_ = os.MkdirAll(tempDir, 0700)
}
return WorkspaceResult{WorkspaceDir: tempDir, TempDir: tempDir}
}
headerWs := strings.TrimSpace(workspaceHeader)
if headerWs != "" {
return WorkspaceResult{WorkspaceDir: headerWs}
}
return WorkspaceResult{WorkspaceDir: cfg.Workspace}
}

73
main.go
View File

@ -1,26 +1,56 @@
package main package main
import ( import (
"cursor-api-proxy/cmd" "flag"
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/env"
"cursor-api-proxy/internal/server"
"fmt" "fmt"
"os" "os"
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/handler"
"cursor-api-proxy/internal/svc"
cmd "cursor-api-proxy/cmd/cli"
"github.com/zeromicro/go-zero/core/conf"
"github.com/zeromicro/go-zero/rest"
) )
const version = "1.0.0" const version = "1.0.0"
var configFile = flag.String("f", "etc/chat-api.yaml", "the config file")
func main() { func main() {
args, err := cmd.ParseArgs(os.Args[1:]) // Check for CLI commands first (before flag.Parse)
if err != nil { args := os.Args[1:]
fmt.Fprintf(os.Stderr, "Error: %v\n", err) if len(args) > 0 {
os.Exit(1) parsed, err := cmd.ParseArgs(args)
if err != nil {
// Not a CLI command, proceed to HTTP server
} else if handleCLICommand(parsed) {
return
}
} }
// HTTP server mode (go-zero)
flag.Parse()
var c config.Config
conf.MustLoad(*configFile, &c)
server := rest.MustNewServer(c.RestConf)
defer server.Stop()
ctx := svc.NewServiceContext(c)
handler.RegisterHandlers(server, ctx)
fmt.Printf("Starting server at %s:%d...\n", c.Host, c.Port)
server.Start()
}
func handleCLICommand(args cmd.ParsedArgs) bool {
if args.Help { if args.Help {
cmd.PrintHelp(version) cmd.PrintHelp(version)
return return true
} }
if args.Login { if args.Login {
@ -28,7 +58,7 @@ func main() {
fmt.Fprintf(os.Stderr, "Error: %v\n", err) fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
return return true
} }
if args.Logout { if args.Logout {
@ -36,7 +66,7 @@ func main() {
fmt.Fprintf(os.Stderr, "Error: %v\n", err) fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
return return true
} }
if args.AccountsList { if args.AccountsList {
@ -44,7 +74,7 @@ func main() {
fmt.Fprintf(os.Stderr, "Error: %v\n", err) fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
return return true
} }
if args.ResetHwid { if args.ResetHwid {
@ -52,22 +82,9 @@ func main() {
fmt.Fprintf(os.Stderr, "Error: %v\n", err) fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
return return true
} }
e := env.OsEnvToMap() // Not a CLI command
if args.Tailscale { return false
e["CURSOR_BRIDGE_HOST"] = "0.0.0.0"
}
cwd, _ := os.Getwd()
cfg := config.LoadBridgeConfig(e, cwd)
servers := server.StartBridgeServer(server.ServerOptions{
Version: version,
Config: cfg,
})
server.SetupGracefulShutdown(servers, 10000)
select {}
} }

View File

@ -1,13 +1,14 @@
package logger package logger
import ( import (
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/pool"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"time" "time"
"cursor-api-proxy/internal/config"
"cursor-api-proxy/pkg/domain/entity"
) )
const ( const (
@ -192,7 +193,7 @@ func LogAccountAssigned(configDir string) {
fmt.Printf("%s %s→%s account %s%s%s\n", ts(), cBCyan, cReset, cBold, name, cReset) fmt.Printf("%s %s→%s account %s%s%s\n", ts(), cBCyan, cReset, cBold, name, cReset)
} }
func LogAccountStats(verbose bool, stats []pool.AccountStat) { func LogAccountStats(verbose bool, stats []entity.AccountStat) {
if !verbose || len(stats) == 0 { if !verbose || len(stats) == 0 {
return return
} }

View File

@ -2,7 +2,7 @@ package process_test
import ( import (
"context" "context"
"cursor-api-proxy/internal/process" "cursor-api-proxy/pkg/infrastructure/process"
"os" "os"
"testing" "testing"
"time" "time"

View File

@ -3,7 +3,7 @@ package process
import ( import (
"bufio" "bufio"
"context" "context"
"cursor-api-proxy/internal/env" "cursor-api-proxy/pkg/infrastructure/env"
"fmt" "fmt"
"os/exec" "os/exec"
"strings" "strings"

View File

@ -1,7 +1,7 @@
package winlimit package winlimit
import ( import (
"cursor-api-proxy/internal/env" "cursor-api-proxy/pkg/infrastructure/env"
"runtime" "runtime"
) )

View File

@ -182,7 +182,7 @@ func (p *PlaywrightProvider) Generate(ctx context.Context, model string, message
fmt.Println("Browser is open. You can:") fmt.Println("Browser is open. You can:")
fmt.Println("1. Log in to Gemini now") fmt.Println("1. Log in to Gemini now")
fmt.Println("2. Continue without login") fmt.Println("2. Continue without login")
fmt.Println("========================================\n") fmt.Println("========================================")
} }
} else { } else {
fmt.Println("[GeminiWeb] ✓ Logged in") fmt.Println("[GeminiWeb] ✓ Logged in")

View File

@ -97,7 +97,7 @@ func (p *Provider) Generate(ctx context.Context, model string, messages []entity
fmt.Println("Browser is open. You can:") fmt.Println("Browser is open. You can:")
fmt.Println("1. Log in to Gemini now") fmt.Println("1. Log in to Gemini now")
fmt.Println("2. Continue without login") fmt.Println("2. Continue without login")
fmt.Println("========================================\n") fmt.Println("========================================")
} }
} else { } else {
fmt.Printf("[GeminiWeb] Logged in\n") fmt.Printf("[GeminiWeb] Logged in\n")

View File

@ -1,7 +1,7 @@
package main package main
import ( import (
"cursor-api-proxy/internal/providers/geminiweb" "cursor-api-proxy/pkg/provider/geminiweb"
"fmt" "fmt"
"os" "os"
@ -92,7 +92,7 @@ func analyzeDOM(page *rod.Page) {
ariaLabel, _ := el.Attribute("aria-label") ariaLabel, _ := el.Attribute("aria-label")
placeholder, _ := el.Attribute("placeholder") placeholder, _ := el.Attribute("placeholder")
fmt.Printf(" [%d] tag=%s class=%s aria-label=%s placeholder=%s\n", fmt.Printf(" [%d] tag=%s class=%s aria-label=%s placeholder=%s\n",
i, tag, class, ariaLabel, placeholder) i, tag, ptrToStr(class), ptrToStr(ariaLabel), ptrToStr(placeholder))
} }
} }
} }
@ -120,7 +120,7 @@ func analyzeDOM(page *rod.Page) {
text, _ := el.Text() text, _ := el.Text()
text = truncate(text, 30) text = truncate(text, 30)
fmt.Printf(" [%d] tag=%s class=%s aria-label=%s text=%s\n", fmt.Printf(" [%d] tag=%s class=%s aria-label=%s text=%s\n",
i, tag, class, ariaLabel, text) i, tag, ptrToStr(class), ptrToStr(ariaLabel), text)
} }
} }
} }
@ -145,7 +145,7 @@ func analyzeDOM(page *rod.Page) {
ariaLabel, _ := el.Attribute("aria-label") ariaLabel, _ := el.Attribute("aria-label")
text, _ := el.Text() text, _ := el.Text()
fmt.Printf(" [%d] tag=%s class=%s aria-label=%s text=%s\n", fmt.Printf(" [%d] tag=%s class=%s aria-label=%s text=%s\n",
i, tag, class, ariaLabel, truncate(text, 30)) i, tag, ptrToStr(class), ptrToStr(ariaLabel), truncate(text, 30))
} }
} }
} }
@ -157,3 +157,10 @@ func truncate(s string, max int) string {
} }
return s[:max] + "..." return s[:max] + "..."
} }
func ptrToStr(s *string) string {
if s == nil {
return "<nil>"
}
return *s
}