first commit
This commit is contained in:
commit
a2f1d05391
|
|
@ -0,0 +1,102 @@
|
|||
# Cursor API Proxy
|
||||
|
||||
[English](./README.md) | 繁體中文
|
||||
|
||||
一個讓你可以透過標準 OpenAI/Anthropic API 格式存取 Cursor AI 編輯器的代理伺服器。
|
||||
|
||||
## 功能特色
|
||||
|
||||
- **API 相容**:支援 OpenAI 格式和 Anthropic 格式的 API 呼叫
|
||||
- **多帳號管理**:支援新增、移除、切換多個 Cursor 帳號
|
||||
- **Tailscale 支援**:可綁定到 `0.0.0.0` 供區域網路存取
|
||||
- **HWID 重置**:內建反偵測功能,可重置機器識別碼
|
||||
- **連線池**:最佳化的連線管理
|
||||
|
||||
## 安裝
|
||||
|
||||
```bash
|
||||
git clone https://github.com/your-repo/cursor-api-proxy-go.git
|
||||
cd cursor-api-proxy-go
|
||||
go build -o cursor-api-proxy .
|
||||
```
|
||||
|
||||
## 使用方式
|
||||
|
||||
### 啟動伺服器
|
||||
|
||||
```bash
|
||||
./cursor-api-proxy
|
||||
```
|
||||
|
||||
預設監聽 `127.0.0.1:8080`。
|
||||
|
||||
### 登入帳號
|
||||
|
||||
```bash
|
||||
# 登入帳號
|
||||
./cursor-api-proxy login myaccount
|
||||
|
||||
# 使用代理登入
|
||||
./cursor-api-proxy login myaccount --proxy=http://127.0.0.1:7890
|
||||
```
|
||||
|
||||
### 列出帳號
|
||||
|
||||
```bash
|
||||
./cursor-api-proxy accounts
|
||||
```
|
||||
|
||||
### 登出帳號
|
||||
|
||||
```bash
|
||||
./cursor-api-proxy logout myaccount
|
||||
```
|
||||
|
||||
### 重置 HWID(反BAN)
|
||||
|
||||
```bash
|
||||
# 基本重置
|
||||
./cursor-api-proxy reset-hwid
|
||||
|
||||
# 深度清理(清除 session 和 cookies)
|
||||
./cursor-api-proxy reset-hwid --deep-clean
|
||||
```
|
||||
|
||||
### 其他選項
|
||||
|
||||
| 選項 | 說明 |
|
||||
|------|------|
|
||||
| `--tailscale` | 綁定到 `0.0.0.0` 供區域網路存取 |
|
||||
| `-h, --help` | 顯示說明 |
|
||||
|
||||
## API 端點
|
||||
|
||||
| 端點 | 方法 | 說明 |
|
||||
|------|------|------|
|
||||
| `http://127.0.0.1:8080/v1/chat/completions` | POST | OpenAI 格式聊天完成 |
|
||||
| `http://127.0.0.1:8080/v1/models` | GET | 列出可用模型 |
|
||||
| `http://127.0.0.1:8080/v1/chat/messages` | POST | Anthropic 格式聊天 |
|
||||
| `http://127.0.0.1:8080/health` | GET | 健康檢查 |
|
||||
|
||||
## 環境變數
|
||||
|
||||
| 變數 | 預設值 | 說明 |
|
||||
|------|--------|------|
|
||||
| `CURSOR_BRIDGE_HOST` | `127.0.0.1` | 監聽位址 |
|
||||
| `CURSOR_BRIDGE_PORT` | `8080` | 監聽連接埠 |
|
||||
| `HTTPS_PROXY` | - | HTTP 代理伺服器 |
|
||||
|
||||
## 常見問題
|
||||
|
||||
**Q: 為什麼需要登入帳號?**
|
||||
A: Cursor API 需要驗證才能使用,請先登入你的 Cursor 帳號。
|
||||
|
||||
**Q: 如何處理被BAN的問題?**
|
||||
A: 使用 `reset-hwid` 命令重置機器識別碼,加上 `--deep-clean` 進行更徹底的清理。
|
||||
|
||||
**Q: 可以在其他設備上使用嗎?**
|
||||
A: 可以,使用 `--tailscale` 選項啟動伺服器,然後透過區域網路 IP 存取。
|
||||
|
||||
## 授權
|
||||
|
||||
MIT License
|
||||
|
|
@ -0,0 +1,196 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/agent"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type AccountInfo struct {
|
||||
Name string
|
||||
ConfigDir string
|
||||
Authenticated bool
|
||||
Email string
|
||||
DisplayName string
|
||||
AuthID string
|
||||
Plan string
|
||||
SubscriptionStatus string
|
||||
ExpiresAt string
|
||||
}
|
||||
|
||||
func ReadAccountInfo(name, configDir string) AccountInfo {
|
||||
info := AccountInfo{Name: name, ConfigDir: configDir}
|
||||
|
||||
configFile := filepath.Join(configDir, "cli-config.json")
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
return info
|
||||
}
|
||||
|
||||
var raw struct {
|
||||
AuthInfo *struct {
|
||||
Email string `json:"email"`
|
||||
DisplayName string `json:"displayName"`
|
||||
AuthID string `json:"authId"`
|
||||
} `json:"authInfo"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &raw); err == nil && raw.AuthInfo != nil {
|
||||
info.Authenticated = true
|
||||
info.Email = raw.AuthInfo.Email
|
||||
info.DisplayName = raw.AuthInfo.DisplayName
|
||||
info.AuthID = raw.AuthInfo.AuthID
|
||||
}
|
||||
|
||||
statsigFile := filepath.Join(configDir, "statsig-cache.json")
|
||||
statsigData, err := os.ReadFile(statsigFile)
|
||||
if err != nil {
|
||||
return info
|
||||
}
|
||||
|
||||
var statsigWrapper struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(statsigData, &statsigWrapper); err != nil || statsigWrapper.Data == "" {
|
||||
return info
|
||||
}
|
||||
|
||||
var statsig struct {
|
||||
User *struct {
|
||||
Custom *struct {
|
||||
IsEnterpriseUser bool `json:"isEnterpriseUser"`
|
||||
StripeSubscriptionStatus string `json:"stripeSubscriptionStatus"`
|
||||
StripeMembershipExpiration string `json:"stripeMembershipExpiration"`
|
||||
} `json:"custom"`
|
||||
} `json:"user"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(statsigWrapper.Data), &statsig); err != nil {
|
||||
return info
|
||||
}
|
||||
|
||||
if statsig.User != nil && statsig.User.Custom != nil {
|
||||
c := statsig.User.Custom
|
||||
if c.IsEnterpriseUser {
|
||||
info.Plan = "Enterprise"
|
||||
} else if c.StripeSubscriptionStatus == "active" {
|
||||
info.Plan = "Pro"
|
||||
} else {
|
||||
info.Plan = "Free"
|
||||
}
|
||||
info.SubscriptionStatus = c.StripeSubscriptionStatus
|
||||
info.ExpiresAt = c.StripeMembershipExpiration
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func HandleAccountsList() error {
|
||||
accountsDir := agent.AccountsDir()
|
||||
|
||||
entries, err := os.ReadDir(accountsDir)
|
||||
if err != nil {
|
||||
fmt.Println("No accounts found. Use 'cursor-api-proxy login' to add one.")
|
||||
return nil
|
||||
}
|
||||
|
||||
var names []string
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
names = append(names, e.Name())
|
||||
}
|
||||
}
|
||||
|
||||
if len(names) == 0 {
|
||||
fmt.Println("No accounts found. Use 'cursor-api-proxy login' to add one.")
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Print("Cursor Accounts:\n\n")
|
||||
|
||||
keychainToken := agent.ReadKeychainToken()
|
||||
|
||||
for i, name := range names {
|
||||
configDir := filepath.Join(accountsDir, name)
|
||||
info := ReadAccountInfo(name, configDir)
|
||||
|
||||
fmt.Printf(" %d. %s\n", i+1, name)
|
||||
|
||||
if info.Authenticated {
|
||||
cachedToken := agent.ReadCachedToken(configDir)
|
||||
keychainMatchesAccount := keychainToken != "" && info.AuthID != "" && TokenSub(keychainToken) == info.AuthID
|
||||
token := cachedToken
|
||||
if token == "" && keychainMatchesAccount {
|
||||
token = keychainToken
|
||||
}
|
||||
|
||||
var liveProfile *StripeProfile
|
||||
var liveUsage *UsageData
|
||||
if token != "" {
|
||||
liveUsage, _ = FetchAccountUsage(token)
|
||||
liveProfile, _ = FetchStripeProfile(token)
|
||||
}
|
||||
|
||||
if info.Email != "" {
|
||||
display := ""
|
||||
if info.DisplayName != "" {
|
||||
display = " (" + info.DisplayName + ")"
|
||||
}
|
||||
fmt.Printf(" %s%s\n", info.Email, display)
|
||||
}
|
||||
|
||||
if info.Plan != "" && liveProfile == nil {
|
||||
canceled := ""
|
||||
if info.SubscriptionStatus == "canceled" {
|
||||
canceled = " · canceled"
|
||||
}
|
||||
expiry := ""
|
||||
if info.ExpiresAt != "" {
|
||||
expiry = " · expires " + info.ExpiresAt
|
||||
}
|
||||
fmt.Printf(" %s%s%s\n", info.Plan, canceled, expiry)
|
||||
}
|
||||
fmt.Println(" Authenticated")
|
||||
|
||||
if liveProfile != nil {
|
||||
fmt.Printf(" %s\n", DescribePlan(liveProfile))
|
||||
}
|
||||
if liveUsage != nil {
|
||||
for _, line := range FormatUsageSummary(liveUsage) {
|
||||
fmt.Println(line)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fmt.Println(" Not authenticated")
|
||||
}
|
||||
|
||||
fmt.Println("")
|
||||
}
|
||||
|
||||
fmt.Println("Tip: run 'cursor-api-proxy logout <name>' to remove an account.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func HandleLogout(accountName string) error {
|
||||
if accountName == "" {
|
||||
fmt.Fprintln(os.Stderr, "Error: Please specify the account name to remove.")
|
||||
fmt.Fprintln(os.Stderr, "Usage: cursor-api-proxy logout <account-name>")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
accountsDir := agent.AccountsDir()
|
||||
configDir := filepath.Join(accountsDir, accountName)
|
||||
|
||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
||||
fmt.Fprintf(os.Stderr, "Account '%s' not found.\n", accountName)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := os.RemoveAll(configDir); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error removing account: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("Account '%s' removed.\n", accountName)
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
package cmd
|
||||
|
||||
import "fmt"
|
||||
|
||||
type ParsedArgs struct {
|
||||
Tailscale bool
|
||||
Help bool
|
||||
Login bool
|
||||
AccountsList bool
|
||||
Logout bool
|
||||
AccountName string
|
||||
Proxies []string
|
||||
ResetHwid bool
|
||||
DeepClean bool
|
||||
DryRun bool
|
||||
}
|
||||
|
||||
func ParseArgs(argv []string) (ParsedArgs, error) {
|
||||
var args ParsedArgs
|
||||
|
||||
for i := 0; i < len(argv); i++ {
|
||||
arg := argv[i]
|
||||
|
||||
switch arg {
|
||||
case "login", "add-account":
|
||||
args.Login = true
|
||||
if i+1 < len(argv) && len(argv[i+1]) > 0 && argv[i+1][0] != '-' {
|
||||
i++
|
||||
args.AccountName = argv[i]
|
||||
}
|
||||
|
||||
case "logout", "remove-account":
|
||||
args.Logout = true
|
||||
if i+1 < len(argv) && len(argv[i+1]) > 0 && argv[i+1][0] != '-' {
|
||||
i++
|
||||
args.AccountName = argv[i]
|
||||
}
|
||||
|
||||
case "accounts", "list-accounts":
|
||||
args.AccountsList = true
|
||||
|
||||
case "reset-hwid", "reset":
|
||||
args.ResetHwid = true
|
||||
|
||||
case "--deep-clean":
|
||||
args.DeepClean = true
|
||||
|
||||
case "--dry-run":
|
||||
args.DryRun = true
|
||||
|
||||
case "--tailscale":
|
||||
args.Tailscale = true
|
||||
|
||||
case "--help", "-h":
|
||||
args.Help = true
|
||||
|
||||
default:
|
||||
if len(arg) > len("--proxy=") && arg[:len("--proxy=")] == "--proxy=" {
|
||||
raw := arg[len("--proxy="):]
|
||||
parts := splitComma(raw)
|
||||
for _, p := range parts {
|
||||
if p != "" {
|
||||
args.Proxies = append(args.Proxies, p)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return args, fmt.Errorf("Unknown argument: %s", arg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func splitComma(s string) []string {
|
||||
var result []string
|
||||
start := 0
|
||||
for i := 0; i <= len(s); i++ {
|
||||
if i == len(s) || s[i] == ',' {
|
||||
part := trim(s[start:i])
|
||||
if part != "" {
|
||||
result = append(result, part)
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func trim(s string) string {
|
||||
start := 0
|
||||
end := len(s)
|
||||
for start < end && (s[start] == ' ' || s[start] == '\t') {
|
||||
start++
|
||||
}
|
||||
for end > start && (s[end-1] == ' ' || s[end-1] == '\t') {
|
||||
end--
|
||||
}
|
||||
return s[start:end]
|
||||
}
|
||||
|
||||
func PrintHelp(version string) {
|
||||
fmt.Printf("cursor-api-proxy v%s\n\n", version)
|
||||
fmt.Println("Usage:")
|
||||
fmt.Println(" cursor-api-proxy [options]")
|
||||
fmt.Println("")
|
||||
fmt.Println("Commands:")
|
||||
fmt.Println(" login [name] Log into a Cursor account (saved to ~/.cursor-api-proxy/accounts/)")
|
||||
fmt.Println(" login [name] --proxy=... Same, but with a proxy from a comma-separated list")
|
||||
fmt.Println(" logout <name> Remove a saved Cursor account")
|
||||
fmt.Println(" accounts List saved accounts with plan info")
|
||||
fmt.Println(" reset-hwid Reset Cursor machine/telemetry IDs (anti-ban)")
|
||||
fmt.Println(" reset-hwid --deep-clean Also wipe session storage and cookies")
|
||||
fmt.Println("")
|
||||
fmt.Println("Options:")
|
||||
fmt.Println(" --tailscale Bind to 0.0.0.0 for tailnet/LAN access")
|
||||
fmt.Println(" -h, --help Show this help message")
|
||||
}
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"cursor-api-proxy/internal/agent"
|
||||
"cursor-api-proxy/internal/env"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
var loginURLRe = regexp.MustCompile(`https://cursor\.com/loginDeepControl.*?redirectTarget=cli`)
|
||||
|
||||
func HandleLogin(accountName string, proxies []string) error {
|
||||
e := env.OsEnvToMap()
|
||||
loaded := env.LoadEnvConfig(e, "")
|
||||
agentBin := loaded.AgentBin
|
||||
|
||||
if accountName == "" {
|
||||
accountName = fmt.Sprintf("account-%d", time.Now().UnixMilli()%10000)
|
||||
}
|
||||
|
||||
accountsDir := agent.AccountsDir()
|
||||
configDir := filepath.Join(accountsDir, accountName)
|
||||
dirWasNew := !fileExists(configDir)
|
||||
|
||||
if err := os.MkdirAll(accountsDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create accounts dir: %w", err)
|
||||
}
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create config dir: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Logging into Cursor account: %s\n", accountName)
|
||||
fmt.Printf("Config: %s\n\n", configDir)
|
||||
fmt.Println("Run the login command — complete the login in your browser.")
|
||||
fmt.Println("")
|
||||
|
||||
cleanupDir := func() {
|
||||
if dirWasNew {
|
||||
_ = os.RemoveAll(configDir)
|
||||
}
|
||||
}
|
||||
|
||||
cmdEnv := make([]string, 0, len(e)+2)
|
||||
for k, v := range e {
|
||||
cmdEnv = append(cmdEnv, k+"="+v)
|
||||
}
|
||||
cmdEnv = append(cmdEnv, "CURSOR_CONFIG_DIR="+configDir)
|
||||
cmdEnv = append(cmdEnv, "NO_OPEN_BROWSER=1")
|
||||
|
||||
child := exec.Command(agentBin, "login")
|
||||
child.Env = cmdEnv
|
||||
child.Stdin = os.Stdin
|
||||
child.Stderr = os.Stderr
|
||||
|
||||
stdoutPipe, err := child.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := child.Start(); err != nil {
|
||||
cleanupDir()
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("could not find '%s'. Make sure the Cursor CLI is installed", agentBin)
|
||||
}
|
||||
return fmt.Errorf("error launching agent login: %w", err)
|
||||
}
|
||||
|
||||
// Handle cancellation signals
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
|
||||
go func() {
|
||||
sig := <-sigCh
|
||||
_ = child.Process.Kill()
|
||||
cleanupDir()
|
||||
if sig == syscall.SIGINT {
|
||||
fmt.Println("\n\nLogin cancelled.")
|
||||
}
|
||||
os.Exit(0)
|
||||
}()
|
||||
defer signal.Stop(sigCh)
|
||||
|
||||
var stdoutBuf string
|
||||
scanner := bufio.NewScanner(stdoutPipe)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
fmt.Println(line)
|
||||
stdoutBuf += line + "\n"
|
||||
|
||||
if loginURLRe.MatchString(stdoutBuf) {
|
||||
match := loginURLRe.FindString(stdoutBuf)
|
||||
if match != "" {
|
||||
fmt.Printf("\nOpen this URL in your browser (incognito recommended):\n%s\n\n", match)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := child.Wait(); err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
cleanupDir()
|
||||
return fmt.Errorf("login failed with code %d", exitErr.ExitCode())
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Cache keychain token for this account
|
||||
token := agent.ReadKeychainToken()
|
||||
if token != "" {
|
||||
agent.WriteCachedToken(configDir, token)
|
||||
}
|
||||
|
||||
fmt.Printf("\nAccount '%s' saved — it will be auto-discovered when you start the proxy.\n", accountName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
|
@ -0,0 +1,261 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func sha256hex() string {
|
||||
b := make([]byte, 32)
|
||||
_, _ = rand.Read(b)
|
||||
h := sha256.Sum256(b)
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
func sha512hex() string {
|
||||
b := make([]byte, 64)
|
||||
_, _ = rand.Read(b)
|
||||
h := sha512.Sum512(b)
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
func newUUID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
func log(icon, msg string) {
|
||||
fmt.Printf(" %s %s\n", icon, msg)
|
||||
}
|
||||
|
||||
func getCursorGlobalStorage() string {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, "Library", "Application Support", "Cursor", "User", "globalStorage")
|
||||
case "windows":
|
||||
appdata := os.Getenv("APPDATA")
|
||||
return filepath.Join(appdata, "Cursor", "User", "globalStorage")
|
||||
default:
|
||||
xdg := os.Getenv("XDG_CONFIG_HOME")
|
||||
if xdg == "" {
|
||||
home, _ := os.UserHomeDir()
|
||||
xdg = filepath.Join(home, ".config")
|
||||
}
|
||||
return filepath.Join(xdg, "Cursor", "User", "globalStorage")
|
||||
}
|
||||
}
|
||||
|
||||
func getCursorRoot() string {
|
||||
gs := getCursorGlobalStorage()
|
||||
return filepath.Dir(filepath.Dir(gs))
|
||||
}
|
||||
|
||||
func generateNewIDs() map[string]string {
|
||||
return map[string]string{
|
||||
"telemetry.machineId": sha256hex(),
|
||||
"telemetry.macMachineId": sha512hex(),
|
||||
"telemetry.devDeviceId": newUUID(),
|
||||
"telemetry.sqmId": "{" + fmt.Sprintf("%s", newUUID()+"") + "}",
|
||||
"storage.serviceMachineId": newUUID(),
|
||||
}
|
||||
}
|
||||
|
||||
func killCursor() {
|
||||
log("", "Stopping Cursor processes...")
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
exec.Command("taskkill", "/F", "/IM", "Cursor.exe").Run()
|
||||
default:
|
||||
exec.Command("pkill", "-x", "Cursor").Run()
|
||||
exec.Command("pkill", "-f", "Cursor.app").Run()
|
||||
}
|
||||
log("", "Cursor stopped (or was not running)")
|
||||
}
|
||||
|
||||
func updateStorageJSON(storagePath string, ids map[string]string) {
|
||||
if _, err := os.Stat(storagePath); os.IsNotExist(err) {
|
||||
log("", fmt.Sprintf("storage.json not found: %s", storagePath))
|
||||
return
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
exec.Command("chflags", "nouchg", storagePath).Run()
|
||||
exec.Command("chmod", "644", storagePath).Run()
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(storagePath)
|
||||
if err != nil {
|
||||
log("", fmt.Sprintf("storage.json read error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
var obj map[string]interface{}
|
||||
if err := json.Unmarshal(data, &obj); err != nil {
|
||||
log("", fmt.Sprintf("storage.json parse error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range ids {
|
||||
obj[k] = v
|
||||
}
|
||||
|
||||
out, err := json.MarshalIndent(obj, "", " ")
|
||||
if err != nil {
|
||||
log("", fmt.Sprintf("storage.json marshal error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.WriteFile(storagePath, out, 0644); err != nil {
|
||||
log("", fmt.Sprintf("storage.json write error: %v", err))
|
||||
return
|
||||
}
|
||||
log("", "storage.json updated")
|
||||
}
|
||||
|
||||
func updateStateVscdb(dbPath string, ids map[string]string) {
|
||||
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
|
||||
log("", fmt.Sprintf("state.vscdb not found: %s", dbPath))
|
||||
return
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
exec.Command("chflags", "nouchg", dbPath).Run()
|
||||
exec.Command("chmod", "644", dbPath).Run()
|
||||
}
|
||||
|
||||
if err := updateVscdbPureGo(dbPath, ids); err != nil {
|
||||
log("", fmt.Sprintf("state.vscdb error: %v", err))
|
||||
} else {
|
||||
log("", "state.vscdb updated")
|
||||
}
|
||||
}
|
||||
|
||||
func updateMachineIDFile(machineID, cursorRoot string) {
|
||||
var candidates []string
|
||||
if runtime.GOOS == "linux" {
|
||||
candidates = []string{
|
||||
filepath.Join(cursorRoot, "machineid"),
|
||||
filepath.Join(cursorRoot, "machineId"),
|
||||
}
|
||||
} else {
|
||||
candidates = []string{filepath.Join(cursorRoot, "machineId")}
|
||||
}
|
||||
|
||||
filePath := candidates[0]
|
||||
for _, c := range candidates {
|
||||
if _, err := os.Stat(c); err == nil {
|
||||
filePath = c
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
|
||||
log("", fmt.Sprintf("machineId dir error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
if _, err := os.Stat(filePath); err == nil {
|
||||
exec.Command("chflags", "nouchg", filePath).Run()
|
||||
exec.Command("chmod", "644", filePath).Run()
|
||||
}
|
||||
}
|
||||
|
||||
if err := os.WriteFile(filePath, []byte(machineID+"\n"), 0644); err != nil {
|
||||
log("", fmt.Sprintf("machineId write error: %v", err))
|
||||
return
|
||||
}
|
||||
log("", fmt.Sprintf("machineId file updated (%s)", filepath.Base(filePath)))
|
||||
}
|
||||
|
||||
var dirsToWipe = []string{
|
||||
"Session Storage", "Local Storage", "IndexedDB", "Cache", "Code Cache",
|
||||
"GPUCache", "Service Worker", "Network", "Cookies", "Cookies-journal",
|
||||
}
|
||||
|
||||
func deepClean(cursorRoot string) {
|
||||
log("", "Deep-cleaning session data...")
|
||||
wiped := 0
|
||||
for _, name := range dirsToWipe {
|
||||
target := filepath.Join(cursorRoot, name)
|
||||
if _, err := os.Stat(target); os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
info, err := os.Stat(target)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if info.IsDir() {
|
||||
if err := os.RemoveAll(target); err == nil {
|
||||
wiped++
|
||||
}
|
||||
} else {
|
||||
if err := os.Remove(target); err == nil {
|
||||
wiped++
|
||||
}
|
||||
}
|
||||
}
|
||||
log("", fmt.Sprintf("Wiped %d cache/session items", wiped))
|
||||
}
|
||||
|
||||
func HandleResetHwid(doDeepClean, dryRun bool) error {
|
||||
fmt.Print("\nCursor HWID Reset\n\n")
|
||||
fmt.Println(" Resets all machine / telemetry IDs so Cursor sees a fresh install.")
|
||||
fmt.Print(" Cursor must be closed — it will be killed automatically.\n\n")
|
||||
|
||||
globalStorage := getCursorGlobalStorage()
|
||||
cursorRoot := getCursorRoot()
|
||||
|
||||
if _, err := os.Stat(globalStorage); os.IsNotExist(err) {
|
||||
fmt.Printf("Cursor config not found at:\n %s\n", globalStorage)
|
||||
fmt.Println(" Make sure Cursor is installed and has been run at least once.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if dryRun {
|
||||
fmt.Println(" [DRY RUN] Would reset IDs in:")
|
||||
fmt.Printf(" %s\n", filepath.Join(globalStorage, "storage.json"))
|
||||
fmt.Printf(" %s\n", filepath.Join(globalStorage, "state.vscdb"))
|
||||
fmt.Printf(" %s\n", filepath.Join(cursorRoot, "machineId"))
|
||||
return nil
|
||||
}
|
||||
|
||||
killCursor()
|
||||
|
||||
time.Sleep(800 * time.Millisecond)
|
||||
|
||||
newIDs := generateNewIDs()
|
||||
log("", "Generated new IDs:")
|
||||
for k, v := range newIDs {
|
||||
fmt.Printf(" %s: %s\n", k, v)
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
log("", "Updating storage.json...")
|
||||
updateStorageJSON(filepath.Join(globalStorage, "storage.json"), newIDs)
|
||||
|
||||
log("", "Updating state.vscdb...")
|
||||
updateStateVscdb(filepath.Join(globalStorage, "state.vscdb"), newIDs)
|
||||
|
||||
log("", "Updating machineId file...")
|
||||
updateMachineIDFile(newIDs["telemetry.machineId"], cursorRoot)
|
||||
|
||||
if doDeepClean {
|
||||
fmt.Println()
|
||||
deepClean(cursorRoot)
|
||||
}
|
||||
|
||||
fmt.Print("\nHWID reset complete. You can now restart Cursor.\n\n")
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func updateVscdbPureGo(dbPath string, ids map[string]string) error {
|
||||
db, err := sql.Open("sqlite", dbPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS ItemTable (key TEXT PRIMARY KEY, value TEXT NOT NULL)`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create table: %w", err)
|
||||
}
|
||||
|
||||
for k, v := range ids {
|
||||
_, err = db.Exec(`INSERT OR REPLACE INTO ItemTable (key, value) VALUES (?, ?)`, k, v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert %s: %w", k, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,255 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ModelUsage struct {
|
||||
NumRequests int `json:"numRequests"`
|
||||
NumRequestsTotal int `json:"numRequestsTotal"`
|
||||
NumTokens int `json:"numTokens"`
|
||||
MaxTokenUsage *int `json:"maxTokenUsage"`
|
||||
MaxRequestUsage *int `json:"maxRequestUsage"`
|
||||
}
|
||||
|
||||
type UsageData struct {
|
||||
StartOfMonth string `json:"startOfMonth"`
|
||||
Models map[string]ModelUsage `json:"-"`
|
||||
}
|
||||
|
||||
type StripeProfile struct {
|
||||
MembershipType string `json:"membershipType"`
|
||||
SubscriptionStatus string `json:"subscriptionStatus"`
|
||||
DaysRemainingOnTrial *int `json:"daysRemainingOnTrial"`
|
||||
IsTeamMember bool `json:"isTeamMember"`
|
||||
IsYearlyPlan bool `json:"isYearlyPlan"`
|
||||
}
|
||||
|
||||
func DecodeJWTPayload(token string) map[string]interface{} {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) < 2 {
|
||||
return nil
|
||||
}
|
||||
padded := strings.ReplaceAll(parts[1], "-", "+")
|
||||
padded = strings.ReplaceAll(padded, "_", "/")
|
||||
data, err := base64.StdEncoding.DecodeString(padded + strings.Repeat("=", (4-len(padded)%4)%4))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func TokenSub(token string) string {
|
||||
payload := DecodeJWTPayload(token)
|
||||
if payload == nil {
|
||||
return ""
|
||||
}
|
||||
if sub, ok := payload["sub"].(string); ok {
|
||||
return sub
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func apiGet(path, token string) (map[string]interface{}, error) {
|
||||
client := &http.Client{Timeout: 8 * time.Second}
|
||||
req, err := http.NewRequest("GET", "https://api2.cursor.sh"+path, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func FetchAccountUsage(token string) (*UsageData, error) {
|
||||
raw, err := apiGet("/auth/usage", token)
|
||||
if err != nil || raw == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startOfMonth, _ := raw["startOfMonth"].(string)
|
||||
usage := &UsageData{
|
||||
StartOfMonth: startOfMonth,
|
||||
Models: make(map[string]ModelUsage),
|
||||
}
|
||||
|
||||
for k, v := range raw {
|
||||
if k == "startOfMonth" {
|
||||
continue
|
||||
}
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var mu ModelUsage
|
||||
if err := json.Unmarshal(data, &mu); err == nil {
|
||||
usage.Models[k] = mu
|
||||
}
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func FetchStripeProfile(token string) (*StripeProfile, error) {
|
||||
raw, err := apiGet("/auth/full_stripe_profile", token)
|
||||
if err != nil || raw == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
profile := &StripeProfile{
|
||||
MembershipType: fmt.Sprintf("%v", raw["membershipType"]),
|
||||
SubscriptionStatus: fmt.Sprintf("%v", raw["subscriptionStatus"]),
|
||||
IsTeamMember: raw["isTeamMember"] == true,
|
||||
IsYearlyPlan: raw["isYearlyPlan"] == true,
|
||||
}
|
||||
if d, ok := raw["daysRemainingOnTrial"].(float64); ok {
|
||||
di := int(d)
|
||||
profile.DaysRemainingOnTrial = &di
|
||||
}
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
func DescribePlan(profile *StripeProfile) string {
|
||||
if profile == nil {
|
||||
return ""
|
||||
}
|
||||
switch profile.MembershipType {
|
||||
case "free_trial":
|
||||
days := 0
|
||||
if profile.DaysRemainingOnTrial != nil {
|
||||
days = *profile.DaysRemainingOnTrial
|
||||
}
|
||||
return fmt.Sprintf("Pro Trial (%dd left) — unlimited fast requests", days)
|
||||
case "pro":
|
||||
return "Pro — extended limits"
|
||||
case "pro_plus":
|
||||
return "Pro+ — extended limits"
|
||||
case "ultra":
|
||||
return "Ultra — extended limits"
|
||||
case "free", "hobby":
|
||||
return "Hobby (free) — limited agent requests"
|
||||
default:
|
||||
return fmt.Sprintf("%s · %s", profile.MembershipType, profile.SubscriptionStatus)
|
||||
}
|
||||
}
|
||||
|
||||
var modelLabels = map[string]string{
|
||||
"gpt-4": "Fast Premium Requests",
|
||||
"claude-sonnet-4-6": "Claude Sonnet 4.6",
|
||||
"claude-sonnet-4-5-20250929-v1": "Claude Sonnet 4.5",
|
||||
"claude-sonnet-4-20250514-v1": "Claude Sonnet 4",
|
||||
"claude-opus-4-6-v1": "Claude Opus 4.6",
|
||||
"claude-opus-4-5-20251101-v1": "Claude Opus 4.5",
|
||||
"claude-opus-4-1-20250805-v1": "Claude Opus 4.1",
|
||||
"claude-opus-4-20250514-v1": "Claude Opus 4",
|
||||
"claude-haiku-4-5-20251001-v1": "Claude Haiku 4.5",
|
||||
"claude-3-5-haiku-20241022-v1": "Claude 3.5 Haiku",
|
||||
"gpt-5": "GPT-5",
|
||||
"gpt-4o": "GPT-4o",
|
||||
"o1": "o1",
|
||||
"o3-mini": "o3-mini",
|
||||
"cursor-small": "Cursor Small (free)",
|
||||
}
|
||||
|
||||
func modelLabel(key string) string {
|
||||
if label, ok := modelLabels[key]; ok {
|
||||
return label
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func FormatUsageSummary(usage *UsageData) []string {
|
||||
if usage == nil {
|
||||
return nil
|
||||
}
|
||||
var lines []string
|
||||
|
||||
start := "?"
|
||||
if usage.StartOfMonth != "" {
|
||||
if t, err := time.Parse(time.RFC3339, usage.StartOfMonth); err == nil {
|
||||
start = t.Format("2006-01-02")
|
||||
} else {
|
||||
start = usage.StartOfMonth
|
||||
}
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf(" Billing period from %s", start))
|
||||
|
||||
if len(usage.Models) == 0 {
|
||||
lines = append(lines, " No requests this billing period")
|
||||
return lines
|
||||
}
|
||||
|
||||
type entry struct {
|
||||
key string
|
||||
usage ModelUsage
|
||||
}
|
||||
var entries []entry
|
||||
for k, v := range usage.Models {
|
||||
entries = append(entries, entry{k, v})
|
||||
}
|
||||
|
||||
// Sort: entries with limits first, then by usage descending
|
||||
for i := 1; i < len(entries); i++ {
|
||||
for j := i; j > 0; j-- {
|
||||
a, b := entries[j-1], entries[j]
|
||||
aHasLimit := a.usage.MaxRequestUsage != nil
|
||||
bHasLimit := b.usage.MaxRequestUsage != nil
|
||||
if !aHasLimit && bHasLimit {
|
||||
entries[j-1], entries[j] = entries[j], entries[j-1]
|
||||
} else if aHasLimit == bHasLimit && a.usage.NumRequests < b.usage.NumRequests {
|
||||
entries[j-1], entries[j] = entries[j], entries[j-1]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, e := range entries {
|
||||
used := e.usage.NumRequests
|
||||
max := e.usage.MaxRequestUsage
|
||||
label := modelLabel(e.key)
|
||||
if max != nil && *max > 0 {
|
||||
pct := int(float64(used) / float64(*max) * 100)
|
||||
bar := makeBar(used, *max, 12)
|
||||
lines = append(lines, fmt.Sprintf(" %s: %d/%d (%d%%) [%s]", label, used, *max, pct, bar))
|
||||
} else if used > 0 {
|
||||
lines = append(lines, fmt.Sprintf(" %s: %d requests", label, used))
|
||||
} else {
|
||||
lines = append(lines, fmt.Sprintf(" %s: 0 requests (unlimited)", label))
|
||||
}
|
||||
}
|
||||
|
||||
return lines
|
||||
}
|
||||
|
||||
func makeBar(used, max, width int) string {
|
||||
fill := int(float64(used) / float64(max) * float64(width))
|
||||
if fill > width {
|
||||
fill = width
|
||||
}
|
||||
return strings.Repeat("█", fill) + strings.Repeat("░", width-fill)
|
||||
}
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
module cursor-api-proxy
|
||||
|
||||
go 1.25.0
|
||||
|
||||
require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
modernc.org/libc v1.70.0 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.48.0 // indirect
|
||||
)
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw=
|
||||
modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/sqlite v1.48.0 h1:ElZyLop3Q2mHYk5IFPPXADejZrlHu7APbpB0sF78bq4=
|
||||
modernc.org/sqlite v1.48.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
package agent
|
||||
|
||||
import "cursor-api-proxy/internal/config"
|
||||
|
||||
func BuildAgentFixedArgs(cfg config.BridgeConfig, workspaceDir, model string, stream bool) []string {
|
||||
args := []string{"--print"}
|
||||
if cfg.ApproveMcps {
|
||||
args = append(args, "--approve-mcps")
|
||||
}
|
||||
if cfg.Force {
|
||||
args = append(args, "--force")
|
||||
}
|
||||
if cfg.ChatOnlyWorkspace {
|
||||
args = append(args, "--trust")
|
||||
}
|
||||
args = append(args, "--mode", "ask")
|
||||
args = append(args, "--workspace", workspaceDir)
|
||||
args = append(args, "--model", model)
|
||||
if stream {
|
||||
args = append(args, "--stream-partial-output", "--output-format", "stream-json")
|
||||
} else {
|
||||
args = append(args, "--output-format", "text")
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func BuildAgentCmdArgs(cfg config.BridgeConfig, workspaceDir, model, prompt string, stream bool) []string {
|
||||
return append(BuildAgentFixedArgs(cfg, workspaceDir, model, stream), prompt)
|
||||
}
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
func getCandidates(agentScriptPath, configDirOverride string) []string {
|
||||
if configDirOverride != "" {
|
||||
return []string{filepath.Join(configDirOverride, "cli-config.json")}
|
||||
}
|
||||
|
||||
var result []string
|
||||
|
||||
if dir := os.Getenv("CURSOR_CONFIG_DIR"); dir != "" {
|
||||
result = append(result, filepath.Join(dir, "cli-config.json"))
|
||||
}
|
||||
|
||||
if agentScriptPath != "" {
|
||||
agentDir := filepath.Dir(agentScriptPath)
|
||||
result = append(result, filepath.Join(agentDir, "..", "data", "config", "cli-config.json"))
|
||||
}
|
||||
|
||||
home := os.Getenv("HOME")
|
||||
if home == "" {
|
||||
home = os.Getenv("USERPROFILE")
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
local := os.Getenv("LOCALAPPDATA")
|
||||
if local == "" {
|
||||
local = filepath.Join(home, "AppData", "Local")
|
||||
}
|
||||
result = append(result, filepath.Join(local, "cursor-agent", "cli-config.json"))
|
||||
case "darwin":
|
||||
result = append(result, filepath.Join(home, "Library", "Application Support", "cursor-agent", "cli-config.json"))
|
||||
default:
|
||||
xdg := os.Getenv("XDG_CONFIG_HOME")
|
||||
if xdg == "" {
|
||||
xdg = filepath.Join(home, ".config")
|
||||
}
|
||||
result = append(result, filepath.Join(xdg, "cursor-agent", "cli-config.json"))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func RunMaxModePreflight(agentScriptPath, configDirOverride string) {
|
||||
for _, candidate := range getCandidates(agentScriptPath, configDirOverride) {
|
||||
data, err := os.ReadFile(candidate)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Strip BOM if present
|
||||
if len(data) >= 3 && data[0] == 0xEF && data[1] == 0xBB && data[2] == 0xBF {
|
||||
data = data[3:]
|
||||
}
|
||||
|
||||
var raw map[string]interface{}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
continue
|
||||
}
|
||||
if raw == nil || len(raw) <= 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
raw["maxMode"] = true
|
||||
if model, ok := raw["model"].(map[string]interface{}); ok {
|
||||
model["maxMode"] = true
|
||||
}
|
||||
|
||||
out, err := json.MarshalIndent(raw, "", " ")
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if err := os.WriteFile(candidate, out, 0644); err != nil {
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/process"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func init() {
|
||||
process.MaxModeFn = RunMaxModePreflight
|
||||
}
|
||||
|
||||
func cacheTokenForAccount(configDir string) {
|
||||
if configDir == "" {
|
||||
return
|
||||
}
|
||||
token := ReadKeychainToken()
|
||||
if token != "" {
|
||||
WriteCachedToken(configDir, token)
|
||||
}
|
||||
}
|
||||
|
||||
func AccountsDir() string {
|
||||
home := os.Getenv("HOME")
|
||||
if home == "" {
|
||||
home = os.Getenv("USERPROFILE")
|
||||
}
|
||||
return filepath.Join(home, ".cursor-api-proxy", "accounts")
|
||||
}
|
||||
|
||||
func RunAgentSync(cfg config.BridgeConfig, workspaceDir string, cmdArgs []string, tempDir, configDir string, ctx context.Context) (process.RunResult, error) {
|
||||
opts := process.RunOptions{
|
||||
Cwd: workspaceDir,
|
||||
TimeoutMs: cfg.TimeoutMs,
|
||||
MaxMode: cfg.MaxMode,
|
||||
ConfigDir: configDir,
|
||||
Ctx: ctx,
|
||||
}
|
||||
|
||||
result, err := process.Run(cfg.AgentBin, cmdArgs, opts)
|
||||
|
||||
cacheTokenForAccount(configDir)
|
||||
if tempDir != "" {
|
||||
os.RemoveAll(tempDir)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func RunAgentStreamWithContext(cfg config.BridgeConfig, workspaceDir string, cmdArgs []string, onLine func(string), tempDir, configDir string, ctx context.Context) (process.StreamResult, error) {
|
||||
opts := process.RunStreamingOptions{
|
||||
RunOptions: process.RunOptions{
|
||||
Cwd: workspaceDir,
|
||||
TimeoutMs: cfg.TimeoutMs,
|
||||
MaxMode: cfg.MaxMode,
|
||||
ConfigDir: configDir,
|
||||
Ctx: ctx,
|
||||
},
|
||||
OnLine: onLine,
|
||||
}
|
||||
|
||||
result, err := process.RunStreaming(cfg.AgentBin, cmdArgs, opts)
|
||||
|
||||
cacheTokenForAccount(configDir)
|
||||
if tempDir != "" {
|
||||
os.RemoveAll(tempDir)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const tokenFile = ".cursor-token"
|
||||
|
||||
func ReadCachedToken(configDir string) string {
|
||||
p := filepath.Join(configDir, tokenFile)
|
||||
data, err := os.ReadFile(p)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(string(data))
|
||||
}
|
||||
|
||||
func WriteCachedToken(configDir, token string) {
|
||||
p := filepath.Join(configDir, tokenFile)
|
||||
_ = os.WriteFile(p, []byte(token), 0600)
|
||||
}
|
||||
|
||||
func ReadKeychainToken() string {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return ""
|
||||
}
|
||||
out, err := exec.Command("security", "find-generic-password", "-s", "cursor-access-token", "-w").Output()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
|
|
@ -0,0 +1,134 @@
|
|||
package anthropic
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/openai"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type MessageParam struct {
|
||||
Role string `json:"role"`
|
||||
Content interface{} `json:"content"`
|
||||
}
|
||||
|
||||
type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System interface{} `json:"system"`
|
||||
Stream bool `json:"stream"`
|
||||
Tools []interface{} `json:"tools"`
|
||||
}
|
||||
|
||||
func systemToText(system interface{}) string {
|
||||
if system == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case []interface{}:
|
||||
var parts []string
|
||||
for _, p := range v {
|
||||
if m, ok := p.(map[string]interface{}); ok {
|
||||
if m["type"] == "text" {
|
||||
if t, ok := m["text"].(string); ok {
|
||||
parts = append(parts, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func anthropicBlockToText(p interface{}) string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := p.(type) {
|
||||
case string:
|
||||
return v
|
||||
case map[string]interface{}:
|
||||
typ, _ := v["type"].(string)
|
||||
switch typ {
|
||||
case "text":
|
||||
if t, ok := v["text"].(string); ok {
|
||||
return t
|
||||
}
|
||||
case "image":
|
||||
if src, ok := v["source"].(map[string]interface{}); ok {
|
||||
srcType, _ := src["type"].(string)
|
||||
switch srcType {
|
||||
case "base64":
|
||||
mt, _ := src["media_type"].(string)
|
||||
if mt == "" {
|
||||
mt = "image"
|
||||
}
|
||||
return "[Image: base64 " + mt + "]"
|
||||
case "url":
|
||||
url, _ := src["url"].(string)
|
||||
return "[Image: " + url + "]"
|
||||
}
|
||||
}
|
||||
return "[Image]"
|
||||
case "document":
|
||||
title, _ := v["title"].(string)
|
||||
if title == "" {
|
||||
if src, ok := v["source"].(map[string]interface{}); ok {
|
||||
title, _ = src["url"].(string)
|
||||
}
|
||||
}
|
||||
if title != "" {
|
||||
return "[Document: " + title + "]"
|
||||
}
|
||||
return "[Document]"
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func anthropicContentToText(content interface{}) string {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []interface{}:
|
||||
var parts []string
|
||||
for _, p := range v {
|
||||
if t := anthropicBlockToText(p); t != "" {
|
||||
parts = append(parts, t)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func BuildPromptFromAnthropicMessages(messages []MessageParam, system interface{}) string {
|
||||
var oaiMessages []interface{}
|
||||
|
||||
systemText := systemToText(system)
|
||||
if systemText != "" {
|
||||
oaiMessages = append(oaiMessages, map[string]interface{}{
|
||||
"role": "system",
|
||||
"content": systemText,
|
||||
})
|
||||
}
|
||||
|
||||
for _, m := range messages {
|
||||
text := anthropicContentToText(m.Content)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
role := m.Role
|
||||
if role != "user" && role != "assistant" {
|
||||
role = "user"
|
||||
}
|
||||
oaiMessages = append(oaiMessages, map[string]interface{}{
|
||||
"role": role,
|
||||
"content": text,
|
||||
})
|
||||
}
|
||||
|
||||
return openai.BuildPromptFromMessages(oaiMessages)
|
||||
}
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
package anthropic_test
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/anthropic"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_Simple(t *testing.T) {
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "Hi there"},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
||||
if !strings.Contains(prompt, "Hello") {
|
||||
t.Errorf("prompt missing user message: %q", prompt)
|
||||
}
|
||||
if !strings.Contains(prompt, "Hi there") {
|
||||
t.Errorf("prompt missing assistant message: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_WithSystem(t *testing.T) {
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: "ping"},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, "You are a helpful bot.")
|
||||
if !strings.Contains(prompt, "You are a helpful bot.") {
|
||||
t.Errorf("prompt missing system: %q", prompt)
|
||||
}
|
||||
if !strings.Contains(prompt, "ping") {
|
||||
t.Errorf("prompt missing user: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_SystemArray(t *testing.T) {
|
||||
system := []interface{}{
|
||||
map[string]interface{}{"type": "text", "text": "Part A"},
|
||||
map[string]interface{}{"type": "text", "text": "Part B"},
|
||||
}
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: "test"},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, system)
|
||||
if !strings.Contains(prompt, "Part A") {
|
||||
t.Errorf("prompt missing Part A: %q", prompt)
|
||||
}
|
||||
if !strings.Contains(prompt, "Part B") {
|
||||
t.Errorf("prompt missing Part B: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_ContentBlocks(t *testing.T) {
|
||||
content := []interface{}{
|
||||
map[string]interface{}{"type": "text", "text": "block one"},
|
||||
map[string]interface{}{"type": "text", "text": "block two"},
|
||||
}
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: content},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
||||
if !strings.Contains(prompt, "block one") {
|
||||
t.Errorf("prompt missing 'block one': %q", prompt)
|
||||
}
|
||||
if !strings.Contains(prompt, "block two") {
|
||||
t.Errorf("prompt missing 'block two': %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_ImageBlock(t *testing.T) {
|
||||
content := []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "image",
|
||||
"source": map[string]interface{}{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": "abc123",
|
||||
},
|
||||
},
|
||||
}
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: content},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
||||
if !strings.Contains(prompt, "[Image") {
|
||||
t.Errorf("prompt missing [Image]: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_EmptyContentSkipped(t *testing.T) {
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: ""},
|
||||
{Role: "assistant", Content: "response"},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
||||
if !strings.Contains(prompt, "response") {
|
||||
t.Errorf("prompt missing 'response': %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_UnknownRoleBecomesUser(t *testing.T) {
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "system", Content: "system-as-user"},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
||||
if !strings.Contains(prompt, "system-as-user") {
|
||||
t.Errorf("prompt missing 'system-as-user': %q", prompt)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/env"
|
||||
)
|
||||
|
||||
type BridgeConfig struct {
|
||||
AgentBin string
|
||||
Host string
|
||||
Port int
|
||||
RequiredKey string
|
||||
DefaultModel string
|
||||
Mode string
|
||||
Force bool
|
||||
ApproveMcps bool
|
||||
StrictModel bool
|
||||
Workspace string
|
||||
TimeoutMs int
|
||||
TLSCertPath string
|
||||
TLSKeyPath string
|
||||
SessionsLogPath string
|
||||
ChatOnlyWorkspace bool
|
||||
Verbose bool
|
||||
MaxMode bool
|
||||
ConfigDirs []string
|
||||
MultiPort bool
|
||||
WinCmdlineMax int
|
||||
}
|
||||
|
||||
func LoadBridgeConfig(e env.EnvSource, cwd string) BridgeConfig {
|
||||
loaded := env.LoadEnvConfig(e, cwd)
|
||||
return BridgeConfig{
|
||||
AgentBin: loaded.AgentBin,
|
||||
Host: loaded.Host,
|
||||
Port: loaded.Port,
|
||||
RequiredKey: loaded.RequiredKey,
|
||||
DefaultModel: loaded.DefaultModel,
|
||||
Mode: "ask",
|
||||
Force: loaded.Force,
|
||||
ApproveMcps: loaded.ApproveMcps,
|
||||
StrictModel: loaded.StrictModel,
|
||||
Workspace: loaded.Workspace,
|
||||
TimeoutMs: loaded.TimeoutMs,
|
||||
TLSCertPath: loaded.TLSCertPath,
|
||||
TLSKeyPath: loaded.TLSKeyPath,
|
||||
SessionsLogPath: loaded.SessionsLogPath,
|
||||
ChatOnlyWorkspace: loaded.ChatOnlyWorkspace,
|
||||
Verbose: loaded.Verbose,
|
||||
MaxMode: loaded.MaxMode,
|
||||
ConfigDirs: loaded.ConfigDirs,
|
||||
MultiPort: loaded.MultiPort,
|
||||
WinCmdlineMax: loaded.WinCmdlineMax,
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,123 @@
|
|||
package config_test
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/env"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadBridgeConfig_Defaults(t *testing.T) {
|
||||
cfg := config.LoadBridgeConfig(env.EnvSource{}, "/workspace")
|
||||
|
||||
if cfg.AgentBin != "agent" {
|
||||
t.Errorf("AgentBin = %q, want %q", cfg.AgentBin, "agent")
|
||||
}
|
||||
if cfg.Host != "127.0.0.1" {
|
||||
t.Errorf("Host = %q, want %q", cfg.Host, "127.0.0.1")
|
||||
}
|
||||
if cfg.Port != 8765 {
|
||||
t.Errorf("Port = %d, want 8765", cfg.Port)
|
||||
}
|
||||
if cfg.RequiredKey != "" {
|
||||
t.Errorf("RequiredKey = %q, want empty", cfg.RequiredKey)
|
||||
}
|
||||
if cfg.DefaultModel != "auto" {
|
||||
t.Errorf("DefaultModel = %q, want %q", cfg.DefaultModel, "auto")
|
||||
}
|
||||
if cfg.Force {
|
||||
t.Error("Force should be false")
|
||||
}
|
||||
if cfg.ApproveMcps {
|
||||
t.Error("ApproveMcps should be false")
|
||||
}
|
||||
if !cfg.StrictModel {
|
||||
t.Error("StrictModel should be true")
|
||||
}
|
||||
if cfg.Mode != "ask" {
|
||||
t.Errorf("Mode = %q, want %q", cfg.Mode, "ask")
|
||||
}
|
||||
if cfg.Workspace != "/workspace" {
|
||||
t.Errorf("Workspace = %q, want /workspace", cfg.Workspace)
|
||||
}
|
||||
if !cfg.ChatOnlyWorkspace {
|
||||
t.Error("ChatOnlyWorkspace should be true")
|
||||
}
|
||||
if cfg.WinCmdlineMax != 30000 {
|
||||
t.Errorf("WinCmdlineMax = %d, want 30000", cfg.WinCmdlineMax)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadBridgeConfig_FromEnv(t *testing.T) {
|
||||
e := env.EnvSource{
|
||||
"CURSOR_AGENT_BIN": "/usr/bin/agent",
|
||||
"CURSOR_BRIDGE_HOST": "0.0.0.0",
|
||||
"CURSOR_BRIDGE_PORT": "9999",
|
||||
"CURSOR_BRIDGE_API_KEY": "sk-secret",
|
||||
"CURSOR_BRIDGE_DEFAULT_MODEL": "org/claude-3-opus",
|
||||
"CURSOR_BRIDGE_FORCE": "true",
|
||||
"CURSOR_BRIDGE_APPROVE_MCPS": "yes",
|
||||
"CURSOR_BRIDGE_STRICT_MODEL": "false",
|
||||
"CURSOR_BRIDGE_WORKSPACE": "./my-workspace",
|
||||
"CURSOR_BRIDGE_TIMEOUT_MS": "60000",
|
||||
"CURSOR_BRIDGE_CHAT_ONLY_WORKSPACE": "false",
|
||||
"CURSOR_BRIDGE_VERBOSE": "1",
|
||||
"CURSOR_BRIDGE_TLS_CERT": "./certs/test.crt",
|
||||
"CURSOR_BRIDGE_TLS_KEY": "./certs/test.key",
|
||||
}
|
||||
cfg := config.LoadBridgeConfig(e, "/tmp/project")
|
||||
|
||||
if cfg.AgentBin != "/usr/bin/agent" {
|
||||
t.Errorf("AgentBin = %q, want /usr/bin/agent", cfg.AgentBin)
|
||||
}
|
||||
if cfg.Host != "0.0.0.0" {
|
||||
t.Errorf("Host = %q, want 0.0.0.0", cfg.Host)
|
||||
}
|
||||
if cfg.Port != 9999 {
|
||||
t.Errorf("Port = %d, want 9999", cfg.Port)
|
||||
}
|
||||
if cfg.RequiredKey != "sk-secret" {
|
||||
t.Errorf("RequiredKey = %q, want sk-secret", cfg.RequiredKey)
|
||||
}
|
||||
if cfg.DefaultModel != "claude-3-opus" {
|
||||
t.Errorf("DefaultModel = %q, want claude-3-opus", cfg.DefaultModel)
|
||||
}
|
||||
if !cfg.Force {
|
||||
t.Error("Force should be true")
|
||||
}
|
||||
if !cfg.ApproveMcps {
|
||||
t.Error("ApproveMcps should be true")
|
||||
}
|
||||
if cfg.StrictModel {
|
||||
t.Error("StrictModel should be false")
|
||||
}
|
||||
if !filepath.IsAbs(cfg.Workspace) {
|
||||
t.Errorf("Workspace should be absolute, got %q", cfg.Workspace)
|
||||
}
|
||||
if !strings.Contains(cfg.Workspace, "my-workspace") {
|
||||
t.Errorf("Workspace %q should contain 'my-workspace'", cfg.Workspace)
|
||||
}
|
||||
if cfg.TimeoutMs != 60000 {
|
||||
t.Errorf("TimeoutMs = %d, want 60000", cfg.TimeoutMs)
|
||||
}
|
||||
if cfg.ChatOnlyWorkspace {
|
||||
t.Error("ChatOnlyWorkspace should be false")
|
||||
}
|
||||
if !cfg.Verbose {
|
||||
t.Error("Verbose should be true")
|
||||
}
|
||||
if cfg.TLSCertPath != "/tmp/project/certs/test.crt" {
|
||||
t.Errorf("TLSCertPath = %q, want /tmp/project/certs/test.crt", cfg.TLSCertPath)
|
||||
}
|
||||
if cfg.TLSKeyPath != "/tmp/project/certs/test.key" {
|
||||
t.Errorf("TLSKeyPath = %q, want /tmp/project/certs/test.key", cfg.TLSKeyPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadBridgeConfig_WideHost(t *testing.T) {
|
||||
cfg := config.LoadBridgeConfig(env.EnvSource{"CURSOR_BRIDGE_HOST": "0.0.0.0"}, "/workspace")
|
||||
if cfg.Host != "0.0.0.0" {
|
||||
t.Errorf("Host = %q, want 0.0.0.0", cfg.Host)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,331 @@
|
|||
package env
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type EnvSource map[string]string
|
||||
|
||||
type LoadedEnv struct {
|
||||
AgentBin string
|
||||
AgentNode string
|
||||
AgentScript string
|
||||
CommandShell string
|
||||
Host string
|
||||
Port int
|
||||
RequiredKey string
|
||||
DefaultModel string
|
||||
Force bool
|
||||
ApproveMcps bool
|
||||
StrictModel bool
|
||||
Workspace string
|
||||
TimeoutMs int
|
||||
TLSCertPath string
|
||||
TLSKeyPath string
|
||||
SessionsLogPath string
|
||||
ChatOnlyWorkspace bool
|
||||
Verbose bool
|
||||
MaxMode bool
|
||||
ConfigDirs []string
|
||||
MultiPort bool
|
||||
WinCmdlineMax int
|
||||
}
|
||||
|
||||
type AgentCommand struct {
|
||||
Command string
|
||||
Args []string
|
||||
Env map[string]string
|
||||
WindowsVerbatimArguments bool
|
||||
AgentScriptPath string
|
||||
ConfigDir string
|
||||
}
|
||||
|
||||
func getEnvVal(e EnvSource, names []string) string {
|
||||
for _, name := range names {
|
||||
if v, ok := e[name]; ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func envBool(e EnvSource, names []string, def bool) bool {
|
||||
raw := getEnvVal(e, names)
|
||||
if raw == "" {
|
||||
return def
|
||||
}
|
||||
switch strings.ToLower(raw) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
case "0", "false", "no", "off":
|
||||
return false
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func envInt(e EnvSource, names []string, def int) int {
|
||||
raw := getEnvVal(e, names)
|
||||
if raw == "" {
|
||||
return def
|
||||
}
|
||||
v, err := strconv.Atoi(raw)
|
||||
if err != nil {
|
||||
return def
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func normalizeModelId(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "auto"
|
||||
}
|
||||
parts := strings.Split(raw, "/")
|
||||
last := parts[len(parts)-1]
|
||||
if last == "" {
|
||||
return "auto"
|
||||
}
|
||||
return last
|
||||
}
|
||||
|
||||
func resolveAbs(raw, cwd string) string {
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
if filepath.IsAbs(raw) {
|
||||
return raw
|
||||
}
|
||||
return filepath.Join(cwd, raw)
|
||||
}
|
||||
|
||||
func isAuthenticatedAccountDir(dir string) bool {
|
||||
data, err := os.ReadFile(filepath.Join(dir, "cli-config.json"))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
var cfg struct {
|
||||
AuthInfo *struct {
|
||||
Email string `json:"email"`
|
||||
} `json:"authInfo"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return false
|
||||
}
|
||||
return cfg.AuthInfo != nil && cfg.AuthInfo.Email != ""
|
||||
}
|
||||
|
||||
func discoverAccountDirs(homeDir string) []string {
|
||||
if homeDir == "" {
|
||||
return nil
|
||||
}
|
||||
accountsDir := filepath.Join(homeDir, ".cursor-api-proxy", "accounts")
|
||||
entries, err := os.ReadDir(accountsDir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var dirs []string
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() {
|
||||
continue
|
||||
}
|
||||
dir := filepath.Join(accountsDir, e.Name())
|
||||
if isAuthenticatedAccountDir(dir) {
|
||||
dirs = append(dirs, dir)
|
||||
}
|
||||
}
|
||||
return dirs
|
||||
}
|
||||
|
||||
func OsEnvToMap() EnvSource {
|
||||
m := make(EnvSource)
|
||||
for _, kv := range os.Environ() {
|
||||
parts := strings.SplitN(kv, "=", 2)
|
||||
if len(parts) == 2 {
|
||||
m[parts[0]] = parts[1]
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func LoadEnvConfig(e EnvSource, cwd string) LoadedEnv {
|
||||
if e == nil {
|
||||
e = OsEnvToMap()
|
||||
}
|
||||
if cwd == "" {
|
||||
var err error
|
||||
cwd, err = os.Getwd()
|
||||
if err != nil {
|
||||
cwd = "."
|
||||
}
|
||||
}
|
||||
|
||||
host := getEnvVal(e, []string{"CURSOR_BRIDGE_HOST"})
|
||||
if host == "" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
port := envInt(e, []string{"CURSOR_BRIDGE_PORT"}, 8765)
|
||||
if port <= 0 {
|
||||
port = 8765
|
||||
}
|
||||
|
||||
home := getEnvVal(e, []string{"HOME", "USERPROFILE"})
|
||||
|
||||
sessionsLogPath := func() string {
|
||||
if p := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_SESSIONS_LOG"}), cwd); p != "" {
|
||||
return p
|
||||
}
|
||||
if home != "" {
|
||||
return filepath.Join(home, ".cursor-api-proxy", "sessions.log")
|
||||
}
|
||||
return filepath.Join(cwd, "sessions.log")
|
||||
}()
|
||||
|
||||
var configDirs []string
|
||||
if raw := getEnvVal(e, []string{"CURSOR_CONFIG_DIRS", "CURSOR_ACCOUNT_DIRS"}); raw != "" {
|
||||
for _, d := range strings.Split(raw, ",") {
|
||||
d = strings.TrimSpace(d)
|
||||
if d != "" {
|
||||
if p := resolveAbs(d, cwd); p != "" {
|
||||
configDirs = append(configDirs, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(configDirs) == 0 {
|
||||
configDirs = discoverAccountDirs(home)
|
||||
}
|
||||
|
||||
winMax := envInt(e, []string{"CURSOR_BRIDGE_WIN_CMDLINE_MAX"}, 30000)
|
||||
if winMax < 4096 {
|
||||
winMax = 4096
|
||||
}
|
||||
if winMax > 32700 {
|
||||
winMax = 32700
|
||||
}
|
||||
|
||||
agentBin := getEnvVal(e, []string{"CURSOR_AGENT_BIN", "CURSOR_CLI_BIN", "CURSOR_CLI_PATH"})
|
||||
if agentBin == "" {
|
||||
agentBin = "agent"
|
||||
}
|
||||
commandShell := getEnvVal(e, []string{"COMSPEC"})
|
||||
if commandShell == "" {
|
||||
commandShell = "cmd.exe"
|
||||
}
|
||||
workspace := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_WORKSPACE"}), cwd)
|
||||
if workspace == "" {
|
||||
workspace = cwd
|
||||
}
|
||||
|
||||
return LoadedEnv{
|
||||
AgentBin: agentBin,
|
||||
AgentNode: getEnvVal(e, []string{"CURSOR_AGENT_NODE"}),
|
||||
AgentScript: getEnvVal(e, []string{"CURSOR_AGENT_SCRIPT"}),
|
||||
CommandShell: commandShell,
|
||||
Host: host,
|
||||
Port: port,
|
||||
RequiredKey: getEnvVal(e, []string{"CURSOR_BRIDGE_API_KEY"}),
|
||||
DefaultModel: normalizeModelId(getEnvVal(e, []string{"CURSOR_BRIDGE_DEFAULT_MODEL"})),
|
||||
Force: envBool(e, []string{"CURSOR_BRIDGE_FORCE"}, false),
|
||||
ApproveMcps: envBool(e, []string{"CURSOR_BRIDGE_APPROVE_MCPS"}, false),
|
||||
StrictModel: envBool(e, []string{"CURSOR_BRIDGE_STRICT_MODEL"}, true),
|
||||
Workspace: workspace,
|
||||
TimeoutMs: envInt(e, []string{"CURSOR_BRIDGE_TIMEOUT_MS"}, 300000),
|
||||
TLSCertPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_CERT"}), cwd),
|
||||
TLSKeyPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_KEY"}), cwd),
|
||||
SessionsLogPath: sessionsLogPath,
|
||||
ChatOnlyWorkspace: envBool(e, []string{"CURSOR_BRIDGE_CHAT_ONLY_WORKSPACE"}, true),
|
||||
Verbose: envBool(e, []string{"CURSOR_BRIDGE_VERBOSE"}, false),
|
||||
MaxMode: envBool(e, []string{"CURSOR_BRIDGE_MAX_MODE"}, false),
|
||||
ConfigDirs: configDirs,
|
||||
MultiPort: envBool(e, []string{"CURSOR_BRIDGE_MULTI_PORT"}, false),
|
||||
WinCmdlineMax: winMax,
|
||||
}
|
||||
}
|
||||
|
||||
func ResolveAgentCommand(cmd string, args []string, e EnvSource, cwd string) AgentCommand {
|
||||
if e == nil {
|
||||
e = OsEnvToMap()
|
||||
}
|
||||
loaded := LoadEnvConfig(e, cwd)
|
||||
|
||||
cloneEnv := func() map[string]string {
|
||||
m := make(map[string]string, len(e))
|
||||
for k, v := range e {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
if loaded.AgentNode != "" && loaded.AgentScript != "" {
|
||||
agentScriptPath := loaded.AgentScript
|
||||
if !filepath.IsAbs(agentScriptPath) {
|
||||
agentScriptPath = filepath.Join(cwd, agentScriptPath)
|
||||
}
|
||||
agentDir := filepath.Dir(agentScriptPath)
|
||||
configDir := filepath.Join(agentDir, "..", "data", "config")
|
||||
env2 := cloneEnv()
|
||||
env2["CURSOR_INVOKED_AS"] = "agent.cmd"
|
||||
ac := AgentCommand{
|
||||
Command: loaded.AgentNode,
|
||||
Args: append([]string{loaded.AgentScript}, args...),
|
||||
Env: env2,
|
||||
AgentScriptPath: agentScriptPath,
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil {
|
||||
ac.ConfigDir = configDir
|
||||
}
|
||||
return ac
|
||||
}
|
||||
|
||||
if strings.HasSuffix(strings.ToLower(cmd), ".cmd") {
|
||||
cmdResolved := cmd
|
||||
if !filepath.IsAbs(cmd) {
|
||||
cmdResolved = filepath.Join(cwd, cmd)
|
||||
}
|
||||
dir := filepath.Dir(cmdResolved)
|
||||
nodeBin := filepath.Join(dir, "node.exe")
|
||||
script := filepath.Join(dir, "index.js")
|
||||
if _, err1 := os.Stat(nodeBin); err1 == nil {
|
||||
if _, err2 := os.Stat(script); err2 == nil {
|
||||
configDir := filepath.Join(dir, "..", "data", "config")
|
||||
env2 := cloneEnv()
|
||||
env2["CURSOR_INVOKED_AS"] = "agent.cmd"
|
||||
ac := AgentCommand{
|
||||
Command: nodeBin,
|
||||
Args: append([]string{script}, args...),
|
||||
Env: env2,
|
||||
AgentScriptPath: script,
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil {
|
||||
ac.ConfigDir = configDir
|
||||
}
|
||||
return ac
|
||||
}
|
||||
}
|
||||
|
||||
quotedArgs := make([]string, len(args))
|
||||
for i, a := range args {
|
||||
if strings.Contains(a, " ") {
|
||||
quotedArgs[i] = `"` + a + `"`
|
||||
} else {
|
||||
quotedArgs[i] = a
|
||||
}
|
||||
}
|
||||
cmdLine := `""` + cmd + `" ` + strings.Join(quotedArgs, " ") + `"`
|
||||
return AgentCommand{
|
||||
Command: loaded.CommandShell,
|
||||
Args: []string{"/d", "/s", "/c", cmdLine},
|
||||
Env: cloneEnv(),
|
||||
WindowsVerbatimArguments: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return AgentCommand{Command: cmd, Args: args, Env: cloneEnv()}
|
||||
}
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
package env
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestLoadEnvConfigDefaults(t *testing.T) {
|
||||
e := EnvSource{}
|
||||
loaded := LoadEnvConfig(e, "/tmp")
|
||||
|
||||
if loaded.Host != "127.0.0.1" {
|
||||
t.Errorf("expected 127.0.0.1, got %s", loaded.Host)
|
||||
}
|
||||
if loaded.Port != 8765 {
|
||||
t.Errorf("expected 8765, got %d", loaded.Port)
|
||||
}
|
||||
if loaded.DefaultModel != "auto" {
|
||||
t.Errorf("expected auto, got %s", loaded.DefaultModel)
|
||||
}
|
||||
if loaded.AgentBin != "agent" {
|
||||
t.Errorf("expected agent, got %s", loaded.AgentBin)
|
||||
}
|
||||
if !loaded.StrictModel {
|
||||
t.Error("expected strictModel=true by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadEnvConfigOverride(t *testing.T) {
|
||||
e := EnvSource{
|
||||
"CURSOR_BRIDGE_HOST": "0.0.0.0",
|
||||
"CURSOR_BRIDGE_PORT": "9000",
|
||||
"CURSOR_BRIDGE_DEFAULT_MODEL": "gpt-4",
|
||||
"CURSOR_AGENT_BIN": "/usr/local/bin/agent",
|
||||
}
|
||||
loaded := LoadEnvConfig(e, "/tmp")
|
||||
|
||||
if loaded.Host != "0.0.0.0" {
|
||||
t.Errorf("expected 0.0.0.0, got %s", loaded.Host)
|
||||
}
|
||||
if loaded.Port != 9000 {
|
||||
t.Errorf("expected 9000, got %d", loaded.Port)
|
||||
}
|
||||
if loaded.DefaultModel != "gpt-4" {
|
||||
t.Errorf("expected gpt-4, got %s", loaded.DefaultModel)
|
||||
}
|
||||
if loaded.AgentBin != "/usr/local/bin/agent" {
|
||||
t.Errorf("expected /usr/local/bin/agent, got %s", loaded.AgentBin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeModelID(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"gpt-4", "gpt-4"},
|
||||
{"openai/gpt-4", "gpt-4"},
|
||||
{"", "auto"},
|
||||
{" ", "auto"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got := normalizeModelId(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("normalizeModelId(%q) = %q, want %q", tc.input, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,285 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/agent"
|
||||
"cursor-api-proxy/internal/anthropic"
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/httputil"
|
||||
"cursor-api-proxy/internal/logger"
|
||||
"cursor-api-proxy/internal/models"
|
||||
"cursor-api-proxy/internal/openai"
|
||||
"cursor-api-proxy/internal/parser"
|
||||
"cursor-api-proxy/internal/pool"
|
||||
"cursor-api-proxy/internal/sanitize"
|
||||
"cursor-api-proxy/internal/winlimit"
|
||||
"cursor-api-proxy/internal/workspace"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
|
||||
var req anthropic.MessagesRequest
|
||||
if err := json.Unmarshal([]byte(rawBody), &req); err != nil {
|
||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
||||
"error": map[string]string{"type": "invalid_request_error", "message": "invalid JSON body"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
requested := openai.NormalizeModelID(req.Model)
|
||||
model := ResolveModel(requested, lastModelRef, cfg)
|
||||
|
||||
// Parse system from raw body to handle both string and array
|
||||
var rawMap map[string]interface{}
|
||||
_ = json.Unmarshal([]byte(rawBody), &rawMap)
|
||||
|
||||
cleanSystem := sanitize.SanitizeSystem(req.System)
|
||||
|
||||
// SanitizeMessages expects []interface{}
|
||||
rawMessages := make([]interface{}, len(req.Messages))
|
||||
for i, m := range req.Messages {
|
||||
rawMessages[i] = map[string]interface{}{"role": m.Role, "content": m.Content}
|
||||
}
|
||||
cleanRawMessages := sanitize.SanitizeMessages(rawMessages)
|
||||
|
||||
var cleanMessages []anthropic.MessageParam
|
||||
for _, raw := range cleanRawMessages {
|
||||
if m, ok := raw.(map[string]interface{}); ok {
|
||||
role, _ := m["role"].(string)
|
||||
cleanMessages = append(cleanMessages, anthropic.MessageParam{Role: role, Content: m["content"]})
|
||||
}
|
||||
}
|
||||
|
||||
toolsText := openai.ToolsToSystemText(req.Tools, nil)
|
||||
var systemWithTools interface{}
|
||||
if toolsText != "" {
|
||||
sysStr := ""
|
||||
switch v := cleanSystem.(type) {
|
||||
case string:
|
||||
sysStr = v
|
||||
}
|
||||
if sysStr != "" {
|
||||
systemWithTools = sysStr + "\n\n" + toolsText
|
||||
} else {
|
||||
systemWithTools = toolsText
|
||||
}
|
||||
} else {
|
||||
systemWithTools = cleanSystem
|
||||
}
|
||||
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(cleanMessages, systemWithTools)
|
||||
|
||||
if req.MaxTokens == 0 {
|
||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
||||
"error": map[string]string{"type": "invalid_request_error", "message": "max_tokens is required"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
cursorModel := models.ResolveToCursorModel(model)
|
||||
if cursorModel == "" {
|
||||
cursorModel = model
|
||||
}
|
||||
|
||||
var trafficMsgs []logger.TrafficMessage
|
||||
if s := systemToString(cleanSystem); s != "" {
|
||||
trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: "system", Content: s})
|
||||
}
|
||||
for _, m := range cleanMessages {
|
||||
text := contentToString(m.Content)
|
||||
if text != "" {
|
||||
trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: m.Role, Content: text})
|
||||
}
|
||||
}
|
||||
logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, req.Stream)
|
||||
|
||||
headerWs := r.Header.Get("x-cursor-workspace")
|
||||
ws := workspace.ResolveWorkspace(cfg, headerWs)
|
||||
|
||||
fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, req.Stream)
|
||||
fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir)
|
||||
|
||||
if !fit.OK {
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"type": "api_error", "message": fit.Error},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
if fit.Truncated {
|
||||
fmt.Printf("[%s] Windows: prompt truncated (%d -> %d chars).\n",
|
||||
time.Now().UTC().Format(time.RFC3339), fit.OriginalLength, fit.FinalPromptLength)
|
||||
}
|
||||
|
||||
cmdArgs := fit.Args
|
||||
msgID := "msg_" + uuid.New().String()
|
||||
|
||||
var truncatedHeaders map[string]string
|
||||
if fit.Truncated {
|
||||
truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
httputil.WriteSSEHeaders(w, truncatedHeaders)
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
writeEvent := func(evt interface{}) {
|
||||
data, _ := json.Marshal(evt)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "message_start",
|
||||
"message": map[string]interface{}{
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []interface{}{},
|
||||
},
|
||||
})
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": map[string]string{"type": "text", "text": ""},
|
||||
})
|
||||
|
||||
var accumulated string
|
||||
parseLine := parser.CreateStreamParser(
|
||||
func(text string) {
|
||||
accumulated += text
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]string{"type": "text_delta", "text": text},
|
||||
})
|
||||
},
|
||||
func() {
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": 0})
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
|
||||
"usage": map[string]int{"output_tokens": 0},
|
||||
})
|
||||
writeEvent(map[string]interface{}{"type": "message_stop"})
|
||||
},
|
||||
)
|
||||
|
||||
configDir := pool.GetNextAccountConfigDir()
|
||||
logger.LogAccountAssigned(configDir)
|
||||
pool.ReportRequestStart(configDir)
|
||||
streamStart := time.Now().UnixMilli()
|
||||
|
||||
ctx := r.Context()
|
||||
result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, parseLine, ws.TempDir, configDir, ctx)
|
||||
|
||||
latencyMs := time.Now().UnixMilli() - streamStart
|
||||
pool.ReportRequestEnd(configDir)
|
||||
|
||||
if err == nil && isRateLimited(result.Stderr) {
|
||||
pool.ReportRateLimit(configDir, 60000)
|
||||
}
|
||||
if err != nil || result.Code != 0 {
|
||||
pool.ReportRequestError(configDir, latencyMs)
|
||||
if err != nil {
|
||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
|
||||
} else {
|
||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr)
|
||||
}
|
||||
} else {
|
||||
pool.ReportRequestSuccess(configDir, latencyMs)
|
||||
}
|
||||
logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
|
||||
return
|
||||
}
|
||||
|
||||
configDir := pool.GetNextAccountConfigDir()
|
||||
logger.LogAccountAssigned(configDir)
|
||||
pool.ReportRequestStart(configDir)
|
||||
syncStart := time.Now().UnixMilli()
|
||||
|
||||
out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context())
|
||||
syncLatency := time.Now().UnixMilli() - syncStart
|
||||
pool.ReportRequestEnd(configDir)
|
||||
|
||||
if err != nil {
|
||||
pool.ReportRequestError(configDir, syncLatency)
|
||||
logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"type": "api_error", "message": err.Error()},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
if isRateLimited(out.Stderr) {
|
||||
pool.ReportRateLimit(configDir, 60000)
|
||||
}
|
||||
|
||||
if out.Code != 0 {
|
||||
pool.ReportRequestError(configDir, syncLatency)
|
||||
logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
|
||||
errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr)
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"type": "api_error", "message": errMsg},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
pool.ReportRequestSuccess(configDir, syncLatency)
|
||||
content := trimSpace(out.Stdout)
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, content, false)
|
||||
logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
|
||||
|
||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []map[string]string{{"type": "text", "text": content}},
|
||||
"model": model,
|
||||
"stop_reason": "end_turn",
|
||||
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
|
||||
}, truncatedHeaders)
|
||||
}
|
||||
|
||||
func systemToString(system interface{}) string {
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []interface{}:
|
||||
result := ""
|
||||
for _, p := range v {
|
||||
if m, ok := p.(map[string]interface{}); ok && m["type"] == "text" {
|
||||
if t, ok := m["text"].(string); ok {
|
||||
result += t
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func contentToString(content interface{}) string {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []interface{}:
|
||||
result := ""
|
||||
for _, p := range v {
|
||||
if m, ok := p.(map[string]interface{}); ok && m["type"] == "text" {
|
||||
if t, ok := m["text"].(string); ok {
|
||||
result += t
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
|
@ -0,0 +1,249 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/agent"
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/httputil"
|
||||
"cursor-api-proxy/internal/logger"
|
||||
"cursor-api-proxy/internal/models"
|
||||
"cursor-api-proxy/internal/openai"
|
||||
"cursor-api-proxy/internal/parser"
|
||||
"cursor-api-proxy/internal/pool"
|
||||
"cursor-api-proxy/internal/sanitize"
|
||||
"cursor-api-proxy/internal/winlimit"
|
||||
"cursor-api-proxy/internal/workspace"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var rateLimitRe = regexp.MustCompile(`(?i)\b429\b|rate.?limit|too many requests`)
|
||||
|
||||
func isRateLimited(stderr string) bool {
|
||||
return rateLimitRe.MatchString(stderr)
|
||||
}
|
||||
|
||||
func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
|
||||
var bodyMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil {
|
||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
||||
"error": map[string]string{"message": "invalid JSON body", "code": "bad_request"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
rawModel, _ := bodyMap["model"].(string)
|
||||
requested := openai.NormalizeModelID(rawModel)
|
||||
model := ResolveModel(requested, lastModelRef, cfg)
|
||||
cursorModel := models.ResolveToCursorModel(model)
|
||||
if cursorModel == "" {
|
||||
cursorModel = model
|
||||
}
|
||||
|
||||
var messages []interface{}
|
||||
if m, ok := bodyMap["messages"].([]interface{}); ok {
|
||||
messages = m
|
||||
}
|
||||
|
||||
var tools []interface{}
|
||||
if t, ok := bodyMap["tools"].([]interface{}); ok {
|
||||
tools = t
|
||||
}
|
||||
var funcs []interface{}
|
||||
if f, ok := bodyMap["functions"].([]interface{}); ok {
|
||||
funcs = f
|
||||
}
|
||||
|
||||
cleanMessages := sanitize.SanitizeMessages(messages)
|
||||
|
||||
toolsText := openai.ToolsToSystemText(tools, funcs)
|
||||
messagesWithTools := cleanMessages
|
||||
if toolsText != "" {
|
||||
messagesWithTools = append([]interface{}{map[string]interface{}{"role": "system", "content": toolsText}}, cleanMessages...)
|
||||
}
|
||||
prompt := openai.BuildPromptFromMessages(messagesWithTools)
|
||||
|
||||
var trafficMsgs []logger.TrafficMessage
|
||||
for _, raw := range cleanMessages {
|
||||
if m, ok := raw.(map[string]interface{}); ok {
|
||||
role, _ := m["role"].(string)
|
||||
content := openai.MessageContentToText(m["content"])
|
||||
trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: role, Content: content})
|
||||
}
|
||||
}
|
||||
|
||||
isStream := false
|
||||
if s, ok := bodyMap["stream"].(bool); ok {
|
||||
isStream = s
|
||||
}
|
||||
|
||||
logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, isStream)
|
||||
|
||||
headerWs := r.Header.Get("x-cursor-workspace")
|
||||
ws := workspace.ResolveWorkspace(cfg, headerWs)
|
||||
|
||||
fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, isStream)
|
||||
fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir)
|
||||
|
||||
if !fit.OK {
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"message": fit.Error, "code": "windows_cmdline_limit"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
if fit.Truncated {
|
||||
fmt.Printf("[%s] Windows: prompt truncated for CreateProcess limit (%d -> %d chars, tail preserved).\n",
|
||||
time.Now().UTC().Format(time.RFC3339), fit.OriginalLength, fit.FinalPromptLength)
|
||||
}
|
||||
|
||||
cmdArgs := fit.Args
|
||||
id := "chatcmpl_" + uuid.New().String()
|
||||
created := time.Now().Unix()
|
||||
|
||||
var truncatedHeaders map[string]string
|
||||
if fit.Truncated {
|
||||
truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
|
||||
}
|
||||
|
||||
if isStream {
|
||||
httputil.WriteSSEHeaders(w, truncatedHeaders)
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
var accumulated string
|
||||
parseLine := parser.CreateStreamParser(
|
||||
func(text string) {
|
||||
accumulated += text
|
||||
chunk := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
},
|
||||
func() {
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
|
||||
stopChunk := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(stopChunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
configDir := pool.GetNextAccountConfigDir()
|
||||
logger.LogAccountAssigned(configDir)
|
||||
pool.ReportRequestStart(configDir)
|
||||
streamStart := time.Now().UnixMilli()
|
||||
|
||||
ctx := r.Context()
|
||||
result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, parseLine, ws.TempDir, configDir, ctx)
|
||||
|
||||
latencyMs := time.Now().UnixMilli() - streamStart
|
||||
pool.ReportRequestEnd(configDir)
|
||||
|
||||
if err == nil && isRateLimited(result.Stderr) {
|
||||
pool.ReportRateLimit(configDir, 60000)
|
||||
}
|
||||
|
||||
if err != nil || (result.Code != 0 && ctx.Err() == nil) {
|
||||
pool.ReportRequestError(configDir, latencyMs)
|
||||
if err != nil {
|
||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
|
||||
} else {
|
||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr)
|
||||
}
|
||||
} else {
|
||||
pool.ReportRequestSuccess(configDir, latencyMs)
|
||||
}
|
||||
logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
|
||||
return
|
||||
}
|
||||
|
||||
configDir := pool.GetNextAccountConfigDir()
|
||||
logger.LogAccountAssigned(configDir)
|
||||
pool.ReportRequestStart(configDir)
|
||||
syncStart := time.Now().UnixMilli()
|
||||
|
||||
out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context())
|
||||
syncLatency := time.Now().UnixMilli() - syncStart
|
||||
pool.ReportRequestEnd(configDir)
|
||||
|
||||
if err != nil {
|
||||
pool.ReportRequestError(configDir, syncLatency)
|
||||
logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"message": err.Error(), "code": "cursor_cli_error"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
if isRateLimited(out.Stderr) {
|
||||
pool.ReportRateLimit(configDir, 60000)
|
||||
}
|
||||
|
||||
if out.Code != 0 {
|
||||
pool.ReportRequestError(configDir, syncLatency)
|
||||
logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
|
||||
errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr)
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"message": errMsg, "code": "cursor_cli_error"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
pool.ReportRequestSuccess(configDir, syncLatency)
|
||||
content := trimSpace(out.Stdout)
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, content, false)
|
||||
logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
|
||||
|
||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"message": map[string]string{"role": "assistant", "content": content},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}, truncatedHeaders)
|
||||
}
|
||||
|
||||
func trimSpace(s string) string {
|
||||
result := ""
|
||||
start := 0
|
||||
end := len(s)
|
||||
for start < end && (s[start] == ' ' || s[start] == '\t' || s[start] == '\n' || s[start] == '\r') {
|
||||
start++
|
||||
}
|
||||
for end > start && (s[end-1] == ' ' || s[end-1] == '\t' || s[end-1] == '\n' || s[end-1] == '\r') {
|
||||
end--
|
||||
}
|
||||
result = s[start:end]
|
||||
return result
|
||||
}
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/httputil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func HandleHealth(w http.ResponseWriter, r *http.Request, version string, cfg config.BridgeConfig) {
|
||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
||||
"ok": true,
|
||||
"version": version,
|
||||
"workspace": cfg.Workspace,
|
||||
"mode": cfg.Mode,
|
||||
"defaultModel": cfg.DefaultModel,
|
||||
"force": cfg.Force,
|
||||
"approveMcps": cfg.ApproveMcps,
|
||||
"strictModel": cfg.StrictModel,
|
||||
}, nil)
|
||||
}
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/httputil"
|
||||
"cursor-api-proxy/internal/models"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const modelCacheTTLMs = 5 * 60 * 1000
|
||||
|
||||
type ModelCache struct {
|
||||
At int64
|
||||
Models []models.CursorCliModel
|
||||
}
|
||||
|
||||
type ModelCacheRef struct {
|
||||
mu sync.Mutex
|
||||
cache *ModelCache
|
||||
inflight bool
|
||||
waiters []chan struct{}
|
||||
}
|
||||
|
||||
func (ref *ModelCacheRef) HandleModels(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig) {
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
ref.mu.Lock()
|
||||
if ref.cache != nil && now-ref.cache.At <= modelCacheTTLMs {
|
||||
cache := ref.cache
|
||||
ref.mu.Unlock()
|
||||
writeModels(w, cache.Models)
|
||||
return
|
||||
}
|
||||
|
||||
if ref.inflight {
|
||||
// Wait for the in-flight fetch
|
||||
ch := make(chan struct{}, 1)
|
||||
ref.waiters = append(ref.waiters, ch)
|
||||
ref.mu.Unlock()
|
||||
<-ch
|
||||
ref.mu.Lock()
|
||||
cache := ref.cache
|
||||
ref.mu.Unlock()
|
||||
writeModels(w, cache.Models)
|
||||
return
|
||||
}
|
||||
|
||||
ref.inflight = true
|
||||
ref.mu.Unlock()
|
||||
|
||||
fetched, err := models.ListCursorCliModels(cfg.AgentBin, 60000)
|
||||
|
||||
ref.mu.Lock()
|
||||
ref.inflight = false
|
||||
if err == nil {
|
||||
ref.cache = &ModelCache{At: time.Now().UnixMilli(), Models: fetched}
|
||||
}
|
||||
waiters := ref.waiters
|
||||
ref.waiters = nil
|
||||
ref.mu.Unlock()
|
||||
|
||||
for _, ch := range waiters {
|
||||
ch <- struct{}{}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"message": err.Error(), "code": "models_fetch_error"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
writeModels(w, fetched)
|
||||
}
|
||||
|
||||
func writeModels(w http.ResponseWriter, mods []models.CursorCliModel) {
|
||||
cursorModels := make([]map[string]interface{}, len(mods))
|
||||
for i, m := range mods {
|
||||
cursorModels[i] = map[string]interface{}{
|
||||
"id": m.ID,
|
||||
"object": "model",
|
||||
"owned_by": "cursor",
|
||||
"name": m.Name,
|
||||
}
|
||||
}
|
||||
|
||||
ids := make([]string, len(mods))
|
||||
for i, m := range mods {
|
||||
ids[i] = m.ID
|
||||
}
|
||||
aliases := models.GetAnthropicModelAliases(ids)
|
||||
for _, a := range aliases {
|
||||
cursorModels = append(cursorModels, map[string]interface{}{
|
||||
"id": a.ID,
|
||||
"object": "model",
|
||||
"owned_by": "cursor",
|
||||
"name": a.Name,
|
||||
})
|
||||
}
|
||||
|
||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
||||
"object": "list",
|
||||
"data": cursorModels,
|
||||
}, nil)
|
||||
}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
package handlers
|
||||
|
||||
import "cursor-api-proxy/internal/config"
|
||||
|
||||
func ResolveModel(requested string, lastModelRef *string, cfg config.BridgeConfig) string {
|
||||
isAuto := requested == "auto"
|
||||
var explicitModel string
|
||||
if requested != "" && !isAuto {
|
||||
explicitModel = requested
|
||||
}
|
||||
if explicitModel != "" {
|
||||
*lastModelRef = explicitModel
|
||||
}
|
||||
if isAuto {
|
||||
return "auto"
|
||||
}
|
||||
if explicitModel != "" {
|
||||
return explicitModel
|
||||
}
|
||||
if cfg.StrictModel && *lastModelRef != "" {
|
||||
return *lastModelRef
|
||||
}
|
||||
if *lastModelRef != "" {
|
||||
return *lastModelRef
|
||||
}
|
||||
return cfg.DefaultModel
|
||||
}
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var bearerRe = regexp.MustCompile(`(?i)^Bearer\s+(.+)$`)
|
||||
|
||||
func ExtractBearerToken(r *http.Request) string {
|
||||
h := r.Header.Get("Authorization")
|
||||
if h == "" {
|
||||
return ""
|
||||
}
|
||||
m := bearerRe.FindStringSubmatch(h)
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
return m[1]
|
||||
}
|
||||
|
||||
func WriteJSON(w http.ResponseWriter, status int, body interface{}, extraHeaders map[string]string) {
|
||||
for k, v := range extraHeaders {
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(body)
|
||||
}
|
||||
|
||||
func WriteSSEHeaders(w http.ResponseWriter, extraHeaders map[string]string) {
|
||||
for k, v := range extraHeaders {
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
|
||||
func ReadBody(r *http.Request) (string, error) {
|
||||
data, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractBearerToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
header string
|
||||
want string
|
||||
}{
|
||||
{"Bearer mytoken123", "mytoken123"},
|
||||
{"bearer MYTOKEN", "MYTOKEN"},
|
||||
{"", ""},
|
||||
{"Basic abc", ""},
|
||||
{"Bearer ", ""},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
if tc.header != "" {
|
||||
req.Header.Set("Authorization", tc.header)
|
||||
}
|
||||
got := ExtractBearerToken(req)
|
||||
if got != tc.want {
|
||||
t.Errorf("ExtractBearerToken(%q) = %q, want %q", tc.header, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteJSON(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
WriteJSON(w, 200, map[string]string{"ok": "true"}, nil)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Errorf("expected 200, got %d", w.Code)
|
||||
}
|
||||
if w.Header().Get("Content-Type") != "application/json" {
|
||||
t.Errorf("expected application/json, got %s", w.Header().Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteJSONWithExtraHeaders(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
WriteJSON(w, 201, nil, map[string]string{"X-Custom": "value"})
|
||||
|
||||
if w.Header().Get("X-Custom") != "value" {
|
||||
t.Errorf("expected X-Custom=value, got %s", w.Header().Get("X-Custom"))
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,184 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/pool"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
cReset = "\x1b[0m"
|
||||
cBold = "\x1b[1m"
|
||||
cDim = "\x1b[2m"
|
||||
cCyan = "\x1b[36m"
|
||||
cBCyan = "\x1b[1;96m"
|
||||
cGreen = "\x1b[32m"
|
||||
cBGreen = "\x1b[1;92m"
|
||||
cYellow = "\x1b[33m"
|
||||
cMagenta = "\x1b[35m"
|
||||
cBMagenta = "\x1b[1;95m"
|
||||
cRed = "\x1b[31m"
|
||||
cGray = "\x1b[90m"
|
||||
cWhite = "\x1b[97m"
|
||||
)
|
||||
|
||||
var roleStyle = map[string]string{
|
||||
"system": cYellow,
|
||||
"user": cCyan,
|
||||
"assistant": cGreen,
|
||||
}
|
||||
|
||||
var roleEmoji = map[string]string{
|
||||
"system": "🔧",
|
||||
"user": "👤",
|
||||
"assistant": "🤖",
|
||||
}
|
||||
|
||||
func ts() string {
|
||||
return cGray + time.Now().UTC().Format(time.RFC3339Nano) + cReset
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
head := int(float64(max) * 0.6)
|
||||
tail := max - head
|
||||
omitted := len(s) - head - tail
|
||||
return s[:head] + fmt.Sprintf("%s … (%d chars omitted) … ", cDim, omitted) + s[len(s)-tail:] + cReset
|
||||
}
|
||||
|
||||
func hr(ch string, length int) string {
|
||||
return cGray + strings.Repeat(ch, length) + cReset
|
||||
}
|
||||
|
||||
type TrafficMessage struct {
|
||||
Role string
|
||||
Content string
|
||||
}
|
||||
|
||||
func LogIncoming(method, pathname, remoteAddress string) {
|
||||
fmt.Printf("[%s] Incoming: %s %s (from %s)\n", time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress)
|
||||
}
|
||||
|
||||
func LogAccountAssigned(configDir string) {
|
||||
if configDir == "" {
|
||||
return
|
||||
}
|
||||
name := filepath.Base(configDir)
|
||||
fmt.Printf("[%s] %s→ account%s %s%s%s\n", time.Now().UTC().Format(time.RFC3339), cBCyan, cReset, cBold, name, cReset)
|
||||
}
|
||||
|
||||
func LogAccountStats(verbose bool, stats []pool.AccountStat) {
|
||||
if !verbose || len(stats) == 0 {
|
||||
return
|
||||
}
|
||||
now := time.Now().UnixMilli()
|
||||
fmt.Printf("%s┌─ Account Stats %s┐%s\n", cGray, strings.Repeat("─", 44), cReset)
|
||||
for _, s := range stats {
|
||||
name := fmt.Sprintf("%-20s", filepath.Base(s.ConfigDir))
|
||||
active := fmt.Sprintf("%sdim%sactive:0%s", cDim, "", cReset)
|
||||
if s.ActiveRequests > 0 {
|
||||
active = fmt.Sprintf("%sactive:%d%s", cBCyan, s.ActiveRequests, cReset)
|
||||
}
|
||||
total := fmt.Sprintf("total:%s%d%s", cBold, s.TotalRequests, cReset)
|
||||
ok := fmt.Sprintf("%sok:%d%s", cGreen, s.TotalSuccess, cReset)
|
||||
errStr := fmt.Sprintf("%sdim%serr:0%s", cDim, "", cReset)
|
||||
if s.TotalErrors > 0 {
|
||||
errStr = fmt.Sprintf("%serr:%d%s", cRed, s.TotalErrors, cReset)
|
||||
}
|
||||
rl := fmt.Sprintf("%sdim%srl:0%s", cDim, "", cReset)
|
||||
if s.TotalRateLimits > 0 {
|
||||
rl = fmt.Sprintf("%srl:%d%s", cYellow, s.TotalRateLimits, cReset)
|
||||
}
|
||||
avg := "avg:-"
|
||||
if s.TotalRequests > 0 {
|
||||
avg = fmt.Sprintf("avg:%dms", s.TotalLatencyMs/int64(s.TotalRequests))
|
||||
}
|
||||
status := fmt.Sprintf("%s✓%s", cGreen, cReset)
|
||||
if s.IsRateLimited {
|
||||
recovers := time.UnixMilli(s.RateLimitUntil).UTC().Format(time.RFC3339)
|
||||
_ = now
|
||||
status = fmt.Sprintf("%s⛔ rate-limited (recovers %s)%s", cRed, recovers, cReset)
|
||||
}
|
||||
fmt.Printf(" %s%s%s %s %s %s %s %s %s%s%s %s\n",
|
||||
cBold, name, cReset, active, total, ok, errStr, rl, cDim, avg, cReset, status)
|
||||
}
|
||||
fmt.Printf("%s└%s┘%s\n", cGray, strings.Repeat("─", 60), cReset)
|
||||
}
|
||||
|
||||
func LogTrafficRequest(verbose bool, model string, messages []TrafficMessage, isStream bool) {
|
||||
if !verbose {
|
||||
return
|
||||
}
|
||||
modeTag := fmt.Sprintf("%sdim%ssync%s", cDim, "", cReset)
|
||||
if isStream {
|
||||
modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset)
|
||||
}
|
||||
modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset)
|
||||
fmt.Println(hr("─", 60))
|
||||
fmt.Printf("%s 📤 %s%sREQUEST%s %s %s\n", ts(), cBCyan, cBold, cReset, modelStr, modeTag)
|
||||
for _, m := range messages {
|
||||
roleColor := cWhite
|
||||
if c, ok := roleStyle[m.Role]; ok {
|
||||
roleColor = c
|
||||
}
|
||||
emoji := "💬"
|
||||
if e, ok := roleEmoji[m.Role]; ok {
|
||||
emoji = e
|
||||
}
|
||||
label := fmt.Sprintf("%s%s[%s]%s", roleColor, cBold, m.Role, cReset)
|
||||
charCount := fmt.Sprintf("%s(%d chars)%s", cDim, len(m.Content), cReset)
|
||||
preview := truncate(strings.ReplaceAll(m.Content, "\n", "↵ "), 280)
|
||||
fmt.Printf(" %s %s %s\n", emoji, label, charCount)
|
||||
fmt.Printf(" %s%s%s\n", cDim, preview, cReset)
|
||||
}
|
||||
}
|
||||
|
||||
func LogTrafficResponse(verbose bool, model, text string, isStream bool) {
|
||||
if !verbose {
|
||||
return
|
||||
}
|
||||
modeTag := fmt.Sprintf("%sdim%ssync%s", cDim, "", cReset)
|
||||
if isStream {
|
||||
modeTag = fmt.Sprintf("%s⚡ stream%s", cBGreen, cReset)
|
||||
}
|
||||
modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset)
|
||||
charCount := fmt.Sprintf("%s%d%s%s chars%s", cBold, len(text), cReset, cDim, cReset)
|
||||
preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 480)
|
||||
fmt.Printf("%s 📥 %s%sRESPONSE%s %s %s %s\n", ts(), cBGreen, cBold, cReset, modelStr, modeTag, charCount)
|
||||
fmt.Printf(" 🤖 %s%s%s\n", cGreen, preview, cReset)
|
||||
fmt.Println(hr("─", 60))
|
||||
}
|
||||
|
||||
func AppendSessionLine(logPath, method, pathname, remoteAddress string, statusCode int) {
|
||||
line := fmt.Sprintf("%s %s %s %s %d\n", time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, statusCode)
|
||||
dir := filepath.Dir(logPath)
|
||||
if err := os.MkdirAll(dir, 0755); err == nil {
|
||||
f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err == nil {
|
||||
_, _ = f.WriteString(line)
|
||||
f.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func LogAgentError(logPath, method, pathname, remoteAddress string, exitCode int, stderr string) string {
|
||||
errMsg := fmt.Sprintf("Cursor CLI failed (exit %d): %s", exitCode, strings.TrimSpace(stderr))
|
||||
fmt.Fprintf(os.Stderr, "[%s] Agent error: %s\n", time.Now().UTC().Format(time.RFC3339), errMsg)
|
||||
truncated := strings.TrimSpace(stderr)
|
||||
if len(truncated) > 200 {
|
||||
truncated = truncated[:200]
|
||||
}
|
||||
truncated = strings.ReplaceAll(truncated, "\n", " ")
|
||||
line := fmt.Sprintf("%s ERROR %s %s %s agent_exit_%d %s\n",
|
||||
time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, exitCode, truncated)
|
||||
if f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil {
|
||||
_, _ = f.WriteString(line)
|
||||
f.Close()
|
||||
}
|
||||
return errMsg
|
||||
}
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/process"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type CursorCliModel struct {
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
|
||||
var modelLineRe = regexp.MustCompile(`^([A-Za-z0-9][A-Za-z0-9._:/-]*)\s+-\s+(.*)$`)
|
||||
var trailingParenRe = regexp.MustCompile(`\s*\([^)]*\)\s*$`)
|
||||
|
||||
func ParseCursorCliModels(output string) []CursorCliModel {
|
||||
lines := strings.Split(output, "\n")
|
||||
seen := make(map[string]CursorCliModel)
|
||||
var order []string
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
m := modelLineRe.FindStringSubmatch(line)
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
id := m[1]
|
||||
rawName := m[2]
|
||||
name := strings.TrimSpace(trailingParenRe.ReplaceAllString(rawName, ""))
|
||||
if name == "" {
|
||||
name = id
|
||||
}
|
||||
if _, exists := seen[id]; !exists {
|
||||
seen[id] = CursorCliModel{ID: id, Name: name}
|
||||
order = append(order, id)
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]CursorCliModel, 0, len(order))
|
||||
for _, id := range order {
|
||||
result = append(result, seen[id])
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func ListCursorCliModels(agentBin string, timeoutMs int) ([]CursorCliModel, error) {
|
||||
tmpDir := os.TempDir()
|
||||
result, err := process.Run(agentBin, []string{"--list-models"}, process.RunOptions{
|
||||
Cwd: tmpDir,
|
||||
TimeoutMs: timeoutMs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("agent --list-models failed: %s", strings.TrimSpace(result.Stderr))
|
||||
}
|
||||
return ParseCursorCliModels(result.Stdout), nil
|
||||
}
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
package models
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseCursorCliModels(t *testing.T) {
|
||||
output := `
|
||||
gpt-4o - GPT-4o (some info)
|
||||
claude-3-5-sonnet - Claude 3.5 Sonnet
|
||||
gpt-4o - GPT-4o duplicate
|
||||
invalid line without dash
|
||||
`
|
||||
result := ParseCursorCliModels(output)
|
||||
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("expected 2 unique models, got %d: %v", len(result), result)
|
||||
}
|
||||
if result[0].ID != "gpt-4o" {
|
||||
t.Errorf("expected gpt-4o, got %s", result[0].ID)
|
||||
}
|
||||
if result[0].Name != "GPT-4o" {
|
||||
t.Errorf("expected 'GPT-4o', got %s", result[0].Name)
|
||||
}
|
||||
if result[1].ID != "claude-3-5-sonnet" {
|
||||
t.Errorf("expected claude-3-5-sonnet, got %s", result[1].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCursorCliModelsEmpty(t *testing.T) {
|
||||
result := ParseCursorCliModels("")
|
||||
if len(result) != 0 {
|
||||
t.Fatalf("expected empty, got %v", result)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
package models
|
||||
|
||||
import "strings"
|
||||
|
||||
var anthropicToCursor = map[string]string{
|
||||
"claude-opus-4-6": "opus-4.6",
|
||||
"claude-opus-4.6": "opus-4.6",
|
||||
"claude-sonnet-4-6": "sonnet-4.6",
|
||||
"claude-sonnet-4.6": "sonnet-4.6",
|
||||
"claude-opus-4-5": "opus-4.5",
|
||||
"claude-opus-4.5": "opus-4.5",
|
||||
"claude-sonnet-4-5": "sonnet-4.5",
|
||||
"claude-sonnet-4.5": "sonnet-4.5",
|
||||
"claude-opus-4": "opus-4.6",
|
||||
"claude-sonnet-4": "sonnet-4.6",
|
||||
"claude-haiku-4-5-20251001": "sonnet-4.5",
|
||||
"claude-haiku-4-5": "sonnet-4.5",
|
||||
"claude-haiku-4-6": "sonnet-4.6",
|
||||
"claude-haiku-4": "sonnet-4.5",
|
||||
"claude-opus-4-6-thinking": "opus-4.6-thinking",
|
||||
"claude-sonnet-4-6-thinking": "sonnet-4.6-thinking",
|
||||
"claude-opus-4-5-thinking": "opus-4.5-thinking",
|
||||
"claude-sonnet-4-5-thinking": "sonnet-4.5-thinking",
|
||||
}
|
||||
|
||||
type ModelAlias struct {
|
||||
CursorID string
|
||||
AnthropicID string
|
||||
Name string
|
||||
}
|
||||
|
||||
var cursorToAnthropicAlias = []ModelAlias{
|
||||
{"opus-4.6", "claude-opus-4-6", "Claude 4.6 Opus"},
|
||||
{"opus-4.6-thinking", "claude-opus-4-6-thinking", "Claude 4.6 Opus (Thinking)"},
|
||||
{"sonnet-4.6", "claude-sonnet-4-6", "Claude 4.6 Sonnet"},
|
||||
{"sonnet-4.6-thinking", "claude-sonnet-4-6-thinking", "Claude 4.6 Sonnet (Thinking)"},
|
||||
{"opus-4.5", "claude-opus-4-5", "Claude 4.5 Opus"},
|
||||
{"opus-4.5-thinking", "claude-opus-4-5-thinking", "Claude 4.5 Opus (Thinking)"},
|
||||
{"sonnet-4.5", "claude-sonnet-4-5", "Claude 4.5 Sonnet"},
|
||||
{"sonnet-4.5-thinking", "claude-sonnet-4-5-thinking", "Claude 4.5 Sonnet (Thinking)"},
|
||||
}
|
||||
|
||||
func ResolveToCursorModel(requested string) string {
|
||||
if strings.TrimSpace(requested) == "" {
|
||||
return ""
|
||||
}
|
||||
key := strings.ToLower(strings.TrimSpace(requested))
|
||||
if v, ok := anthropicToCursor[key]; ok {
|
||||
return v
|
||||
}
|
||||
return strings.TrimSpace(requested)
|
||||
}
|
||||
|
||||
type AnthropicAlias struct {
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
|
||||
func GetAnthropicModelAliases(availableCursorIDs []string) []AnthropicAlias {
|
||||
set := make(map[string]bool, len(availableCursorIDs))
|
||||
for _, id := range availableCursorIDs {
|
||||
set[id] = true
|
||||
}
|
||||
var result []AnthropicAlias
|
||||
for _, a := range cursorToAnthropicAlias {
|
||||
if set[a.CursorID] {
|
||||
result = append(result, AnthropicAlias{ID: a.AnthropicID, Name: a.Name})
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
|
@ -0,0 +1,184 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []interface{} `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
Tools []interface{} `json:"tools"`
|
||||
ToolChoice interface{} `json:"tool_choice"`
|
||||
Functions []interface{} `json:"functions"`
|
||||
FunctionCall interface{} `json:"function_call"`
|
||||
}
|
||||
|
||||
func NormalizeModelID(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
parts := strings.Split(trimmed, "/")
|
||||
last := parts[len(parts)-1]
|
||||
if last == "" {
|
||||
return ""
|
||||
}
|
||||
return last
|
||||
}
|
||||
|
||||
func imageURLToText(imageURL interface{}) string {
|
||||
if imageURL == nil {
|
||||
return "[Image]"
|
||||
}
|
||||
var url string
|
||||
switch v := imageURL.(type) {
|
||||
case string:
|
||||
url = v
|
||||
case map[string]interface{}:
|
||||
if u, ok := v["url"].(string); ok {
|
||||
url = u
|
||||
}
|
||||
}
|
||||
if url == "" {
|
||||
return "[Image]"
|
||||
}
|
||||
if strings.HasPrefix(url, "data:") {
|
||||
end := strings.Index(url, ";")
|
||||
mime := "image"
|
||||
if end > 5 {
|
||||
mime = url[5:end]
|
||||
}
|
||||
return "[Image: base64 " + mime + "]"
|
||||
}
|
||||
return "[Image: " + url + "]"
|
||||
}
|
||||
|
||||
func MessageContentToText(content interface{}) string {
|
||||
if content == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []interface{}:
|
||||
var parts []string
|
||||
for _, p := range v {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
switch part := p.(type) {
|
||||
case string:
|
||||
parts = append(parts, part)
|
||||
case map[string]interface{}:
|
||||
typ, _ := part["type"].(string)
|
||||
switch typ {
|
||||
case "text":
|
||||
if t, ok := part["text"].(string); ok {
|
||||
parts = append(parts, t)
|
||||
}
|
||||
case "image_url":
|
||||
parts = append(parts, imageURLToText(part["image_url"]))
|
||||
case "image":
|
||||
src := part["source"]
|
||||
if src == nil {
|
||||
src = part["url"]
|
||||
}
|
||||
parts = append(parts, imageURLToText(src))
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func ToolsToSystemText(tools []interface{}, functions []interface{}) string {
|
||||
var defs []interface{}
|
||||
|
||||
for _, t := range tools {
|
||||
if m, ok := t.(map[string]interface{}); ok {
|
||||
if m["type"] == "function" {
|
||||
if fn := m["function"]; fn != nil {
|
||||
defs = append(defs, fn)
|
||||
}
|
||||
} else {
|
||||
defs = append(defs, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
defs = append(defs, functions...)
|
||||
|
||||
if len(defs) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, "Available tools (respond with a JSON object to call one):", "")
|
||||
|
||||
for _, raw := range defs {
|
||||
fn, ok := raw.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, _ := fn["name"].(string)
|
||||
desc, _ := fn["description"].(string)
|
||||
params := "{}"
|
||||
if p := fn["parameters"]; p != nil {
|
||||
if b, err := json.MarshalIndent(p, "", " "); err == nil {
|
||||
params = string(b)
|
||||
}
|
||||
}
|
||||
lines = append(lines, "Function: "+name+"\nDescription: "+desc+"\nParameters: "+params)
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
type SimpleMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func BuildPromptFromMessages(messages []interface{}) string {
|
||||
var systemParts []string
|
||||
var convo []string
|
||||
|
||||
for _, raw := range messages {
|
||||
m, ok := raw.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role, _ := m["role"].(string)
|
||||
text := MessageContentToText(m["content"])
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
switch role {
|
||||
case "system", "developer":
|
||||
systemParts = append(systemParts, text)
|
||||
case "user":
|
||||
convo = append(convo, "User: "+text)
|
||||
case "assistant":
|
||||
convo = append(convo, "Assistant: "+text)
|
||||
case "tool", "function":
|
||||
convo = append(convo, "Tool: "+text)
|
||||
}
|
||||
}
|
||||
|
||||
system := ""
|
||||
if len(systemParts) > 0 {
|
||||
system = "System:\n" + strings.Join(systemParts, "\n\n") + "\n\n"
|
||||
}
|
||||
transcript := strings.Join(convo, "\n\n")
|
||||
return system + transcript + "\n\nAssistant:"
|
||||
}
|
||||
|
||||
func BuildPromptFromSimpleMessages(messages []SimpleMessage) string {
|
||||
ifaces := make([]interface{}, len(messages))
|
||||
for i, m := range messages {
|
||||
ifaces[i] = map[string]interface{}{"role": m.Role, "content": m.Content}
|
||||
}
|
||||
return BuildPromptFromMessages(ifaces)
|
||||
}
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
package openai
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeModelID(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"gpt-4", "gpt-4"},
|
||||
{"openai/gpt-4", "gpt-4"},
|
||||
{"anthropic/claude-3", "claude-3"},
|
||||
{"", ""},
|
||||
{" ", ""},
|
||||
{"a/b/c", "c"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got := NormalizeModelID(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("NormalizeModelID(%q) = %q, want %q", tc.input, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromMessages(t *testing.T) {
|
||||
messages := []interface{}{
|
||||
map[string]interface{}{"role": "system", "content": "You are helpful."},
|
||||
map[string]interface{}{"role": "user", "content": "Hello"},
|
||||
map[string]interface{}{"role": "assistant", "content": "Hi there"},
|
||||
}
|
||||
got := BuildPromptFromMessages(messages)
|
||||
if got == "" {
|
||||
t.Fatal("expected non-empty prompt")
|
||||
}
|
||||
containsSystem := false
|
||||
containsUser := false
|
||||
containsAssistant := false
|
||||
for i := 0; i < len(got)-10; i++ {
|
||||
if got[i:i+6] == "System" {
|
||||
containsSystem = true
|
||||
}
|
||||
if got[i:i+4] == "User" {
|
||||
containsUser = true
|
||||
}
|
||||
if got[i:i+9] == "Assistant" {
|
||||
containsAssistant = true
|
||||
}
|
||||
}
|
||||
if !containsSystem || !containsUser || !containsAssistant {
|
||||
t.Errorf("prompt missing sections: system=%v user=%v assistant=%v\n%s",
|
||||
containsSystem, containsUser, containsAssistant, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolsToSystemText(t *testing.T) {
|
||||
tools := []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": map[string]interface{}{"type": "object"},
|
||||
},
|
||||
},
|
||||
}
|
||||
got := ToolsToSystemText(tools, nil)
|
||||
if got == "" {
|
||||
t.Fatal("expected non-empty tools text")
|
||||
}
|
||||
if len(got) < 10 {
|
||||
t.Errorf("tools text too short: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolsToSystemTextEmpty(t *testing.T) {
|
||||
got := ToolsToSystemText(nil, nil)
|
||||
if got != "" {
|
||||
t.Errorf("expected empty string for no tools, got %q", got)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
package parser
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type StreamParser func(line string)
|
||||
|
||||
func CreateStreamParser(onText func(string), onDone func()) StreamParser {
|
||||
accumulated := ""
|
||||
done := false
|
||||
|
||||
return func(line string) {
|
||||
if done {
|
||||
return
|
||||
}
|
||||
|
||||
var obj struct {
|
||||
Type string `json:"type"`
|
||||
Subtype string `json:"subtype"`
|
||||
Message *struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
} `json:"message"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(line), &obj); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if obj.Type == "assistant" && obj.Message != nil {
|
||||
text := ""
|
||||
for _, p := range obj.Message.Content {
|
||||
if p.Type == "text" && p.Text != "" {
|
||||
text += p.Text
|
||||
}
|
||||
}
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
if text == accumulated {
|
||||
return
|
||||
}
|
||||
if len(accumulated) > 0 && len(text) > len(accumulated) && text[:len(accumulated)] == accumulated {
|
||||
delta := text[len(accumulated):]
|
||||
if delta != "" {
|
||||
onText(delta)
|
||||
}
|
||||
accumulated = text
|
||||
} else {
|
||||
onText(text)
|
||||
accumulated += text
|
||||
}
|
||||
}
|
||||
|
||||
if obj.Type == "result" && obj.Subtype == "success" {
|
||||
done = true
|
||||
onDone()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,164 @@
|
|||
package parser
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func makeAssistantLine(text string) string {
|
||||
obj := map[string]interface{}{
|
||||
"type": "assistant",
|
||||
"message": map[string]interface{}{
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "text", "text": text},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(obj)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func makeResultLine() string {
|
||||
b, _ := json.Marshal(map[string]string{"type": "result", "subtype": "success"})
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func TestStreamParserIncrementalDeltas(t *testing.T) {
|
||||
var texts []string
|
||||
doneCount := 0
|
||||
parse := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
parse(makeAssistantLine("Hello"))
|
||||
if len(texts) != 1 || texts[0] != "Hello" {
|
||||
t.Fatalf("expected [Hello], got %v", texts)
|
||||
}
|
||||
|
||||
parse(makeAssistantLine("Hello world"))
|
||||
if len(texts) != 2 || texts[1] != " world" {
|
||||
t.Fatalf("expected second call with ' world', got %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserDeduplicatesFinalMessage(t *testing.T) {
|
||||
var texts []string
|
||||
parse := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
parse(makeAssistantLine("Hi"))
|
||||
parse(makeAssistantLine("Hi there"))
|
||||
if len(texts) != 2 {
|
||||
t.Fatalf("expected 2 calls, got %d: %v", len(texts), texts)
|
||||
}
|
||||
if texts[0] != "Hi" || texts[1] != " there" {
|
||||
t.Fatalf("unexpected texts: %v", texts)
|
||||
}
|
||||
|
||||
// Final duplicate: full accumulated text again
|
||||
parse(makeAssistantLine("Hi there"))
|
||||
if len(texts) != 2 {
|
||||
t.Fatalf("expected no new call after duplicate, got %d: %v", len(texts), texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserCallsOnDone(t *testing.T) {
|
||||
var texts []string
|
||||
doneCount := 0
|
||||
parse := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
parse(makeResultLine())
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onDone called once, got %d", doneCount)
|
||||
}
|
||||
if len(texts) != 0 {
|
||||
t.Fatalf("expected no text, got %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserIgnoresLinesAfterDone(t *testing.T) {
|
||||
var texts []string
|
||||
doneCount := 0
|
||||
parse := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
parse(makeResultLine())
|
||||
parse(makeAssistantLine("late"))
|
||||
if len(texts) != 0 {
|
||||
t.Fatalf("expected no text after done, got %v", texts)
|
||||
}
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onDone called once, got %d", doneCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserIgnoresNonAssistantLines(t *testing.T) {
|
||||
var texts []string
|
||||
parse := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
b1, _ := json.Marshal(map[string]interface{}{"type": "user", "message": map[string]interface{}{}})
|
||||
parse(string(b1))
|
||||
b2, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "assistant",
|
||||
"message": map[string]interface{}{"content": []interface{}{}},
|
||||
})
|
||||
parse(string(b2))
|
||||
parse(`{"type":"assistant","message":{"content":[{"type":"code","text":"x"}]}}`)
|
||||
|
||||
if len(texts) != 0 {
|
||||
t.Fatalf("expected no texts, got %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserIgnoresParseErrors(t *testing.T) {
|
||||
var texts []string
|
||||
doneCount := 0
|
||||
parse := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
parse("not json")
|
||||
parse("{")
|
||||
parse("")
|
||||
|
||||
if len(texts) != 0 || doneCount != 0 {
|
||||
t.Fatalf("expected nothing, got texts=%v done=%d", texts, doneCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserJoinsMultipleTextParts(t *testing.T) {
|
||||
var texts []string
|
||||
parse := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
obj := map[string]interface{}{
|
||||
"type": "assistant",
|
||||
"message": map[string]interface{}{
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "text", "text": " "},
|
||||
{"type": "text", "text": "world"},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(obj)
|
||||
parse(string(b))
|
||||
|
||||
if len(texts) != 1 || texts[0] != "Hello world" {
|
||||
t.Fatalf("expected ['Hello world'], got %v", texts)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,259 @@
|
|||
package pool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type accountStatus struct {
|
||||
configDir string
|
||||
activeRequests int
|
||||
lastUsed int64
|
||||
rateLimitUntil int64
|
||||
totalRequests int
|
||||
totalSuccess int
|
||||
totalErrors int
|
||||
totalRateLimits int
|
||||
totalLatencyMs int64
|
||||
}
|
||||
|
||||
type AccountStat struct {
|
||||
ConfigDir string
|
||||
ActiveRequests int
|
||||
TotalRequests int
|
||||
TotalSuccess int
|
||||
TotalErrors int
|
||||
TotalRateLimits int
|
||||
TotalLatencyMs int64
|
||||
IsRateLimited bool
|
||||
RateLimitUntil int64
|
||||
}
|
||||
|
||||
type AccountPool struct {
|
||||
mu sync.Mutex
|
||||
accounts []*accountStatus
|
||||
}
|
||||
|
||||
func NewAccountPool(configDirs []string) *AccountPool {
|
||||
accounts := make([]*accountStatus, 0, len(configDirs))
|
||||
for _, dir := range configDirs {
|
||||
accounts = append(accounts, &accountStatus{configDir: dir})
|
||||
}
|
||||
return &AccountPool{accounts: accounts}
|
||||
}
|
||||
|
||||
func (p *AccountPool) GetNextConfigDir() string {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if len(p.accounts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
available := make([]*accountStatus, 0, len(p.accounts))
|
||||
for _, a := range p.accounts {
|
||||
if a.rateLimitUntil < now {
|
||||
available = append(available, a)
|
||||
}
|
||||
}
|
||||
|
||||
target := available
|
||||
if len(target) == 0 {
|
||||
target = make([]*accountStatus, len(p.accounts))
|
||||
copy(target, p.accounts)
|
||||
// sort by earliest recovery
|
||||
for i := 1; i < len(target); i++ {
|
||||
for j := i; j > 0 && target[j].rateLimitUntil < target[j-1].rateLimitUntil; j-- {
|
||||
target[j], target[j-1] = target[j-1], target[j]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pick least busy then least recently used
|
||||
best := target[0]
|
||||
for _, a := range target[1:] {
|
||||
if a.activeRequests < best.activeRequests {
|
||||
best = a
|
||||
} else if a.activeRequests == best.activeRequests && a.lastUsed < best.lastUsed {
|
||||
best = a
|
||||
}
|
||||
}
|
||||
best.lastUsed = now
|
||||
return best.configDir
|
||||
}
|
||||
|
||||
func (p *AccountPool) find(configDir string) *accountStatus {
|
||||
for _, a := range p.accounts {
|
||||
if a.configDir == configDir {
|
||||
return a
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *AccountPool) ReportRequestStart(configDir string) {
|
||||
if configDir == "" {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if a := p.find(configDir); a != nil {
|
||||
a.activeRequests++
|
||||
a.totalRequests++
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AccountPool) ReportRequestEnd(configDir string) {
|
||||
if configDir == "" {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if a := p.find(configDir); a != nil && a.activeRequests > 0 {
|
||||
a.activeRequests--
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AccountPool) ReportRequestSuccess(configDir string, latencyMs int64) {
|
||||
if configDir == "" {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if a := p.find(configDir); a != nil {
|
||||
a.totalSuccess++
|
||||
a.totalLatencyMs += latencyMs
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AccountPool) ReportRequestError(configDir string, latencyMs int64) {
|
||||
if configDir == "" {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if a := p.find(configDir); a != nil {
|
||||
a.totalErrors++
|
||||
a.totalLatencyMs += latencyMs
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AccountPool) ReportRateLimit(configDir string, penaltyMs int64) {
|
||||
if configDir == "" {
|
||||
return
|
||||
}
|
||||
if penaltyMs <= 0 {
|
||||
penaltyMs = 60000
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if a := p.find(configDir); a != nil {
|
||||
a.rateLimitUntil = time.Now().UnixMilli() + penaltyMs
|
||||
a.totalRateLimits++
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AccountPool) GetStats() []AccountStat {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
now := time.Now().UnixMilli()
|
||||
stats := make([]AccountStat, len(p.accounts))
|
||||
for i, a := range p.accounts {
|
||||
stats[i] = AccountStat{
|
||||
ConfigDir: a.configDir,
|
||||
ActiveRequests: a.activeRequests,
|
||||
TotalRequests: a.totalRequests,
|
||||
TotalSuccess: a.totalSuccess,
|
||||
TotalErrors: a.totalErrors,
|
||||
TotalRateLimits: a.totalRateLimits,
|
||||
TotalLatencyMs: a.totalLatencyMs,
|
||||
IsRateLimited: a.rateLimitUntil > now,
|
||||
RateLimitUntil: a.rateLimitUntil,
|
||||
}
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
func (p *AccountPool) Count() int {
|
||||
return len(p.accounts)
|
||||
}
|
||||
|
||||
// ─── Global pool ───────────────────────────────────────────────────────────
|
||||
|
||||
var (
|
||||
globalPool *AccountPool
|
||||
globalMu sync.Mutex
|
||||
)
|
||||
|
||||
func InitAccountPool(configDirs []string) {
|
||||
globalMu.Lock()
|
||||
defer globalMu.Unlock()
|
||||
globalPool = NewAccountPool(configDirs)
|
||||
}
|
||||
|
||||
func GetNextAccountConfigDir() string {
|
||||
globalMu.Lock()
|
||||
p := globalPool
|
||||
globalMu.Unlock()
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
return p.GetNextConfigDir()
|
||||
}
|
||||
|
||||
func ReportRequestStart(configDir string) {
|
||||
globalMu.Lock()
|
||||
p := globalPool
|
||||
globalMu.Unlock()
|
||||
if p != nil {
|
||||
p.ReportRequestStart(configDir)
|
||||
}
|
||||
}
|
||||
|
||||
func ReportRequestEnd(configDir string) {
|
||||
globalMu.Lock()
|
||||
p := globalPool
|
||||
globalMu.Unlock()
|
||||
if p != nil {
|
||||
p.ReportRequestEnd(configDir)
|
||||
}
|
||||
}
|
||||
|
||||
func ReportRequestSuccess(configDir string, latencyMs int64) {
|
||||
globalMu.Lock()
|
||||
p := globalPool
|
||||
globalMu.Unlock()
|
||||
if p != nil {
|
||||
p.ReportRequestSuccess(configDir, latencyMs)
|
||||
}
|
||||
}
|
||||
|
||||
func ReportRequestError(configDir string, latencyMs int64) {
|
||||
globalMu.Lock()
|
||||
p := globalPool
|
||||
globalMu.Unlock()
|
||||
if p != nil {
|
||||
p.ReportRequestError(configDir, latencyMs)
|
||||
}
|
||||
}
|
||||
|
||||
func ReportRateLimit(configDir string, penaltyMs int64) {
|
||||
globalMu.Lock()
|
||||
p := globalPool
|
||||
globalMu.Unlock()
|
||||
if p != nil {
|
||||
p.ReportRateLimit(configDir, penaltyMs)
|
||||
}
|
||||
}
|
||||
|
||||
func GetAccountStats() []AccountStat {
|
||||
globalMu.Lock()
|
||||
p := globalPool
|
||||
globalMu.Unlock()
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return p.GetStats()
|
||||
}
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
package pool
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestEmptyPool(t *testing.T) {
|
||||
p := NewAccountPool(nil)
|
||||
if got := p.GetNextConfigDir(); got != "" {
|
||||
t.Fatalf("expected empty string for empty pool, got %q", got)
|
||||
}
|
||||
if p.Count() != 0 {
|
||||
t.Fatalf("expected count 0, got %d", p.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingleDir(t *testing.T) {
|
||||
p := NewAccountPool([]string{"/dir1"})
|
||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
||||
t.Fatalf("expected /dir1, got %q", got)
|
||||
}
|
||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
||||
t.Fatalf("expected /dir1 again, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobin(t *testing.T) {
|
||||
p := NewAccountPool([]string{"/a", "/b", "/c"})
|
||||
got := []string{
|
||||
p.GetNextConfigDir(),
|
||||
p.GetNextConfigDir(),
|
||||
p.GetNextConfigDir(),
|
||||
p.GetNextConfigDir(),
|
||||
}
|
||||
want := []string{"/a", "/b", "/c", "/a"}
|
||||
for i, w := range want {
|
||||
if got[i] != w {
|
||||
t.Fatalf("call %d: expected %q, got %q", i, w, got[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLeastBusy(t *testing.T) {
|
||||
p := NewAccountPool([]string{"/dir1", "/dir2", "/dir3"})
|
||||
p.ReportRequestStart("/dir1")
|
||||
p.ReportRequestStart("/dir2")
|
||||
|
||||
if got := p.GetNextConfigDir(); got != "/dir3" {
|
||||
t.Fatalf("expected /dir3 (least busy), got %q", got)
|
||||
}
|
||||
|
||||
p.ReportRequestStart("/dir3")
|
||||
p.ReportRequestEnd("/dir1")
|
||||
|
||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
||||
t.Fatalf("expected /dir1 after end, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipsRateLimited(t *testing.T) {
|
||||
p := NewAccountPool([]string{"/dir1", "/dir2"})
|
||||
p.ReportRateLimit("/dir1", 60000)
|
||||
|
||||
if got := p.GetNextConfigDir(); got != "/dir2" {
|
||||
t.Fatalf("expected /dir2, got %q", got)
|
||||
}
|
||||
if got := p.GetNextConfigDir(); got != "/dir2" {
|
||||
t.Fatalf("expected /dir2 again, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallbackToSoonestRecovery(t *testing.T) {
|
||||
p := NewAccountPool([]string{"/dir1", "/dir2"})
|
||||
p.ReportRateLimit("/dir1", 60000)
|
||||
p.ReportRateLimit("/dir2", 30000)
|
||||
|
||||
// dir2 recovers sooner — should be selected
|
||||
if got := p.GetNextConfigDir(); got != "/dir2" {
|
||||
t.Fatalf("expected /dir2 (sooner recovery), got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestActiveRequestsDoesNotGoNegative(t *testing.T) {
|
||||
p := NewAccountPool([]string{"/dir1"})
|
||||
p.ReportRequestEnd("/dir1")
|
||||
p.ReportRequestEnd("/dir1")
|
||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
||||
t.Fatalf("pool should still work, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIgnoreUnknownConfigDir(t *testing.T) {
|
||||
p := NewAccountPool([]string{"/dir1"})
|
||||
p.ReportRequestStart("/nonexistent")
|
||||
p.ReportRequestEnd("/nonexistent")
|
||||
p.ReportRateLimit("/nonexistent", 60000)
|
||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
||||
t.Fatalf("expected /dir1, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitExpires(t *testing.T) {
|
||||
p := NewAccountPool([]string{"/dir1", "/dir2"})
|
||||
p.ReportRateLimit("/dir1", 50)
|
||||
|
||||
if got := p.GetNextConfigDir(); got != "/dir2" {
|
||||
t.Fatalf("immediately expected /dir2, got %q", got)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
||||
t.Fatalf("after expiry expected /dir1, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalPool(t *testing.T) {
|
||||
InitAccountPool([]string{"/g1", "/g2"})
|
||||
if got := GetNextAccountConfigDir(); got != "/g1" {
|
||||
t.Fatalf("expected /g1, got %q", got)
|
||||
}
|
||||
if got := GetNextAccountConfigDir(); got != "/g2" {
|
||||
t.Fatalf("expected /g2, got %q", got)
|
||||
}
|
||||
if got := GetNextAccountConfigDir(); got != "/g1" {
|
||||
t.Fatalf("expected /g1 again, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalPoolEmpty(t *testing.T) {
|
||||
InitAccountPool(nil)
|
||||
if got := GetNextAccountConfigDir(); got != "" {
|
||||
t.Fatalf("expected empty string for empty global pool, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalPoolReinit(t *testing.T) {
|
||||
InitAccountPool([]string{"/old1", "/old2"})
|
||||
GetNextAccountConfigDir()
|
||||
InitAccountPool([]string{"/new1"})
|
||||
if got := GetNextAccountConfigDir(); got != "/new1" {
|
||||
t.Fatalf("expected /new1 after reinit, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalPoolFunctionsNoopBeforeInit(t *testing.T) {
|
||||
InitAccountPool(nil)
|
||||
ReportRequestStart("/dir1")
|
||||
ReportRequestEnd("/dir1")
|
||||
ReportRateLimit("/dir1", 1000)
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
//go:build !windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func killProcessGroup(c *exec.Cmd) error {
|
||||
if c.Process == nil {
|
||||
return nil
|
||||
}
|
||||
// 殺死整個 process group(負號表示 group)
|
||||
pgid, err := syscall.Getpgid(c.Process.Pid)
|
||||
if err == nil {
|
||||
_ = syscall.Kill(-pgid, syscall.SIGKILL)
|
||||
}
|
||||
// 同時也 kill 主程序,以防萬一
|
||||
return c.Process.Kill()
|
||||
}
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
//go:build windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
func killProcessGroup(c *exec.Cmd) error {
|
||||
if c.Process == nil {
|
||||
return nil
|
||||
}
|
||||
return c.Process.Kill()
|
||||
}
|
||||
|
|
@ -0,0 +1,248 @@
|
|||
package process
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"cursor-api-proxy/internal/env"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RunResult struct {
|
||||
Code int
|
||||
Stdout string
|
||||
Stderr string
|
||||
}
|
||||
|
||||
type RunOptions struct {
|
||||
Cwd string
|
||||
TimeoutMs int
|
||||
MaxMode bool
|
||||
ConfigDir string
|
||||
Ctx context.Context
|
||||
}
|
||||
|
||||
type RunStreamingOptions struct {
|
||||
RunOptions
|
||||
OnLine func(line string)
|
||||
}
|
||||
|
||||
// ─── Global child process registry ──────────────────────────────────────────
|
||||
|
||||
var (
|
||||
activeMu sync.Mutex
|
||||
activeChildren []*exec.Cmd
|
||||
)
|
||||
|
||||
func registerChild(c *exec.Cmd) {
|
||||
activeMu.Lock()
|
||||
activeChildren = append(activeChildren, c)
|
||||
activeMu.Unlock()
|
||||
}
|
||||
|
||||
func unregisterChild(c *exec.Cmd) {
|
||||
activeMu.Lock()
|
||||
for i, ch := range activeChildren {
|
||||
if ch == c {
|
||||
activeChildren = append(activeChildren[:i], activeChildren[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
activeMu.Unlock()
|
||||
}
|
||||
|
||||
func KillAllChildProcesses() {
|
||||
activeMu.Lock()
|
||||
all := make([]*exec.Cmd, len(activeChildren))
|
||||
copy(all, activeChildren)
|
||||
activeChildren = nil
|
||||
activeMu.Unlock()
|
||||
for _, c := range all {
|
||||
killProcessGroup(c)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Spawn ────────────────────────────────────────────────────────────────
|
||||
|
||||
func spawnChild(cmdStr string, args []string, opts *RunOptions, maxModeFn func(scriptPath, configDir string)) *exec.Cmd {
|
||||
envSrc := env.OsEnvToMap()
|
||||
resolved := env.ResolveAgentCommand(cmdStr, args, envSrc, opts.Cwd)
|
||||
|
||||
if opts.MaxMode && maxModeFn != nil {
|
||||
maxModeFn(resolved.AgentScriptPath, opts.ConfigDir)
|
||||
}
|
||||
|
||||
envMap := make(map[string]string, len(resolved.Env))
|
||||
for k, v := range resolved.Env {
|
||||
envMap[k] = v
|
||||
}
|
||||
if opts.ConfigDir != "" {
|
||||
envMap["CURSOR_CONFIG_DIR"] = opts.ConfigDir
|
||||
} else if resolved.ConfigDir != "" {
|
||||
if _, exists := envMap["CURSOR_CONFIG_DIR"]; !exists {
|
||||
envMap["CURSOR_CONFIG_DIR"] = resolved.ConfigDir
|
||||
}
|
||||
}
|
||||
|
||||
envSlice := make([]string, 0, len(envMap))
|
||||
for k, v := range envMap {
|
||||
envSlice = append(envSlice, k+"="+v)
|
||||
}
|
||||
|
||||
ctx := opts.Ctx
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// 使用 WaitDelay 確保 context cancel 後子程序 goroutine 能及時退出
|
||||
c := exec.CommandContext(ctx, resolved.Command, resolved.Args...)
|
||||
c.Dir = opts.Cwd
|
||||
c.Env = envSlice
|
||||
// 設定新的 process group,使 kill 能傳遞給所有子孫程序
|
||||
c.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
// WaitDelay:context cancel 後額外等待這麼久再強制關閉 pipes
|
||||
c.WaitDelay = 5 * time.Second
|
||||
// Cancel 函式:殺死整個 process group
|
||||
c.Cancel = func() error {
|
||||
return killProcessGroup(c)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// MaxModeFn is set by the agent package to avoid import cycle.
|
||||
var MaxModeFn func(agentScriptPath, configDir string)
|
||||
|
||||
func Run(cmdStr string, args []string, opts RunOptions) (RunResult, error) {
|
||||
ctx := opts.Ctx
|
||||
var cancel context.CancelFunc
|
||||
if opts.TimeoutMs > 0 {
|
||||
if ctx == nil {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond)
|
||||
} else {
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond)
|
||||
}
|
||||
defer cancel()
|
||||
opts.Ctx = ctx
|
||||
} else if ctx == nil {
|
||||
opts.Ctx = context.Background()
|
||||
}
|
||||
|
||||
c := spawnChild(cmdStr, args, &opts, MaxModeFn)
|
||||
var stdoutBuf, stderrBuf strings.Builder
|
||||
c.Stdout = &stdoutBuf
|
||||
c.Stderr = &stderrBuf
|
||||
|
||||
if err := c.Start(); err != nil {
|
||||
// context 已取消或命令找不到時
|
||||
if opts.Ctx != nil && opts.Ctx.Err() != nil {
|
||||
return RunResult{Code: -1}, nil
|
||||
}
|
||||
if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") {
|
||||
return RunResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr)
|
||||
}
|
||||
return RunResult{}, err
|
||||
}
|
||||
registerChild(c)
|
||||
defer unregisterChild(c)
|
||||
|
||||
err := c.Wait()
|
||||
code := 0
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
code = exitErr.ExitCode()
|
||||
if code == 0 {
|
||||
code = -1
|
||||
}
|
||||
} else {
|
||||
// context cancelled or killed — return -1 but no error
|
||||
return RunResult{Code: -1, Stdout: stdoutBuf.String(), Stderr: stderrBuf.String()}, nil
|
||||
}
|
||||
}
|
||||
return RunResult{
|
||||
Code: code,
|
||||
Stdout: stdoutBuf.String(),
|
||||
Stderr: stderrBuf.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type StreamResult struct {
|
||||
Code int
|
||||
Stderr string
|
||||
}
|
||||
|
||||
func RunStreaming(cmdStr string, args []string, opts RunStreamingOptions) (StreamResult, error) {
|
||||
ctx := opts.Ctx
|
||||
var cancel context.CancelFunc
|
||||
if opts.TimeoutMs > 0 {
|
||||
if ctx == nil {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond)
|
||||
} else {
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond)
|
||||
}
|
||||
defer cancel()
|
||||
opts.RunOptions.Ctx = ctx
|
||||
} else if opts.RunOptions.Ctx == nil {
|
||||
opts.RunOptions.Ctx = context.Background()
|
||||
}
|
||||
|
||||
c := spawnChild(cmdStr, args, &opts.RunOptions, MaxModeFn)
|
||||
stdoutPipe, err := c.StdoutPipe()
|
||||
if err != nil {
|
||||
return StreamResult{}, err
|
||||
}
|
||||
stderrPipe, err := c.StderrPipe()
|
||||
if err != nil {
|
||||
return StreamResult{}, err
|
||||
}
|
||||
|
||||
if err := c.Start(); err != nil {
|
||||
if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") {
|
||||
return StreamResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr)
|
||||
}
|
||||
return StreamResult{}, err
|
||||
}
|
||||
registerChild(c)
|
||||
defer unregisterChild(c)
|
||||
|
||||
var stderrBuf strings.Builder
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
scanner := bufio.NewScanner(stdoutPipe)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.TrimSpace(line) != "" {
|
||||
opts.OnLine(line)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
scanner := bufio.NewScanner(stderrPipe)
|
||||
for scanner.Scan() {
|
||||
stderrBuf.WriteString(scanner.Text())
|
||||
stderrBuf.WriteString("\n")
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
err = c.Wait()
|
||||
code := 0
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
code = exitErr.ExitCode()
|
||||
if code == 0 {
|
||||
code = -1
|
||||
}
|
||||
}
|
||||
}
|
||||
return StreamResult{Code: code, Stderr: stderrBuf.String()}, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,283 @@
|
|||
package process_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"cursor-api-proxy/internal/process"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// sh 是跨平台 shell 執行小 script 的輔助函式
|
||||
func sh(t *testing.T, script string, opts process.RunOptions) (process.RunResult, error) {
|
||||
t.Helper()
|
||||
return process.Run("sh", []string{"-c", script}, opts)
|
||||
}
|
||||
|
||||
func TestRun_StdoutAndStderr(t *testing.T) {
|
||||
result, err := sh(t, "echo hello; echo world >&2", process.RunOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
t.Errorf("Code = %d, want 0", result.Code)
|
||||
}
|
||||
if result.Stdout != "hello\n" {
|
||||
t.Errorf("Stdout = %q, want %q", result.Stdout, "hello\n")
|
||||
}
|
||||
if result.Stderr != "world\n" {
|
||||
t.Errorf("Stderr = %q, want %q", result.Stderr, "world\n")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_BasicSpawn(t *testing.T) {
|
||||
result, err := sh(t, "printf ok", process.RunOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
t.Errorf("Code = %d, want 0", result.Code)
|
||||
}
|
||||
if result.Stdout != "ok" {
|
||||
t.Errorf("Stdout = %q, want ok", result.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_ConfigDir_Propagated(t *testing.T) {
|
||||
result, err := process.Run("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`},
|
||||
process.RunOptions{ConfigDir: "/test/account/dir"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Stdout != "/test/account/dir" {
|
||||
t.Errorf("Stdout = %q, want /test/account/dir", result.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_ConfigDir_Absent(t *testing.T) {
|
||||
// 確保沒有殘留的環境變數
|
||||
_ = os.Unsetenv("CURSOR_CONFIG_DIR")
|
||||
result, err := process.Run("sh", []string{"-c", `printf "${CURSOR_CONFIG_DIR:-unset}"`},
|
||||
process.RunOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Stdout != "unset" {
|
||||
t.Errorf("Stdout = %q, want unset", result.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_NonZeroExit(t *testing.T) {
|
||||
result, err := sh(t, "exit 42", process.RunOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code != 42 {
|
||||
t.Errorf("Code = %d, want 42", result.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_Timeout(t *testing.T) {
|
||||
start := time.Now()
|
||||
result, err := sh(t, "sleep 30", process.RunOptions{TimeoutMs: 300})
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code == 0 {
|
||||
t.Error("expected non-zero exit code after timeout")
|
||||
}
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_OnLine(t *testing.T) {
|
||||
var lines []string
|
||||
result, err := process.RunStreaming("sh", []string{"-c", "printf 'a\nb\nc\n'"},
|
||||
process.RunStreamingOptions{
|
||||
OnLine: func(line string) { lines = append(lines, line) },
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
t.Errorf("Code = %d, want 0", result.Code)
|
||||
}
|
||||
if len(lines) != 3 {
|
||||
t.Errorf("got %d lines, want 3: %v", len(lines), lines)
|
||||
}
|
||||
if lines[0] != "a" || lines[1] != "b" || lines[2] != "c" {
|
||||
t.Errorf("lines = %v, want [a b c]", lines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_FlushFinalLine(t *testing.T) {
|
||||
var lines []string
|
||||
result, err := process.RunStreaming("sh", []string{"-c", "printf tail"},
|
||||
process.RunStreamingOptions{
|
||||
OnLine: func(line string) { lines = append(lines, line) },
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
t.Errorf("Code = %d, want 0", result.Code)
|
||||
}
|
||||
if len(lines) != 1 {
|
||||
t.Errorf("got %d lines, want 1: %v", len(lines), lines)
|
||||
}
|
||||
if lines[0] != "tail" {
|
||||
t.Errorf("lines[0] = %q, want tail", lines[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_ConfigDir(t *testing.T) {
|
||||
var lines []string
|
||||
_, err := process.RunStreaming("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`},
|
||||
process.RunStreamingOptions{
|
||||
RunOptions: process.RunOptions{ConfigDir: "/my/config/dir"},
|
||||
OnLine: func(line string) { lines = append(lines, line) },
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(lines) != 1 || lines[0] != "/my/config/dir" {
|
||||
t.Errorf("lines = %v, want [/my/config/dir]", lines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_Stderr(t *testing.T) {
|
||||
result, err := process.RunStreaming("sh", []string{"-c", "echo err-output >&2"},
|
||||
process.RunStreamingOptions{OnLine: func(string) {}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Stderr == "" {
|
||||
t.Error("expected stderr to contain output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_Timeout(t *testing.T) {
|
||||
start := time.Now()
|
||||
result, err := process.RunStreaming("sh", []string{"-c", "sleep 30"},
|
||||
process.RunStreamingOptions{
|
||||
RunOptions: process.RunOptions{TimeoutMs: 300},
|
||||
OnLine: func(string) {},
|
||||
})
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code == 0 {
|
||||
t.Error("expected non-zero exit code after timeout")
|
||||
}
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_Concurrent(t *testing.T) {
|
||||
var lines1, lines2 []string
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
run := func(label string, target *[]string) {
|
||||
process.RunStreaming("sh", []string{"-c", "printf '" + label + "'"},
|
||||
process.RunStreamingOptions{
|
||||
OnLine: func(line string) { *target = append(*target, line) },
|
||||
})
|
||||
done <- struct{}{}
|
||||
}
|
||||
|
||||
go run("stream1", &lines1)
|
||||
go run("stream2", &lines2)
|
||||
|
||||
<-done
|
||||
<-done
|
||||
|
||||
if len(lines1) != 1 || lines1[0] != "stream1" {
|
||||
t.Errorf("lines1 = %v, want [stream1]", lines1)
|
||||
}
|
||||
if len(lines2) != 1 || lines2[0] != "stream2" {
|
||||
t.Errorf("lines2 = %v, want [stream2]", lines2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_ContextCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
start := time.Now()
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
process.RunStreaming("sh", []string{"-c", "sleep 30"},
|
||||
process.RunStreamingOptions{
|
||||
RunOptions: process.RunOptions{Ctx: ctx},
|
||||
OnLine: func(string) {},
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
time.AfterFunc(100*time.Millisecond, cancel)
|
||||
<-done
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_ContextCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
start := time.Now()
|
||||
done := make(chan process.RunResult, 1)
|
||||
|
||||
go func() {
|
||||
r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx})
|
||||
done <- r
|
||||
}()
|
||||
|
||||
time.AfterFunc(100*time.Millisecond, cancel)
|
||||
result := <-done
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if result.Code == 0 {
|
||||
t.Error("expected non-zero exit code after cancel")
|
||||
}
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_AlreadyCancelledContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // 已取消
|
||||
|
||||
start := time.Now()
|
||||
result, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx})
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if result.Code == 0 {
|
||||
t.Error("expected non-zero exit code")
|
||||
}
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKillAllChildProcesses(t *testing.T) {
|
||||
done := make(chan process.RunResult, 1)
|
||||
go func() {
|
||||
r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{})
|
||||
done <- r
|
||||
}()
|
||||
|
||||
time.Sleep(80 * time.Millisecond)
|
||||
process.KillAllChildProcesses()
|
||||
result := <-done
|
||||
|
||||
if result.Code == 0 {
|
||||
t.Error("expected non-zero exit code after kill")
|
||||
}
|
||||
// 再次呼叫不應 panic
|
||||
process.KillAllChildProcesses()
|
||||
}
|
||||
|
|
@ -0,0 +1,136 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/handlers"
|
||||
"cursor-api-proxy/internal/httputil"
|
||||
"cursor-api-proxy/internal/logger"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RouterOptions struct {
|
||||
Version string
|
||||
Config config.BridgeConfig
|
||||
ModelCache *handlers.ModelCacheRef
|
||||
LastModel *string
|
||||
}
|
||||
|
||||
func NewRouter(opts RouterOptions) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
cfg := opts.Config
|
||||
pathname := r.URL.Path
|
||||
method := r.Method
|
||||
remoteAddress := r.RemoteAddr
|
||||
if r.Header.Get("X-Real-IP") != "" {
|
||||
remoteAddress = r.Header.Get("X-Real-IP")
|
||||
}
|
||||
|
||||
logger.LogIncoming(method, pathname, remoteAddress)
|
||||
|
||||
defer func() {
|
||||
logger.AppendSessionLine(cfg.SessionsLogPath, method, pathname, remoteAddress, 200)
|
||||
}()
|
||||
|
||||
if cfg.RequiredKey != "" {
|
||||
token := httputil.ExtractBearerToken(r)
|
||||
if token != cfg.RequiredKey {
|
||||
httputil.WriteJSON(w, 401, map[string]interface{}{
|
||||
"error": map[string]string{"message": "Invalid API key", "code": "unauthorized"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case method == "GET" && pathname == "/health":
|
||||
handlers.HandleHealth(w, r, opts.Version, cfg)
|
||||
|
||||
case method == "GET" && pathname == "/v1/models":
|
||||
opts.ModelCache.HandleModels(w, r, cfg)
|
||||
|
||||
case method == "POST" && pathname == "/v1/chat/completions":
|
||||
raw, err := httputil.ReadBody(r)
|
||||
if err != nil {
|
||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
||||
"error": map[string]string{"message": "failed to read body", "code": "bad_request"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
handlers.HandleChatCompletions(w, r, cfg, opts.LastModel, raw, method, pathname, remoteAddress)
|
||||
|
||||
case method == "POST" && pathname == "/v1/messages":
|
||||
raw, err := httputil.ReadBody(r)
|
||||
if err != nil {
|
||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
||||
"error": map[string]string{"message": "failed to read body", "code": "bad_request"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
handlers.HandleAnthropicMessages(w, r, cfg, opts.LastModel, raw, method, pathname, remoteAddress)
|
||||
|
||||
case (method == "POST" || method == "GET") && pathname == "/v1/completions":
|
||||
httputil.WriteJSON(w, 404, map[string]interface{}{
|
||||
"error": map[string]string{
|
||||
"message": "Legacy completions endpoint is not supported. Use POST /v1/chat/completions instead.",
|
||||
"code": "not_found",
|
||||
},
|
||||
}, nil)
|
||||
|
||||
case pathname == "/v1/embeddings":
|
||||
httputil.WriteJSON(w, 404, map[string]interface{}{
|
||||
"error": map[string]string{
|
||||
"message": "Embeddings are not supported by this proxy.",
|
||||
"code": "not_found",
|
||||
},
|
||||
}, nil)
|
||||
|
||||
default:
|
||||
httputil.WriteJSON(w, 404, map[string]interface{}{
|
||||
"error": map[string]string{"message": "Not found", "code": "not_found"},
|
||||
}, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func recoveryMiddleware(logPath string, next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
msg := fmt.Sprintf("%v", rec)
|
||||
fmt.Fprintf(os.Stderr, "[%s] Proxy panic: %s\n", time.Now().UTC().Format(time.RFC3339), msg)
|
||||
line := fmt.Sprintf("%s ERROR %s %s %s %s\n",
|
||||
time.Now().UTC().Format(time.RFC3339), r.Method, r.URL.Path, r.RemoteAddr,
|
||||
msg[:min(200, len(msg))])
|
||||
if f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil {
|
||||
_, _ = f.WriteString(line)
|
||||
f.Close()
|
||||
}
|
||||
if !isHeaderWritten(w) {
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"message": msg, "code": "internal_error"},
|
||||
}, nil)
|
||||
}
|
||||
}
|
||||
}()
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func isHeaderWritten(w http.ResponseWriter) bool {
|
||||
// Can't reliably detect without wrapping; always try to write
|
||||
return false
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func WrapWithRecovery(logPath string, handler http.HandlerFunc) http.HandlerFunc {
|
||||
return recoveryMiddleware(logPath, handler)
|
||||
}
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
package sanitize
|
||||
|
||||
import "regexp"
|
||||
|
||||
type rule struct {
|
||||
pattern *regexp.Regexp
|
||||
replacement string
|
||||
}
|
||||
|
||||
var rules = []rule{
|
||||
{regexp.MustCompile(`(?i)x-anthropic-billing-header:[^\n]*\n?`), ""},
|
||||
{regexp.MustCompile(`(?i)\bcc_version=[^\s;,\n]+[;,]?\s*`), ""},
|
||||
{regexp.MustCompile(`(?i)\bcc_entrypoint=[^\s;,\n]+[;,]?\s*`), ""},
|
||||
{regexp.MustCompile(`(?i)\bcch=[a-f0-9]+[;,]?\s*`), ""},
|
||||
{regexp.MustCompile(`\bClaude Code\b`), "Cursor"},
|
||||
{regexp.MustCompile(`(?i)Anthropic['']s official CLI for Claude`), "Cursor AI assistant"},
|
||||
{regexp.MustCompile(`\bAnthropic\b`), "Cursor"},
|
||||
{regexp.MustCompile(`(?i)anthropic\.com`), "cursor.com"},
|
||||
{regexp.MustCompile(`(?i)claude\.ai`), "cursor.sh"},
|
||||
{regexp.MustCompile(`^[;,\s]+`), ""},
|
||||
}
|
||||
|
||||
func SanitizeText(text string) string {
|
||||
for _, r := range rules {
|
||||
text = r.pattern.ReplaceAllString(text, r.replacement)
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func SanitizeMessages(messages []interface{}) []interface{} {
|
||||
result := make([]interface{}, len(messages))
|
||||
for i, raw := range messages {
|
||||
if raw == nil {
|
||||
result[i] = raw
|
||||
continue
|
||||
}
|
||||
m, ok := raw.(map[string]interface{})
|
||||
if !ok {
|
||||
result[i] = raw
|
||||
continue
|
||||
}
|
||||
newMsg := make(map[string]interface{}, len(m))
|
||||
for k, v := range m {
|
||||
newMsg[k] = v
|
||||
}
|
||||
switch c := m["content"].(type) {
|
||||
case string:
|
||||
newMsg["content"] = SanitizeText(c)
|
||||
case []interface{}:
|
||||
newParts := make([]interface{}, len(c))
|
||||
for j, p := range c {
|
||||
if pm, ok := p.(map[string]interface{}); ok && pm["type"] == "text" {
|
||||
if t, ok := pm["text"].(string); ok {
|
||||
newPart := make(map[string]interface{}, len(pm))
|
||||
for k, v := range pm {
|
||||
newPart[k] = v
|
||||
}
|
||||
newPart["text"] = SanitizeText(t)
|
||||
newParts[j] = newPart
|
||||
continue
|
||||
}
|
||||
}
|
||||
newParts[j] = p
|
||||
}
|
||||
newMsg["content"] = newParts
|
||||
}
|
||||
result[i] = newMsg
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func SanitizeSystem(system interface{}) interface{} {
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
return SanitizeText(v)
|
||||
case []interface{}:
|
||||
result := make([]interface{}, len(v))
|
||||
for i, p := range v {
|
||||
if pm, ok := p.(map[string]interface{}); ok && pm["type"] == "text" {
|
||||
if t, ok := pm["text"].(string); ok {
|
||||
newPart := make(map[string]interface{}, len(pm))
|
||||
for k, val := range pm {
|
||||
newPart[k] = val
|
||||
}
|
||||
newPart["text"] = SanitizeText(t)
|
||||
result[i] = newPart
|
||||
continue
|
||||
}
|
||||
}
|
||||
result[i] = p
|
||||
}
|
||||
return result
|
||||
}
|
||||
return system
|
||||
}
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
package sanitize
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSanitizeTextAnthropicBilling(t *testing.T) {
|
||||
input := "x-anthropic-billing-header: abc123\nHello"
|
||||
got := SanitizeText(input)
|
||||
if strings.Contains(got, "x-anthropic-billing-header") {
|
||||
t.Errorf("billing header not removed: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeTextClaudeCode(t *testing.T) {
|
||||
input := "I am Claude Code assistant"
|
||||
got := SanitizeText(input)
|
||||
if strings.Contains(got, "Claude Code") {
|
||||
t.Errorf("'Claude Code' not replaced: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "Cursor") {
|
||||
t.Errorf("'Cursor' not present in output: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeTextAnthropic(t *testing.T) {
|
||||
input := "Powered by Anthropic's technology and anthropic.com"
|
||||
got := SanitizeText(input)
|
||||
if strings.Contains(got, "Anthropic") {
|
||||
t.Errorf("'Anthropic' not replaced: %q", got)
|
||||
}
|
||||
if strings.Contains(got, "anthropic.com") {
|
||||
t.Errorf("'anthropic.com' not replaced: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeTextNoChange(t *testing.T) {
|
||||
input := "Hello, this is a normal message about cursor."
|
||||
got := SanitizeText(input)
|
||||
if got != input {
|
||||
t.Errorf("unexpected change: %q -> %q", input, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeMessages(t *testing.T) {
|
||||
messages := []interface{}{
|
||||
map[string]interface{}{"role": "user", "content": "Ask Claude Code something"},
|
||||
map[string]interface{}{"role": "system", "content": "Use Anthropic's tools"},
|
||||
}
|
||||
result := SanitizeMessages(messages)
|
||||
|
||||
for _, raw := range result {
|
||||
m := raw.(map[string]interface{})
|
||||
c := m["content"].(string)
|
||||
if strings.Contains(c, "Claude Code") || strings.Contains(c, "Anthropic") {
|
||||
t.Errorf("found unsanitized content: %q", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,180 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/handlers"
|
||||
"cursor-api-proxy/internal/pool"
|
||||
"cursor-api-proxy/internal/process"
|
||||
"cursor-api-proxy/internal/router"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ServerOptions struct {
|
||||
Version string
|
||||
Config config.BridgeConfig
|
||||
}
|
||||
|
||||
func StartBridgeServer(opts ServerOptions) []*http.Server {
|
||||
cfg := opts.Config
|
||||
var servers []*http.Server
|
||||
|
||||
if len(cfg.ConfigDirs) > 0 {
|
||||
if cfg.MultiPort {
|
||||
for i, dir := range cfg.ConfigDirs {
|
||||
port := cfg.Port + i
|
||||
subCfg := cfg
|
||||
subCfg.Port = port
|
||||
subCfg.ConfigDirs = []string{dir}
|
||||
subCfg.MultiPort = false
|
||||
pool.InitAccountPool([]string{dir})
|
||||
srv := startSingleServer(ServerOptions{Version: opts.Version, Config: subCfg})
|
||||
servers = append(servers, srv)
|
||||
}
|
||||
return servers
|
||||
}
|
||||
pool.InitAccountPool(cfg.ConfigDirs)
|
||||
}
|
||||
|
||||
servers = append(servers, startSingleServer(opts))
|
||||
return servers
|
||||
}
|
||||
|
||||
func startSingleServer(opts ServerOptions) *http.Server {
|
||||
cfg := opts.Config
|
||||
|
||||
modelCache := &handlers.ModelCacheRef{}
|
||||
lastModel := cfg.DefaultModel
|
||||
|
||||
handler := router.NewRouter(router.RouterOptions{
|
||||
Version: opts.Version,
|
||||
Config: cfg,
|
||||
ModelCache: modelCache,
|
||||
LastModel: &lastModel,
|
||||
})
|
||||
handler = router.WrapWithRecovery(cfg.SessionsLogPath, handler)
|
||||
|
||||
useTLS := cfg.TLSCertPath != "" && cfg.TLSKeyPath != ""
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
if useTLS {
|
||||
cert, err := tls.LoadX509KeyPair(cfg.TLSCertPath, cfg.TLSKeyPath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "TLS error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if useTLS {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
go func() {
|
||||
var err error
|
||||
if useTLS {
|
||||
err = srv.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
err = srv.ListenAndServe()
|
||||
}
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
if isAddrInUse(err) {
|
||||
fmt.Fprintf(os.Stderr, "❌ Port %d is already in use. Set CURSOR_BRIDGE_PORT to use a different port.\n", cfg.Port)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "❌ Server error: %v\n", err)
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
fmt.Printf("cursor-api-proxy listening on %s://%s:%d\n", scheme, cfg.Host, cfg.Port)
|
||||
fmt.Printf("- agent bin: %s\n", cfg.AgentBin)
|
||||
fmt.Printf("- workspace: %s\n", cfg.Workspace)
|
||||
fmt.Printf("- mode: %s\n", cfg.Mode)
|
||||
fmt.Printf("- default model: %s\n", cfg.DefaultModel)
|
||||
fmt.Printf("- force: %v\n", cfg.Force)
|
||||
fmt.Printf("- approve mcps: %v\n", cfg.ApproveMcps)
|
||||
fmt.Printf("- required api key: %v\n", cfg.RequiredKey != "")
|
||||
fmt.Printf("- sessions log: %s\n", cfg.SessionsLogPath)
|
||||
if cfg.ChatOnlyWorkspace {
|
||||
fmt.Println("- chat-only workspace: yes (isolated temp dir)")
|
||||
} else {
|
||||
fmt.Println("- chat-only workspace: no")
|
||||
}
|
||||
if cfg.Verbose {
|
||||
fmt.Println("- verbose traffic: yes (CURSOR_BRIDGE_VERBOSE=true)")
|
||||
} else {
|
||||
fmt.Println("- verbose traffic: no")
|
||||
}
|
||||
if cfg.MaxMode {
|
||||
fmt.Println("- max mode: yes (CURSOR_BRIDGE_MAX_MODE=true)")
|
||||
} else {
|
||||
fmt.Println("- max mode: no")
|
||||
}
|
||||
fmt.Printf("- Windows cmdline budget: %d (prompt tail truncation when over limit; Windows only)\n", cfg.WinCmdlineMax)
|
||||
if len(cfg.ConfigDirs) > 0 {
|
||||
fmt.Printf("- account pool: enabled with %d configuration directories\n", len(cfg.ConfigDirs))
|
||||
}
|
||||
|
||||
return srv
|
||||
}
|
||||
|
||||
func SetupGracefulShutdown(servers []*http.Server, timeoutMs int) {
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
|
||||
|
||||
go func() {
|
||||
sig := <-sigCh
|
||||
fmt.Printf("\n[%s] %s received — shutting down gracefully…\n",
|
||||
time.Now().UTC().Format(time.RFC3339), sig)
|
||||
|
||||
process.KillAllChildProcesses()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for _, srv := range servers {
|
||||
_ = srv.Shutdown(ctx)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
os.Exit(0)
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintln(os.Stderr, "[shutdown] Timed out waiting for connections to drain — forcing exit.")
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func isAddrInUse(err error) bool {
|
||||
return err != nil && (contains(err.Error(), "address already in use") || contains(err.Error(), "bind: address already in use"))
|
||||
}
|
||||
|
||||
func contains(s, sub string) bool {
|
||||
return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsHelper(s, sub))
|
||||
}
|
||||
|
||||
func containsHelper(s, sub string) bool {
|
||||
for i := 0; i <= len(s)-len(sub); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -0,0 +1,331 @@
|
|||
package server_test
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/server"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// freePort 取得一個暫時可用的隨機 port
|
||||
func freePort(t *testing.T) int {
|
||||
t.Helper()
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
port := l.Addr().(*net.TCPAddr).Port
|
||||
l.Close()
|
||||
return port
|
||||
}
|
||||
|
||||
// makeFakeAgentBin 建立一個 shell script,模擬 agent 固定輸出
|
||||
// sync 模式:直接輸出一行文字
|
||||
// stream 模式:輸出 JSON stream 行
|
||||
func makeFakeAgentBin(t *testing.T, syncOutput string) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
script := dir + "/agent"
|
||||
content := fmt.Sprintf(`#!/bin/sh
|
||||
# 若有 --stream-json 則輸出 stream 格式
|
||||
for arg; do
|
||||
if [ "$arg" = "--stream-json" ]; then
|
||||
printf '%%s\n' '{"type":"assistant","message":{"content":[{"type":"text","text":"%s"}]}}'
|
||||
printf '%%s\n' '{"type":"result","subtype":"success"}'
|
||||
exit 0
|
||||
fi
|
||||
done
|
||||
# 否則輸出 sync 格式
|
||||
printf '%%s' '%s'
|
||||
`, syncOutput, syncOutput)
|
||||
if err := os.WriteFile(script, []byte(content), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return script
|
||||
}
|
||||
|
||||
// makeFakeAgentBinWithModels 額外支援 --list-models 輸出
|
||||
func makeFakeAgentBinWithModels(t *testing.T) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
script := dir + "/agent"
|
||||
content := `#!/bin/sh
|
||||
for arg; do
|
||||
if [ "$arg" = "--list-models" ]; then
|
||||
printf 'claude-3-opus - Claude 3 Opus\n'
|
||||
printf 'claude-3-sonnet - Claude 3 Sonnet\n'
|
||||
exit 0
|
||||
fi
|
||||
if [ "$arg" = "--stream-json" ]; then
|
||||
printf '%s\n' '{"type":"assistant","message":{"content":[{"type":"text","text":"Hello"}]}}'
|
||||
printf '%s\n' '{"type":"result","subtype":"success"}'
|
||||
exit 0
|
||||
fi
|
||||
done
|
||||
printf 'Hello from agent'
|
||||
`
|
||||
if err := os.WriteFile(script, []byte(content), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return script
|
||||
}
|
||||
|
||||
func makeTestConfig(agentBin string, port int, overrides ...func(*config.BridgeConfig)) config.BridgeConfig {
|
||||
cfg := config.BridgeConfig{
|
||||
AgentBin: agentBin,
|
||||
Host: "127.0.0.1",
|
||||
Port: port,
|
||||
DefaultModel: "auto",
|
||||
Mode: "ask",
|
||||
Force: false,
|
||||
ApproveMcps: false,
|
||||
StrictModel: true,
|
||||
Workspace: os.TempDir(),
|
||||
TimeoutMs: 30000,
|
||||
SessionsLogPath: os.TempDir() + "/test-sessions.log",
|
||||
ChatOnlyWorkspace: true,
|
||||
Verbose: false,
|
||||
MaxMode: false,
|
||||
ConfigDirs: []string{},
|
||||
MultiPort: false,
|
||||
WinCmdlineMax: 30000,
|
||||
}
|
||||
for _, fn := range overrides {
|
||||
fn(&cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func waitListening(t *testing.T, host string, port int, timeout time.Duration) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), 50*time.Millisecond)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("server on port %d did not start within %v", port, timeout)
|
||||
}
|
||||
|
||||
func doRequest(t *testing.T, method, url, body string, headers map[string]string) (int, string) {
|
||||
t.Helper()
|
||||
var reqBody io.Reader
|
||||
if body != "" {
|
||||
reqBody = strings.NewReader(body)
|
||||
}
|
||||
req, err := http.NewRequest(method, url, reqBody)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if body != "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
data, _ := io.ReadAll(resp.Body)
|
||||
return resp.StatusCode, string(data)
|
||||
}
|
||||
|
||||
func TestBridgeServer_Health(t *testing.T) {
|
||||
port := freePort(t)
|
||||
agentBin := makeFakeAgentBinWithModels(t)
|
||||
cfg := makeTestConfig(agentBin, port)
|
||||
|
||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||||
defer func() {
|
||||
for _, s := range srvs {
|
||||
s.Shutdown(context.Background())
|
||||
}
|
||||
}()
|
||||
|
||||
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", nil)
|
||||
if status != 200 {
|
||||
t.Fatalf("status = %d, want 200; body: %s", status, body)
|
||||
}
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal([]byte(body), &result)
|
||||
if result["ok"] != true {
|
||||
t.Errorf("ok = %v, want true", result["ok"])
|
||||
}
|
||||
if result["version"] != "1.0.0" {
|
||||
t.Errorf("version = %v, want 1.0.0", result["version"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeServer_Models(t *testing.T) {
|
||||
port := freePort(t)
|
||||
agentBin := makeFakeAgentBinWithModels(t)
|
||||
cfg := makeTestConfig(agentBin, port)
|
||||
|
||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||||
defer func() {
|
||||
for _, s := range srvs {
|
||||
s.Shutdown(context.Background())
|
||||
}
|
||||
}()
|
||||
|
||||
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/v1/models", port), "", nil)
|
||||
if status != 200 {
|
||||
t.Fatalf("status = %d, want 200; body: %s", status, body)
|
||||
}
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal([]byte(body), &result)
|
||||
if result["object"] != "list" {
|
||||
t.Errorf("object = %v, want list", result["object"])
|
||||
}
|
||||
data := result["data"].([]interface{})
|
||||
if len(data) < 2 {
|
||||
t.Errorf("data len = %d, want >= 2", len(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeServer_Unauthorized(t *testing.T) {
|
||||
port := freePort(t)
|
||||
agentBin := makeFakeAgentBinWithModels(t)
|
||||
cfg := makeTestConfig(agentBin, port, func(c *config.BridgeConfig) {
|
||||
c.RequiredKey = "secret123"
|
||||
})
|
||||
|
||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||||
defer func() {
|
||||
for _, s := range srvs {
|
||||
s.Shutdown(context.Background())
|
||||
}
|
||||
}()
|
||||
|
||||
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", nil)
|
||||
if status != 401 {
|
||||
t.Fatalf("status = %d, want 401; body: %s", status, body)
|
||||
}
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal([]byte(body), &result)
|
||||
errObj := result["error"].(map[string]interface{})
|
||||
if errObj["message"] != "Invalid API key" {
|
||||
t.Errorf("message = %v, want 'Invalid API key'", errObj["message"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeServer_AuthorizedKey(t *testing.T) {
|
||||
port := freePort(t)
|
||||
agentBin := makeFakeAgentBinWithModels(t)
|
||||
cfg := makeTestConfig(agentBin, port, func(c *config.BridgeConfig) {
|
||||
c.RequiredKey = "secret123"
|
||||
})
|
||||
|
||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||||
defer func() {
|
||||
for _, s := range srvs {
|
||||
s.Shutdown(context.Background())
|
||||
}
|
||||
}()
|
||||
|
||||
status, _ := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", map[string]string{
|
||||
"Authorization": "Bearer secret123",
|
||||
})
|
||||
if status != 200 {
|
||||
t.Errorf("status = %d, want 200", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeServer_NotFound(t *testing.T) {
|
||||
port := freePort(t)
|
||||
agentBin := makeFakeAgentBinWithModels(t)
|
||||
cfg := makeTestConfig(agentBin, port)
|
||||
|
||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||||
defer func() {
|
||||
for _, s := range srvs {
|
||||
s.Shutdown(context.Background())
|
||||
}
|
||||
}()
|
||||
|
||||
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/unknown", port), "", nil)
|
||||
if status != 404 {
|
||||
t.Fatalf("status = %d, want 404; body: %s", status, body)
|
||||
}
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal([]byte(body), &result)
|
||||
errObj := result["error"].(map[string]interface{})
|
||||
if errObj["code"] != "not_found" {
|
||||
t.Errorf("code = %v, want not_found", errObj["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeServer_ChatCompletions_Sync(t *testing.T) {
|
||||
port := freePort(t)
|
||||
agentBin := makeFakeAgentBin(t, "Hello from agent")
|
||||
cfg := makeTestConfig(agentBin, port)
|
||||
|
||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||||
defer func() {
|
||||
for _, s := range srvs {
|
||||
s.Shutdown(context.Background())
|
||||
}
|
||||
}()
|
||||
|
||||
reqBody := `{"model":"claude-3-opus","messages":[{"role":"user","content":"Hi"}]}`
|
||||
status, body := doRequest(t, "POST", fmt.Sprintf("http://127.0.0.1:%d/v1/chat/completions", port), reqBody, nil)
|
||||
if status != 200 {
|
||||
t.Fatalf("status = %d, want 200; body: %s", status, body)
|
||||
}
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal([]byte(body), &result)
|
||||
if result["object"] != "chat.completion" {
|
||||
t.Errorf("object = %v, want chat.completion", result["object"])
|
||||
}
|
||||
choices := result["choices"].([]interface{})
|
||||
msg := choices[0].(map[string]interface{})["message"].(map[string]interface{})
|
||||
if msg["content"] != "Hello from agent" {
|
||||
t.Errorf("content = %v, want 'Hello from agent'", msg["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeServer_MultiPort(t *testing.T) {
|
||||
basePort := freePort(t)
|
||||
agentBin := makeFakeAgentBinWithModels(t)
|
||||
|
||||
dir1 := t.TempDir()
|
||||
dir2 := t.TempDir()
|
||||
|
||||
cfg := makeTestConfig(agentBin, basePort, func(c *config.BridgeConfig) {
|
||||
c.ConfigDirs = []string{dir1, dir2}
|
||||
c.MultiPort = true
|
||||
})
|
||||
|
||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||||
if len(srvs) != 2 {
|
||||
t.Fatalf("got %d servers, want 2", len(srvs))
|
||||
}
|
||||
|
||||
// 等待兩個 server 啟動(port 可能會衝突,這裡不嚴格測試 port 分配)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
defer func() {
|
||||
for _, s := range srvs {
|
||||
s.Shutdown(context.Background())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
package winlimit
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/env"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
const WinPromptOmissionPrefix = "[Earlier messages omitted: Windows command-line length limit.]\n\n"
|
||||
|
||||
type FitPromptResult struct {
|
||||
OK bool
|
||||
Args []string
|
||||
Truncated bool
|
||||
OriginalLength int
|
||||
FinalPromptLength int
|
||||
Error string
|
||||
}
|
||||
|
||||
func estimateCmdlineLength(resolved env.AgentCommand) int {
|
||||
argv := append([]string{resolved.Command}, resolved.Args...)
|
||||
if resolved.WindowsVerbatimArguments {
|
||||
n := 0
|
||||
for _, a := range argv {
|
||||
n += len(a)
|
||||
}
|
||||
if len(argv) > 1 {
|
||||
n += len(argv) - 1
|
||||
}
|
||||
return n + 512
|
||||
}
|
||||
dstLen := 0
|
||||
for _, a := range argv {
|
||||
dstLen += len(a)
|
||||
}
|
||||
dstLen = dstLen*2 + len(argv)*2
|
||||
if len(argv) > 1 {
|
||||
dstLen += len(argv) - 1
|
||||
}
|
||||
return dstLen + 512
|
||||
}
|
||||
|
||||
func FitPromptToWinCmdline(agentBin string, fixedArgs []string, prompt string, maxCmdline int, cwd string) FitPromptResult {
|
||||
if runtime.GOOS != "windows" {
|
||||
args := make([]string, len(fixedArgs)+1)
|
||||
copy(args, fixedArgs)
|
||||
args[len(fixedArgs)] = prompt
|
||||
return FitPromptResult{
|
||||
OK: true,
|
||||
Args: args,
|
||||
Truncated: false,
|
||||
OriginalLength: len(prompt),
|
||||
FinalPromptLength: len(prompt),
|
||||
}
|
||||
}
|
||||
|
||||
e := env.OsEnvToMap()
|
||||
measured := func(p string) int {
|
||||
args := make([]string, len(fixedArgs)+1)
|
||||
copy(args, fixedArgs)
|
||||
args[len(fixedArgs)] = p
|
||||
resolved := env.ResolveAgentCommand(agentBin, args, e, cwd)
|
||||
return estimateCmdlineLength(resolved)
|
||||
}
|
||||
|
||||
if measured("") > maxCmdline {
|
||||
return FitPromptResult{
|
||||
OK: false,
|
||||
Error: "Windows command line exceeds the configured limit even without a prompt; shorten workspace path, model id, or CURSOR_BRIDGE_WIN_CMDLINE_MAX.",
|
||||
}
|
||||
}
|
||||
|
||||
if measured(prompt) <= maxCmdline {
|
||||
args := make([]string, len(fixedArgs)+1)
|
||||
copy(args, fixedArgs)
|
||||
args[len(fixedArgs)] = prompt
|
||||
return FitPromptResult{
|
||||
OK: true,
|
||||
Args: args,
|
||||
Truncated: false,
|
||||
OriginalLength: len(prompt),
|
||||
FinalPromptLength: len(prompt),
|
||||
}
|
||||
}
|
||||
|
||||
prefix := WinPromptOmissionPrefix
|
||||
if measured(prefix) > maxCmdline {
|
||||
return FitPromptResult{
|
||||
OK: false,
|
||||
Error: "Windows command line too long to fit even the truncation notice; shorten workspace path or flags.",
|
||||
}
|
||||
}
|
||||
|
||||
lo, hi, best := 0, len(prompt), 0
|
||||
for lo <= hi {
|
||||
mid := (lo + hi) / 2
|
||||
var tail string
|
||||
if mid > 0 {
|
||||
tail = prompt[len(prompt)-mid:]
|
||||
}
|
||||
candidate := prefix + tail
|
||||
if measured(candidate) <= maxCmdline {
|
||||
best = mid
|
||||
lo = mid + 1
|
||||
} else {
|
||||
hi = mid - 1
|
||||
}
|
||||
}
|
||||
|
||||
var finalPrompt string
|
||||
if best == 0 {
|
||||
finalPrompt = prefix
|
||||
} else {
|
||||
finalPrompt = prefix + prompt[len(prompt)-best:]
|
||||
}
|
||||
|
||||
args := make([]string, len(fixedArgs)+1)
|
||||
copy(args, fixedArgs)
|
||||
args[len(fixedArgs)] = finalPrompt
|
||||
return FitPromptResult{
|
||||
OK: true,
|
||||
Args: args,
|
||||
Truncated: true,
|
||||
OriginalLength: len(prompt),
|
||||
FinalPromptLength: len(finalPrompt),
|
||||
}
|
||||
}
|
||||
|
||||
func WarnPromptTruncated(originalLength, finalLength int) {
|
||||
_ = originalLength
|
||||
_ = finalLength
|
||||
// fmt.Fprintf skipped to avoid import; caller may log as needed
|
||||
}
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
package winlimit
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNonWindowsPassThrough(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping non-Windows test on Windows")
|
||||
}
|
||||
|
||||
fixedArgs := []string{"--print", "--model", "gpt-4"}
|
||||
prompt := "Hello world"
|
||||
result := FitPromptToWinCmdline("agent", fixedArgs, prompt, 30000, "/tmp")
|
||||
|
||||
if !result.OK {
|
||||
t.Fatalf("expected OK=true on non-Windows, got error: %s", result.Error)
|
||||
}
|
||||
if result.Truncated {
|
||||
t.Error("expected no truncation on non-Windows")
|
||||
}
|
||||
if result.OriginalLength != len(prompt) {
|
||||
t.Errorf("expected original length %d, got %d", len(prompt), result.OriginalLength)
|
||||
}
|
||||
// Last arg should be the prompt
|
||||
if len(result.Args) == 0 || result.Args[len(result.Args)-1] != prompt {
|
||||
t.Errorf("expected last arg to be prompt, got %v", result.Args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOmissionPrefix(t *testing.T) {
|
||||
if !strings.Contains(WinPromptOmissionPrefix, "Earlier messages omitted") {
|
||||
t.Errorf("omission prefix should mention earlier messages, got: %q", WinPromptOmissionPrefix)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
package workspace
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/config"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type WorkspaceResult struct {
|
||||
WorkspaceDir string
|
||||
TempDir string
|
||||
}
|
||||
|
||||
func ResolveWorkspace(cfg config.BridgeConfig, workspaceHeader string) WorkspaceResult {
|
||||
if cfg.ChatOnlyWorkspace {
|
||||
tempDir, err := os.MkdirTemp("", "cursor-proxy-")
|
||||
if err != nil {
|
||||
tempDir = filepath.Join(os.TempDir(), "cursor-proxy-fallback")
|
||||
_ = os.MkdirAll(tempDir, 0700)
|
||||
}
|
||||
return WorkspaceResult{WorkspaceDir: tempDir, TempDir: tempDir}
|
||||
}
|
||||
|
||||
headerWs := strings.TrimSpace(workspaceHeader)
|
||||
if headerWs != "" {
|
||||
return WorkspaceResult{WorkspaceDir: headerWs}
|
||||
}
|
||||
return WorkspaceResult{WorkspaceDir: cfg.Workspace}
|
||||
}
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/cmd"
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/env"
|
||||
"cursor-api-proxy/internal/server"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
const version = "1.0.0"
|
||||
|
||||
func main() {
|
||||
args, err := cmd.ParseArgs(os.Args[1:])
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if args.Help {
|
||||
cmd.PrintHelp(version)
|
||||
return
|
||||
}
|
||||
|
||||
if args.Login {
|
||||
if err := cmd.HandleLogin(args.AccountName, args.Proxies); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if args.Logout {
|
||||
if err := cmd.HandleLogout(args.AccountName); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if args.AccountsList {
|
||||
if err := cmd.HandleAccountsList(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if args.ResetHwid {
|
||||
if err := cmd.HandleResetHwid(args.DeepClean, args.DryRun); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
e := env.OsEnvToMap()
|
||||
if args.Tailscale {
|
||||
e["CURSOR_BRIDGE_HOST"] = "0.0.0.0"
|
||||
}
|
||||
|
||||
cwd, _ := os.Getwd()
|
||||
cfg := config.LoadBridgeConfig(e, cwd)
|
||||
|
||||
servers := server.StartBridgeServer(server.ServerOptions{
|
||||
Version: version,
|
||||
Config: cfg,
|
||||
})
|
||||
server.SetupGracefulShutdown(servers, 10000)
|
||||
|
||||
select {}
|
||||
}
|
||||
Loading…
Reference in New Issue