Task 9: Cleanup - remove old internal files, update import paths, add go-zero entry point
- Removed old files from internal/* (migrated to pkg/*) - Removed old CLI files from cmd/ (now in cmd/cli/) - Updated import paths (internal/* -> pkg/*) - Rewrote main.go to support CLI commands + go-zero HTTP server - Fixed AccountStat type references (use entity.AccountStat) - Fixed cmd/cli/* to use usecase package instead of agent - Fixed logger to use entity.AccountStat - Fixed geminiweb fmt.Println redundant newlines - Fixed scripts/detect-gemini-dom.go pointer format issues
This commit is contained in:
parent
7e0b7a970c
commit
081f404f77
196
cmd/accounts.go
196
cmd/accounts.go
|
|
@ -1,196 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/agent"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type AccountInfo struct {
|
||||
Name string
|
||||
ConfigDir string
|
||||
Authenticated bool
|
||||
Email string
|
||||
DisplayName string
|
||||
AuthID string
|
||||
Plan string
|
||||
SubscriptionStatus string
|
||||
ExpiresAt string
|
||||
}
|
||||
|
||||
func ReadAccountInfo(name, configDir string) AccountInfo {
|
||||
info := AccountInfo{Name: name, ConfigDir: configDir}
|
||||
|
||||
configFile := filepath.Join(configDir, "cli-config.json")
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
return info
|
||||
}
|
||||
|
||||
var raw struct {
|
||||
AuthInfo *struct {
|
||||
Email string `json:"email"`
|
||||
DisplayName string `json:"displayName"`
|
||||
AuthID string `json:"authId"`
|
||||
} `json:"authInfo"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &raw); err == nil && raw.AuthInfo != nil {
|
||||
info.Authenticated = true
|
||||
info.Email = raw.AuthInfo.Email
|
||||
info.DisplayName = raw.AuthInfo.DisplayName
|
||||
info.AuthID = raw.AuthInfo.AuthID
|
||||
}
|
||||
|
||||
statsigFile := filepath.Join(configDir, "statsig-cache.json")
|
||||
statsigData, err := os.ReadFile(statsigFile)
|
||||
if err != nil {
|
||||
return info
|
||||
}
|
||||
|
||||
var statsigWrapper struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(statsigData, &statsigWrapper); err != nil || statsigWrapper.Data == "" {
|
||||
return info
|
||||
}
|
||||
|
||||
var statsig struct {
|
||||
User *struct {
|
||||
Custom *struct {
|
||||
IsEnterpriseUser bool `json:"isEnterpriseUser"`
|
||||
StripeSubscriptionStatus string `json:"stripeSubscriptionStatus"`
|
||||
StripeMembershipExpiration string `json:"stripeMembershipExpiration"`
|
||||
} `json:"custom"`
|
||||
} `json:"user"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(statsigWrapper.Data), &statsig); err != nil {
|
||||
return info
|
||||
}
|
||||
|
||||
if statsig.User != nil && statsig.User.Custom != nil {
|
||||
c := statsig.User.Custom
|
||||
if c.IsEnterpriseUser {
|
||||
info.Plan = "Enterprise"
|
||||
} else if c.StripeSubscriptionStatus == "active" {
|
||||
info.Plan = "Pro"
|
||||
} else {
|
||||
info.Plan = "Free"
|
||||
}
|
||||
info.SubscriptionStatus = c.StripeSubscriptionStatus
|
||||
info.ExpiresAt = c.StripeMembershipExpiration
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func HandleAccountsList() error {
|
||||
accountsDir := agent.AccountsDir()
|
||||
|
||||
entries, err := os.ReadDir(accountsDir)
|
||||
if err != nil {
|
||||
fmt.Println("No accounts found. Use 'cursor-api-proxy login' to add one.")
|
||||
return nil
|
||||
}
|
||||
|
||||
var names []string
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
names = append(names, e.Name())
|
||||
}
|
||||
}
|
||||
|
||||
if len(names) == 0 {
|
||||
fmt.Println("No accounts found. Use 'cursor-api-proxy login' to add one.")
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Print("Cursor Accounts:\n\n")
|
||||
|
||||
keychainToken := agent.ReadKeychainToken()
|
||||
|
||||
for i, name := range names {
|
||||
configDir := filepath.Join(accountsDir, name)
|
||||
info := ReadAccountInfo(name, configDir)
|
||||
|
||||
fmt.Printf(" %d. %s\n", i+1, name)
|
||||
|
||||
if info.Authenticated {
|
||||
cachedToken := agent.ReadCachedToken(configDir)
|
||||
keychainMatchesAccount := keychainToken != "" && info.AuthID != "" && TokenSub(keychainToken) == info.AuthID
|
||||
token := cachedToken
|
||||
if token == "" && keychainMatchesAccount {
|
||||
token = keychainToken
|
||||
}
|
||||
|
||||
var liveProfile *StripeProfile
|
||||
var liveUsage *UsageData
|
||||
if token != "" {
|
||||
liveUsage, _ = FetchAccountUsage(token)
|
||||
liveProfile, _ = FetchStripeProfile(token)
|
||||
}
|
||||
|
||||
if info.Email != "" {
|
||||
display := ""
|
||||
if info.DisplayName != "" {
|
||||
display = " (" + info.DisplayName + ")"
|
||||
}
|
||||
fmt.Printf(" %s%s\n", info.Email, display)
|
||||
}
|
||||
|
||||
if info.Plan != "" && liveProfile == nil {
|
||||
canceled := ""
|
||||
if info.SubscriptionStatus == "canceled" {
|
||||
canceled = " · canceled"
|
||||
}
|
||||
expiry := ""
|
||||
if info.ExpiresAt != "" {
|
||||
expiry = " · expires " + info.ExpiresAt
|
||||
}
|
||||
fmt.Printf(" %s%s%s\n", info.Plan, canceled, expiry)
|
||||
}
|
||||
fmt.Println(" Authenticated")
|
||||
|
||||
if liveProfile != nil {
|
||||
fmt.Printf(" %s\n", DescribePlan(liveProfile))
|
||||
}
|
||||
if liveUsage != nil {
|
||||
for _, line := range FormatUsageSummary(liveUsage) {
|
||||
fmt.Println(line)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fmt.Println(" Not authenticated")
|
||||
}
|
||||
|
||||
fmt.Println("")
|
||||
}
|
||||
|
||||
fmt.Println("Tip: run 'cursor-api-proxy logout <name>' to remove an account.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func HandleLogout(accountName string) error {
|
||||
if accountName == "" {
|
||||
fmt.Fprintln(os.Stderr, "Error: Please specify the account name to remove.")
|
||||
fmt.Fprintln(os.Stderr, "Usage: cursor-api-proxy logout <account-name>")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
accountsDir := agent.AccountsDir()
|
||||
configDir := filepath.Join(accountsDir, accountName)
|
||||
|
||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
||||
fmt.Fprintf(os.Stderr, "Account '%s' not found.\n", accountName)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := os.RemoveAll(configDir); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error removing account: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("Account '%s' removed.\n", accountName)
|
||||
return nil
|
||||
}
|
||||
118
cmd/args.go
118
cmd/args.go
|
|
@ -1,118 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import "fmt"
|
||||
|
||||
type ParsedArgs struct {
|
||||
Tailscale bool
|
||||
Help bool
|
||||
Login bool
|
||||
AccountsList bool
|
||||
Logout bool
|
||||
AccountName string
|
||||
Proxies []string
|
||||
ResetHwid bool
|
||||
DeepClean bool
|
||||
DryRun bool
|
||||
}
|
||||
|
||||
func ParseArgs(argv []string) (ParsedArgs, error) {
|
||||
var args ParsedArgs
|
||||
|
||||
for i := 0; i < len(argv); i++ {
|
||||
arg := argv[i]
|
||||
|
||||
switch arg {
|
||||
case "login", "add-account":
|
||||
args.Login = true
|
||||
if i+1 < len(argv) && len(argv[i+1]) > 0 && argv[i+1][0] != '-' {
|
||||
i++
|
||||
args.AccountName = argv[i]
|
||||
}
|
||||
|
||||
case "logout", "remove-account":
|
||||
args.Logout = true
|
||||
if i+1 < len(argv) && len(argv[i+1]) > 0 && argv[i+1][0] != '-' {
|
||||
i++
|
||||
args.AccountName = argv[i]
|
||||
}
|
||||
|
||||
case "accounts", "list-accounts":
|
||||
args.AccountsList = true
|
||||
|
||||
case "reset-hwid", "reset":
|
||||
args.ResetHwid = true
|
||||
|
||||
case "--deep-clean":
|
||||
args.DeepClean = true
|
||||
|
||||
case "--dry-run":
|
||||
args.DryRun = true
|
||||
|
||||
case "--tailscale":
|
||||
args.Tailscale = true
|
||||
|
||||
case "--help", "-h":
|
||||
args.Help = true
|
||||
|
||||
default:
|
||||
if len(arg) > len("--proxy=") && arg[:len("--proxy=")] == "--proxy=" {
|
||||
raw := arg[len("--proxy="):]
|
||||
parts := splitComma(raw)
|
||||
for _, p := range parts {
|
||||
if p != "" {
|
||||
args.Proxies = append(args.Proxies, p)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return args, fmt.Errorf("Unknown argument: %s", arg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func splitComma(s string) []string {
|
||||
var result []string
|
||||
start := 0
|
||||
for i := 0; i <= len(s); i++ {
|
||||
if i == len(s) || s[i] == ',' {
|
||||
part := trim(s[start:i])
|
||||
if part != "" {
|
||||
result = append(result, part)
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func trim(s string) string {
|
||||
start := 0
|
||||
end := len(s)
|
||||
for start < end && (s[start] == ' ' || s[start] == '\t') {
|
||||
start++
|
||||
}
|
||||
for end > start && (s[end-1] == ' ' || s[end-1] == '\t') {
|
||||
end--
|
||||
}
|
||||
return s[start:end]
|
||||
}
|
||||
|
||||
func PrintHelp(version string) {
|
||||
fmt.Printf("cursor-api-proxy v%s\n\n", version)
|
||||
fmt.Println("Usage:")
|
||||
fmt.Println(" cursor-api-proxy [options]")
|
||||
fmt.Println("")
|
||||
fmt.Println("Commands:")
|
||||
fmt.Println(" login [name] Log into a Cursor account (saved to ~/.cursor-api-proxy/accounts/)")
|
||||
fmt.Println(" login [name] --proxy=... Same, but with a proxy from a comma-separated list")
|
||||
fmt.Println(" logout <name> Remove a saved Cursor account")
|
||||
fmt.Println(" accounts List saved accounts with plan info")
|
||||
fmt.Println(" reset-hwid Reset Cursor machine/telemetry IDs (anti-ban)")
|
||||
fmt.Println(" reset-hwid --deep-clean Also wipe session storage and cookies")
|
||||
fmt.Println("")
|
||||
fmt.Println("Options:")
|
||||
fmt.Println(" --tailscale Bind to 0.0.0.0 for tailnet/LAN access")
|
||||
fmt.Println(" -h, --help Show this help message")
|
||||
}
|
||||
|
|
@ -1,11 +1,12 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/agent"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"cursor-api-proxy/pkg/usecase"
|
||||
)
|
||||
|
||||
type AccountInfo struct {
|
||||
|
|
@ -86,7 +87,7 @@ func ReadAccountInfo(name, configDir string) AccountInfo {
|
|||
}
|
||||
|
||||
func HandleAccountsList() error {
|
||||
accountsDir := agent.AccountsDir()
|
||||
accountsDir := usecase.AccountsDir()
|
||||
|
||||
entries, err := os.ReadDir(accountsDir)
|
||||
if err != nil {
|
||||
|
|
@ -108,7 +109,7 @@ func HandleAccountsList() error {
|
|||
|
||||
fmt.Print("Cursor Accounts:\n\n")
|
||||
|
||||
keychainToken := agent.ReadKeychainToken()
|
||||
keychainToken := usecase.ReadKeychainToken()
|
||||
|
||||
for i, name := range names {
|
||||
configDir := filepath.Join(accountsDir, name)
|
||||
|
|
@ -117,7 +118,7 @@ func HandleAccountsList() error {
|
|||
fmt.Printf(" %d. %s\n", i+1, name)
|
||||
|
||||
if info.Authenticated {
|
||||
cachedToken := agent.ReadCachedToken(configDir)
|
||||
cachedToken := usecase.ReadCachedToken(configDir)
|
||||
keychainMatchesAccount := keychainToken != "" && info.AuthID != "" && TokenSub(keychainToken) == info.AuthID
|
||||
token := cachedToken
|
||||
if token == "" && keychainMatchesAccount {
|
||||
|
|
@ -178,7 +179,7 @@ func HandleLogout(accountName string) error {
|
|||
os.Exit(1)
|
||||
}
|
||||
|
||||
accountsDir := agent.AccountsDir()
|
||||
accountsDir := usecase.AccountsDir()
|
||||
configDir := filepath.Join(accountsDir, accountName)
|
||||
|
||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@ package cmd
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"cursor-api-proxy/internal/agent"
|
||||
"cursor-api-proxy/internal/env"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
|
@ -12,6 +10,9 @@ import (
|
|||
"regexp"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"cursor-api-proxy/pkg/infrastructure/env"
|
||||
"cursor-api-proxy/pkg/usecase"
|
||||
)
|
||||
|
||||
var loginURLRe = regexp.MustCompile(`https://cursor\.com/loginDeepControl.*?redirectTarget=cli`)
|
||||
|
|
@ -25,7 +26,7 @@ func HandleLogin(accountName string, proxies []string) error {
|
|||
accountName = fmt.Sprintf("account-%d", time.Now().UnixMilli()%10000)
|
||||
}
|
||||
|
||||
accountsDir := agent.AccountsDir()
|
||||
accountsDir := usecase.AccountsDir()
|
||||
configDir := filepath.Join(accountsDir, accountName)
|
||||
dirWasNew := !fileExists(configDir)
|
||||
|
||||
|
|
@ -110,9 +111,9 @@ func HandleLogin(accountName string, proxies []string) error {
|
|||
}
|
||||
|
||||
// Cache keychain token for this account
|
||||
token := agent.ReadKeychainToken()
|
||||
token := usecase.ReadKeychainToken()
|
||||
if token != "" {
|
||||
agent.WriteCachedToken(configDir, token)
|
||||
usecase.WriteCachedToken(configDir, token)
|
||||
}
|
||||
|
||||
fmt.Printf("\nAccount '%s' saved — it will be auto-discovered when you start the proxy.\n", accountName)
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ package main
|
|||
|
||||
import (
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/env"
|
||||
"cursor-api-proxy/internal/providers/geminiweb"
|
||||
"cursor-api-proxy/pkg/infrastructure/env"
|
||||
"cursor-api-proxy/pkg/provider/geminiweb"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
|
|
|||
125
cmd/login.go
125
cmd/login.go
|
|
@ -1,125 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"cursor-api-proxy/internal/agent"
|
||||
"cursor-api-proxy/internal/env"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
var loginURLRe = regexp.MustCompile(`https://cursor\.com/loginDeepControl.*?redirectTarget=cli`)
|
||||
|
||||
func HandleLogin(accountName string, proxies []string) error {
|
||||
e := env.OsEnvToMap()
|
||||
loaded := env.LoadEnvConfig(e, "")
|
||||
agentBin := loaded.AgentBin
|
||||
|
||||
if accountName == "" {
|
||||
accountName = fmt.Sprintf("account-%d", time.Now().UnixMilli()%10000)
|
||||
}
|
||||
|
||||
accountsDir := agent.AccountsDir()
|
||||
configDir := filepath.Join(accountsDir, accountName)
|
||||
dirWasNew := !fileExists(configDir)
|
||||
|
||||
if err := os.MkdirAll(accountsDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create accounts dir: %w", err)
|
||||
}
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create config dir: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Logging into Cursor account: %s\n", accountName)
|
||||
fmt.Printf("Config: %s\n\n", configDir)
|
||||
fmt.Println("Run the login command — complete the login in your browser.")
|
||||
fmt.Println("")
|
||||
|
||||
cleanupDir := func() {
|
||||
if dirWasNew {
|
||||
_ = os.RemoveAll(configDir)
|
||||
}
|
||||
}
|
||||
|
||||
cmdEnv := make([]string, 0, len(e)+2)
|
||||
for k, v := range e {
|
||||
cmdEnv = append(cmdEnv, k+"="+v)
|
||||
}
|
||||
cmdEnv = append(cmdEnv, "CURSOR_CONFIG_DIR="+configDir)
|
||||
cmdEnv = append(cmdEnv, "NO_OPEN_BROWSER=1")
|
||||
|
||||
child := exec.Command(agentBin, "login")
|
||||
child.Env = cmdEnv
|
||||
child.Stdin = os.Stdin
|
||||
child.Stderr = os.Stderr
|
||||
|
||||
stdoutPipe, err := child.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := child.Start(); err != nil {
|
||||
cleanupDir()
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("could not find '%s'. Make sure the Cursor CLI is installed", agentBin)
|
||||
}
|
||||
return fmt.Errorf("error launching agent login: %w", err)
|
||||
}
|
||||
|
||||
// Handle cancellation signals
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
|
||||
go func() {
|
||||
sig := <-sigCh
|
||||
_ = child.Process.Kill()
|
||||
cleanupDir()
|
||||
if sig == syscall.SIGINT {
|
||||
fmt.Println("\n\nLogin cancelled.")
|
||||
}
|
||||
os.Exit(0)
|
||||
}()
|
||||
defer signal.Stop(sigCh)
|
||||
|
||||
var stdoutBuf string
|
||||
scanner := bufio.NewScanner(stdoutPipe)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
fmt.Println(line)
|
||||
stdoutBuf += line + "\n"
|
||||
|
||||
if loginURLRe.MatchString(stdoutBuf) {
|
||||
match := loginURLRe.FindString(stdoutBuf)
|
||||
if match != "" {
|
||||
fmt.Printf("\nOpen this URL in your browser (incognito recommended):\n%s\n\n", match)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := child.Wait(); err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
cleanupDir()
|
||||
return fmt.Errorf("login failed with code %d", exitErr.ExitCode())
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Cache keychain token for this account
|
||||
token := agent.ReadKeychainToken()
|
||||
if token != "" {
|
||||
agent.WriteCachedToken(configDir, token)
|
||||
}
|
||||
|
||||
fmt.Printf("\nAccount '%s' saved — it will be auto-discovered when you start the proxy.\n", accountName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
261
cmd/resethwid.go
261
cmd/resethwid.go
|
|
@ -1,261 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func sha256hex() string {
|
||||
b := make([]byte, 32)
|
||||
_, _ = rand.Read(b)
|
||||
h := sha256.Sum256(b)
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
func sha512hex() string {
|
||||
b := make([]byte, 64)
|
||||
_, _ = rand.Read(b)
|
||||
h := sha512.Sum512(b)
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
func newUUID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
func log(icon, msg string) {
|
||||
fmt.Printf(" %s %s\n", icon, msg)
|
||||
}
|
||||
|
||||
func getCursorGlobalStorage() string {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, "Library", "Application Support", "Cursor", "User", "globalStorage")
|
||||
case "windows":
|
||||
appdata := os.Getenv("APPDATA")
|
||||
return filepath.Join(appdata, "Cursor", "User", "globalStorage")
|
||||
default:
|
||||
xdg := os.Getenv("XDG_CONFIG_HOME")
|
||||
if xdg == "" {
|
||||
home, _ := os.UserHomeDir()
|
||||
xdg = filepath.Join(home, ".config")
|
||||
}
|
||||
return filepath.Join(xdg, "Cursor", "User", "globalStorage")
|
||||
}
|
||||
}
|
||||
|
||||
func getCursorRoot() string {
|
||||
gs := getCursorGlobalStorage()
|
||||
return filepath.Dir(filepath.Dir(gs))
|
||||
}
|
||||
|
||||
func generateNewIDs() map[string]string {
|
||||
return map[string]string{
|
||||
"telemetry.machineId": sha256hex(),
|
||||
"telemetry.macMachineId": sha512hex(),
|
||||
"telemetry.devDeviceId": newUUID(),
|
||||
"telemetry.sqmId": "{" + fmt.Sprintf("%s", newUUID()+"") + "}",
|
||||
"storage.serviceMachineId": newUUID(),
|
||||
}
|
||||
}
|
||||
|
||||
func killCursor() {
|
||||
log("", "Stopping Cursor processes...")
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
exec.Command("taskkill", "/F", "/IM", "Cursor.exe").Run()
|
||||
default:
|
||||
exec.Command("pkill", "-x", "Cursor").Run()
|
||||
exec.Command("pkill", "-f", "Cursor.app").Run()
|
||||
}
|
||||
log("", "Cursor stopped (or was not running)")
|
||||
}
|
||||
|
||||
func updateStorageJSON(storagePath string, ids map[string]string) {
|
||||
if _, err := os.Stat(storagePath); os.IsNotExist(err) {
|
||||
log("", fmt.Sprintf("storage.json not found: %s", storagePath))
|
||||
return
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
exec.Command("chflags", "nouchg", storagePath).Run()
|
||||
exec.Command("chmod", "644", storagePath).Run()
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(storagePath)
|
||||
if err != nil {
|
||||
log("", fmt.Sprintf("storage.json read error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
var obj map[string]interface{}
|
||||
if err := json.Unmarshal(data, &obj); err != nil {
|
||||
log("", fmt.Sprintf("storage.json parse error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range ids {
|
||||
obj[k] = v
|
||||
}
|
||||
|
||||
out, err := json.MarshalIndent(obj, "", " ")
|
||||
if err != nil {
|
||||
log("", fmt.Sprintf("storage.json marshal error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.WriteFile(storagePath, out, 0644); err != nil {
|
||||
log("", fmt.Sprintf("storage.json write error: %v", err))
|
||||
return
|
||||
}
|
||||
log("", "storage.json updated")
|
||||
}
|
||||
|
||||
func updateStateVscdb(dbPath string, ids map[string]string) {
|
||||
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
|
||||
log("", fmt.Sprintf("state.vscdb not found: %s", dbPath))
|
||||
return
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
exec.Command("chflags", "nouchg", dbPath).Run()
|
||||
exec.Command("chmod", "644", dbPath).Run()
|
||||
}
|
||||
|
||||
if err := updateVscdbPureGo(dbPath, ids); err != nil {
|
||||
log("", fmt.Sprintf("state.vscdb error: %v", err))
|
||||
} else {
|
||||
log("", "state.vscdb updated")
|
||||
}
|
||||
}
|
||||
|
||||
func updateMachineIDFile(machineID, cursorRoot string) {
|
||||
var candidates []string
|
||||
if runtime.GOOS == "linux" {
|
||||
candidates = []string{
|
||||
filepath.Join(cursorRoot, "machineid"),
|
||||
filepath.Join(cursorRoot, "machineId"),
|
||||
}
|
||||
} else {
|
||||
candidates = []string{filepath.Join(cursorRoot, "machineId")}
|
||||
}
|
||||
|
||||
filePath := candidates[0]
|
||||
for _, c := range candidates {
|
||||
if _, err := os.Stat(c); err == nil {
|
||||
filePath = c
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
|
||||
log("", fmt.Sprintf("machineId dir error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
if _, err := os.Stat(filePath); err == nil {
|
||||
exec.Command("chflags", "nouchg", filePath).Run()
|
||||
exec.Command("chmod", "644", filePath).Run()
|
||||
}
|
||||
}
|
||||
|
||||
if err := os.WriteFile(filePath, []byte(machineID+"\n"), 0644); err != nil {
|
||||
log("", fmt.Sprintf("machineId write error: %v", err))
|
||||
return
|
||||
}
|
||||
log("", fmt.Sprintf("machineId file updated (%s)", filepath.Base(filePath)))
|
||||
}
|
||||
|
||||
var dirsToWipe = []string{
|
||||
"Session Storage", "Local Storage", "IndexedDB", "Cache", "Code Cache",
|
||||
"GPUCache", "Service Worker", "Network", "Cookies", "Cookies-journal",
|
||||
}
|
||||
|
||||
func deepClean(cursorRoot string) {
|
||||
log("", "Deep-cleaning session data...")
|
||||
wiped := 0
|
||||
for _, name := range dirsToWipe {
|
||||
target := filepath.Join(cursorRoot, name)
|
||||
if _, err := os.Stat(target); os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
info, err := os.Stat(target)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if info.IsDir() {
|
||||
if err := os.RemoveAll(target); err == nil {
|
||||
wiped++
|
||||
}
|
||||
} else {
|
||||
if err := os.Remove(target); err == nil {
|
||||
wiped++
|
||||
}
|
||||
}
|
||||
}
|
||||
log("", fmt.Sprintf("Wiped %d cache/session items", wiped))
|
||||
}
|
||||
|
||||
func HandleResetHwid(doDeepClean, dryRun bool) error {
|
||||
fmt.Print("\nCursor HWID Reset\n\n")
|
||||
fmt.Println(" Resets all machine / telemetry IDs so Cursor sees a fresh install.")
|
||||
fmt.Print(" Cursor must be closed — it will be killed automatically.\n\n")
|
||||
|
||||
globalStorage := getCursorGlobalStorage()
|
||||
cursorRoot := getCursorRoot()
|
||||
|
||||
if _, err := os.Stat(globalStorage); os.IsNotExist(err) {
|
||||
fmt.Printf("Cursor config not found at:\n %s\n", globalStorage)
|
||||
fmt.Println(" Make sure Cursor is installed and has been run at least once.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if dryRun {
|
||||
fmt.Println(" [DRY RUN] Would reset IDs in:")
|
||||
fmt.Printf(" %s\n", filepath.Join(globalStorage, "storage.json"))
|
||||
fmt.Printf(" %s\n", filepath.Join(globalStorage, "state.vscdb"))
|
||||
fmt.Printf(" %s\n", filepath.Join(cursorRoot, "machineId"))
|
||||
return nil
|
||||
}
|
||||
|
||||
killCursor()
|
||||
|
||||
time.Sleep(800 * time.Millisecond)
|
||||
|
||||
newIDs := generateNewIDs()
|
||||
log("", "Generated new IDs:")
|
||||
for k, v := range newIDs {
|
||||
fmt.Printf(" %s: %s\n", k, v)
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
log("", "Updating storage.json...")
|
||||
updateStorageJSON(filepath.Join(globalStorage, "storage.json"), newIDs)
|
||||
|
||||
log("", "Updating state.vscdb...")
|
||||
updateStateVscdb(filepath.Join(globalStorage, "state.vscdb"), newIDs)
|
||||
|
||||
log("", "Updating machineId file...")
|
||||
updateMachineIDFile(newIDs["telemetry.machineId"], cursorRoot)
|
||||
|
||||
if doDeepClean {
|
||||
fmt.Println()
|
||||
deepClean(cursorRoot)
|
||||
}
|
||||
|
||||
fmt.Print("\nHWID reset complete. You can now restart Cursor.\n\n")
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func updateVscdbPureGo(dbPath string, ids map[string]string) error {
|
||||
db, err := sql.Open("sqlite", dbPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS ItemTable (key TEXT PRIMARY KEY, value TEXT NOT NULL)`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create table: %w", err)
|
||||
}
|
||||
|
||||
for k, v := range ids {
|
||||
_, err = db.Exec(`INSERT OR REPLACE INTO ItemTable (key, value) VALUES (?, ?)`, k, v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert %s: %w", k, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
255
cmd/usage.go
255
cmd/usage.go
|
|
@ -1,255 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ModelUsage struct {
|
||||
NumRequests int `json:"numRequests"`
|
||||
NumRequestsTotal int `json:"numRequestsTotal"`
|
||||
NumTokens int `json:"numTokens"`
|
||||
MaxTokenUsage *int `json:"maxTokenUsage"`
|
||||
MaxRequestUsage *int `json:"maxRequestUsage"`
|
||||
}
|
||||
|
||||
type UsageData struct {
|
||||
StartOfMonth string `json:"startOfMonth"`
|
||||
Models map[string]ModelUsage `json:"-"`
|
||||
}
|
||||
|
||||
type StripeProfile struct {
|
||||
MembershipType string `json:"membershipType"`
|
||||
SubscriptionStatus string `json:"subscriptionStatus"`
|
||||
DaysRemainingOnTrial *int `json:"daysRemainingOnTrial"`
|
||||
IsTeamMember bool `json:"isTeamMember"`
|
||||
IsYearlyPlan bool `json:"isYearlyPlan"`
|
||||
}
|
||||
|
||||
func DecodeJWTPayload(token string) map[string]interface{} {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) < 2 {
|
||||
return nil
|
||||
}
|
||||
padded := strings.ReplaceAll(parts[1], "-", "+")
|
||||
padded = strings.ReplaceAll(padded, "_", "/")
|
||||
data, err := base64.StdEncoding.DecodeString(padded + strings.Repeat("=", (4-len(padded)%4)%4))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func TokenSub(token string) string {
|
||||
payload := DecodeJWTPayload(token)
|
||||
if payload == nil {
|
||||
return ""
|
||||
}
|
||||
if sub, ok := payload["sub"].(string); ok {
|
||||
return sub
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func apiGet(path, token string) (map[string]interface{}, error) {
|
||||
client := &http.Client{Timeout: 8 * time.Second}
|
||||
req, err := http.NewRequest("GET", "https://api2.cursor.sh"+path, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func FetchAccountUsage(token string) (*UsageData, error) {
|
||||
raw, err := apiGet("/auth/usage", token)
|
||||
if err != nil || raw == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startOfMonth, _ := raw["startOfMonth"].(string)
|
||||
usage := &UsageData{
|
||||
StartOfMonth: startOfMonth,
|
||||
Models: make(map[string]ModelUsage),
|
||||
}
|
||||
|
||||
for k, v := range raw {
|
||||
if k == "startOfMonth" {
|
||||
continue
|
||||
}
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var mu ModelUsage
|
||||
if err := json.Unmarshal(data, &mu); err == nil {
|
||||
usage.Models[k] = mu
|
||||
}
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func FetchStripeProfile(token string) (*StripeProfile, error) {
|
||||
raw, err := apiGet("/auth/full_stripe_profile", token)
|
||||
if err != nil || raw == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
profile := &StripeProfile{
|
||||
MembershipType: fmt.Sprintf("%v", raw["membershipType"]),
|
||||
SubscriptionStatus: fmt.Sprintf("%v", raw["subscriptionStatus"]),
|
||||
IsTeamMember: raw["isTeamMember"] == true,
|
||||
IsYearlyPlan: raw["isYearlyPlan"] == true,
|
||||
}
|
||||
if d, ok := raw["daysRemainingOnTrial"].(float64); ok {
|
||||
di := int(d)
|
||||
profile.DaysRemainingOnTrial = &di
|
||||
}
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
func DescribePlan(profile *StripeProfile) string {
|
||||
if profile == nil {
|
||||
return ""
|
||||
}
|
||||
switch profile.MembershipType {
|
||||
case "free_trial":
|
||||
days := 0
|
||||
if profile.DaysRemainingOnTrial != nil {
|
||||
days = *profile.DaysRemainingOnTrial
|
||||
}
|
||||
return fmt.Sprintf("Pro Trial (%dd left) — unlimited fast requests", days)
|
||||
case "pro":
|
||||
return "Pro — extended limits"
|
||||
case "pro_plus":
|
||||
return "Pro+ — extended limits"
|
||||
case "ultra":
|
||||
return "Ultra — extended limits"
|
||||
case "free", "hobby":
|
||||
return "Hobby (free) — limited agent requests"
|
||||
default:
|
||||
return fmt.Sprintf("%s · %s", profile.MembershipType, profile.SubscriptionStatus)
|
||||
}
|
||||
}
|
||||
|
||||
var modelLabels = map[string]string{
|
||||
"gpt-4": "Fast Premium Requests",
|
||||
"claude-sonnet-4-6": "Claude Sonnet 4.6",
|
||||
"claude-sonnet-4-5-20250929-v1": "Claude Sonnet 4.5",
|
||||
"claude-sonnet-4-20250514-v1": "Claude Sonnet 4",
|
||||
"claude-opus-4-6-v1": "Claude Opus 4.6",
|
||||
"claude-opus-4-5-20251101-v1": "Claude Opus 4.5",
|
||||
"claude-opus-4-1-20250805-v1": "Claude Opus 4.1",
|
||||
"claude-opus-4-20250514-v1": "Claude Opus 4",
|
||||
"claude-haiku-4-5-20251001-v1": "Claude Haiku 4.5",
|
||||
"claude-3-5-haiku-20241022-v1": "Claude 3.5 Haiku",
|
||||
"gpt-5": "GPT-5",
|
||||
"gpt-4o": "GPT-4o",
|
||||
"o1": "o1",
|
||||
"o3-mini": "o3-mini",
|
||||
"cursor-small": "Cursor Small (free)",
|
||||
}
|
||||
|
||||
func modelLabel(key string) string {
|
||||
if label, ok := modelLabels[key]; ok {
|
||||
return label
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func FormatUsageSummary(usage *UsageData) []string {
|
||||
if usage == nil {
|
||||
return nil
|
||||
}
|
||||
var lines []string
|
||||
|
||||
start := "?"
|
||||
if usage.StartOfMonth != "" {
|
||||
if t, err := time.Parse(time.RFC3339, usage.StartOfMonth); err == nil {
|
||||
start = t.Format("2006-01-02")
|
||||
} else {
|
||||
start = usage.StartOfMonth
|
||||
}
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf(" Billing period from %s", start))
|
||||
|
||||
if len(usage.Models) == 0 {
|
||||
lines = append(lines, " No requests this billing period")
|
||||
return lines
|
||||
}
|
||||
|
||||
type entry struct {
|
||||
key string
|
||||
usage ModelUsage
|
||||
}
|
||||
var entries []entry
|
||||
for k, v := range usage.Models {
|
||||
entries = append(entries, entry{k, v})
|
||||
}
|
||||
|
||||
// Sort: entries with limits first, then by usage descending
|
||||
for i := 1; i < len(entries); i++ {
|
||||
for j := i; j > 0; j-- {
|
||||
a, b := entries[j-1], entries[j]
|
||||
aHasLimit := a.usage.MaxRequestUsage != nil
|
||||
bHasLimit := b.usage.MaxRequestUsage != nil
|
||||
if !aHasLimit && bHasLimit {
|
||||
entries[j-1], entries[j] = entries[j], entries[j-1]
|
||||
} else if aHasLimit == bHasLimit && a.usage.NumRequests < b.usage.NumRequests {
|
||||
entries[j-1], entries[j] = entries[j], entries[j-1]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, e := range entries {
|
||||
used := e.usage.NumRequests
|
||||
max := e.usage.MaxRequestUsage
|
||||
label := modelLabel(e.key)
|
||||
if max != nil && *max > 0 {
|
||||
pct := int(float64(used) / float64(*max) * 100)
|
||||
bar := makeBar(used, *max, 12)
|
||||
lines = append(lines, fmt.Sprintf(" %s: %d/%d (%d%%) [%s]", label, used, *max, pct, bar))
|
||||
} else if used > 0 {
|
||||
lines = append(lines, fmt.Sprintf(" %s: %d requests", label, used))
|
||||
} else {
|
||||
lines = append(lines, fmt.Sprintf(" %s: 0 requests (unlimited)", label))
|
||||
}
|
||||
}
|
||||
|
||||
return lines
|
||||
}
|
||||
|
||||
func makeBar(used, max, width int) string {
|
||||
fill := int(float64(used) / float64(max) * float64(width))
|
||||
if fill > width {
|
||||
fill = width
|
||||
}
|
||||
return strings.Repeat("█", fill) + strings.Repeat("░", width-fill)
|
||||
}
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
package agent
|
||||
|
||||
import "cursor-api-proxy/internal/config"
|
||||
|
||||
func BuildAgentFixedArgs(cfg config.BridgeConfig, workspaceDir, model string, stream bool) []string {
|
||||
args := []string{"--print"}
|
||||
if cfg.ApproveMcps {
|
||||
args = append(args, "--approve-mcps")
|
||||
}
|
||||
if cfg.Force {
|
||||
args = append(args, "--force")
|
||||
}
|
||||
if cfg.ChatOnlyWorkspace {
|
||||
args = append(args, "--trust")
|
||||
}
|
||||
args = append(args, "--workspace", workspaceDir)
|
||||
args = append(args, "--model", model)
|
||||
if stream {
|
||||
args = append(args, "--stream-partial-output", "--output-format", "stream-json")
|
||||
} else {
|
||||
args = append(args, "--output-format", "text")
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func BuildAgentCmdArgs(cfg config.BridgeConfig, workspaceDir, model, prompt string, stream bool) []string {
|
||||
return append(BuildAgentFixedArgs(cfg, workspaceDir, model, stream), prompt)
|
||||
}
|
||||
|
|
@ -1,85 +0,0 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
func getCandidates(agentScriptPath, configDirOverride string) []string {
|
||||
if configDirOverride != "" {
|
||||
return []string{filepath.Join(configDirOverride, "cli-config.json")}
|
||||
}
|
||||
|
||||
var result []string
|
||||
|
||||
if dir := os.Getenv("CURSOR_CONFIG_DIR"); dir != "" {
|
||||
result = append(result, filepath.Join(dir, "cli-config.json"))
|
||||
}
|
||||
|
||||
if agentScriptPath != "" {
|
||||
agentDir := filepath.Dir(agentScriptPath)
|
||||
result = append(result, filepath.Join(agentDir, "..", "data", "config", "cli-config.json"))
|
||||
}
|
||||
|
||||
home := os.Getenv("HOME")
|
||||
if home == "" {
|
||||
home = os.Getenv("USERPROFILE")
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
local := os.Getenv("LOCALAPPDATA")
|
||||
if local == "" {
|
||||
local = filepath.Join(home, "AppData", "Local")
|
||||
}
|
||||
result = append(result, filepath.Join(local, "cursor-agent", "cli-config.json"))
|
||||
case "darwin":
|
||||
result = append(result, filepath.Join(home, "Library", "Application Support", "cursor-agent", "cli-config.json"))
|
||||
default:
|
||||
xdg := os.Getenv("XDG_CONFIG_HOME")
|
||||
if xdg == "" {
|
||||
xdg = filepath.Join(home, ".config")
|
||||
}
|
||||
result = append(result, filepath.Join(xdg, "cursor-agent", "cli-config.json"))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func RunMaxModePreflight(agentScriptPath, configDirOverride string) {
|
||||
for _, candidate := range getCandidates(agentScriptPath, configDirOverride) {
|
||||
data, err := os.ReadFile(candidate)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Strip BOM if present
|
||||
if len(data) >= 3 && data[0] == 0xEF && data[1] == 0xBB && data[2] == 0xBF {
|
||||
data = data[3:]
|
||||
}
|
||||
|
||||
var raw map[string]interface{}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
continue
|
||||
}
|
||||
if raw == nil || len(raw) <= 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
raw["maxMode"] = true
|
||||
if model, ok := raw["model"].(map[string]interface{}); ok {
|
||||
model["maxMode"] = true
|
||||
}
|
||||
|
||||
out, err := json.MarshalIndent(raw, "", " ")
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if err := os.WriteFile(candidate, out, 0644); err != nil {
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
@ -1,72 +0,0 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/process"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func init() {
|
||||
process.MaxModeFn = RunMaxModePreflight
|
||||
}
|
||||
|
||||
func cacheTokenForAccount(configDir string) {
|
||||
if configDir == "" {
|
||||
return
|
||||
}
|
||||
token := ReadKeychainToken()
|
||||
if token != "" {
|
||||
WriteCachedToken(configDir, token)
|
||||
}
|
||||
}
|
||||
|
||||
func AccountsDir() string {
|
||||
home := os.Getenv("HOME")
|
||||
if home == "" {
|
||||
home = os.Getenv("USERPROFILE")
|
||||
}
|
||||
return filepath.Join(home, ".cursor-api-proxy", "accounts")
|
||||
}
|
||||
|
||||
func RunAgentSync(cfg config.BridgeConfig, workspaceDir string, cmdArgs []string, tempDir, configDir string, ctx context.Context) (process.RunResult, error) {
|
||||
opts := process.RunOptions{
|
||||
Cwd: workspaceDir,
|
||||
TimeoutMs: cfg.TimeoutMs,
|
||||
MaxMode: cfg.MaxMode,
|
||||
ConfigDir: configDir,
|
||||
Ctx: ctx,
|
||||
}
|
||||
|
||||
result, err := process.Run(cfg.AgentBin, cmdArgs, opts)
|
||||
|
||||
cacheTokenForAccount(configDir)
|
||||
if tempDir != "" {
|
||||
os.RemoveAll(tempDir)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func RunAgentStreamWithContext(cfg config.BridgeConfig, workspaceDir string, cmdArgs []string, onLine func(string), tempDir, configDir string, ctx context.Context) (process.StreamResult, error) {
|
||||
opts := process.RunStreamingOptions{
|
||||
RunOptions: process.RunOptions{
|
||||
Cwd: workspaceDir,
|
||||
TimeoutMs: cfg.TimeoutMs,
|
||||
MaxMode: cfg.MaxMode,
|
||||
ConfigDir: configDir,
|
||||
Ctx: ctx,
|
||||
},
|
||||
OnLine: onLine,
|
||||
}
|
||||
|
||||
result, err := process.RunStreaming(cfg.AgentBin, cmdArgs, opts)
|
||||
|
||||
cacheTokenForAccount(configDir)
|
||||
if tempDir != "" {
|
||||
os.RemoveAll(tempDir)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const tokenFile = ".cursor-token"
|
||||
|
||||
func ReadCachedToken(configDir string) string {
|
||||
p := filepath.Join(configDir, tokenFile)
|
||||
data, err := os.ReadFile(p)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(string(data))
|
||||
}
|
||||
|
||||
func WriteCachedToken(configDir, token string) {
|
||||
p := filepath.Join(configDir, tokenFile)
|
||||
_ = os.WriteFile(p, []byte(token), 0600)
|
||||
}
|
||||
|
||||
func ReadKeychainToken() string {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return ""
|
||||
}
|
||||
out, err := exec.Command("security", "find-generic-password", "-s", "cursor-access-token", "-w").Output()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
|
|
@ -1,174 +0,0 @@
|
|||
package anthropic
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/openai"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type MessageParam struct {
|
||||
Role string `json:"role"`
|
||||
Content interface{} `json:"content"`
|
||||
}
|
||||
|
||||
type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System interface{} `json:"system"`
|
||||
Stream bool `json:"stream"`
|
||||
Tools []interface{} `json:"tools"`
|
||||
}
|
||||
|
||||
func systemToText(system interface{}) string {
|
||||
if system == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case []interface{}:
|
||||
var parts []string
|
||||
for _, p := range v {
|
||||
if m, ok := p.(map[string]interface{}); ok {
|
||||
if m["type"] == "text" {
|
||||
if t, ok := m["text"].(string); ok {
|
||||
parts = append(parts, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func anthropicBlockToText(p interface{}) string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := p.(type) {
|
||||
case string:
|
||||
return v
|
||||
case map[string]interface{}:
|
||||
typ, _ := v["type"].(string)
|
||||
switch typ {
|
||||
case "text":
|
||||
if t, ok := v["text"].(string); ok {
|
||||
return t
|
||||
}
|
||||
case "image":
|
||||
if src, ok := v["source"].(map[string]interface{}); ok {
|
||||
srcType, _ := src["type"].(string)
|
||||
switch srcType {
|
||||
case "base64":
|
||||
mt, _ := src["media_type"].(string)
|
||||
if mt == "" {
|
||||
mt = "image"
|
||||
}
|
||||
return "[Image: base64 " + mt + "]"
|
||||
case "url":
|
||||
url, _ := src["url"].(string)
|
||||
return "[Image: " + url + "]"
|
||||
}
|
||||
}
|
||||
return "[Image]"
|
||||
case "document":
|
||||
title, _ := v["title"].(string)
|
||||
if title == "" {
|
||||
if src, ok := v["source"].(map[string]interface{}); ok {
|
||||
title, _ = src["url"].(string)
|
||||
}
|
||||
}
|
||||
if title != "" {
|
||||
return "[Document: " + title + "]"
|
||||
}
|
||||
return "[Document]"
|
||||
case "tool_use":
|
||||
name, _ := v["name"].(string)
|
||||
id, _ := v["id"].(string)
|
||||
input := v["input"]
|
||||
inputJSON, _ := json.Marshal(input)
|
||||
if inputJSON == nil {
|
||||
inputJSON = []byte("{}")
|
||||
}
|
||||
tag := fmt.Sprintf("<tool_call>\n{\"name\": \"%s\", \"arguments\": %s}\n</tool_call>", name, string(inputJSON))
|
||||
if id != "" {
|
||||
tag = fmt.Sprintf("[tool_use_id=%s] ", id) + tag
|
||||
}
|
||||
return tag
|
||||
case "tool_result":
|
||||
toolUseID, _ := v["tool_use_id"].(string)
|
||||
content := v["content"]
|
||||
var contentText string
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
contentText = c
|
||||
case []interface{}:
|
||||
var parts []string
|
||||
for _, block := range c {
|
||||
if bm, ok := block.(map[string]interface{}); ok {
|
||||
if bm["type"] == "text" {
|
||||
if t, ok := bm["text"].(string); ok {
|
||||
parts = append(parts, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
contentText = strings.Join(parts, "\n")
|
||||
}
|
||||
label := "Tool result"
|
||||
if toolUseID != "" {
|
||||
label += " [id=" + toolUseID + "]"
|
||||
}
|
||||
return label + ": " + contentText
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func anthropicContentToText(content interface{}) string {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []interface{}:
|
||||
var parts []string
|
||||
for _, p := range v {
|
||||
if t := anthropicBlockToText(p); t != "" {
|
||||
parts = append(parts, t)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func BuildPromptFromAnthropicMessages(messages []MessageParam, system interface{}) string {
|
||||
var oaiMessages []interface{}
|
||||
|
||||
systemText := systemToText(system)
|
||||
if systemText != "" {
|
||||
oaiMessages = append(oaiMessages, map[string]interface{}{
|
||||
"role": "system",
|
||||
"content": systemText,
|
||||
})
|
||||
}
|
||||
|
||||
for _, m := range messages {
|
||||
text := anthropicContentToText(m.Content)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
role := m.Role
|
||||
if role != "user" && role != "assistant" {
|
||||
role = "user"
|
||||
}
|
||||
oaiMessages = append(oaiMessages, map[string]interface{}{
|
||||
"role": role,
|
||||
"content": text,
|
||||
})
|
||||
}
|
||||
|
||||
return openai.BuildPromptFromMessages(oaiMessages)
|
||||
}
|
||||
|
|
@ -1,109 +0,0 @@
|
|||
package anthropic_test
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/anthropic"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_Simple(t *testing.T) {
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "Hi there"},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
||||
if !strings.Contains(prompt, "Hello") {
|
||||
t.Errorf("prompt missing user message: %q", prompt)
|
||||
}
|
||||
if !strings.Contains(prompt, "Hi there") {
|
||||
t.Errorf("prompt missing assistant message: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_WithSystem(t *testing.T) {
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: "ping"},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, "You are a helpful bot.")
|
||||
if !strings.Contains(prompt, "You are a helpful bot.") {
|
||||
t.Errorf("prompt missing system: %q", prompt)
|
||||
}
|
||||
if !strings.Contains(prompt, "ping") {
|
||||
t.Errorf("prompt missing user: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_SystemArray(t *testing.T) {
|
||||
system := []interface{}{
|
||||
map[string]interface{}{"type": "text", "text": "Part A"},
|
||||
map[string]interface{}{"type": "text", "text": "Part B"},
|
||||
}
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: "test"},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, system)
|
||||
if !strings.Contains(prompt, "Part A") {
|
||||
t.Errorf("prompt missing Part A: %q", prompt)
|
||||
}
|
||||
if !strings.Contains(prompt, "Part B") {
|
||||
t.Errorf("prompt missing Part B: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_ContentBlocks(t *testing.T) {
|
||||
content := []interface{}{
|
||||
map[string]interface{}{"type": "text", "text": "block one"},
|
||||
map[string]interface{}{"type": "text", "text": "block two"},
|
||||
}
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: content},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
||||
if !strings.Contains(prompt, "block one") {
|
||||
t.Errorf("prompt missing 'block one': %q", prompt)
|
||||
}
|
||||
if !strings.Contains(prompt, "block two") {
|
||||
t.Errorf("prompt missing 'block two': %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_ImageBlock(t *testing.T) {
|
||||
content := []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "image",
|
||||
"source": map[string]interface{}{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": "abc123",
|
||||
},
|
||||
},
|
||||
}
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: content},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
||||
if !strings.Contains(prompt, "[Image") {
|
||||
t.Errorf("prompt missing [Image]: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_EmptyContentSkipped(t *testing.T) {
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: ""},
|
||||
{Role: "assistant", Content: "response"},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
||||
if !strings.Contains(prompt, "response") {
|
||||
t.Errorf("prompt missing 'response': %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_UnknownRoleBecomesUser(t *testing.T) {
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "system", Content: "system-as-user"},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
||||
if !strings.Contains(prompt, "system-as-user") {
|
||||
t.Errorf("prompt missing 'system-as-user': %q", prompt)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,40 +0,0 @@
|
|||
package apitypes
|
||||
|
||||
type Message struct {
|
||||
Role string
|
||||
Content string
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Type string
|
||||
Function ToolFunction
|
||||
}
|
||||
|
||||
type ToolFunction struct {
|
||||
Name string
|
||||
Description string
|
||||
Parameters interface{}
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
ID string
|
||||
Name string
|
||||
Arguments string
|
||||
}
|
||||
|
||||
type StreamChunk struct {
|
||||
Type ChunkType
|
||||
Text string
|
||||
Thinking string
|
||||
ToolCall *ToolCall
|
||||
Done bool
|
||||
}
|
||||
|
||||
type ChunkType int
|
||||
|
||||
const (
|
||||
ChunkText ChunkType = iota
|
||||
ChunkThinking
|
||||
ChunkToolCall
|
||||
ChunkDone
|
||||
)
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/env"
|
||||
"cursor-api-proxy/pkg/infrastructure/env"
|
||||
|
||||
"github.com/zeromicro/go-zero/rest"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ package config_test
|
|||
|
||||
import (
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/env"
|
||||
"cursor-api-proxy/pkg/infrastructure/env"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
|
|
|||
|
|
@ -1,381 +0,0 @@
|
|||
package env
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type EnvSource map[string]string
|
||||
|
||||
type LoadedEnv struct {
|
||||
AgentBin string
|
||||
AgentNode string
|
||||
AgentScript string
|
||||
CommandShell string
|
||||
Host string
|
||||
Port int
|
||||
RequiredKey string
|
||||
DefaultModel string
|
||||
Provider string
|
||||
Force bool
|
||||
ApproveMcps bool
|
||||
StrictModel bool
|
||||
Workspace string
|
||||
TimeoutMs int
|
||||
TLSCertPath string
|
||||
TLSKeyPath string
|
||||
SessionsLogPath string
|
||||
ChatOnlyWorkspace bool
|
||||
Verbose bool
|
||||
MaxMode bool
|
||||
ConfigDirs []string
|
||||
MultiPort bool
|
||||
WinCmdlineMax int
|
||||
GeminiAccountDir string
|
||||
GeminiBrowserVisible bool
|
||||
GeminiMaxSessions int
|
||||
}
|
||||
|
||||
type AgentCommand struct {
|
||||
Command string
|
||||
Args []string
|
||||
Env map[string]string
|
||||
WindowsVerbatimArguments bool
|
||||
AgentScriptPath string
|
||||
ConfigDir string
|
||||
}
|
||||
|
||||
func getEnvVal(e EnvSource, names []string) string {
|
||||
for _, name := range names {
|
||||
if v, ok := e[name]; ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func envBool(e EnvSource, names []string, def bool) bool {
|
||||
raw := getEnvVal(e, names)
|
||||
if raw == "" {
|
||||
return def
|
||||
}
|
||||
switch strings.ToLower(raw) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
case "0", "false", "no", "off":
|
||||
return false
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func envInt(e EnvSource, names []string, def int) int {
|
||||
raw := getEnvVal(e, names)
|
||||
if raw == "" {
|
||||
return def
|
||||
}
|
||||
v, err := strconv.Atoi(raw)
|
||||
if err != nil {
|
||||
return def
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func normalizeModelId(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "auto"
|
||||
}
|
||||
parts := strings.Split(raw, "/")
|
||||
last := parts[len(parts)-1]
|
||||
if last == "" {
|
||||
return "auto"
|
||||
}
|
||||
return last
|
||||
}
|
||||
|
||||
func resolveAbs(raw, cwd string) string {
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
if filepath.IsAbs(raw) {
|
||||
return raw
|
||||
}
|
||||
return filepath.Join(cwd, raw)
|
||||
}
|
||||
|
||||
func isAuthenticatedAccountDir(dir string) bool {
|
||||
data, err := os.ReadFile(filepath.Join(dir, "cli-config.json"))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
var cfg struct {
|
||||
AuthInfo *struct {
|
||||
Email string `json:"email"`
|
||||
} `json:"authInfo"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return false
|
||||
}
|
||||
return cfg.AuthInfo != nil && cfg.AuthInfo.Email != ""
|
||||
}
|
||||
|
||||
func discoverAccountDirs(homeDir string) []string {
|
||||
if homeDir == "" {
|
||||
return nil
|
||||
}
|
||||
accountsDir := filepath.Join(homeDir, ".cursor-api-proxy", "accounts")
|
||||
entries, err := os.ReadDir(accountsDir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var dirs []string
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() {
|
||||
continue
|
||||
}
|
||||
dir := filepath.Join(accountsDir, e.Name())
|
||||
if isAuthenticatedAccountDir(dir) {
|
||||
dirs = append(dirs, dir)
|
||||
}
|
||||
}
|
||||
return dirs
|
||||
}
|
||||
|
||||
func parseDotEnv(path string) EnvSource {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
m := make(EnvSource)
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) == 2 {
|
||||
m[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func OsEnvToMap(cwdHint ...string) EnvSource {
|
||||
m := make(EnvSource)
|
||||
for _, kv := range os.Environ() {
|
||||
parts := strings.SplitN(kv, "=", 2)
|
||||
if len(parts) == 2 {
|
||||
m[parts[0]] = parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
cwd := ""
|
||||
if len(cwdHint) > 0 && cwdHint[0] != "" {
|
||||
cwd = cwdHint[0]
|
||||
} else {
|
||||
cwd, _ = os.Getwd()
|
||||
}
|
||||
|
||||
if dotenv := parseDotEnv(filepath.Join(cwd, ".env")); dotenv != nil {
|
||||
for k, v := range dotenv {
|
||||
if _, exists := m[k]; !exists {
|
||||
m[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func LoadEnvConfig(e EnvSource, cwd string) LoadedEnv {
|
||||
if e == nil {
|
||||
e = OsEnvToMap()
|
||||
}
|
||||
if cwd == "" {
|
||||
var err error
|
||||
cwd, err = os.Getwd()
|
||||
if err != nil {
|
||||
cwd = "."
|
||||
}
|
||||
}
|
||||
|
||||
host := getEnvVal(e, []string{"CURSOR_BRIDGE_HOST"})
|
||||
if host == "" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
port := envInt(e, []string{"CURSOR_BRIDGE_PORT"}, 8765)
|
||||
if port <= 0 {
|
||||
port = 8765
|
||||
}
|
||||
|
||||
home := getEnvVal(e, []string{"HOME", "USERPROFILE"})
|
||||
|
||||
sessionsLogPath := func() string {
|
||||
if p := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_SESSIONS_LOG"}), cwd); p != "" {
|
||||
return p
|
||||
}
|
||||
if home != "" {
|
||||
return filepath.Join(home, ".cursor-api-proxy", "sessions.log")
|
||||
}
|
||||
return filepath.Join(cwd, "sessions.log")
|
||||
}()
|
||||
|
||||
var configDirs []string
|
||||
if raw := getEnvVal(e, []string{"CURSOR_CONFIG_DIRS", "CURSOR_ACCOUNT_DIRS"}); raw != "" {
|
||||
for _, d := range strings.Split(raw, ",") {
|
||||
d = strings.TrimSpace(d)
|
||||
if d != "" {
|
||||
if p := resolveAbs(d, cwd); p != "" {
|
||||
configDirs = append(configDirs, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(configDirs) == 0 {
|
||||
configDirs = discoverAccountDirs(home)
|
||||
}
|
||||
|
||||
winMax := envInt(e, []string{"CURSOR_BRIDGE_WIN_CMDLINE_MAX"}, 30000)
|
||||
if winMax < 4096 {
|
||||
winMax = 4096
|
||||
}
|
||||
if winMax > 32700 {
|
||||
winMax = 32700
|
||||
}
|
||||
|
||||
agentBin := getEnvVal(e, []string{"CURSOR_AGENT_BIN", "CURSOR_CLI_BIN", "CURSOR_CLI_PATH"})
|
||||
if agentBin == "" {
|
||||
agentBin = "agent"
|
||||
}
|
||||
commandShell := getEnvVal(e, []string{"COMSPEC"})
|
||||
if commandShell == "" {
|
||||
commandShell = "cmd.exe"
|
||||
}
|
||||
workspace := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_WORKSPACE"}), cwd)
|
||||
if workspace == "" {
|
||||
workspace = cwd
|
||||
}
|
||||
|
||||
geminiAccountDir := getEnvVal(e, []string{"GEMINI_ACCOUNT_DIR"})
|
||||
if geminiAccountDir == "" {
|
||||
geminiAccountDir = filepath.Join(home, ".cursor-api-proxy", "gemini-accounts")
|
||||
} else {
|
||||
geminiAccountDir = resolveAbs(geminiAccountDir, cwd)
|
||||
}
|
||||
|
||||
return LoadedEnv{
|
||||
AgentBin: agentBin,
|
||||
AgentNode: getEnvVal(e, []string{"CURSOR_AGENT_NODE"}),
|
||||
AgentScript: getEnvVal(e, []string{"CURSOR_AGENT_SCRIPT"}),
|
||||
CommandShell: commandShell,
|
||||
Host: host,
|
||||
Port: port,
|
||||
RequiredKey: getEnvVal(e, []string{"CURSOR_BRIDGE_API_KEY"}),
|
||||
DefaultModel: normalizeModelId(getEnvVal(e, []string{"CURSOR_BRIDGE_DEFAULT_MODEL"})),
|
||||
Provider: getEnvVal(e, []string{"CURSOR_BRIDGE_PROVIDER"}),
|
||||
Force: envBool(e, []string{"CURSOR_BRIDGE_FORCE"}, false),
|
||||
ApproveMcps: envBool(e, []string{"CURSOR_BRIDGE_APPROVE_MCPS"}, false),
|
||||
StrictModel: envBool(e, []string{"CURSOR_BRIDGE_STRICT_MODEL"}, true),
|
||||
Workspace: workspace,
|
||||
TimeoutMs: envInt(e, []string{"CURSOR_BRIDGE_TIMEOUT_MS"}, 300000),
|
||||
TLSCertPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_CERT"}), cwd),
|
||||
TLSKeyPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_KEY"}), cwd),
|
||||
SessionsLogPath: sessionsLogPath,
|
||||
ChatOnlyWorkspace: envBool(e, []string{"CURSOR_BRIDGE_CHAT_ONLY_WORKSPACE"}, true),
|
||||
Verbose: envBool(e, []string{"CURSOR_BRIDGE_VERBOSE"}, false),
|
||||
MaxMode: envBool(e, []string{"CURSOR_BRIDGE_MAX_MODE"}, false),
|
||||
ConfigDirs: configDirs,
|
||||
MultiPort: envBool(e, []string{"CURSOR_BRIDGE_MULTI_PORT"}, false),
|
||||
WinCmdlineMax: winMax,
|
||||
GeminiAccountDir: geminiAccountDir,
|
||||
GeminiBrowserVisible: envBool(e, []string{"GEMINI_BROWSER_VISIBLE"}, false),
|
||||
GeminiMaxSessions: envInt(e, []string{"GEMINI_MAX_SESSIONS"}, 3),
|
||||
}
|
||||
}
|
||||
|
||||
func ResolveAgentCommand(cmd string, args []string, e EnvSource, cwd string) AgentCommand {
|
||||
if e == nil {
|
||||
e = OsEnvToMap()
|
||||
}
|
||||
loaded := LoadEnvConfig(e, cwd)
|
||||
|
||||
cloneEnv := func() map[string]string {
|
||||
m := make(map[string]string, len(e))
|
||||
for k, v := range e {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
if loaded.AgentNode != "" && loaded.AgentScript != "" {
|
||||
agentScriptPath := loaded.AgentScript
|
||||
if !filepath.IsAbs(agentScriptPath) {
|
||||
agentScriptPath = filepath.Join(cwd, agentScriptPath)
|
||||
}
|
||||
agentDir := filepath.Dir(agentScriptPath)
|
||||
configDir := filepath.Join(agentDir, "..", "data", "config")
|
||||
env2 := cloneEnv()
|
||||
env2["CURSOR_INVOKED_AS"] = "agent.cmd"
|
||||
ac := AgentCommand{
|
||||
Command: loaded.AgentNode,
|
||||
Args: append([]string{loaded.AgentScript}, args...),
|
||||
Env: env2,
|
||||
AgentScriptPath: agentScriptPath,
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil {
|
||||
ac.ConfigDir = configDir
|
||||
}
|
||||
return ac
|
||||
}
|
||||
|
||||
if strings.HasSuffix(strings.ToLower(cmd), ".cmd") {
|
||||
cmdResolved := cmd
|
||||
if !filepath.IsAbs(cmd) {
|
||||
cmdResolved = filepath.Join(cwd, cmd)
|
||||
}
|
||||
dir := filepath.Dir(cmdResolved)
|
||||
nodeBin := filepath.Join(dir, "node.exe")
|
||||
script := filepath.Join(dir, "index.js")
|
||||
if _, err1 := os.Stat(nodeBin); err1 == nil {
|
||||
if _, err2 := os.Stat(script); err2 == nil {
|
||||
configDir := filepath.Join(dir, "..", "data", "config")
|
||||
env2 := cloneEnv()
|
||||
env2["CURSOR_INVOKED_AS"] = "agent.cmd"
|
||||
ac := AgentCommand{
|
||||
Command: nodeBin,
|
||||
Args: append([]string{script}, args...),
|
||||
Env: env2,
|
||||
AgentScriptPath: script,
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil {
|
||||
ac.ConfigDir = configDir
|
||||
}
|
||||
return ac
|
||||
}
|
||||
}
|
||||
|
||||
quotedArgs := make([]string, len(args))
|
||||
for i, a := range args {
|
||||
if strings.Contains(a, " ") {
|
||||
quotedArgs[i] = `"` + a + `"`
|
||||
} else {
|
||||
quotedArgs[i] = a
|
||||
}
|
||||
}
|
||||
cmdLine := `""` + cmd + `" ` + strings.Join(quotedArgs, " ") + `"`
|
||||
return AgentCommand{
|
||||
Command: loaded.CommandShell,
|
||||
Args: []string{"/d", "/s", "/c", cmdLine},
|
||||
Env: cloneEnv(),
|
||||
WindowsVerbatimArguments: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return AgentCommand{Command: cmd, Args: args, Env: cloneEnv()}
|
||||
}
|
||||
|
|
@ -1,65 +0,0 @@
|
|||
package env
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestLoadEnvConfigDefaults(t *testing.T) {
|
||||
e := EnvSource{}
|
||||
loaded := LoadEnvConfig(e, "/tmp")
|
||||
|
||||
if loaded.Host != "127.0.0.1" {
|
||||
t.Errorf("expected 127.0.0.1, got %s", loaded.Host)
|
||||
}
|
||||
if loaded.Port != 8765 {
|
||||
t.Errorf("expected 8765, got %d", loaded.Port)
|
||||
}
|
||||
if loaded.DefaultModel != "auto" {
|
||||
t.Errorf("expected auto, got %s", loaded.DefaultModel)
|
||||
}
|
||||
if loaded.AgentBin != "agent" {
|
||||
t.Errorf("expected agent, got %s", loaded.AgentBin)
|
||||
}
|
||||
if !loaded.StrictModel {
|
||||
t.Error("expected strictModel=true by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadEnvConfigOverride(t *testing.T) {
|
||||
e := EnvSource{
|
||||
"CURSOR_BRIDGE_HOST": "0.0.0.0",
|
||||
"CURSOR_BRIDGE_PORT": "9000",
|
||||
"CURSOR_BRIDGE_DEFAULT_MODEL": "gpt-4",
|
||||
"CURSOR_AGENT_BIN": "/usr/local/bin/agent",
|
||||
}
|
||||
loaded := LoadEnvConfig(e, "/tmp")
|
||||
|
||||
if loaded.Host != "0.0.0.0" {
|
||||
t.Errorf("expected 0.0.0.0, got %s", loaded.Host)
|
||||
}
|
||||
if loaded.Port != 9000 {
|
||||
t.Errorf("expected 9000, got %d", loaded.Port)
|
||||
}
|
||||
if loaded.DefaultModel != "gpt-4" {
|
||||
t.Errorf("expected gpt-4, got %s", loaded.DefaultModel)
|
||||
}
|
||||
if loaded.AgentBin != "/usr/local/bin/agent" {
|
||||
t.Errorf("expected /usr/local/bin/agent, got %s", loaded.AgentBin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeModelID(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"gpt-4", "gpt-4"},
|
||||
{"openai/gpt-4", "gpt-4"},
|
||||
{"", "auto"},
|
||||
{" ", "auto"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got := normalizeModelId(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("normalizeModelId(%q) = %q, want %q", tc.input, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,577 +0,0 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"cursor-api-proxy/internal/agent"
|
||||
"cursor-api-proxy/internal/anthropic"
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/httputil"
|
||||
"cursor-api-proxy/internal/logger"
|
||||
"cursor-api-proxy/internal/models"
|
||||
"cursor-api-proxy/internal/openai"
|
||||
"cursor-api-proxy/internal/parser"
|
||||
"cursor-api-proxy/internal/pool"
|
||||
"cursor-api-proxy/internal/sanitize"
|
||||
"cursor-api-proxy/internal/toolcall"
|
||||
"cursor-api-proxy/internal/winlimit"
|
||||
"cursor-api-proxy/internal/workspace"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, ph pool.PoolHandle, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
|
||||
var req anthropic.MessagesRequest
|
||||
if err := json.Unmarshal([]byte(rawBody), &req); err != nil {
|
||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
||||
"error": map[string]string{"type": "invalid_request_error", "message": "invalid JSON body"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
requested := openai.NormalizeModelID(req.Model)
|
||||
model := ResolveModel(requested, lastModelRef, cfg)
|
||||
|
||||
var rawMap map[string]interface{}
|
||||
_ = json.Unmarshal([]byte(rawBody), &rawMap)
|
||||
|
||||
cleanSystem := sanitize.SanitizeSystem(req.System)
|
||||
|
||||
rawMessages := make([]interface{}, len(req.Messages))
|
||||
for i, m := range req.Messages {
|
||||
rawMessages[i] = map[string]interface{}{"role": m.Role, "content": m.Content}
|
||||
}
|
||||
cleanRawMessages := sanitize.SanitizeMessages(rawMessages)
|
||||
|
||||
var cleanMessages []anthropic.MessageParam
|
||||
for _, raw := range cleanRawMessages {
|
||||
if m, ok := raw.(map[string]interface{}); ok {
|
||||
role, _ := m["role"].(string)
|
||||
cleanMessages = append(cleanMessages, anthropic.MessageParam{Role: role, Content: m["content"]})
|
||||
}
|
||||
}
|
||||
|
||||
toolsText := openai.ToolsToSystemText(req.Tools, nil)
|
||||
var systemWithTools interface{}
|
||||
if toolsText != "" {
|
||||
sysStr := ""
|
||||
switch v := cleanSystem.(type) {
|
||||
case string:
|
||||
sysStr = v
|
||||
}
|
||||
if sysStr != "" {
|
||||
systemWithTools = sysStr + "\n\n" + toolsText
|
||||
} else {
|
||||
systemWithTools = toolsText
|
||||
}
|
||||
} else {
|
||||
systemWithTools = cleanSystem
|
||||
}
|
||||
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(cleanMessages, systemWithTools)
|
||||
|
||||
if req.MaxTokens == 0 {
|
||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
||||
"error": map[string]string{"type": "invalid_request_error", "message": "max_tokens is required"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
cursorModel := models.ResolveToCursorModel(model)
|
||||
if cursorModel == "" {
|
||||
cursorModel = model
|
||||
}
|
||||
|
||||
var trafficMsgs []logger.TrafficMessage
|
||||
if s := systemToString(cleanSystem); s != "" {
|
||||
trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: "system", Content: s})
|
||||
}
|
||||
for _, m := range cleanMessages {
|
||||
text := contentToString(m.Content)
|
||||
if text != "" {
|
||||
trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: m.Role, Content: text})
|
||||
}
|
||||
}
|
||||
logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, req.Stream)
|
||||
|
||||
headerWs := r.Header.Get("x-cursor-workspace")
|
||||
ws := workspace.ResolveWorkspace(cfg, headerWs)
|
||||
|
||||
fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, req.Stream)
|
||||
fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir)
|
||||
|
||||
if cfg.Verbose {
|
||||
if len(prompt) > 200 {
|
||||
logger.LogDebug("model=%s prompt_len=%d prompt_preview=%q", cursorModel, len(prompt), prompt[:200]+"...")
|
||||
} else {
|
||||
logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, len(prompt), prompt)
|
||||
}
|
||||
logger.LogDebug("cmd_args=%v", fit.Args)
|
||||
}
|
||||
|
||||
if !fit.OK {
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"type": "api_error", "message": fit.Error},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
if fit.Truncated {
|
||||
logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength)
|
||||
}
|
||||
|
||||
cmdArgs := fit.Args
|
||||
msgID := "msg_" + uuid.New().String()
|
||||
|
||||
var truncatedHeaders map[string]string
|
||||
if fit.Truncated {
|
||||
truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
|
||||
}
|
||||
|
||||
hasTools := len(req.Tools) > 0
|
||||
var toolNames map[string]bool
|
||||
if hasTools {
|
||||
toolNames = toolcall.CollectToolNames(req.Tools)
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
httputil.WriteSSEHeaders(w, truncatedHeaders)
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
writeEvent := func(evt interface{}) {
|
||||
data, _ := json.Marshal(evt)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
var accumulated string
|
||||
var accumulatedThinking string
|
||||
var chunkNum int
|
||||
var p parser.Parser
|
||||
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "message_start",
|
||||
"message": map[string]interface{}{
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []interface{}{},
|
||||
},
|
||||
})
|
||||
|
||||
if hasTools {
|
||||
toolCallMarkerRe := regexp.MustCompile(`行政法规|<function_calls>`)
|
||||
var toolCallMode bool
|
||||
|
||||
textBlockOpen := false
|
||||
textBlockIndex := 0
|
||||
thinkingOpen := false
|
||||
thinkingBlockIndex := 0
|
||||
blockCount := 0
|
||||
|
||||
p = parser.CreateStreamParserWithThinking(
|
||||
func(text string) {
|
||||
accumulated += text
|
||||
chunkNum++
|
||||
logger.LogStreamChunk(model, text, chunkNum)
|
||||
|
||||
if toolCallMode {
|
||||
return
|
||||
}
|
||||
if toolCallMarkerRe.MatchString(text) {
|
||||
if textBlockOpen {
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": textBlockIndex})
|
||||
textBlockOpen = false
|
||||
}
|
||||
if thinkingOpen {
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": thinkingBlockIndex})
|
||||
thinkingOpen = false
|
||||
}
|
||||
toolCallMode = true
|
||||
return
|
||||
}
|
||||
if !textBlockOpen && !thinkingOpen {
|
||||
textBlockIndex = blockCount
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": textBlockIndex,
|
||||
"content_block": map[string]string{"type": "text", "text": ""},
|
||||
})
|
||||
textBlockOpen = true
|
||||
blockCount++
|
||||
}
|
||||
if thinkingOpen {
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": thinkingBlockIndex})
|
||||
thinkingOpen = false
|
||||
}
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": textBlockIndex,
|
||||
"delta": map[string]string{"type": "text_delta", "text": text},
|
||||
})
|
||||
},
|
||||
func(thinking string) {
|
||||
accumulatedThinking += thinking
|
||||
chunkNum++
|
||||
if toolCallMode {
|
||||
return
|
||||
}
|
||||
if !thinkingOpen {
|
||||
thinkingBlockIndex = blockCount
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": thinkingBlockIndex,
|
||||
"content_block": map[string]string{"type": "thinking", "thinking": ""},
|
||||
})
|
||||
thinkingOpen = true
|
||||
blockCount++
|
||||
}
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": thinkingBlockIndex,
|
||||
"delta": map[string]string{"type": "thinking_delta", "thinking": thinking},
|
||||
})
|
||||
},
|
||||
func() {
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
|
||||
parsed := toolcall.ExtractToolCalls(accumulated, toolNames)
|
||||
|
||||
blockIndex := 0
|
||||
if thinkingOpen {
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": thinkingBlockIndex})
|
||||
thinkingOpen = false
|
||||
}
|
||||
|
||||
if parsed.HasToolCalls() {
|
||||
if textBlockOpen {
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": textBlockIndex})
|
||||
blockIndex = textBlockIndex + 1
|
||||
}
|
||||
if parsed.TextContent != "" && !textBlockOpen && !toolCallMode {
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start", "index": blockIndex,
|
||||
"content_block": map[string]string{"type": "text", "text": ""},
|
||||
})
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_delta", "index": blockIndex,
|
||||
"delta": map[string]string{"type": "text_delta", "text": parsed.TextContent},
|
||||
})
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
|
||||
blockIndex++
|
||||
}
|
||||
for _, tc := range parsed.ToolCalls {
|
||||
toolID := "toolu_" + uuid.New().String()[:12]
|
||||
var inputObj interface{}
|
||||
_ = json.Unmarshal([]byte(tc.Arguments), &inputObj)
|
||||
if inputObj == nil {
|
||||
inputObj = map[string]interface{}{}
|
||||
}
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start", "index": blockIndex,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "tool_use", "id": toolID, "name": tc.Name, "input": map[string]interface{}{},
|
||||
},
|
||||
})
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_delta", "index": blockIndex,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "input_json_delta", "partial_json": tc.Arguments,
|
||||
},
|
||||
})
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
|
||||
blockIndex++
|
||||
}
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{"stop_reason": "tool_use", "stop_sequence": nil},
|
||||
"usage": map[string]int{"output_tokens": 0},
|
||||
})
|
||||
writeEvent(map[string]interface{}{"type": "message_stop"})
|
||||
} else {
|
||||
if textBlockOpen {
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": textBlockIndex})
|
||||
} else if accumulated != "" {
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start", "index": blockIndex,
|
||||
"content_block": map[string]string{"type": "text", "text": ""},
|
||||
})
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_delta", "index": blockIndex,
|
||||
"delta": map[string]string{"type": "text_delta", "text": accumulated},
|
||||
})
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
|
||||
blockIndex++
|
||||
} else {
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start", "index": blockIndex,
|
||||
"content_block": map[string]string{"type": "text", "text": ""},
|
||||
})
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
|
||||
blockIndex++
|
||||
}
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
|
||||
"usage": map[string]int{"output_tokens": 0},
|
||||
})
|
||||
writeEvent(map[string]interface{}{"type": "message_stop"})
|
||||
}
|
||||
},
|
||||
)
|
||||
} else {
|
||||
// 非 tools 模式:即時串流 thinking 和 text
|
||||
blockCount := 0
|
||||
thinkingOpen := false
|
||||
textOpen := false
|
||||
|
||||
p = parser.CreateStreamParserWithThinking(
|
||||
func(text string) {
|
||||
accumulated += text
|
||||
chunkNum++
|
||||
logger.LogStreamChunk(model, text, chunkNum)
|
||||
if thinkingOpen {
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
|
||||
thinkingOpen = false
|
||||
}
|
||||
if !textOpen {
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": blockCount,
|
||||
"content_block": map[string]string{"type": "text", "text": ""},
|
||||
})
|
||||
textOpen = true
|
||||
blockCount++
|
||||
}
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": blockCount - 1,
|
||||
"delta": map[string]string{"type": "text_delta", "text": text},
|
||||
})
|
||||
},
|
||||
func(thinking string) {
|
||||
accumulatedThinking += thinking
|
||||
chunkNum++
|
||||
if !thinkingOpen {
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": blockCount,
|
||||
"content_block": map[string]string{"type": "thinking", "thinking": ""},
|
||||
})
|
||||
thinkingOpen = true
|
||||
blockCount++
|
||||
}
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": blockCount - 1,
|
||||
"delta": map[string]string{"type": "thinking_delta", "thinking": thinking},
|
||||
})
|
||||
},
|
||||
func() {
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
|
||||
if thinkingOpen {
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
|
||||
thinkingOpen = false
|
||||
}
|
||||
if !textOpen {
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": blockCount,
|
||||
"content_block": map[string]string{"type": "text", "text": ""},
|
||||
})
|
||||
blockCount++
|
||||
}
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
|
||||
"usage": map[string]int{"output_tokens": 0},
|
||||
})
|
||||
writeEvent(map[string]interface{}{"type": "message_stop"})
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
configDir := ph.GetNextConfigDir()
|
||||
logger.LogAccountAssigned(configDir)
|
||||
ph.ReportRequestStart(configDir)
|
||||
logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true)
|
||||
streamStart := time.Now().UnixMilli()
|
||||
|
||||
ctx := r.Context()
|
||||
wrappedParser := func(line string) {
|
||||
logger.LogRawLine(line)
|
||||
p.Parse(line)
|
||||
}
|
||||
result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, ctx)
|
||||
|
||||
if ctx.Err() == nil {
|
||||
p.Flush()
|
||||
}
|
||||
|
||||
latencyMs := time.Now().UnixMilli() - streamStart
|
||||
ph.ReportRequestEnd(configDir)
|
||||
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
|
||||
} else if ctx.Err() == context.Canceled {
|
||||
logger.LogClientDisconnect(method, pathname, model, latencyMs)
|
||||
} else if err == nil && isRateLimited(result.Stderr) {
|
||||
ph.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr))
|
||||
}
|
||||
|
||||
if err != nil || (result.Code != 0 && ctx.Err() == nil) {
|
||||
ph.ReportRequestError(configDir, latencyMs)
|
||||
if err != nil {
|
||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
|
||||
} else {
|
||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr)
|
||||
}
|
||||
logger.LogRequestDone(method, pathname, model, latencyMs, result.Code)
|
||||
} else if ctx.Err() == nil {
|
||||
ph.ReportRequestSuccess(configDir, latencyMs)
|
||||
logger.LogRequestDone(method, pathname, model, latencyMs, 0)
|
||||
}
|
||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
||||
return
|
||||
}
|
||||
|
||||
configDir := ph.GetNextConfigDir()
|
||||
logger.LogAccountAssigned(configDir)
|
||||
ph.ReportRequestStart(configDir)
|
||||
logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, false)
|
||||
syncStart := time.Now().UnixMilli()
|
||||
|
||||
out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context())
|
||||
syncLatency := time.Now().UnixMilli() - syncStart
|
||||
ph.ReportRequestEnd(configDir)
|
||||
|
||||
ctx := r.Context()
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
|
||||
httputil.WriteJSON(w, 504, map[string]interface{}{
|
||||
"error": map[string]string{"type": "api_error", "message": fmt.Sprintf("request timed out after %dms", cfg.TimeoutMs)},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
if ctx.Err() == context.Canceled {
|
||||
logger.LogClientDisconnect(method, pathname, model, syncLatency)
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
ph.ReportRequestError(configDir, syncLatency)
|
||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
||||
logger.LogRequestDone(method, pathname, model, syncLatency, -1)
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"type": "api_error", "message": err.Error()},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
if isRateLimited(out.Stderr) {
|
||||
ph.ReportRateLimit(configDir, extractRetryAfterMs(out.Stderr))
|
||||
}
|
||||
|
||||
if out.Code != 0 {
|
||||
ph.ReportRequestError(configDir, syncLatency)
|
||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
||||
errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr)
|
||||
logger.LogRequestDone(method, pathname, model, syncLatency, out.Code)
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"type": "api_error", "message": errMsg},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
ph.ReportRequestSuccess(configDir, syncLatency)
|
||||
content := strings.TrimSpace(out.Stdout)
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, content, false)
|
||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
||||
logger.LogRequestDone(method, pathname, model, syncLatency, 0)
|
||||
|
||||
if hasTools {
|
||||
parsed := toolcall.ExtractToolCalls(content, toolNames)
|
||||
if parsed.HasToolCalls() {
|
||||
var contentBlocks []map[string]interface{}
|
||||
if parsed.TextContent != "" {
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "text", "text": parsed.TextContent,
|
||||
})
|
||||
}
|
||||
for _, tc := range parsed.ToolCalls {
|
||||
toolID := "toolu_" + uuid.New().String()[:12]
|
||||
var inputObj interface{}
|
||||
_ = json.Unmarshal([]byte(tc.Arguments), &inputObj)
|
||||
if inputObj == nil {
|
||||
inputObj = map[string]interface{}{}
|
||||
}
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "tool_use", "id": toolID, "name": tc.Name, "input": inputObj,
|
||||
})
|
||||
}
|
||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": contentBlocks,
|
||||
"model": model,
|
||||
"stop_reason": "tool_use",
|
||||
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
|
||||
}, truncatedHeaders)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []map[string]string{{"type": "text", "text": content}},
|
||||
"model": model,
|
||||
"stop_reason": "end_turn",
|
||||
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
|
||||
}, truncatedHeaders)
|
||||
}
|
||||
|
||||
func systemToString(system interface{}) string {
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []interface{}:
|
||||
result := ""
|
||||
for _, p := range v {
|
||||
if m, ok := p.(map[string]interface{}); ok && m["type"] == "text" {
|
||||
if t, ok := m["text"].(string); ok {
|
||||
result += t
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func contentToString(content interface{}) string {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []interface{}:
|
||||
result := ""
|
||||
for _, p := range v {
|
||||
if m, ok := p.(map[string]interface{}); ok && m["type"] == "text" {
|
||||
if t, ok := m["text"].(string); ok {
|
||||
result += t
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
|
@ -1,471 +0,0 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"cursor-api-proxy/internal/agent"
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/httputil"
|
||||
"cursor-api-proxy/internal/logger"
|
||||
"cursor-api-proxy/internal/models"
|
||||
"cursor-api-proxy/internal/openai"
|
||||
"cursor-api-proxy/internal/parser"
|
||||
"cursor-api-proxy/internal/pool"
|
||||
"cursor-api-proxy/internal/sanitize"
|
||||
"cursor-api-proxy/internal/toolcall"
|
||||
"cursor-api-proxy/internal/winlimit"
|
||||
"cursor-api-proxy/internal/workspace"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var rateLimitRe = regexp.MustCompile(`(?i)\b429\b|rate.?limit|too many requests`)
|
||||
var retryAfterRe = regexp.MustCompile(`(?i)retry-after:\s*(\d+)`)
|
||||
|
||||
func isRateLimited(stderr string) bool {
|
||||
return rateLimitRe.MatchString(stderr)
|
||||
}
|
||||
|
||||
func extractRetryAfterMs(stderr string) int64 {
|
||||
if m := retryAfterRe.FindStringSubmatch(stderr); len(m) > 1 {
|
||||
if secs, err := strconv.ParseInt(m[1], 10, 64); err == nil && secs > 0 {
|
||||
return secs * 1000
|
||||
}
|
||||
}
|
||||
return 60000
|
||||
}
|
||||
|
||||
func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, ph pool.PoolHandle, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
|
||||
var bodyMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil {
|
||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
||||
"error": map[string]string{"message": "invalid JSON body", "code": "bad_request"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
rawModel, _ := bodyMap["model"].(string)
|
||||
requested := openai.NormalizeModelID(rawModel)
|
||||
model := ResolveModel(requested, lastModelRef, cfg)
|
||||
cursorModel := models.ResolveToCursorModel(model)
|
||||
if cursorModel == "" {
|
||||
cursorModel = model
|
||||
}
|
||||
|
||||
var messages []interface{}
|
||||
if m, ok := bodyMap["messages"].([]interface{}); ok {
|
||||
messages = m
|
||||
}
|
||||
|
||||
var tools []interface{}
|
||||
if t, ok := bodyMap["tools"].([]interface{}); ok {
|
||||
tools = t
|
||||
}
|
||||
var funcs []interface{}
|
||||
if f, ok := bodyMap["functions"].([]interface{}); ok {
|
||||
funcs = f
|
||||
}
|
||||
|
||||
cleanMessages := sanitize.SanitizeMessages(messages)
|
||||
|
||||
toolsText := openai.ToolsToSystemText(tools, funcs)
|
||||
messagesWithTools := cleanMessages
|
||||
if toolsText != "" {
|
||||
messagesWithTools = append([]interface{}{map[string]interface{}{"role": "system", "content": toolsText}}, cleanMessages...)
|
||||
}
|
||||
prompt := openai.BuildPromptFromMessages(messagesWithTools)
|
||||
|
||||
var trafficMsgs []logger.TrafficMessage
|
||||
for _, raw := range cleanMessages {
|
||||
if m, ok := raw.(map[string]interface{}); ok {
|
||||
role, _ := m["role"].(string)
|
||||
content := openai.MessageContentToText(m["content"])
|
||||
trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: role, Content: content})
|
||||
}
|
||||
}
|
||||
|
||||
isStream := false
|
||||
if s, ok := bodyMap["stream"].(bool); ok {
|
||||
isStream = s
|
||||
}
|
||||
|
||||
logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, isStream)
|
||||
|
||||
headerWs := r.Header.Get("x-cursor-workspace")
|
||||
ws := workspace.ResolveWorkspace(cfg, headerWs)
|
||||
|
||||
promptLen := len(prompt)
|
||||
if cfg.Verbose {
|
||||
if promptLen > 200 {
|
||||
logger.LogDebug("model=%s prompt_len=%d prompt_start=%q", cursorModel, promptLen, prompt[:200])
|
||||
} else {
|
||||
logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, promptLen, prompt)
|
||||
}
|
||||
}
|
||||
|
||||
fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, isStream)
|
||||
fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir)
|
||||
|
||||
if cfg.Verbose {
|
||||
logger.LogDebug("cmd=%s args=%v", cfg.AgentBin, fit.Args)
|
||||
}
|
||||
|
||||
if !fit.OK {
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"message": fit.Error, "code": "windows_cmdline_limit"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
if fit.Truncated {
|
||||
logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength)
|
||||
}
|
||||
|
||||
cmdArgs := fit.Args
|
||||
id := "chatcmpl_" + uuid.New().String()
|
||||
created := time.Now().Unix()
|
||||
|
||||
var truncatedHeaders map[string]string
|
||||
if fit.Truncated {
|
||||
truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
|
||||
}
|
||||
|
||||
hasTools := len(tools) > 0 || len(funcs) > 0
|
||||
var toolNames map[string]bool
|
||||
if hasTools {
|
||||
toolNames = toolcall.CollectToolNames(tools)
|
||||
for _, f := range funcs {
|
||||
if fm, ok := f.(map[string]interface{}); ok {
|
||||
if name, ok := fm["name"].(string); ok {
|
||||
toolNames[name] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isStream {
|
||||
httputil.WriteSSEHeaders(w, truncatedHeaders)
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
var accumulated string
|
||||
var chunkNum int
|
||||
var p parser.Parser
|
||||
|
||||
// toolCallMarkerRe 偵測 tool call 開頭標記,一旦出現就停止即時輸出並進入累積模式
|
||||
toolCallMarkerRe := regexp.MustCompile(`<tool_call>|<function_calls>`)
|
||||
if hasTools {
|
||||
var toolCallMode bool // 是否已進入 tool call 累積模式
|
||||
p = parser.CreateStreamParserWithThinking(
|
||||
func(text string) {
|
||||
accumulated += text
|
||||
chunkNum++
|
||||
logger.LogStreamChunk(model, text, chunkNum)
|
||||
if toolCallMode {
|
||||
// 已進入累積模式,不即時輸出
|
||||
return
|
||||
}
|
||||
if toolCallMarkerRe.MatchString(text) {
|
||||
// 偵測到 tool call 標記,切換為累積模式
|
||||
toolCallMode = true
|
||||
return
|
||||
}
|
||||
chunk := map[string]interface{}{
|
||||
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
},
|
||||
func(_ string) {}, // thinking ignored in tools mode
|
||||
func() {
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
|
||||
parsed := toolcall.ExtractToolCalls(accumulated, toolNames)
|
||||
|
||||
if parsed.HasToolCalls() {
|
||||
if parsed.TextContent != "" && toolCallMode {
|
||||
// 已有部分 text 被即時輸出,只補發剩餘的
|
||||
chunk := map[string]interface{}{
|
||||
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]interface{}{"role": "assistant", "content": parsed.TextContent}, "finish_reason": nil},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
for i, tc := range parsed.ToolCalls {
|
||||
callID := "call_" + uuid.New().String()[:8]
|
||||
chunk := map[string]interface{}{
|
||||
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]interface{}{
|
||||
"tool_calls": []map[string]interface{}{
|
||||
{
|
||||
"index": i,
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": tc.Name,
|
||||
"arguments": tc.Arguments,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, "finish_reason": nil},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
stopChunk := map[string]interface{}{
|
||||
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "tool_calls"},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(stopChunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
} else {
|
||||
stopChunk := map[string]interface{}{
|
||||
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(stopChunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
} else {
|
||||
p = parser.CreateStreamParserWithThinking(
|
||||
func(text string) {
|
||||
accumulated += text
|
||||
chunkNum++
|
||||
logger.LogStreamChunk(model, text, chunkNum)
|
||||
chunk := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
},
|
||||
func(thinking string) {
|
||||
chunk := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]interface{}{"reasoning_content": thinking}, "finish_reason": nil},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
},
|
||||
func() {
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
|
||||
stopChunk := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(stopChunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
configDir := ph.GetNextConfigDir()
|
||||
logger.LogAccountAssigned(configDir)
|
||||
ph.ReportRequestStart(configDir)
|
||||
logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true)
|
||||
streamStart := time.Now().UnixMilli()
|
||||
|
||||
ctx := r.Context()
|
||||
wrappedParser := func(line string) {
|
||||
logger.LogRawLine(line)
|
||||
p.Parse(line)
|
||||
}
|
||||
result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, ctx)
|
||||
|
||||
// agent 結束後,若未收到 result/success 訊號,強制 flush 以確保 SSE stream 正確結尾
|
||||
if ctx.Err() == nil {
|
||||
p.Flush()
|
||||
}
|
||||
|
||||
latencyMs := time.Now().UnixMilli() - streamStart
|
||||
ph.ReportRequestEnd(configDir)
|
||||
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
|
||||
} else if ctx.Err() == context.Canceled {
|
||||
logger.LogClientDisconnect(method, pathname, model, latencyMs)
|
||||
} else if err == nil && isRateLimited(result.Stderr) {
|
||||
ph.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr))
|
||||
}
|
||||
|
||||
if err != nil || (result.Code != 0 && ctx.Err() == nil) {
|
||||
ph.ReportRequestError(configDir, latencyMs)
|
||||
if err != nil {
|
||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
|
||||
} else {
|
||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr)
|
||||
}
|
||||
logger.LogRequestDone(method, pathname, model, latencyMs, result.Code)
|
||||
} else if ctx.Err() == nil {
|
||||
ph.ReportRequestSuccess(configDir, latencyMs)
|
||||
logger.LogRequestDone(method, pathname, model, latencyMs, 0)
|
||||
}
|
||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
||||
return
|
||||
}
|
||||
|
||||
configDir := ph.GetNextConfigDir()
|
||||
logger.LogAccountAssigned(configDir)
|
||||
ph.ReportRequestStart(configDir)
|
||||
logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, false)
|
||||
syncStart := time.Now().UnixMilli()
|
||||
|
||||
out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context())
|
||||
syncLatency := time.Now().UnixMilli() - syncStart
|
||||
ph.ReportRequestEnd(configDir)
|
||||
|
||||
ctx := r.Context()
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
|
||||
httputil.WriteJSON(w, 504, map[string]interface{}{
|
||||
"error": map[string]string{"message": fmt.Sprintf("request timed out after %dms", cfg.TimeoutMs), "code": "timeout"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
if ctx.Err() == context.Canceled {
|
||||
logger.LogClientDisconnect(method, pathname, model, syncLatency)
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
ph.ReportRequestError(configDir, syncLatency)
|
||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
||||
logger.LogRequestDone(method, pathname, model, syncLatency, -1)
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"message": err.Error(), "code": "cursor_cli_error"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
if isRateLimited(out.Stderr) {
|
||||
ph.ReportRateLimit(configDir, extractRetryAfterMs(out.Stderr))
|
||||
}
|
||||
|
||||
if out.Code != 0 {
|
||||
ph.ReportRequestError(configDir, syncLatency)
|
||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
||||
errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr)
|
||||
logger.LogRequestDone(method, pathname, model, syncLatency, out.Code)
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"message": errMsg, "code": "cursor_cli_error"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
ph.ReportRequestSuccess(configDir, syncLatency)
|
||||
content := strings.TrimSpace(out.Stdout)
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, content, false)
|
||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
||||
logger.LogRequestDone(method, pathname, model, syncLatency, 0)
|
||||
|
||||
if hasTools {
|
||||
parsed := toolcall.ExtractToolCalls(content, toolNames)
|
||||
if parsed.HasToolCalls() {
|
||||
msg := map[string]interface{}{"role": "assistant"}
|
||||
if parsed.TextContent != "" {
|
||||
msg["content"] = parsed.TextContent
|
||||
} else {
|
||||
msg["content"] = nil
|
||||
}
|
||||
var tcArr []map[string]interface{}
|
||||
for _, tc := range parsed.ToolCalls {
|
||||
callID := "call_" + uuid.New().String()[:8]
|
||||
tcArr = append(tcArr, map[string]interface{}{
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": tc.Name,
|
||||
"arguments": tc.Arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
msg["tool_calls"] = tcArr
|
||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "message": msg, "finish_reason": "tool_calls"},
|
||||
},
|
||||
"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}, truncatedHeaders)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"message": map[string]string{"role": "assistant", "content": content},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}, truncatedHeaders)
|
||||
}
|
||||
|
||||
|
|
@ -1,203 +0,0 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"cursor-api-proxy/internal/apitypes"
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/httputil"
|
||||
"cursor-api-proxy/internal/logger"
|
||||
"cursor-api-proxy/internal/providers/geminiweb"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func HandleGeminiChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, rawBody, method, pathname, remoteAddress string) {
|
||||
_ = context.Background() // 確保 context 被使用
|
||||
var bodyMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil {
|
||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
||||
"error": map[string]string{"message": "invalid JSON body", "code": "bad_request"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
rawModel, _ := bodyMap["model"].(string)
|
||||
if rawModel == "" {
|
||||
rawModel = "gemini-2.0-flash"
|
||||
}
|
||||
|
||||
var messages []interface{}
|
||||
if m, ok := bodyMap["messages"].([]interface{}); ok {
|
||||
messages = m
|
||||
}
|
||||
|
||||
isStream := false
|
||||
if s, ok := bodyMap["stream"].(bool); ok {
|
||||
isStream = s
|
||||
}
|
||||
|
||||
// 轉換 messages 為 apitypes.Message
|
||||
var apiMessages []apitypes.Message
|
||||
for _, m := range messages {
|
||||
if msgMap, ok := m.(map[string]interface{}); ok {
|
||||
role, _ := msgMap["role"].(string)
|
||||
content := ""
|
||||
if c, ok := msgMap["content"].(string); ok {
|
||||
content = c
|
||||
}
|
||||
apiMessages = append(apiMessages, apitypes.Message{
|
||||
Role: role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
logger.LogRequestStart(method, pathname, rawModel, cfg.TimeoutMs, isStream)
|
||||
start := time.Now().UnixMilli()
|
||||
|
||||
// 創建 Gemini provider (使用 Playwright)
|
||||
provider, provErr := geminiweb.NewPlaywrightProvider(cfg)
|
||||
if provErr != nil {
|
||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, provErr.Error())
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"message": provErr.Error(), "code": "provider_error"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
if isStream {
|
||||
httputil.WriteSSEHeaders(w, nil)
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
id := "chatcmpl_" + uuid.New().String()
|
||||
created := time.Now().Unix()
|
||||
var accumulated string
|
||||
|
||||
err := provider.Generate(r.Context(), rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) {
|
||||
if chunk.Type == apitypes.ChunkText {
|
||||
accumulated += chunk.Text
|
||||
respChunk := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": rawModel,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]string{"content": chunk.Text},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(respChunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
} else if chunk.Type == apitypes.ChunkThinking {
|
||||
respChunk := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": rawModel,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{"reasoning_content": chunk.Thinking},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(respChunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
} else if chunk.Type == apitypes.ChunkDone {
|
||||
stopChunk := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": rawModel,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(stopChunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
latencyMs := time.Now().UnixMilli() - start
|
||||
if err != nil {
|
||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
|
||||
logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1)
|
||||
return
|
||||
}
|
||||
|
||||
logger.LogTrafficResponse(cfg.Verbose, rawModel, accumulated, true)
|
||||
logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0)
|
||||
return
|
||||
}
|
||||
|
||||
// 非串流模式
|
||||
var resultText string
|
||||
var resultThinking string
|
||||
|
||||
err := provider.Generate(r.Context(), rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) {
|
||||
if chunk.Type == apitypes.ChunkText {
|
||||
resultText += chunk.Text
|
||||
} else if chunk.Type == apitypes.ChunkThinking {
|
||||
resultThinking += chunk.Thinking
|
||||
}
|
||||
})
|
||||
|
||||
latencyMs := time.Now().UnixMilli() - start
|
||||
|
||||
if err != nil {
|
||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
|
||||
logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1)
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"message": err.Error(), "code": "gemini_error"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
logger.LogTrafficResponse(cfg.Verbose, rawModel, resultText, false)
|
||||
logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0)
|
||||
|
||||
id := "chatcmpl_" + uuid.New().String()
|
||||
created := time.Now().Unix()
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion",
|
||||
"created": created,
|
||||
"model": rawModel,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"message": map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": resultText,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
|
||||
httputil.WriteJSON(w, 200, resp, nil)
|
||||
}
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/httputil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func HandleHealth(w http.ResponseWriter, r *http.Request, version string, cfg config.BridgeConfig) {
|
||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
||||
"ok": true,
|
||||
"version": version,
|
||||
"workspace": cfg.Workspace,
|
||||
"mode": cfg.Mode,
|
||||
"defaultModel": cfg.DefaultModel,
|
||||
"force": cfg.Force,
|
||||
"approveMcps": cfg.ApproveMcps,
|
||||
"strictModel": cfg.StrictModel,
|
||||
}, nil)
|
||||
}
|
||||
|
|
@ -1,107 +0,0 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/httputil"
|
||||
"cursor-api-proxy/internal/models"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const modelCacheTTLMs = 5 * 60 * 1000
|
||||
|
||||
type ModelCache struct {
|
||||
At int64
|
||||
Models []models.CursorCliModel
|
||||
}
|
||||
|
||||
type ModelCacheRef struct {
|
||||
mu sync.Mutex
|
||||
cache *ModelCache
|
||||
inflight bool
|
||||
waiters []chan struct{}
|
||||
}
|
||||
|
||||
func (ref *ModelCacheRef) HandleModels(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig) {
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
ref.mu.Lock()
|
||||
if ref.cache != nil && now-ref.cache.At <= modelCacheTTLMs {
|
||||
cache := ref.cache
|
||||
ref.mu.Unlock()
|
||||
writeModels(w, cache.Models)
|
||||
return
|
||||
}
|
||||
|
||||
if ref.inflight {
|
||||
// Wait for the in-flight fetch
|
||||
ch := make(chan struct{}, 1)
|
||||
ref.waiters = append(ref.waiters, ch)
|
||||
ref.mu.Unlock()
|
||||
<-ch
|
||||
ref.mu.Lock()
|
||||
cache := ref.cache
|
||||
ref.mu.Unlock()
|
||||
writeModels(w, cache.Models)
|
||||
return
|
||||
}
|
||||
|
||||
ref.inflight = true
|
||||
ref.mu.Unlock()
|
||||
|
||||
fetched, err := models.ListCursorCliModels(cfg.AgentBin, 60000)
|
||||
|
||||
ref.mu.Lock()
|
||||
ref.inflight = false
|
||||
if err == nil {
|
||||
ref.cache = &ModelCache{At: time.Now().UnixMilli(), Models: fetched}
|
||||
}
|
||||
waiters := ref.waiters
|
||||
ref.waiters = nil
|
||||
ref.mu.Unlock()
|
||||
|
||||
for _, ch := range waiters {
|
||||
ch <- struct{}{}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"message": err.Error(), "code": "models_fetch_error"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
writeModels(w, fetched)
|
||||
}
|
||||
|
||||
func writeModels(w http.ResponseWriter, mods []models.CursorCliModel) {
|
||||
cursorModels := make([]map[string]interface{}, len(mods))
|
||||
for i, m := range mods {
|
||||
cursorModels[i] = map[string]interface{}{
|
||||
"id": m.ID,
|
||||
"object": "model",
|
||||
"owned_by": "cursor",
|
||||
"name": m.Name,
|
||||
}
|
||||
}
|
||||
|
||||
ids := make([]string, len(mods))
|
||||
for i, m := range mods {
|
||||
ids[i] = m.ID
|
||||
}
|
||||
aliases := models.GetAnthropicModelAliases(ids)
|
||||
for _, a := range aliases {
|
||||
cursorModels = append(cursorModels, map[string]interface{}{
|
||||
"id": a.ID,
|
||||
"object": "model",
|
||||
"owned_by": "cursor",
|
||||
"name": a.Name,
|
||||
})
|
||||
}
|
||||
|
||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
||||
"object": "list",
|
||||
"data": cursorModels,
|
||||
}, nil)
|
||||
}
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
package handlers
|
||||
|
||||
import "cursor-api-proxy/internal/config"
|
||||
|
||||
func ResolveModel(requested string, lastModelRef *string, cfg config.BridgeConfig) string {
|
||||
isAuto := requested == "auto"
|
||||
var explicitModel string
|
||||
if requested != "" && !isAuto {
|
||||
explicitModel = requested
|
||||
}
|
||||
if explicitModel != "" {
|
||||
*lastModelRef = explicitModel
|
||||
}
|
||||
if isAuto {
|
||||
return "auto"
|
||||
}
|
||||
if explicitModel != "" {
|
||||
return explicitModel
|
||||
}
|
||||
if cfg.StrictModel && *lastModelRef != "" {
|
||||
return *lastModelRef
|
||||
}
|
||||
if *lastModelRef != "" {
|
||||
return *lastModelRef
|
||||
}
|
||||
return cfg.DefaultModel
|
||||
}
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var bearerRe = regexp.MustCompile(`(?i)^Bearer\s+(.+)$`)
|
||||
|
||||
func ExtractBearerToken(r *http.Request) string {
|
||||
h := r.Header.Get("Authorization")
|
||||
if h == "" {
|
||||
return ""
|
||||
}
|
||||
m := bearerRe.FindStringSubmatch(h)
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
return m[1]
|
||||
}
|
||||
|
||||
func WriteJSON(w http.ResponseWriter, status int, body interface{}, extraHeaders map[string]string) {
|
||||
for k, v := range extraHeaders {
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(body)
|
||||
}
|
||||
|
||||
func WriteSSEHeaders(w http.ResponseWriter, extraHeaders map[string]string) {
|
||||
for k, v := range extraHeaders {
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
|
||||
func ReadBody(r *http.Request) (string, error) {
|
||||
data, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractBearerToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
header string
|
||||
want string
|
||||
}{
|
||||
{"Bearer mytoken123", "mytoken123"},
|
||||
{"bearer MYTOKEN", "MYTOKEN"},
|
||||
{"", ""},
|
||||
{"Basic abc", ""},
|
||||
{"Bearer ", ""},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
if tc.header != "" {
|
||||
req.Header.Set("Authorization", tc.header)
|
||||
}
|
||||
got := ExtractBearerToken(req)
|
||||
if got != tc.want {
|
||||
t.Errorf("ExtractBearerToken(%q) = %q, want %q", tc.header, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteJSON(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
WriteJSON(w, 200, map[string]string{"ok": "true"}, nil)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Errorf("expected 200, got %d", w.Code)
|
||||
}
|
||||
if w.Header().Get("Content-Type") != "application/json" {
|
||||
t.Errorf("expected application/json, got %s", w.Header().Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteJSONWithExtraHeaders(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
WriteJSON(w, 201, nil, map[string]string{"X-Custom": "value"})
|
||||
|
||||
if w.Header().Get("X-Custom") != "value" {
|
||||
t.Errorf("expected X-Custom=value, got %s", w.Header().Get("X-Custom"))
|
||||
}
|
||||
}
|
||||
|
|
@ -1,309 +0,0 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/pool"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
cReset = "\x1b[0m"
|
||||
cBold = "\x1b[1m"
|
||||
cDim = "\x1b[2m"
|
||||
cCyan = "\x1b[36m"
|
||||
cBCyan = "\x1b[1;96m"
|
||||
cGreen = "\x1b[32m"
|
||||
cBGreen = "\x1b[1;92m"
|
||||
cYellow = "\x1b[33m"
|
||||
cMagenta = "\x1b[35m"
|
||||
cBMagenta = "\x1b[1;95m"
|
||||
cRed = "\x1b[31m"
|
||||
cGray = "\x1b[90m"
|
||||
cWhite = "\x1b[97m"
|
||||
)
|
||||
|
||||
var roleStyle = map[string]string{
|
||||
"system": cYellow,
|
||||
"user": cCyan,
|
||||
"assistant": cGreen,
|
||||
}
|
||||
|
||||
var roleEmoji = map[string]string{
|
||||
"system": "🔧",
|
||||
"user": "👤",
|
||||
"assistant": "🤖",
|
||||
}
|
||||
|
||||
func ts() string {
|
||||
return cGray + time.Now().UTC().Format("15:04:05") + cReset
|
||||
}
|
||||
|
||||
func tsDate() string {
|
||||
return cGray + time.Now().UTC().Format("2006-01-02 15:04:05") + cReset
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
head := int(float64(max) * 0.6)
|
||||
tail := max - head
|
||||
omitted := len(s) - head - tail
|
||||
return s[:head] + fmt.Sprintf("%s … (%d chars omitted) … ", cDim, omitted) + s[len(s)-tail:] + cReset
|
||||
}
|
||||
|
||||
func hr(ch string, length int) string {
|
||||
return cGray + strings.Repeat(ch, length) + cReset
|
||||
}
|
||||
|
||||
type TrafficMessage struct {
|
||||
Role string
|
||||
Content string
|
||||
}
|
||||
|
||||
func LogDebug(format string, args ...interface{}) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
fmt.Printf("%s %s[DEBUG]%s %s\n", ts(), cGray, cReset, msg)
|
||||
}
|
||||
|
||||
func LogServerStart(version, scheme, host string, port int, cfg config.BridgeConfig) {
|
||||
provider := cfg.Provider
|
||||
if provider == "" {
|
||||
provider = "cursor"
|
||||
}
|
||||
fmt.Printf("\n%s%s╔══════════════════════════════════════════╗%s\n", cBold, cBCyan, cReset)
|
||||
fmt.Printf("%s%s cursor-api-proxy %sv%s%s%s%s ready%s\n",
|
||||
cBold, cBCyan, cReset, cBold, cWhite, version, cBCyan, cReset)
|
||||
fmt.Printf("%s%s╚══════════════════════════════════════════╝%s\n\n", cBold, cBCyan, cReset)
|
||||
url := fmt.Sprintf("%s://%s:%d", scheme, host, port)
|
||||
fmt.Printf(" %s●%s listening %s%s%s\n", cBGreen, cReset, cBold, url, cReset)
|
||||
fmt.Printf(" %s▸%s provider %s%s%s\n", cCyan, cReset, cBold, provider, cReset)
|
||||
fmt.Printf(" %s▸%s agent %s%s%s\n", cCyan, cReset, cDim, cfg.AgentBin, cReset)
|
||||
fmt.Printf(" %s▸%s workspace %s%s%s\n", cCyan, cReset, cDim, cfg.Workspace, cReset)
|
||||
fmt.Printf(" %s▸%s model %s%s%s\n", cCyan, cReset, cDim, cfg.DefaultModel, cReset)
|
||||
fmt.Printf(" %s▸%s mode %s%s%s\n", cCyan, cReset, cDim, cfg.Mode, cReset)
|
||||
fmt.Printf(" %s▸%s timeout %s%d ms%s\n", cCyan, cReset, cDim, cfg.TimeoutMs, cReset)
|
||||
|
||||
// 顯示 Gemini Web Provider 相關設定
|
||||
if provider == "gemini-web" {
|
||||
fmt.Printf(" %s▸%s gemini-dir %s%s%s\n", cCyan, cReset, cDim, cfg.GeminiAccountDir, cReset)
|
||||
fmt.Printf(" %s▸%s max-sess %s%d%s\n", cCyan, cReset, cDim, cfg.GeminiMaxSessions, cReset)
|
||||
}
|
||||
|
||||
flags := []string{}
|
||||
if cfg.Force {
|
||||
flags = append(flags, "force")
|
||||
}
|
||||
if cfg.ApproveMcps {
|
||||
flags = append(flags, "approve-mcps")
|
||||
}
|
||||
if cfg.MaxMode {
|
||||
flags = append(flags, "max-mode")
|
||||
}
|
||||
if cfg.Verbose {
|
||||
flags = append(flags, "verbose")
|
||||
}
|
||||
if cfg.ChatOnlyWorkspace {
|
||||
flags = append(flags, "chat-only")
|
||||
}
|
||||
if cfg.RequiredKey != "" {
|
||||
flags = append(flags, "api-key-required")
|
||||
}
|
||||
if len(flags) > 0 {
|
||||
fmt.Printf(" %s▸%s flags %s%s%s\n", cCyan, cReset, cYellow, strings.Join(flags, " · "), cReset)
|
||||
}
|
||||
if len(cfg.ConfigDirs) > 0 {
|
||||
fmt.Printf(" %s▸%s pool %s%d accounts%s\n", cCyan, cReset, cBGreen, len(cfg.ConfigDirs), cReset)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
func LogShutdown(sig string) {
|
||||
fmt.Printf("\n%s %s⊘ %s received — shutting down gracefully…%s\n", tsDate(), cYellow, sig, cReset)
|
||||
}
|
||||
|
||||
func LogRequestStart(method, pathname, model string, timeoutMs int, isStream bool) {
|
||||
modeTag := fmt.Sprintf("%ssync%s", cDim, cReset)
|
||||
if isStream {
|
||||
modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset)
|
||||
}
|
||||
fmt.Printf("%s %s▶%s %s %s %s timeout:%dms %s\n",
|
||||
ts(), cBCyan, cReset, method, pathname, model, timeoutMs, modeTag)
|
||||
}
|
||||
|
||||
func LogRequestDone(method, pathname, model string, latencyMs int64, code int) {
|
||||
statusColor := cBGreen
|
||||
if code != 0 {
|
||||
statusColor = cRed
|
||||
}
|
||||
fmt.Printf("%s %s■%s %s %s %s %s%dms exit:%d%s\n",
|
||||
ts(), statusColor, cReset, method, pathname, model, cDim, latencyMs, code, cReset)
|
||||
}
|
||||
|
||||
func LogRequestTimeout(method, pathname, model string, timeoutMs int) {
|
||||
fmt.Printf("%s %s⏱%s %s %s %s %stimed-out after %dms%s\n",
|
||||
ts(), cRed, cReset, method, pathname, model, cRed, timeoutMs, cReset)
|
||||
}
|
||||
|
||||
func LogClientDisconnect(method, pathname, model string, latencyMs int64) {
|
||||
fmt.Printf("%s %s⚡%s %s %s %s %sclient disconnected after %dms%s\n",
|
||||
ts(), cYellow, cReset, method, pathname, model, cYellow, latencyMs, cReset)
|
||||
}
|
||||
|
||||
func LogStreamChunk(model string, text string, chunkNum int) {
|
||||
preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 120)
|
||||
fmt.Printf("%s %s▸%s #%d %s%s%s\n",
|
||||
ts(), cDim, cReset, chunkNum, cWhite, preview, cReset)
|
||||
}
|
||||
|
||||
func LogRawLine(line string) {
|
||||
preview := truncate(strings.ReplaceAll(line, "\n", "↵ "), 200)
|
||||
fmt.Printf("%s %s│%s %sraw%s %s\n",
|
||||
ts(), cGray, cReset, cDim, cReset, preview)
|
||||
}
|
||||
|
||||
func LogIncoming(method, pathname, remoteAddress string) {
|
||||
methodColor := cBCyan
|
||||
switch method {
|
||||
case "POST":
|
||||
methodColor = cBMagenta
|
||||
case "GET":
|
||||
methodColor = cBCyan
|
||||
case "DELETE":
|
||||
methodColor = cRed
|
||||
}
|
||||
fmt.Printf("%s %s%s%s%s %s%s%s %s(%s)%s\n",
|
||||
ts(),
|
||||
methodColor, cBold, method, cReset,
|
||||
cWhite, pathname, cReset,
|
||||
cDim, remoteAddress, cReset,
|
||||
)
|
||||
}
|
||||
|
||||
func LogAccountAssigned(configDir string) {
|
||||
if configDir == "" {
|
||||
return
|
||||
}
|
||||
name := filepath.Base(configDir)
|
||||
fmt.Printf("%s %s→%s account %s%s%s\n", ts(), cBCyan, cReset, cBold, name, cReset)
|
||||
}
|
||||
|
||||
func LogAccountStats(verbose bool, stats []pool.AccountStat) {
|
||||
if !verbose || len(stats) == 0 {
|
||||
return
|
||||
}
|
||||
now := time.Now().UnixMilli()
|
||||
fmt.Printf("%s┌─ Account Stats %s┐%s\n", cGray, strings.Repeat("─", 44), cReset)
|
||||
for _, s := range stats {
|
||||
name := fmt.Sprintf("%-20s", filepath.Base(s.ConfigDir))
|
||||
active := fmt.Sprintf("%sactive:0%s", cDim, cReset)
|
||||
if s.ActiveRequests > 0 {
|
||||
active = fmt.Sprintf("%sactive:%d%s", cBCyan, s.ActiveRequests, cReset)
|
||||
}
|
||||
total := fmt.Sprintf("total:%s%d%s", cBold, s.TotalRequests, cReset)
|
||||
ok := fmt.Sprintf("%sok:%d%s", cGreen, s.TotalSuccess, cReset)
|
||||
errStr := fmt.Sprintf("%serr:0%s", cDim, cReset)
|
||||
if s.TotalErrors > 0 {
|
||||
errStr = fmt.Sprintf("%serr:%d%s", cRed, s.TotalErrors, cReset)
|
||||
}
|
||||
rl := fmt.Sprintf("%srl:0%s", cDim, cReset)
|
||||
if s.TotalRateLimits > 0 {
|
||||
rl = fmt.Sprintf("%srl:%d%s", cYellow, s.TotalRateLimits, cReset)
|
||||
}
|
||||
avg := "avg:-"
|
||||
if s.TotalRequests > 0 {
|
||||
avg = fmt.Sprintf("avg:%dms", s.TotalLatencyMs/int64(s.TotalRequests))
|
||||
}
|
||||
status := fmt.Sprintf("%s✓%s", cGreen, cReset)
|
||||
if s.IsRateLimited {
|
||||
recovers := time.UnixMilli(s.RateLimitUntil).UTC().Format(time.RFC3339)
|
||||
_ = now
|
||||
status = fmt.Sprintf("%s⛔ rate-limited (recovers %s)%s", cRed, recovers, cReset)
|
||||
}
|
||||
fmt.Printf(" %s%s%s %s %s %s %s %s %s%s%s %s\n",
|
||||
cBold, name, cReset, active, total, ok, errStr, rl, cDim, avg, cReset, status)
|
||||
}
|
||||
fmt.Printf("%s└%s┘%s\n", cGray, strings.Repeat("─", 60), cReset)
|
||||
}
|
||||
|
||||
func LogTrafficRequest(verbose bool, model string, messages []TrafficMessage, isStream bool) {
|
||||
if !verbose {
|
||||
return
|
||||
}
|
||||
modeTag := fmt.Sprintf("%ssync%s", cDim, cReset)
|
||||
if isStream {
|
||||
modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset)
|
||||
}
|
||||
modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset)
|
||||
fmt.Println(hr("─", 60))
|
||||
fmt.Printf("%s 📤 %s%sREQUEST%s %s %s\n", ts(), cBCyan, cBold, cReset, modelStr, modeTag)
|
||||
for _, m := range messages {
|
||||
roleColor := cWhite
|
||||
if c, ok := roleStyle[m.Role]; ok {
|
||||
roleColor = c
|
||||
}
|
||||
emoji := "💬"
|
||||
if e, ok := roleEmoji[m.Role]; ok {
|
||||
emoji = e
|
||||
}
|
||||
label := fmt.Sprintf("%s%s[%s]%s", roleColor, cBold, m.Role, cReset)
|
||||
charCount := fmt.Sprintf("%s(%d chars)%s", cDim, len(m.Content), cReset)
|
||||
preview := truncate(strings.ReplaceAll(m.Content, "\n", "↵ "), 280)
|
||||
fmt.Printf(" %s %s %s\n", emoji, label, charCount)
|
||||
fmt.Printf(" %s%s%s\n", cDim, preview, cReset)
|
||||
}
|
||||
}
|
||||
|
||||
func LogTrafficResponse(verbose bool, model, text string, isStream bool) {
|
||||
if !verbose {
|
||||
return
|
||||
}
|
||||
modeTag := fmt.Sprintf("%ssync%s", cDim, cReset)
|
||||
if isStream {
|
||||
modeTag = fmt.Sprintf("%s⚡ stream%s", cBGreen, cReset)
|
||||
}
|
||||
modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset)
|
||||
charCount := fmt.Sprintf("%s%d%s%s chars%s", cBold, len(text), cReset, cDim, cReset)
|
||||
preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 480)
|
||||
fmt.Printf("%s 📥 %s%sRESPONSE%s %s %s %s\n", ts(), cBGreen, cBold, cReset, modelStr, modeTag, charCount)
|
||||
fmt.Printf(" 🤖 %s%s%s\n", cGreen, preview, cReset)
|
||||
fmt.Println(hr("─", 60))
|
||||
}
|
||||
|
||||
func AppendSessionLine(logPath, method, pathname, remoteAddress string, statusCode int) {
|
||||
line := fmt.Sprintf("%s %s %s %s %d\n", time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, statusCode)
|
||||
dir := filepath.Dir(logPath)
|
||||
if err := os.MkdirAll(dir, 0755); err == nil {
|
||||
f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err == nil {
|
||||
_, _ = f.WriteString(line)
|
||||
f.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func LogTruncation(originalLen, finalLen int) {
|
||||
fmt.Printf("%s %s⚠ prompt truncated%s %s(%d → %d chars, tail preserved)%s\n",
|
||||
ts(), cYellow, cReset, cDim, originalLen, finalLen, cReset)
|
||||
}
|
||||
|
||||
func LogAgentError(logPath, method, pathname, remoteAddress string, exitCode int, stderr string) string {
|
||||
errMsg := fmt.Sprintf("Cursor CLI failed (exit %d): %s", exitCode, strings.TrimSpace(stderr))
|
||||
fmt.Fprintf(os.Stderr, "%s %s✗ agent error%s %s%s%s\n", ts(), cRed, cReset, cDim, errMsg, cReset)
|
||||
truncated := strings.TrimSpace(stderr)
|
||||
if len(truncated) > 200 {
|
||||
truncated = truncated[:200]
|
||||
}
|
||||
truncated = strings.ReplaceAll(truncated, "\n", " ")
|
||||
line := fmt.Sprintf("%s ERROR %s %s %s agent_exit_%d %s\n",
|
||||
time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, exitCode, truncated)
|
||||
if f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil {
|
||||
_, _ = f.WriteString(line)
|
||||
f.Close()
|
||||
}
|
||||
return errMsg
|
||||
}
|
||||
|
|
@ -1,62 +0,0 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/process"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type CursorCliModel struct {
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
|
||||
var modelLineRe = regexp.MustCompile(`^([A-Za-z0-9][A-Za-z0-9._:/-]*)\s+-\s+(.*)$`)
|
||||
var trailingParenRe = regexp.MustCompile(`\s*\([^)]*\)\s*$`)
|
||||
|
||||
func ParseCursorCliModels(output string) []CursorCliModel {
|
||||
lines := strings.Split(output, "\n")
|
||||
seen := make(map[string]CursorCliModel)
|
||||
var order []string
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
m := modelLineRe.FindStringSubmatch(line)
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
id := m[1]
|
||||
rawName := m[2]
|
||||
name := strings.TrimSpace(trailingParenRe.ReplaceAllString(rawName, ""))
|
||||
if name == "" {
|
||||
name = id
|
||||
}
|
||||
if _, exists := seen[id]; !exists {
|
||||
seen[id] = CursorCliModel{ID: id, Name: name}
|
||||
order = append(order, id)
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]CursorCliModel, 0, len(order))
|
||||
for _, id := range order {
|
||||
result = append(result, seen[id])
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func ListCursorCliModels(agentBin string, timeoutMs int) ([]CursorCliModel, error) {
|
||||
tmpDir := os.TempDir()
|
||||
result, err := process.Run(agentBin, []string{"--list-models"}, process.RunOptions{
|
||||
Cwd: tmpDir,
|
||||
TimeoutMs: timeoutMs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("agent --list-models failed: %s", strings.TrimSpace(result.Stderr))
|
||||
}
|
||||
return ParseCursorCliModels(result.Stdout), nil
|
||||
}
|
||||
|
|
@ -1,33 +0,0 @@
|
|||
package models
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseCursorCliModels(t *testing.T) {
|
||||
output := `
|
||||
gpt-4o - GPT-4o (some info)
|
||||
claude-3-5-sonnet - Claude 3.5 Sonnet
|
||||
gpt-4o - GPT-4o duplicate
|
||||
invalid line without dash
|
||||
`
|
||||
result := ParseCursorCliModels(output)
|
||||
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("expected 2 unique models, got %d: %v", len(result), result)
|
||||
}
|
||||
if result[0].ID != "gpt-4o" {
|
||||
t.Errorf("expected gpt-4o, got %s", result[0].ID)
|
||||
}
|
||||
if result[0].Name != "GPT-4o" {
|
||||
t.Errorf("expected 'GPT-4o', got %s", result[0].Name)
|
||||
}
|
||||
if result[1].ID != "claude-3-5-sonnet" {
|
||||
t.Errorf("expected claude-3-5-sonnet, got %s", result[1].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCursorCliModelsEmpty(t *testing.T) {
|
||||
result := ParseCursorCliModels("")
|
||||
if len(result) != 0 {
|
||||
t.Fatalf("expected empty, got %v", result)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,123 +0,0 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var anthropicToCursor = map[string]string{
|
||||
"claude-opus-4-6": "opus-4.6",
|
||||
"claude-opus-4.6": "opus-4.6",
|
||||
"claude-sonnet-4-6": "sonnet-4.6",
|
||||
"claude-sonnet-4.6": "sonnet-4.6",
|
||||
"claude-opus-4-5": "opus-4.5",
|
||||
"claude-opus-4.5": "opus-4.5",
|
||||
"claude-sonnet-4-5": "sonnet-4.5",
|
||||
"claude-sonnet-4.5": "sonnet-4.5",
|
||||
"claude-opus-4": "opus-4.6",
|
||||
"claude-sonnet-4": "sonnet-4.6",
|
||||
"claude-haiku-4-5-20251001": "sonnet-4.5",
|
||||
"claude-haiku-4-5": "sonnet-4.5",
|
||||
"claude-haiku-4-6": "sonnet-4.6",
|
||||
"claude-haiku-4": "sonnet-4.5",
|
||||
"claude-opus-4-6-thinking": "opus-4.6-thinking",
|
||||
"claude-sonnet-4-6-thinking": "sonnet-4.6-thinking",
|
||||
"claude-opus-4-5-thinking": "opus-4.5-thinking",
|
||||
"claude-sonnet-4-5-thinking": "sonnet-4.5-thinking",
|
||||
}
|
||||
|
||||
type ModelAlias struct {
|
||||
CursorID string
|
||||
AnthropicID string
|
||||
Name string
|
||||
}
|
||||
|
||||
var cursorToAnthropicAlias = []ModelAlias{
|
||||
{"opus-4.6", "claude-opus-4-6", "Claude 4.6 Opus"},
|
||||
{"opus-4.6-thinking", "claude-opus-4-6-thinking", "Claude 4.6 Opus (Thinking)"},
|
||||
{"sonnet-4.6", "claude-sonnet-4-6", "Claude 4.6 Sonnet"},
|
||||
{"sonnet-4.6-thinking", "claude-sonnet-4-6-thinking", "Claude 4.6 Sonnet (Thinking)"},
|
||||
{"opus-4.5", "claude-opus-4-5", "Claude 4.5 Opus"},
|
||||
{"opus-4.5-thinking", "claude-opus-4-5-thinking", "Claude 4.5 Opus (Thinking)"},
|
||||
{"sonnet-4.5", "claude-sonnet-4-5", "Claude 4.5 Sonnet"},
|
||||
{"sonnet-4.5-thinking", "claude-sonnet-4-5-thinking", "Claude 4.5 Sonnet (Thinking)"},
|
||||
}
|
||||
|
||||
// cursorModelPattern matches cursor model IDs like "opus-4.6", "sonnet-4.7-thinking".
|
||||
var cursorModelPattern = regexp.MustCompile(`^([a-zA-Z]+)-(\d+)\.(\d+)(-thinking)?$`)
|
||||
|
||||
// reverseDynamicPattern matches dynamically generated anthropic aliases
|
||||
// like "claude-opus-4-7", "claude-sonnet-4-7-thinking".
|
||||
var reverseDynamicPattern = regexp.MustCompile(`^claude-([a-zA-Z]+)-(\d+)-(\d+)(-thinking)?$`)
|
||||
|
||||
func generateDynamicAlias(cursorID string) (AnthropicAlias, bool) {
|
||||
m := cursorModelPattern.FindStringSubmatch(cursorID)
|
||||
if m == nil {
|
||||
return AnthropicAlias{}, false
|
||||
}
|
||||
family := m[1]
|
||||
major := m[2]
|
||||
minor := m[3]
|
||||
thinking := m[4]
|
||||
|
||||
anthropicID := "claude-" + family + "-" + major + "-" + minor + thinking
|
||||
capFamily := strings.ToUpper(family[:1]) + family[1:]
|
||||
name := capFamily + " " + major + "." + minor
|
||||
if thinking == "-thinking" {
|
||||
name += " (Thinking)"
|
||||
}
|
||||
return AnthropicAlias{ID: anthropicID, Name: name}, true
|
||||
}
|
||||
|
||||
func reverseDynamicAlias(anthropicID string) (string, bool) {
|
||||
m := reverseDynamicPattern.FindStringSubmatch(anthropicID)
|
||||
if m == nil {
|
||||
return "", false
|
||||
}
|
||||
return m[1] + "-" + m[2] + "." + m[3] + m[4], true
|
||||
}
|
||||
|
||||
func ResolveToCursorModel(requested string) string {
|
||||
if strings.TrimSpace(requested) == "" {
|
||||
return ""
|
||||
}
|
||||
key := strings.ToLower(strings.TrimSpace(requested))
|
||||
if v, ok := anthropicToCursor[key]; ok {
|
||||
return v
|
||||
}
|
||||
if v, ok := reverseDynamicAlias(key); ok {
|
||||
return v
|
||||
}
|
||||
return strings.TrimSpace(requested)
|
||||
}
|
||||
|
||||
type AnthropicAlias struct {
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
|
||||
func GetAnthropicModelAliases(availableCursorIDs []string) []AnthropicAlias {
|
||||
set := make(map[string]bool, len(availableCursorIDs))
|
||||
for _, id := range availableCursorIDs {
|
||||
set[id] = true
|
||||
}
|
||||
|
||||
staticSet := make(map[string]bool, len(cursorToAnthropicAlias))
|
||||
var result []AnthropicAlias
|
||||
for _, a := range cursorToAnthropicAlias {
|
||||
if set[a.CursorID] {
|
||||
staticSet[a.CursorID] = true
|
||||
result = append(result, AnthropicAlias{ID: a.AnthropicID, Name: a.Name})
|
||||
}
|
||||
}
|
||||
|
||||
for _, id := range availableCursorIDs {
|
||||
if staticSet[id] {
|
||||
continue
|
||||
}
|
||||
if alias, ok := generateDynamicAlias(id); ok {
|
||||
result = append(result, alias)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
|
@ -1,163 +0,0 @@
|
|||
package models
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGetAnthropicModelAliases_StaticOnly(t *testing.T) {
|
||||
aliases := GetAnthropicModelAliases([]string{"sonnet-4.6", "opus-4.5"})
|
||||
if len(aliases) != 2 {
|
||||
t.Fatalf("expected 2 aliases, got %d: %v", len(aliases), aliases)
|
||||
}
|
||||
ids := map[string]string{}
|
||||
for _, a := range aliases {
|
||||
ids[a.ID] = a.Name
|
||||
}
|
||||
if ids["claude-sonnet-4-6"] != "Claude 4.6 Sonnet" {
|
||||
t.Errorf("unexpected name for claude-sonnet-4-6: %s", ids["claude-sonnet-4-6"])
|
||||
}
|
||||
if ids["claude-opus-4-5"] != "Claude 4.5 Opus" {
|
||||
t.Errorf("unexpected name for claude-opus-4-5: %s", ids["claude-opus-4-5"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAnthropicModelAliases_DynamicFallback(t *testing.T) {
|
||||
aliases := GetAnthropicModelAliases([]string{"sonnet-4.7", "opus-5.0-thinking", "gpt-4o"})
|
||||
ids := map[string]string{}
|
||||
for _, a := range aliases {
|
||||
ids[a.ID] = a.Name
|
||||
}
|
||||
if ids["claude-sonnet-4-7"] != "Sonnet 4.7" {
|
||||
t.Errorf("unexpected name for claude-sonnet-4-7: %s", ids["claude-sonnet-4-7"])
|
||||
}
|
||||
if ids["claude-opus-5-0-thinking"] != "Opus 5.0 (Thinking)" {
|
||||
t.Errorf("unexpected name for claude-opus-5-0-thinking: %s", ids["claude-opus-5-0-thinking"])
|
||||
}
|
||||
if _, ok := ids["claude-gpt-4o"]; ok {
|
||||
t.Errorf("gpt-4o should not generate a claude alias")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAnthropicModelAliases_Mixed(t *testing.T) {
|
||||
aliases := GetAnthropicModelAliases([]string{"sonnet-4.6", "opus-4.7", "gpt-4o"})
|
||||
ids := map[string]string{}
|
||||
for _, a := range aliases {
|
||||
ids[a.ID] = a.Name
|
||||
}
|
||||
// static entry keeps its custom name
|
||||
if ids["claude-sonnet-4-6"] != "Claude 4.6 Sonnet" {
|
||||
t.Errorf("static alias should keep original name, got: %s", ids["claude-sonnet-4-6"])
|
||||
}
|
||||
// dynamic entry uses auto-generated name
|
||||
if ids["claude-opus-4-7"] != "Opus 4.7" {
|
||||
t.Errorf("dynamic alias name mismatch: %s", ids["claude-opus-4-7"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAnthropicModelAliases_UnknownPattern(t *testing.T) {
|
||||
aliases := GetAnthropicModelAliases([]string{"some-unknown-model"})
|
||||
if len(aliases) != 0 {
|
||||
t.Fatalf("expected 0 aliases for unknown pattern, got %d: %v", len(aliases), aliases)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToCursorModel_Static(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"claude-opus-4-6", "opus-4.6"},
|
||||
{"claude-opus-4.6", "opus-4.6"},
|
||||
{"claude-sonnet-4-5", "sonnet-4.5"},
|
||||
{"claude-opus-4-6-thinking", "opus-4.6-thinking"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got := ResolveToCursorModel(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("ResolveToCursorModel(%q) = %q, want %q", tc.input, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToCursorModel_DynamicFallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"claude-opus-4-7", "opus-4.7"},
|
||||
{"claude-sonnet-5-0", "sonnet-5.0"},
|
||||
{"claude-opus-4-7-thinking", "opus-4.7-thinking"},
|
||||
{"claude-sonnet-5-0-thinking", "sonnet-5.0-thinking"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got := ResolveToCursorModel(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("ResolveToCursorModel(%q) = %q, want %q", tc.input, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToCursorModel_Passthrough(t *testing.T) {
|
||||
tests := []string{"sonnet-4.6", "gpt-4o", "custom-model"}
|
||||
for _, input := range tests {
|
||||
got := ResolveToCursorModel(input)
|
||||
if got != input {
|
||||
t.Errorf("ResolveToCursorModel(%q) = %q, want passthrough %q", input, got, input)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToCursorModel_Empty(t *testing.T) {
|
||||
if got := ResolveToCursorModel(""); got != "" {
|
||||
t.Errorf("ResolveToCursorModel(\"\") = %q, want empty", got)
|
||||
}
|
||||
if got := ResolveToCursorModel(" "); got != "" {
|
||||
t.Errorf("ResolveToCursorModel(\" \") = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateDynamicAlias(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want AnthropicAlias
|
||||
ok bool
|
||||
}{
|
||||
{"opus-4.7", AnthropicAlias{"claude-opus-4-7", "Opus 4.7"}, true},
|
||||
{"sonnet-5.0-thinking", AnthropicAlias{"claude-sonnet-5-0-thinking", "Sonnet 5.0 (Thinking)"}, true},
|
||||
{"gpt-4o", AnthropicAlias{}, false},
|
||||
{"invalid", AnthropicAlias{}, false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got, ok := generateDynamicAlias(tc.input)
|
||||
if ok != tc.ok {
|
||||
t.Errorf("generateDynamicAlias(%q) ok = %v, want %v", tc.input, ok, tc.ok)
|
||||
continue
|
||||
}
|
||||
if ok && (got.ID != tc.want.ID || got.Name != tc.want.Name) {
|
||||
t.Errorf("generateDynamicAlias(%q) = {%q, %q}, want {%q, %q}", tc.input, got.ID, got.Name, tc.want.ID, tc.want.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseDynamicAlias(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
ok bool
|
||||
}{
|
||||
{"claude-opus-4-7", "opus-4.7", true},
|
||||
{"claude-sonnet-5-0-thinking", "sonnet-5.0-thinking", true},
|
||||
{"claude-opus-4-6", "opus-4.6", true},
|
||||
{"claude-opus-4.6", "", false},
|
||||
{"claude-haiku-4-5-20251001", "", false},
|
||||
{"some-model", "", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got, ok := reverseDynamicAlias(tc.input)
|
||||
if ok != tc.ok {
|
||||
t.Errorf("reverseDynamicAlias(%q) ok = %v, want %v", tc.input, ok, tc.ok)
|
||||
continue
|
||||
}
|
||||
if ok && got != tc.want {
|
||||
t.Errorf("reverseDynamicAlias(%q) = %q, want %q", tc.input, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,243 +0,0 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []interface{} `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
Tools []interface{} `json:"tools"`
|
||||
ToolChoice interface{} `json:"tool_choice"`
|
||||
Functions []interface{} `json:"functions"`
|
||||
FunctionCall interface{} `json:"function_call"`
|
||||
}
|
||||
|
||||
func NormalizeModelID(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
parts := strings.Split(trimmed, "/")
|
||||
last := parts[len(parts)-1]
|
||||
if last == "" {
|
||||
return ""
|
||||
}
|
||||
return last
|
||||
}
|
||||
|
||||
func imageURLToText(imageURL interface{}) string {
|
||||
if imageURL == nil {
|
||||
return "[Image]"
|
||||
}
|
||||
var url string
|
||||
switch v := imageURL.(type) {
|
||||
case string:
|
||||
url = v
|
||||
case map[string]interface{}:
|
||||
if u, ok := v["url"].(string); ok {
|
||||
url = u
|
||||
}
|
||||
}
|
||||
if url == "" {
|
||||
return "[Image]"
|
||||
}
|
||||
if strings.HasPrefix(url, "data:") {
|
||||
end := strings.Index(url, ";")
|
||||
mime := "image"
|
||||
if end > 5 {
|
||||
mime = url[5:end]
|
||||
}
|
||||
return "[Image: base64 " + mime + "]"
|
||||
}
|
||||
return "[Image: " + url + "]"
|
||||
}
|
||||
|
||||
func MessageContentToText(content interface{}) string {
|
||||
if content == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []interface{}:
|
||||
var parts []string
|
||||
for _, p := range v {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
switch part := p.(type) {
|
||||
case string:
|
||||
parts = append(parts, part)
|
||||
case map[string]interface{}:
|
||||
typ, _ := part["type"].(string)
|
||||
switch typ {
|
||||
case "text":
|
||||
if t, ok := part["text"].(string); ok {
|
||||
parts = append(parts, t)
|
||||
}
|
||||
case "image_url":
|
||||
parts = append(parts, imageURLToText(part["image_url"]))
|
||||
case "image":
|
||||
src := part["source"]
|
||||
if src == nil {
|
||||
src = part["url"]
|
||||
}
|
||||
parts = append(parts, imageURLToText(src))
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func ToolsToSystemText(tools []interface{}, functions []interface{}) string {
|
||||
var defs []interface{}
|
||||
|
||||
for _, t := range tools {
|
||||
if m, ok := t.(map[string]interface{}); ok {
|
||||
if m["type"] == "function" {
|
||||
if fn := m["function"]; fn != nil {
|
||||
defs = append(defs, fn)
|
||||
}
|
||||
} else {
|
||||
defs = append(defs, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
defs = append(defs, functions...)
|
||||
|
||||
if len(defs) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, "Available tools (respond with a JSON object to call one):", "")
|
||||
|
||||
for _, raw := range defs {
|
||||
fn, ok := raw.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, _ := fn["name"].(string)
|
||||
desc, _ := fn["description"].(string)
|
||||
params := "{}"
|
||||
if p := fn["parameters"]; p != nil {
|
||||
if b, err := json.MarshalIndent(p, "", " "); err == nil {
|
||||
params = string(b)
|
||||
}
|
||||
} else if p := fn["input_schema"]; p != nil {
|
||||
if b, err := json.MarshalIndent(p, "", " "); err == nil {
|
||||
params = string(b)
|
||||
}
|
||||
}
|
||||
lines = append(lines, "Function: "+name+"\nDescription: "+desc+"\nParameters: "+params)
|
||||
}
|
||||
|
||||
lines = append(lines, "",
|
||||
"When you want to call a tool, use this EXACT format:",
|
||||
"",
|
||||
"<tool_call>",
|
||||
`{"name": "function_name", "arguments": {"param1": "value1"}}`,
|
||||
"</tool_call>",
|
||||
"",
|
||||
"Rules:",
|
||||
"- Write your reasoning BEFORE the tool call",
|
||||
"- You may make multiple tool calls by using multiple <tool_call> blocks",
|
||||
"- STOP writing after the last </tool_call> tag",
|
||||
"- If no tool is needed, respond normally without <tool_call> tags",
|
||||
)
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
type SimpleMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func BuildPromptFromMessages(messages []interface{}) string {
|
||||
var systemParts []string
|
||||
var convo []string
|
||||
|
||||
for _, raw := range messages {
|
||||
m, ok := raw.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role, _ := m["role"].(string)
|
||||
text := MessageContentToText(m["content"])
|
||||
|
||||
switch role {
|
||||
case "system", "developer":
|
||||
if text != "" {
|
||||
systemParts = append(systemParts, text)
|
||||
}
|
||||
case "user":
|
||||
if text != "" {
|
||||
convo = append(convo, "User: "+text)
|
||||
}
|
||||
case "assistant":
|
||||
toolCalls, _ := m["tool_calls"].([]interface{})
|
||||
if len(toolCalls) > 0 {
|
||||
var parts []string
|
||||
if text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
for _, tc := range toolCalls {
|
||||
tcMap, ok := tc.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
fn, _ := tcMap["function"].(map[string]interface{})
|
||||
if fn == nil {
|
||||
continue
|
||||
}
|
||||
name, _ := fn["name"].(string)
|
||||
args, _ := fn["arguments"].(string)
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("<tool_call>\n{\"name\": \"%s\", \"arguments\": %s}\n</tool_call>", name, args))
|
||||
}
|
||||
if len(parts) > 0 {
|
||||
convo = append(convo, "Assistant: "+strings.Join(parts, "\n"))
|
||||
}
|
||||
} else if text != "" {
|
||||
convo = append(convo, "Assistant: "+text)
|
||||
}
|
||||
case "tool", "function":
|
||||
name, _ := m["name"].(string)
|
||||
toolCallID, _ := m["tool_call_id"].(string)
|
||||
label := "Tool result"
|
||||
if name != "" {
|
||||
label = "Tool result (" + name + ")"
|
||||
}
|
||||
if toolCallID != "" {
|
||||
label += " [id=" + toolCallID + "]"
|
||||
}
|
||||
if text != "" {
|
||||
convo = append(convo, label+": "+text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
system := ""
|
||||
if len(systemParts) > 0 {
|
||||
system = "System:\n" + strings.Join(systemParts, "\n\n") + "\n\n"
|
||||
}
|
||||
transcript := strings.Join(convo, "\n\n")
|
||||
return system + transcript + "\n\nAssistant:"
|
||||
}
|
||||
|
||||
func BuildPromptFromSimpleMessages(messages []SimpleMessage) string {
|
||||
ifaces := make([]interface{}, len(messages))
|
||||
for i, m := range messages {
|
||||
ifaces[i] = map[string]interface{}{"role": m.Role, "content": m.Content}
|
||||
}
|
||||
return BuildPromptFromMessages(ifaces)
|
||||
}
|
||||
|
|
@ -1,80 +0,0 @@
|
|||
package openai
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeModelID(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"gpt-4", "gpt-4"},
|
||||
{"openai/gpt-4", "gpt-4"},
|
||||
{"anthropic/claude-3", "claude-3"},
|
||||
{"", ""},
|
||||
{" ", ""},
|
||||
{"a/b/c", "c"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got := NormalizeModelID(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("NormalizeModelID(%q) = %q, want %q", tc.input, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromMessages(t *testing.T) {
|
||||
messages := []interface{}{
|
||||
map[string]interface{}{"role": "system", "content": "You are helpful."},
|
||||
map[string]interface{}{"role": "user", "content": "Hello"},
|
||||
map[string]interface{}{"role": "assistant", "content": "Hi there"},
|
||||
}
|
||||
got := BuildPromptFromMessages(messages)
|
||||
if got == "" {
|
||||
t.Fatal("expected non-empty prompt")
|
||||
}
|
||||
containsSystem := false
|
||||
containsUser := false
|
||||
containsAssistant := false
|
||||
for i := 0; i < len(got)-10; i++ {
|
||||
if got[i:i+6] == "System" {
|
||||
containsSystem = true
|
||||
}
|
||||
if got[i:i+4] == "User" {
|
||||
containsUser = true
|
||||
}
|
||||
if got[i:i+9] == "Assistant" {
|
||||
containsAssistant = true
|
||||
}
|
||||
}
|
||||
if !containsSystem || !containsUser || !containsAssistant {
|
||||
t.Errorf("prompt missing sections: system=%v user=%v assistant=%v\n%s",
|
||||
containsSystem, containsUser, containsAssistant, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolsToSystemText(t *testing.T) {
|
||||
tools := []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": map[string]interface{}{"type": "object"},
|
||||
},
|
||||
},
|
||||
}
|
||||
got := ToolsToSystemText(tools, nil)
|
||||
if got == "" {
|
||||
t.Fatal("expected non-empty tools text")
|
||||
}
|
||||
if len(got) < 10 {
|
||||
t.Errorf("tools text too short: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolsToSystemTextEmpty(t *testing.T) {
|
||||
got := ToolsToSystemText(nil, nil)
|
||||
if got != "" {
|
||||
t.Errorf("expected empty string for no tools, got %q", got)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,110 +0,0 @@
|
|||
package parser
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type StreamParser func(line string)
|
||||
|
||||
type Parser struct {
|
||||
Parse StreamParser
|
||||
Flush func()
|
||||
}
|
||||
|
||||
// CreateStreamParser 建立串流解析器(向後相容,不傳遞 thinking)
|
||||
func CreateStreamParser(onText func(string), onDone func()) Parser {
|
||||
return CreateStreamParserWithThinking(onText, nil, onDone)
|
||||
}
|
||||
|
||||
// CreateStreamParserWithThinking 建立串流解析器,支援思考過程輸出。
|
||||
// onThinking 可為 nil,表示忽略思考過程。
|
||||
func CreateStreamParserWithThinking(onText func(string), onThinking func(string), onDone func()) Parser {
|
||||
// accumulated 是所有已輸出內容的串接
|
||||
accumulatedText := ""
|
||||
accumulatedThinking := ""
|
||||
done := false
|
||||
|
||||
parse := func(line string) {
|
||||
if done {
|
||||
return
|
||||
}
|
||||
|
||||
var obj struct {
|
||||
Type string `json:"type"`
|
||||
Subtype string `json:"subtype"`
|
||||
Message *struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
Thinking string `json:"thinking"`
|
||||
} `json:"content"`
|
||||
} `json:"message"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(line), &obj); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if obj.Type == "assistant" && obj.Message != nil {
|
||||
fullText := ""
|
||||
fullThinking := ""
|
||||
for _, p := range obj.Message.Content {
|
||||
switch p.Type {
|
||||
case "text":
|
||||
if p.Text != "" {
|
||||
fullText += p.Text
|
||||
}
|
||||
case "thinking":
|
||||
if p.Thinking != "" {
|
||||
fullThinking += p.Thinking
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 處理思考過程(不因去重而 return,避免跳過同行的文字內容)
|
||||
if onThinking != nil && fullThinking != "" && fullThinking != accumulatedThinking {
|
||||
// 增量模式:新內容以 accumulated 為前綴
|
||||
if len(fullThinking) >= len(accumulatedThinking) && fullThinking[:len(accumulatedThinking)] == accumulatedThinking {
|
||||
delta := fullThinking[len(accumulatedThinking):]
|
||||
if delta != "" {
|
||||
onThinking(delta)
|
||||
}
|
||||
accumulatedThinking = fullThinking
|
||||
} else {
|
||||
// 獨立片段:直接輸出,但 accumulated 要串接
|
||||
onThinking(fullThinking)
|
||||
accumulatedThinking = accumulatedThinking + fullThinking
|
||||
}
|
||||
}
|
||||
|
||||
// 處理一般文字
|
||||
if fullText == "" || fullText == accumulatedText {
|
||||
return
|
||||
}
|
||||
// 增量模式:新內容以 accumulated 為前綴
|
||||
if len(fullText) >= len(accumulatedText) && fullText[:len(accumulatedText)] == accumulatedText {
|
||||
delta := fullText[len(accumulatedText):]
|
||||
if delta != "" {
|
||||
onText(delta)
|
||||
}
|
||||
accumulatedText = fullText
|
||||
} else {
|
||||
// 獨立片段:直接輸出,但 accumulated 要串接
|
||||
onText(fullText)
|
||||
accumulatedText = accumulatedText + fullText
|
||||
}
|
||||
}
|
||||
|
||||
if obj.Type == "result" && obj.Subtype == "success" {
|
||||
done = true
|
||||
onDone()
|
||||
}
|
||||
}
|
||||
|
||||
flush := func() {
|
||||
if !done {
|
||||
done = true
|
||||
onDone()
|
||||
}
|
||||
}
|
||||
|
||||
return Parser{Parse: parse, Flush: flush}
|
||||
}
|
||||
|
|
@ -1,304 +0,0 @@
|
|||
package parser
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func makeAssistantLine(text string) string {
|
||||
obj := map[string]interface{}{
|
||||
"type": "assistant",
|
||||
"message": map[string]interface{}{
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "text", "text": text},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(obj)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func makeResultLine() string {
|
||||
b, _ := json.Marshal(map[string]string{"type": "result", "subtype": "success"})
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func TestStreamParserFragmentMode(t *testing.T) {
|
||||
// cursor --stream-partial-output 模式:每個訊息是獨立 token fragment
|
||||
var texts []string
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
p.Parse(makeAssistantLine("你"))
|
||||
p.Parse(makeAssistantLine("好!有"))
|
||||
p.Parse(makeAssistantLine("什"))
|
||||
p.Parse(makeAssistantLine("麼"))
|
||||
|
||||
if len(texts) != 4 {
|
||||
t.Fatalf("expected 4 fragments, got %d: %v", len(texts), texts)
|
||||
}
|
||||
if texts[0] != "你" || texts[1] != "好!有" || texts[2] != "什" || texts[3] != "麼" {
|
||||
t.Fatalf("unexpected texts: %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserDeduplicatesFinalFullText(t *testing.T) {
|
||||
// 最後一個訊息是完整的累積文字,應被跳過(去重)
|
||||
var texts []string
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
p.Parse(makeAssistantLine("Hello"))
|
||||
p.Parse(makeAssistantLine(" world"))
|
||||
// 最後一個是完整累積文字,應被去重
|
||||
p.Parse(makeAssistantLine("Hello world"))
|
||||
|
||||
if len(texts) != 2 {
|
||||
t.Fatalf("expected 2 fragments (final full text deduplicated), got %d: %v", len(texts), texts)
|
||||
}
|
||||
if texts[0] != "Hello" || texts[1] != " world" {
|
||||
t.Fatalf("unexpected texts: %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserCallsOnDone(t *testing.T) {
|
||||
var texts []string
|
||||
doneCount := 0
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
p.Parse(makeResultLine())
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onDone called once, got %d", doneCount)
|
||||
}
|
||||
if len(texts) != 0 {
|
||||
t.Fatalf("expected no text, got %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserIgnoresLinesAfterDone(t *testing.T) {
|
||||
var texts []string
|
||||
doneCount := 0
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
p.Parse(makeResultLine())
|
||||
p.Parse(makeAssistantLine("late"))
|
||||
if len(texts) != 0 {
|
||||
t.Fatalf("expected no text after done, got %v", texts)
|
||||
}
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onDone called once, got %d", doneCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserIgnoresNonAssistantLines(t *testing.T) {
|
||||
var texts []string
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
b1, _ := json.Marshal(map[string]interface{}{"type": "user", "message": map[string]interface{}{}})
|
||||
p.Parse(string(b1))
|
||||
b2, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "assistant",
|
||||
"message": map[string]interface{}{"content": []interface{}{}},
|
||||
})
|
||||
p.Parse(string(b2))
|
||||
p.Parse(`{"type":"assistant","message":{"content":[{"type":"code","text":"x"}]}}`)
|
||||
|
||||
if len(texts) != 0 {
|
||||
t.Fatalf("expected no texts, got %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserIgnoresParseErrors(t *testing.T) {
|
||||
var texts []string
|
||||
doneCount := 0
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
p.Parse("not json")
|
||||
p.Parse("{")
|
||||
p.Parse("")
|
||||
|
||||
if len(texts) != 0 || doneCount != 0 {
|
||||
t.Fatalf("expected nothing, got texts=%v done=%d", texts, doneCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserJoinsMultipleTextParts(t *testing.T) {
|
||||
var texts []string
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
obj := map[string]interface{}{
|
||||
"type": "assistant",
|
||||
"message": map[string]interface{}{
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "text", "text": " "},
|
||||
{"type": "text", "text": "world"},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(obj)
|
||||
p.Parse(string(b))
|
||||
|
||||
if len(texts) != 1 || texts[0] != "Hello world" {
|
||||
t.Fatalf("expected ['Hello world'], got %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserFlushTriggersDone(t *testing.T) {
|
||||
var texts []string
|
||||
doneCount := 0
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
p.Parse(makeAssistantLine("Hello"))
|
||||
// agent 結束但沒有 result/success,手動 flush
|
||||
p.Flush()
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onDone called once after Flush, got %d", doneCount)
|
||||
}
|
||||
// 再 flush 不應重複觸發
|
||||
p.Flush()
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onDone called only once, got %d", doneCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserFlushAfterDoneIsNoop(t *testing.T) {
|
||||
doneCount := 0
|
||||
p := CreateStreamParser(
|
||||
func(text string) {},
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
p.Parse(makeResultLine())
|
||||
p.Flush()
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onDone called once, got %d", doneCount)
|
||||
}
|
||||
}
|
||||
|
||||
func makeThinkingLine(thinking string) string {
|
||||
obj := map[string]interface{}{
|
||||
"type": "assistant",
|
||||
"message": map[string]interface{}{
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "thinking", "thinking": thinking},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(obj)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func makeThinkingAndTextLine(thinking, text string) string {
|
||||
obj := map[string]interface{}{
|
||||
"type": "assistant",
|
||||
"message": map[string]interface{}{
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "thinking", "thinking": thinking},
|
||||
{"type": "text", "text": text},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(obj)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func TestStreamParserWithThinkingCallsOnThinking(t *testing.T) {
|
||||
var texts []string
|
||||
var thinkings []string
|
||||
p := CreateStreamParserWithThinking(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func(thinking string) { thinkings = append(thinkings, thinking) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
p.Parse(makeThinkingLine("思考中..."))
|
||||
p.Parse(makeAssistantLine("回答"))
|
||||
|
||||
if len(thinkings) != 1 || thinkings[0] != "思考中..." {
|
||||
t.Fatalf("expected thinkings=['思考中...'], got %v", thinkings)
|
||||
}
|
||||
if len(texts) != 1 || texts[0] != "回答" {
|
||||
t.Fatalf("expected texts=['回答'], got %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserWithThinkingNilOnThinkingIgnoresThinking(t *testing.T) {
|
||||
var texts []string
|
||||
p := CreateStreamParserWithThinking(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
nil,
|
||||
func() {},
|
||||
)
|
||||
|
||||
p.Parse(makeThinkingLine("忽略的思考"))
|
||||
p.Parse(makeAssistantLine("文字"))
|
||||
|
||||
if len(texts) != 1 || texts[0] != "文字" {
|
||||
t.Fatalf("expected texts=['文字'], got %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserWithThinkingDeduplication(t *testing.T) {
|
||||
var thinkings []string
|
||||
p := CreateStreamParserWithThinking(
|
||||
func(text string) {},
|
||||
func(thinking string) { thinkings = append(thinkings, thinking) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
p.Parse(makeThinkingLine("A"))
|
||||
p.Parse(makeThinkingLine("B"))
|
||||
// 重複的完整思考,應被跳過
|
||||
p.Parse(makeThinkingLine("AB"))
|
||||
|
||||
if len(thinkings) != 2 || thinkings[0] != "A" || thinkings[1] != "B" {
|
||||
t.Fatalf("expected thinkings=['A','B'], got %v", thinkings)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamParserThinkingDuplicateButTextStillEmitted 驗證 bug 修復:
|
||||
// 當 thinking 重複(去重跳過)但同一行有 text 時,text 仍必須輸出。
|
||||
func TestStreamParserThinkingDuplicateButTextStillEmitted(t *testing.T) {
|
||||
var texts []string
|
||||
var thinkings []string
|
||||
p := CreateStreamParserWithThinking(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func(thinking string) { thinkings = append(thinkings, thinking) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
// 第一行:thinking="思考中" + text(thinking 為新增,兩者都應輸出)
|
||||
p.Parse(makeThinkingAndTextLine("思考中", "第一段"))
|
||||
// 第二行:thinking 與上一行相同(去重),但 text 是新的,text 仍應輸出
|
||||
p.Parse(makeThinkingAndTextLine("思考中", "第二段"))
|
||||
|
||||
if len(thinkings) != 1 || thinkings[0] != "思考中" {
|
||||
t.Fatalf("expected thinkings=['思考中'], got %v", thinkings)
|
||||
}
|
||||
if len(texts) != 2 || texts[0] != "第一段" || texts[1] != "第二段" {
|
||||
t.Fatalf("expected texts=['第一段','第二段'], got %v", texts)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,284 +0,0 @@
|
|||
package pool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type accountStatus struct {
|
||||
configDir string
|
||||
activeRequests int
|
||||
lastUsed int64
|
||||
rateLimitUntil int64
|
||||
totalRequests int
|
||||
totalSuccess int
|
||||
totalErrors int
|
||||
totalRateLimits int
|
||||
totalLatencyMs int64
|
||||
}
|
||||
|
||||
type AccountStat struct {
|
||||
ConfigDir string
|
||||
ActiveRequests int
|
||||
TotalRequests int
|
||||
TotalSuccess int
|
||||
TotalErrors int
|
||||
TotalRateLimits int
|
||||
TotalLatencyMs int64
|
||||
IsRateLimited bool
|
||||
RateLimitUntil int64
|
||||
}
|
||||
|
||||
type AccountPool struct {
|
||||
mu sync.Mutex
|
||||
accounts []*accountStatus
|
||||
}
|
||||
|
||||
func NewAccountPool(configDirs []string) *AccountPool {
|
||||
accounts := make([]*accountStatus, 0, len(configDirs))
|
||||
for _, dir := range configDirs {
|
||||
accounts = append(accounts, &accountStatus{configDir: dir})
|
||||
}
|
||||
return &AccountPool{accounts: accounts}
|
||||
}
|
||||
|
||||
func (p *AccountPool) GetNextConfigDir() string {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if len(p.accounts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
available := make([]*accountStatus, 0, len(p.accounts))
|
||||
for _, a := range p.accounts {
|
||||
if a.rateLimitUntil < now {
|
||||
available = append(available, a)
|
||||
}
|
||||
}
|
||||
|
||||
target := available
|
||||
if len(target) == 0 {
|
||||
target = make([]*accountStatus, len(p.accounts))
|
||||
copy(target, p.accounts)
|
||||
// sort by earliest recovery
|
||||
for i := 1; i < len(target); i++ {
|
||||
for j := i; j > 0 && target[j].rateLimitUntil < target[j-1].rateLimitUntil; j-- {
|
||||
target[j], target[j-1] = target[j-1], target[j]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pick least busy then least recently used
|
||||
best := target[0]
|
||||
for _, a := range target[1:] {
|
||||
if a.activeRequests < best.activeRequests {
|
||||
best = a
|
||||
} else if a.activeRequests == best.activeRequests && a.lastUsed < best.lastUsed {
|
||||
best = a
|
||||
}
|
||||
}
|
||||
best.lastUsed = now
|
||||
return best.configDir
|
||||
}
|
||||
|
||||
func (p *AccountPool) find(configDir string) *accountStatus {
|
||||
for _, a := range p.accounts {
|
||||
if a.configDir == configDir {
|
||||
return a
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *AccountPool) ReportRequestStart(configDir string) {
|
||||
if configDir == "" {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if a := p.find(configDir); a != nil {
|
||||
a.activeRequests++
|
||||
a.totalRequests++
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AccountPool) ReportRequestEnd(configDir string) {
|
||||
if configDir == "" {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if a := p.find(configDir); a != nil && a.activeRequests > 0 {
|
||||
a.activeRequests--
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AccountPool) ReportRequestSuccess(configDir string, latencyMs int64) {
|
||||
if configDir == "" {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if a := p.find(configDir); a != nil {
|
||||
a.totalSuccess++
|
||||
a.totalLatencyMs += latencyMs
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AccountPool) ReportRequestError(configDir string, latencyMs int64) {
|
||||
if configDir == "" {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if a := p.find(configDir); a != nil {
|
||||
a.totalErrors++
|
||||
a.totalLatencyMs += latencyMs
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AccountPool) ReportRateLimit(configDir string, penaltyMs int64) {
|
||||
if configDir == "" {
|
||||
return
|
||||
}
|
||||
if penaltyMs <= 0 {
|
||||
penaltyMs = 60000
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if a := p.find(configDir); a != nil {
|
||||
a.rateLimitUntil = time.Now().UnixMilli() + penaltyMs
|
||||
a.totalRateLimits++
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AccountPool) GetStats() []AccountStat {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
now := time.Now().UnixMilli()
|
||||
stats := make([]AccountStat, len(p.accounts))
|
||||
for i, a := range p.accounts {
|
||||
stats[i] = AccountStat{
|
||||
ConfigDir: a.configDir,
|
||||
ActiveRequests: a.activeRequests,
|
||||
TotalRequests: a.totalRequests,
|
||||
TotalSuccess: a.totalSuccess,
|
||||
TotalErrors: a.totalErrors,
|
||||
TotalRateLimits: a.totalRateLimits,
|
||||
TotalLatencyMs: a.totalLatencyMs,
|
||||
IsRateLimited: a.rateLimitUntil > now,
|
||||
RateLimitUntil: a.rateLimitUntil,
|
||||
}
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
func (p *AccountPool) Count() int {
|
||||
return len(p.accounts)
|
||||
}
|
||||
|
||||
|
||||
// ─── PoolHandle interface ──────────────────────────────────────────────────
|
||||
// PoolHandle 讓 handler 可以注入獨立的 pool 實例,避免多 port 模式共用全域 pool。
|
||||
|
||||
type PoolHandle interface {
|
||||
GetNextConfigDir() string
|
||||
ReportRequestStart(configDir string)
|
||||
ReportRequestEnd(configDir string)
|
||||
ReportRequestSuccess(configDir string, latencyMs int64)
|
||||
ReportRequestError(configDir string, latencyMs int64)
|
||||
ReportRateLimit(configDir string, penaltyMs int64)
|
||||
GetStats() []AccountStat
|
||||
}
|
||||
|
||||
// GlobalPoolHandle 包裝全域函式以實作 PoolHandle 介面(單 port 模式使用)
|
||||
type GlobalPoolHandle struct{}
|
||||
|
||||
func (GlobalPoolHandle) GetNextConfigDir() string { return GetNextAccountConfigDir() }
|
||||
func (GlobalPoolHandle) ReportRequestStart(d string) { ReportRequestStart(d) }
|
||||
func (GlobalPoolHandle) ReportRequestEnd(d string) { ReportRequestEnd(d) }
|
||||
func (GlobalPoolHandle) ReportRequestSuccess(d string, l int64) { ReportRequestSuccess(d, l) }
|
||||
func (GlobalPoolHandle) ReportRequestError(d string, l int64) { ReportRequestError(d, l) }
|
||||
func (GlobalPoolHandle) ReportRateLimit(d string, p int64) { ReportRateLimit(d, p) }
|
||||
func (GlobalPoolHandle) GetStats() []AccountStat { return GetAccountStats() }
|
||||
|
||||
// ─── Global pool ───────────────────────────────────────────────────────────
|
||||
|
||||
var (
|
||||
globalPool *AccountPool
|
||||
globalMu sync.Mutex
|
||||
)
|
||||
|
||||
func InitAccountPool(configDirs []string) {
|
||||
globalMu.Lock()
|
||||
defer globalMu.Unlock()
|
||||
globalPool = NewAccountPool(configDirs)
|
||||
}
|
||||
|
||||
func GetNextAccountConfigDir() string {
|
||||
globalMu.Lock()
|
||||
p := globalPool
|
||||
globalMu.Unlock()
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
return p.GetNextConfigDir()
|
||||
}
|
||||
|
||||
func ReportRequestStart(configDir string) {
|
||||
globalMu.Lock()
|
||||
p := globalPool
|
||||
globalMu.Unlock()
|
||||
if p != nil {
|
||||
p.ReportRequestStart(configDir)
|
||||
}
|
||||
}
|
||||
|
||||
func ReportRequestEnd(configDir string) {
|
||||
globalMu.Lock()
|
||||
p := globalPool
|
||||
globalMu.Unlock()
|
||||
if p != nil {
|
||||
p.ReportRequestEnd(configDir)
|
||||
}
|
||||
}
|
||||
|
||||
func ReportRequestSuccess(configDir string, latencyMs int64) {
|
||||
globalMu.Lock()
|
||||
p := globalPool
|
||||
globalMu.Unlock()
|
||||
if p != nil {
|
||||
p.ReportRequestSuccess(configDir, latencyMs)
|
||||
}
|
||||
}
|
||||
|
||||
func ReportRequestError(configDir string, latencyMs int64) {
|
||||
globalMu.Lock()
|
||||
p := globalPool
|
||||
globalMu.Unlock()
|
||||
if p != nil {
|
||||
p.ReportRequestError(configDir, latencyMs)
|
||||
}
|
||||
}
|
||||
|
||||
func ReportRateLimit(configDir string, penaltyMs int64) {
|
||||
globalMu.Lock()
|
||||
p := globalPool
|
||||
globalMu.Unlock()
|
||||
if p != nil {
|
||||
p.ReportRateLimit(configDir, penaltyMs)
|
||||
}
|
||||
}
|
||||
|
||||
func GetAccountStats() []AccountStat {
|
||||
globalMu.Lock()
|
||||
p := globalPool
|
||||
globalMu.Unlock()
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return p.GetStats()
|
||||
}
|
||||
|
|
@ -1,152 +0,0 @@
|
|||
package pool
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestEmptyPool(t *testing.T) {
|
||||
p := NewAccountPool(nil)
|
||||
if got := p.GetNextConfigDir(); got != "" {
|
||||
t.Fatalf("expected empty string for empty pool, got %q", got)
|
||||
}
|
||||
if p.Count() != 0 {
|
||||
t.Fatalf("expected count 0, got %d", p.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingleDir(t *testing.T) {
|
||||
p := NewAccountPool([]string{"/dir1"})
|
||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
||||
t.Fatalf("expected /dir1, got %q", got)
|
||||
}
|
||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
||||
t.Fatalf("expected /dir1 again, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobin(t *testing.T) {
|
||||
p := NewAccountPool([]string{"/a", "/b", "/c"})
|
||||
got := []string{
|
||||
p.GetNextConfigDir(),
|
||||
p.GetNextConfigDir(),
|
||||
p.GetNextConfigDir(),
|
||||
p.GetNextConfigDir(),
|
||||
}
|
||||
want := []string{"/a", "/b", "/c", "/a"}
|
||||
for i, w := range want {
|
||||
if got[i] != w {
|
||||
t.Fatalf("call %d: expected %q, got %q", i, w, got[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLeastBusy(t *testing.T) {
|
||||
p := NewAccountPool([]string{"/dir1", "/dir2", "/dir3"})
|
||||
p.ReportRequestStart("/dir1")
|
||||
p.ReportRequestStart("/dir2")
|
||||
|
||||
if got := p.GetNextConfigDir(); got != "/dir3" {
|
||||
t.Fatalf("expected /dir3 (least busy), got %q", got)
|
||||
}
|
||||
|
||||
p.ReportRequestStart("/dir3")
|
||||
p.ReportRequestEnd("/dir1")
|
||||
|
||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
||||
t.Fatalf("expected /dir1 after end, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipsRateLimited(t *testing.T) {
|
||||
p := NewAccountPool([]string{"/dir1", "/dir2"})
|
||||
p.ReportRateLimit("/dir1", 60000)
|
||||
|
||||
if got := p.GetNextConfigDir(); got != "/dir2" {
|
||||
t.Fatalf("expected /dir2, got %q", got)
|
||||
}
|
||||
if got := p.GetNextConfigDir(); got != "/dir2" {
|
||||
t.Fatalf("expected /dir2 again, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallbackToSoonestRecovery(t *testing.T) {
|
||||
p := NewAccountPool([]string{"/dir1", "/dir2"})
|
||||
p.ReportRateLimit("/dir1", 60000)
|
||||
p.ReportRateLimit("/dir2", 30000)
|
||||
|
||||
// dir2 recovers sooner — should be selected
|
||||
if got := p.GetNextConfigDir(); got != "/dir2" {
|
||||
t.Fatalf("expected /dir2 (sooner recovery), got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestActiveRequestsDoesNotGoNegative(t *testing.T) {
|
||||
p := NewAccountPool([]string{"/dir1"})
|
||||
p.ReportRequestEnd("/dir1")
|
||||
p.ReportRequestEnd("/dir1")
|
||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
||||
t.Fatalf("pool should still work, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIgnoreUnknownConfigDir(t *testing.T) {
|
||||
p := NewAccountPool([]string{"/dir1"})
|
||||
p.ReportRequestStart("/nonexistent")
|
||||
p.ReportRequestEnd("/nonexistent")
|
||||
p.ReportRateLimit("/nonexistent", 60000)
|
||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
||||
t.Fatalf("expected /dir1, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitExpires(t *testing.T) {
|
||||
p := NewAccountPool([]string{"/dir1", "/dir2"})
|
||||
p.ReportRateLimit("/dir1", 50)
|
||||
|
||||
if got := p.GetNextConfigDir(); got != "/dir2" {
|
||||
t.Fatalf("immediately expected /dir2, got %q", got)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
||||
t.Fatalf("after expiry expected /dir1, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalPool(t *testing.T) {
|
||||
InitAccountPool([]string{"/g1", "/g2"})
|
||||
if got := GetNextAccountConfigDir(); got != "/g1" {
|
||||
t.Fatalf("expected /g1, got %q", got)
|
||||
}
|
||||
if got := GetNextAccountConfigDir(); got != "/g2" {
|
||||
t.Fatalf("expected /g2, got %q", got)
|
||||
}
|
||||
if got := GetNextAccountConfigDir(); got != "/g1" {
|
||||
t.Fatalf("expected /g1 again, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalPoolEmpty(t *testing.T) {
|
||||
InitAccountPool(nil)
|
||||
if got := GetNextAccountConfigDir(); got != "" {
|
||||
t.Fatalf("expected empty string for empty global pool, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalPoolReinit(t *testing.T) {
|
||||
InitAccountPool([]string{"/old1", "/old2"})
|
||||
GetNextAccountConfigDir()
|
||||
InitAccountPool([]string{"/new1"})
|
||||
if got := GetNextAccountConfigDir(); got != "/new1" {
|
||||
t.Fatalf("expected /new1 after reinit, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalPoolFunctionsNoopBeforeInit(t *testing.T) {
|
||||
InitAccountPool(nil)
|
||||
ReportRequestStart("/dir1")
|
||||
ReportRequestEnd("/dir1")
|
||||
ReportRateLimit("/dir1", 1000)
|
||||
}
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
//go:build !windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func killProcessGroup(c *exec.Cmd) error {
|
||||
if c.Process == nil {
|
||||
return nil
|
||||
}
|
||||
// 殺死整個 process group(負號表示 group)
|
||||
pgid, err := syscall.Getpgid(c.Process.Pid)
|
||||
if err == nil {
|
||||
_ = syscall.Kill(-pgid, syscall.SIGKILL)
|
||||
}
|
||||
// 同時也 kill 主程序,以防萬一
|
||||
return c.Process.Kill()
|
||||
}
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
//go:build windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
func killProcessGroup(c *exec.Cmd) error {
|
||||
if c.Process == nil {
|
||||
return nil
|
||||
}
|
||||
return c.Process.Kill()
|
||||
}
|
||||
|
|
@ -1,250 +0,0 @@
|
|||
package process
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"cursor-api-proxy/internal/env"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RunResult struct {
|
||||
Code int
|
||||
Stdout string
|
||||
Stderr string
|
||||
}
|
||||
|
||||
type RunOptions struct {
|
||||
Cwd string
|
||||
TimeoutMs int
|
||||
MaxMode bool
|
||||
ConfigDir string
|
||||
Ctx context.Context
|
||||
}
|
||||
|
||||
type RunStreamingOptions struct {
|
||||
RunOptions
|
||||
OnLine func(line string)
|
||||
}
|
||||
|
||||
// ─── Global child process registry ──────────────────────────────────────────
|
||||
|
||||
var (
|
||||
activeMu sync.Mutex
|
||||
activeChildren []*exec.Cmd
|
||||
)
|
||||
|
||||
func registerChild(c *exec.Cmd) {
|
||||
activeMu.Lock()
|
||||
activeChildren = append(activeChildren, c)
|
||||
activeMu.Unlock()
|
||||
}
|
||||
|
||||
func unregisterChild(c *exec.Cmd) {
|
||||
activeMu.Lock()
|
||||
for i, ch := range activeChildren {
|
||||
if ch == c {
|
||||
activeChildren = append(activeChildren[:i], activeChildren[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
activeMu.Unlock()
|
||||
}
|
||||
|
||||
func KillAllChildProcesses() {
|
||||
activeMu.Lock()
|
||||
all := make([]*exec.Cmd, len(activeChildren))
|
||||
copy(all, activeChildren)
|
||||
activeChildren = nil
|
||||
activeMu.Unlock()
|
||||
for _, c := range all {
|
||||
killProcessGroup(c)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Spawn ────────────────────────────────────────────────────────────────
|
||||
|
||||
func spawnChild(cmdStr string, args []string, opts *RunOptions, maxModeFn func(scriptPath, configDir string)) *exec.Cmd {
|
||||
envSrc := env.OsEnvToMap()
|
||||
resolved := env.ResolveAgentCommand(cmdStr, args, envSrc, opts.Cwd)
|
||||
|
||||
if opts.MaxMode && maxModeFn != nil {
|
||||
maxModeFn(resolved.AgentScriptPath, opts.ConfigDir)
|
||||
}
|
||||
|
||||
envMap := make(map[string]string, len(resolved.Env))
|
||||
for k, v := range resolved.Env {
|
||||
envMap[k] = v
|
||||
}
|
||||
if opts.ConfigDir != "" {
|
||||
envMap["CURSOR_CONFIG_DIR"] = opts.ConfigDir
|
||||
} else if resolved.ConfigDir != "" {
|
||||
if _, exists := envMap["CURSOR_CONFIG_DIR"]; !exists {
|
||||
envMap["CURSOR_CONFIG_DIR"] = resolved.ConfigDir
|
||||
}
|
||||
}
|
||||
|
||||
envSlice := make([]string, 0, len(envMap))
|
||||
for k, v := range envMap {
|
||||
envSlice = append(envSlice, k+"="+v)
|
||||
}
|
||||
|
||||
ctx := opts.Ctx
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// 使用 WaitDelay 確保 context cancel 後子程序 goroutine 能及時退出
|
||||
c := exec.CommandContext(ctx, resolved.Command, resolved.Args...)
|
||||
c.Dir = opts.Cwd
|
||||
c.Env = envSlice
|
||||
// 設定新的 process group,使 kill 能傳遞給所有子孫程序
|
||||
c.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
// WaitDelay:context cancel 後額外等待這麼久再強制關閉 pipes
|
||||
c.WaitDelay = 5 * time.Second
|
||||
// Cancel 函式:殺死整個 process group
|
||||
c.Cancel = func() error {
|
||||
return killProcessGroup(c)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// MaxModeFn is set by the agent package to avoid import cycle.
|
||||
var MaxModeFn func(agentScriptPath, configDir string)
|
||||
|
||||
func Run(cmdStr string, args []string, opts RunOptions) (RunResult, error) {
|
||||
ctx := opts.Ctx
|
||||
var cancel context.CancelFunc
|
||||
if opts.TimeoutMs > 0 {
|
||||
if ctx == nil {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond)
|
||||
} else {
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond)
|
||||
}
|
||||
defer cancel()
|
||||
opts.Ctx = ctx
|
||||
} else if ctx == nil {
|
||||
opts.Ctx = context.Background()
|
||||
}
|
||||
|
||||
c := spawnChild(cmdStr, args, &opts, MaxModeFn)
|
||||
var stdoutBuf, stderrBuf strings.Builder
|
||||
c.Stdout = &stdoutBuf
|
||||
c.Stderr = &stderrBuf
|
||||
|
||||
if err := c.Start(); err != nil {
|
||||
// context 已取消或命令找不到時
|
||||
if opts.Ctx != nil && opts.Ctx.Err() != nil {
|
||||
return RunResult{Code: -1}, nil
|
||||
}
|
||||
if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") {
|
||||
return RunResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr)
|
||||
}
|
||||
return RunResult{}, err
|
||||
}
|
||||
registerChild(c)
|
||||
defer unregisterChild(c)
|
||||
|
||||
err := c.Wait()
|
||||
code := 0
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
code = exitErr.ExitCode()
|
||||
if code == 0 {
|
||||
code = -1
|
||||
}
|
||||
} else {
|
||||
// context cancelled or killed — return -1 but no error
|
||||
return RunResult{Code: -1, Stdout: stdoutBuf.String(), Stderr: stderrBuf.String()}, nil
|
||||
}
|
||||
}
|
||||
return RunResult{
|
||||
Code: code,
|
||||
Stdout: stdoutBuf.String(),
|
||||
Stderr: stderrBuf.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type StreamResult struct {
|
||||
Code int
|
||||
Stderr string
|
||||
}
|
||||
|
||||
func RunStreaming(cmdStr string, args []string, opts RunStreamingOptions) (StreamResult, error) {
|
||||
ctx := opts.Ctx
|
||||
var cancel context.CancelFunc
|
||||
if opts.TimeoutMs > 0 {
|
||||
if ctx == nil {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond)
|
||||
} else {
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond)
|
||||
}
|
||||
defer cancel()
|
||||
opts.RunOptions.Ctx = ctx
|
||||
} else if opts.RunOptions.Ctx == nil {
|
||||
opts.RunOptions.Ctx = context.Background()
|
||||
}
|
||||
|
||||
c := spawnChild(cmdStr, args, &opts.RunOptions, MaxModeFn)
|
||||
stdoutPipe, err := c.StdoutPipe()
|
||||
if err != nil {
|
||||
return StreamResult{}, err
|
||||
}
|
||||
stderrPipe, err := c.StderrPipe()
|
||||
if err != nil {
|
||||
return StreamResult{}, err
|
||||
}
|
||||
|
||||
if err := c.Start(); err != nil {
|
||||
if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") {
|
||||
return StreamResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr)
|
||||
}
|
||||
return StreamResult{}, err
|
||||
}
|
||||
registerChild(c)
|
||||
defer unregisterChild(c)
|
||||
|
||||
var stderrBuf strings.Builder
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
scanner := bufio.NewScanner(stdoutPipe)
|
||||
scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.TrimSpace(line) != "" {
|
||||
opts.OnLine(line)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
scanner := bufio.NewScanner(stderrPipe)
|
||||
scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024)
|
||||
for scanner.Scan() {
|
||||
stderrBuf.WriteString(scanner.Text())
|
||||
stderrBuf.WriteString("\n")
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
err = c.Wait()
|
||||
code := 0
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
code = exitErr.ExitCode()
|
||||
if code == 0 {
|
||||
code = -1
|
||||
}
|
||||
}
|
||||
}
|
||||
return StreamResult{Code: code, Stderr: stderrBuf.String()}, nil
|
||||
}
|
||||
|
|
@ -1,283 +0,0 @@
|
|||
package process_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"cursor-api-proxy/internal/process"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// sh 是跨平台 shell 執行小 script 的輔助函式
|
||||
func sh(t *testing.T, script string, opts process.RunOptions) (process.RunResult, error) {
|
||||
t.Helper()
|
||||
return process.Run("sh", []string{"-c", script}, opts)
|
||||
}
|
||||
|
||||
func TestRun_StdoutAndStderr(t *testing.T) {
|
||||
result, err := sh(t, "echo hello; echo world >&2", process.RunOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
t.Errorf("Code = %d, want 0", result.Code)
|
||||
}
|
||||
if result.Stdout != "hello\n" {
|
||||
t.Errorf("Stdout = %q, want %q", result.Stdout, "hello\n")
|
||||
}
|
||||
if result.Stderr != "world\n" {
|
||||
t.Errorf("Stderr = %q, want %q", result.Stderr, "world\n")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_BasicSpawn(t *testing.T) {
|
||||
result, err := sh(t, "printf ok", process.RunOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
t.Errorf("Code = %d, want 0", result.Code)
|
||||
}
|
||||
if result.Stdout != "ok" {
|
||||
t.Errorf("Stdout = %q, want ok", result.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_ConfigDir_Propagated(t *testing.T) {
|
||||
result, err := process.Run("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`},
|
||||
process.RunOptions{ConfigDir: "/test/account/dir"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Stdout != "/test/account/dir" {
|
||||
t.Errorf("Stdout = %q, want /test/account/dir", result.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_ConfigDir_Absent(t *testing.T) {
|
||||
// 確保沒有殘留的環境變數
|
||||
_ = os.Unsetenv("CURSOR_CONFIG_DIR")
|
||||
result, err := process.Run("sh", []string{"-c", `printf "${CURSOR_CONFIG_DIR:-unset}"`},
|
||||
process.RunOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Stdout != "unset" {
|
||||
t.Errorf("Stdout = %q, want unset", result.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_NonZeroExit(t *testing.T) {
|
||||
result, err := sh(t, "exit 42", process.RunOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code != 42 {
|
||||
t.Errorf("Code = %d, want 42", result.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_Timeout(t *testing.T) {
|
||||
start := time.Now()
|
||||
result, err := sh(t, "sleep 30", process.RunOptions{TimeoutMs: 300})
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code == 0 {
|
||||
t.Error("expected non-zero exit code after timeout")
|
||||
}
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_OnLine(t *testing.T) {
|
||||
var lines []string
|
||||
result, err := process.RunStreaming("sh", []string{"-c", "printf 'a\nb\nc\n'"},
|
||||
process.RunStreamingOptions{
|
||||
OnLine: func(line string) { lines = append(lines, line) },
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
t.Errorf("Code = %d, want 0", result.Code)
|
||||
}
|
||||
if len(lines) != 3 {
|
||||
t.Errorf("got %d lines, want 3: %v", len(lines), lines)
|
||||
}
|
||||
if lines[0] != "a" || lines[1] != "b" || lines[2] != "c" {
|
||||
t.Errorf("lines = %v, want [a b c]", lines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_FlushFinalLine(t *testing.T) {
|
||||
var lines []string
|
||||
result, err := process.RunStreaming("sh", []string{"-c", "printf tail"},
|
||||
process.RunStreamingOptions{
|
||||
OnLine: func(line string) { lines = append(lines, line) },
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
t.Errorf("Code = %d, want 0", result.Code)
|
||||
}
|
||||
if len(lines) != 1 {
|
||||
t.Errorf("got %d lines, want 1: %v", len(lines), lines)
|
||||
}
|
||||
if lines[0] != "tail" {
|
||||
t.Errorf("lines[0] = %q, want tail", lines[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_ConfigDir(t *testing.T) {
|
||||
var lines []string
|
||||
_, err := process.RunStreaming("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`},
|
||||
process.RunStreamingOptions{
|
||||
RunOptions: process.RunOptions{ConfigDir: "/my/config/dir"},
|
||||
OnLine: func(line string) { lines = append(lines, line) },
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(lines) != 1 || lines[0] != "/my/config/dir" {
|
||||
t.Errorf("lines = %v, want [/my/config/dir]", lines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_Stderr(t *testing.T) {
|
||||
result, err := process.RunStreaming("sh", []string{"-c", "echo err-output >&2"},
|
||||
process.RunStreamingOptions{OnLine: func(string) {}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Stderr == "" {
|
||||
t.Error("expected stderr to contain output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_Timeout(t *testing.T) {
|
||||
start := time.Now()
|
||||
result, err := process.RunStreaming("sh", []string{"-c", "sleep 30"},
|
||||
process.RunStreamingOptions{
|
||||
RunOptions: process.RunOptions{TimeoutMs: 300},
|
||||
OnLine: func(string) {},
|
||||
})
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code == 0 {
|
||||
t.Error("expected non-zero exit code after timeout")
|
||||
}
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_Concurrent(t *testing.T) {
|
||||
var lines1, lines2 []string
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
run := func(label string, target *[]string) {
|
||||
process.RunStreaming("sh", []string{"-c", "printf '" + label + "'"},
|
||||
process.RunStreamingOptions{
|
||||
OnLine: func(line string) { *target = append(*target, line) },
|
||||
})
|
||||
done <- struct{}{}
|
||||
}
|
||||
|
||||
go run("stream1", &lines1)
|
||||
go run("stream2", &lines2)
|
||||
|
||||
<-done
|
||||
<-done
|
||||
|
||||
if len(lines1) != 1 || lines1[0] != "stream1" {
|
||||
t.Errorf("lines1 = %v, want [stream1]", lines1)
|
||||
}
|
||||
if len(lines2) != 1 || lines2[0] != "stream2" {
|
||||
t.Errorf("lines2 = %v, want [stream2]", lines2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_ContextCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
start := time.Now()
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
process.RunStreaming("sh", []string{"-c", "sleep 30"},
|
||||
process.RunStreamingOptions{
|
||||
RunOptions: process.RunOptions{Ctx: ctx},
|
||||
OnLine: func(string) {},
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
time.AfterFunc(100*time.Millisecond, cancel)
|
||||
<-done
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_ContextCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
start := time.Now()
|
||||
done := make(chan process.RunResult, 1)
|
||||
|
||||
go func() {
|
||||
r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx})
|
||||
done <- r
|
||||
}()
|
||||
|
||||
time.AfterFunc(100*time.Millisecond, cancel)
|
||||
result := <-done
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if result.Code == 0 {
|
||||
t.Error("expected non-zero exit code after cancel")
|
||||
}
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_AlreadyCancelledContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // 已取消
|
||||
|
||||
start := time.Now()
|
||||
result, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx})
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if result.Code == 0 {
|
||||
t.Error("expected non-zero exit code")
|
||||
}
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKillAllChildProcesses(t *testing.T) {
|
||||
done := make(chan process.RunResult, 1)
|
||||
go func() {
|
||||
r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{})
|
||||
done <- r
|
||||
}()
|
||||
|
||||
time.Sleep(80 * time.Millisecond)
|
||||
process.KillAllChildProcesses()
|
||||
result := <-done
|
||||
|
||||
if result.Code == 0 {
|
||||
t.Error("expected non-zero exit code after kill")
|
||||
}
|
||||
// 再次呼叫不應 panic
|
||||
process.KillAllChildProcesses()
|
||||
}
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
package cursor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"cursor-api-proxy/internal/apitypes"
|
||||
"cursor-api-proxy/internal/config"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
cfg config.BridgeConfig
|
||||
}
|
||||
|
||||
func NewProvider(cfg config.BridgeConfig) *Provider {
|
||||
return &Provider{cfg: cfg}
|
||||
}
|
||||
|
||||
func (p *Provider) Name() string {
|
||||
return "cursor"
|
||||
}
|
||||
|
||||
func (p *Provider) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"cursor-api-proxy/internal/apitypes"
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/providers/cursor"
|
||||
"cursor-api-proxy/internal/providers/geminiweb"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Provider interface {
|
||||
Name() string
|
||||
Close() error
|
||||
Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error
|
||||
}
|
||||
|
||||
func NewProvider(cfg config.BridgeConfig) (Provider, error) {
|
||||
providerType := cfg.Provider
|
||||
if providerType == "" {
|
||||
providerType = "cursor"
|
||||
}
|
||||
|
||||
switch providerType {
|
||||
case "cursor":
|
||||
return cursor.NewProvider(cfg), nil
|
||||
case "gemini-web":
|
||||
return geminiweb.NewPlaywrightProvider(cfg)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown provider: %s", providerType)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,125 +0,0 @@
|
|||
package geminiweb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/go-rod/rod"
|
||||
"github.com/go-rod/rod/lib/launcher"
|
||||
"github.com/go-rod/rod/lib/proto"
|
||||
)
|
||||
|
||||
type Browser struct {
|
||||
browser *rod.Browser
|
||||
visible bool
|
||||
}
|
||||
|
||||
func NewBrowser(visible bool) (*Browser, error) {
|
||||
l := launcher.New()
|
||||
if visible {
|
||||
l = l.Headless(false)
|
||||
} else {
|
||||
l = l.Headless(true)
|
||||
}
|
||||
|
||||
url, err := l.Launch()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to launch browser: %w", err)
|
||||
}
|
||||
|
||||
b := rod.New().ControlURL(url)
|
||||
if err := b.Connect(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect browser: %w", err)
|
||||
}
|
||||
|
||||
return &Browser{browser: b, visible: visible}, nil
|
||||
}
|
||||
|
||||
func (b *Browser) Close() error {
|
||||
if b.browser != nil {
|
||||
return b.browser.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Browser) NewPage() (*rod.Page, error) {
|
||||
return b.browser.Page(proto.TargetCreateTarget{URL: "about:blank"})
|
||||
}
|
||||
|
||||
type Cookie struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
Domain string `json:"domain"`
|
||||
Path string `json:"path"`
|
||||
Expires float64 `json:"expires"`
|
||||
HTTPOnly bool `json:"httpOnly"`
|
||||
Secure bool `json:"secure"`
|
||||
}
|
||||
|
||||
func LoadCookiesFromFile(cookieFile string) ([]Cookie, error) {
|
||||
data, err := os.ReadFile(cookieFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read cookies: %w", err)
|
||||
}
|
||||
|
||||
var cookies []Cookie
|
||||
if err := json.Unmarshal(data, &cookies); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse cookies: %w", err)
|
||||
}
|
||||
|
||||
return cookies, nil
|
||||
}
|
||||
|
||||
func SaveCookiesToFile(cookies []Cookie, cookieFile string) error {
|
||||
data, err := json.MarshalIndent(cookies, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal cookies: %w", err)
|
||||
}
|
||||
|
||||
dir := filepath.Dir(cookieFile)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create cookie dir: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(cookieFile, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write cookies: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetCookiesOnPage(page *rod.Page, cookies []Cookie) error {
|
||||
var protoCookies []*proto.NetworkCookieParam
|
||||
for _, c := range cookies {
|
||||
p := &proto.NetworkCookieParam{
|
||||
Name: c.Name,
|
||||
Value: c.Value,
|
||||
Domain: c.Domain,
|
||||
Path: c.Path,
|
||||
HTTPOnly: c.HTTPOnly,
|
||||
Secure: c.Secure,
|
||||
}
|
||||
if c.Expires > 0 {
|
||||
exp := proto.TimeSinceEpoch(c.Expires)
|
||||
p.Expires = exp
|
||||
}
|
||||
protoCookies = append(protoCookies, p)
|
||||
}
|
||||
return page.SetCookies(protoCookies)
|
||||
}
|
||||
|
||||
func WaitForElement(page *rod.Page, selector string, timeout time.Duration) (*rod.Element, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
return page.Context(ctx).Element(selector)
|
||||
}
|
||||
|
||||
func WaitForElements(page *rod.Page, selector string, timeout time.Duration) (rod.Elements, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
return page.Context(ctx).Elements(selector)
|
||||
}
|
||||
|
|
@ -1,173 +0,0 @@
|
|||
package geminiweb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/go-rod/rod"
|
||||
"github.com/go-rod/rod/lib/launcher"
|
||||
"github.com/go-rod/rod/lib/proto"
|
||||
)
|
||||
|
||||
// BrowserManager 管理瀏覽器實例的生命週期
|
||||
type BrowserManager struct {
|
||||
mu sync.Mutex
|
||||
browser *rod.Browser
|
||||
userDataDir string
|
||||
page *rod.Page
|
||||
visible bool
|
||||
isRunning bool
|
||||
currentModel string
|
||||
}
|
||||
|
||||
var (
|
||||
globalManager *BrowserManager
|
||||
globalMu sync.Mutex
|
||||
)
|
||||
|
||||
// GetBrowserManager 獲取全域瀏覽器管理器(單例)
|
||||
func GetBrowserManager(userDataDir string, visible bool) (*BrowserManager, error) {
|
||||
globalMu.Lock()
|
||||
defer globalMu.Unlock()
|
||||
|
||||
if globalManager != nil {
|
||||
return globalManager, nil
|
||||
}
|
||||
|
||||
manager, err := NewBrowserManager(userDataDir, visible)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
globalManager = manager
|
||||
return globalManager, nil
|
||||
}
|
||||
|
||||
// NewBrowserManager 建立新的瀏覽器管理器
|
||||
func NewBrowserManager(userDataDir string, visible bool) (*BrowserManager, error) {
|
||||
cleanLockFiles(userDataDir)
|
||||
|
||||
if err := os.MkdirAll(userDataDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create user data dir: %w", err)
|
||||
}
|
||||
|
||||
return &BrowserManager{
|
||||
userDataDir: userDataDir,
|
||||
visible: visible,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// cleanLockFiles 清理 Chrome 的殘留鎖檔案
|
||||
func cleanLockFiles(userDataDir string) {
|
||||
lockFiles := []string{
|
||||
"SingletonLock",
|
||||
"SingletonCookie",
|
||||
"SingletonSocket",
|
||||
"Default/SingletonLock",
|
||||
"Default/SingletonCookie",
|
||||
"Default/SingletonSocket",
|
||||
}
|
||||
|
||||
for _, file := range lockFiles {
|
||||
path := filepath.Join(userDataDir, file)
|
||||
os.Remove(path)
|
||||
}
|
||||
}
|
||||
|
||||
// Launch 啟動瀏覽器(如果尚未啟動)
|
||||
func (m *BrowserManager) Launch() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.isRunning && m.browser != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
l := launcher.New()
|
||||
|
||||
if m.visible {
|
||||
l = l.Headless(false)
|
||||
} else {
|
||||
l = l.Headless(true)
|
||||
}
|
||||
|
||||
l = l.UserDataDir(m.userDataDir)
|
||||
|
||||
url, err := l.Launch()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to launch browser: %w", err)
|
||||
}
|
||||
|
||||
b := rod.New().ControlURL(url)
|
||||
if err := b.Connect(); err != nil {
|
||||
return fmt.Errorf("failed to connect browser: %w", err)
|
||||
}
|
||||
|
||||
m.browser = b
|
||||
|
||||
page, err := b.Page(proto.TargetCreateTarget{URL: "about:blank"})
|
||||
if err != nil {
|
||||
_ = b.Close()
|
||||
return fmt.Errorf("failed to create page: %w", err)
|
||||
}
|
||||
|
||||
m.page = page
|
||||
m.isRunning = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPage 獲取頁面
|
||||
func (m *BrowserManager) GetPage() (*rod.Page, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if !m.isRunning || m.browser == nil {
|
||||
return nil, fmt.Errorf("browser not running")
|
||||
}
|
||||
|
||||
return m.page, nil
|
||||
}
|
||||
|
||||
// Close 關閉瀏覽器
|
||||
func (m *BrowserManager) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if !m.isRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if m.browser != nil {
|
||||
err = m.browser.Close()
|
||||
m.browser = nil
|
||||
}
|
||||
|
||||
m.page = nil
|
||||
m.isRunning = false
|
||||
return err
|
||||
}
|
||||
|
||||
// IsRunning 檢查瀏覽器是否正在運行
|
||||
func (m *BrowserManager) IsRunning() bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.isRunning
|
||||
}
|
||||
|
||||
// SetCurrentModel 設定當前模型
|
||||
func (m *BrowserManager) SetCurrentModel(model string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.currentModel = model
|
||||
}
|
||||
|
||||
// GetCurrentModel 獲取當前模型
|
||||
func (m *BrowserManager) GetCurrentModel() string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.currentModel
|
||||
}
|
||||
|
|
@ -1,250 +0,0 @@
|
|||
package geminiweb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-rod/rod"
|
||||
)
|
||||
|
||||
const geminiURL = "https://gemini.google.com/app"
|
||||
|
||||
// 輸入框選擇器(依優先順序)
|
||||
var inputSelectors = []string{
|
||||
".ProseMirror",
|
||||
"rich-textarea",
|
||||
"div[role='textbox'][contenteditable='true']",
|
||||
"div[contenteditable='true']",
|
||||
"textarea",
|
||||
}
|
||||
|
||||
// NavigateToGemini 導航到 Gemini
|
||||
func NavigateToGemini(page *rod.Page) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := page.Context(ctx).Navigate(geminiURL); err != nil {
|
||||
return fmt.Errorf("failed to navigate: %w", err)
|
||||
}
|
||||
return page.Context(ctx).WaitLoad()
|
||||
}
|
||||
|
||||
// IsLoggedIn 檢查是否已登入
|
||||
func IsLoggedIn(page *rod.Page) bool {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
for _, sel := range inputSelectors {
|
||||
if _, err := page.Context(ctx).Element(sel); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SelectModel 選擇模型(可選)
|
||||
func SelectModel(page *rod.Page, model string) error {
|
||||
fmt.Printf("[GeminiWeb] Model selection skipped (using current model)\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
// TypeInput 在輸入框中輸入文字
|
||||
func TypeInput(page *rod.Page, text string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
fmt.Println("[GeminiWeb] Looking for input field...")
|
||||
|
||||
// 1. 嘗試所有選擇器
|
||||
var inputEl *rod.Element
|
||||
var err error
|
||||
|
||||
for _, sel := range inputSelectors {
|
||||
fmt.Printf(" Trying: %s\n", sel)
|
||||
inputEl, err = page.Context(ctx).Element(sel)
|
||||
if err == nil {
|
||||
fmt.Printf(" ✓ Found with: %s\n", sel)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// 2. Fallback: 嘗試等待頁面載入完成後重試
|
||||
fmt.Println("[GeminiWeb] Waiting for page to fully load...")
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
for _, sel := range inputSelectors {
|
||||
fmt.Printf(" Retrying: %s\n", sel)
|
||||
inputEl, err = page.Context(ctx).Element(sel)
|
||||
if err == nil {
|
||||
fmt.Printf(" ✓ Found with: %s\n", sel)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// 3. Debug: 印出頁面標題和 URL
|
||||
info, _ := page.Info()
|
||||
fmt.Printf("[GeminiWeb] DEBUG: URL=%s Title=%s\n", info.URL, info.Title)
|
||||
|
||||
// 4. Fallback: 嘗試更通用的選擇器
|
||||
fmt.Println("[GeminiWeb] Trying generic selectors...")
|
||||
genericSelectors := []string{
|
||||
"div[contenteditable]",
|
||||
"[contenteditable]",
|
||||
"textarea",
|
||||
"input[type='text']",
|
||||
}
|
||||
|
||||
for _, sel := range genericSelectors {
|
||||
fmt.Printf(" Trying generic: %s\n", sel)
|
||||
inputEl, err = page.Context(ctx).Element(sel)
|
||||
if err == nil {
|
||||
fmt.Printf(" ✓ Found with: %s\n", sel)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
info, _ := page.Info()
|
||||
return fmt.Errorf("input field not found after trying all selectors (URL=%s)", info.URL)
|
||||
}
|
||||
|
||||
// 2. Focus 輸入框
|
||||
fmt.Printf("[GeminiWeb] Focusing input field...\n")
|
||||
if err := inputEl.Focus(); err != nil {
|
||||
return fmt.Errorf("failed to focus input: %w", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// 3. 使用 Input 方法
|
||||
fmt.Printf("[GeminiWeb] Typing %d chars...\n", len(text))
|
||||
if err := inputEl.Input(text); err != nil {
|
||||
return fmt.Errorf("failed to input text: %w", err)
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
fmt.Println("[GeminiWeb] Input complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClickSend 發送訊息
|
||||
func ClickSend(page *rod.Page) error {
|
||||
// 方法 1: 按 Enter
|
||||
if err := page.Keyboard.Press('\r'); err != nil {
|
||||
return fmt.Errorf("failed to press Enter: %w", err)
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
return nil
|
||||
}
|
||||
|
||||
// WaitForReady 等待頁面空閒
|
||||
func WaitForReady(page *rod.Page) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
fmt.Println("[GeminiWeb] Checking if page is ready...")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Println("[GeminiWeb] Page ready check timeout, proceeding anyway")
|
||||
return nil
|
||||
default:
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// 檢查是否有停止按鈕
|
||||
hasStopBtn := false
|
||||
stopBtns, _ := page.Elements("button[aria-label*='Stop'], button[aria-label*='停止']")
|
||||
for _, btn := range stopBtns {
|
||||
visible, _ := btn.Visible()
|
||||
if visible {
|
||||
hasStopBtn = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasStopBtn {
|
||||
fmt.Println("[GeminiWeb] Page is ready")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractResponse 提取回應文字
|
||||
func ExtractResponse(page *rod.Page) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var lastText string
|
||||
lastUpdate := time.Now()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if lastText != "" {
|
||||
return lastText, nil
|
||||
}
|
||||
return "", fmt.Errorf("response timeout")
|
||||
default:
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// 尋找回應文字
|
||||
for _, sel := range responseSelectors {
|
||||
elements, err := page.Elements(sel)
|
||||
if err != nil || len(elements) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 取得最後一個元素的文字
|
||||
lastEl := elements[len(elements)-1]
|
||||
text, err := lastEl.Text()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
text = strings.TrimSpace(text)
|
||||
if text != "" && text != lastText && len(text) > len(lastText) {
|
||||
lastText = text
|
||||
lastUpdate = time.Now()
|
||||
fmt.Printf("[GeminiWeb] Response length: %d\n", len(text))
|
||||
}
|
||||
}
|
||||
|
||||
// 檢查是否已完成(2 秒內沒有新內容)
|
||||
if time.Since(lastUpdate) > 2*time.Second && lastText != "" {
|
||||
// 最後檢查一次是否還有停止按鈕
|
||||
hasStopBtn := false
|
||||
stopBtns, _ := page.Elements("button[aria-label*='Stop'], button[aria-label*='停止']")
|
||||
for _, btn := range stopBtns {
|
||||
visible, _ := btn.Visible()
|
||||
if visible {
|
||||
hasStopBtn = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasStopBtn {
|
||||
return lastText, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 默認的回應選擇器
|
||||
var responseSelectors = []string{
|
||||
".model-response-text",
|
||||
".message-content",
|
||||
".markdown",
|
||||
".prose",
|
||||
"model-response",
|
||||
}
|
||||
|
|
@ -1,641 +0,0 @@
|
|||
package geminiweb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"cursor-api-proxy/internal/apitypes"
|
||||
"cursor-api-proxy/internal/config"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/playwright-community/playwright-go"
|
||||
)
|
||||
|
||||
// PlaywrightProvider 使用 Playwright 的 Gemini Provider
|
||||
type PlaywrightProvider struct {
|
||||
cfg config.BridgeConfig
|
||||
pw *playwright.Playwright
|
||||
browser playwright.Browser
|
||||
context playwright.BrowserContext
|
||||
page playwright.Page
|
||||
mu sync.Mutex
|
||||
userDataDir string
|
||||
}
|
||||
|
||||
var (
|
||||
playwrightInstance *playwright.Playwright
|
||||
playwrightOnce sync.Once
|
||||
playwrightErr error
|
||||
)
|
||||
|
||||
// NewPlaywrightProvider 建立新的 Playwright Provider
|
||||
func NewPlaywrightProvider(cfg config.BridgeConfig) (*PlaywrightProvider, error) {
|
||||
// 確保 Playwright 已初始化(單例)
|
||||
playwrightOnce.Do(func() {
|
||||
playwrightInstance, playwrightErr = playwright.Run()
|
||||
if playwrightErr != nil {
|
||||
playwrightErr = fmt.Errorf("failed to run playwright: %w", playwrightErr)
|
||||
}
|
||||
})
|
||||
|
||||
if playwrightErr != nil {
|
||||
return nil, playwrightErr
|
||||
}
|
||||
|
||||
// 清理 Chrome 鎖檔案
|
||||
userDataDir := filepath.Join(cfg.GeminiAccountDir, "default-session")
|
||||
cleanLockFiles(userDataDir)
|
||||
|
||||
// 確保目錄存在
|
||||
if err := os.MkdirAll(userDataDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create user data dir: %w", err)
|
||||
}
|
||||
|
||||
return &PlaywrightProvider{
|
||||
cfg: cfg,
|
||||
pw: playwrightInstance,
|
||||
userDataDir: userDataDir,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getName 返回 Provider 名稱
|
||||
func (p *PlaywrightProvider) Name() string {
|
||||
return "gemini-web"
|
||||
}
|
||||
|
||||
// launchIfNeeded 如果需要則啟動瀏覽器
|
||||
func (p *PlaywrightProvider) launchIfNeeded() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.context != nil && p.page != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Println("[GeminiWeb] Launching Chromium...")
|
||||
|
||||
// 使用 LaunchPersistentContext(自動保存 session)
|
||||
context, err := p.pw.Chromium.LaunchPersistentContext(p.userDataDir,
|
||||
playwright.BrowserTypeLaunchPersistentContextOptions{
|
||||
Headless: playwright.Bool(!p.cfg.GeminiBrowserVisible),
|
||||
Args: []string{
|
||||
"--no-first-run",
|
||||
"--no-default-browser-check",
|
||||
"--disable-background-networking",
|
||||
"--disable-extensions",
|
||||
"--disable-plugins",
|
||||
"--disable-sync",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to launch persistent context: %w", err)
|
||||
}
|
||||
|
||||
p.context = context
|
||||
|
||||
// 取得或建立頁面
|
||||
pages := context.Pages()
|
||||
if len(pages) > 0 {
|
||||
p.page = pages[0]
|
||||
} else {
|
||||
page, err := context.NewPage()
|
||||
if err != nil {
|
||||
_ = context.Close()
|
||||
return fmt.Errorf("failed to create page: %w", err)
|
||||
}
|
||||
p.page = page
|
||||
}
|
||||
|
||||
fmt.Println("[GeminiWeb] Browser launched")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate 生成回應
|
||||
func (p *PlaywrightProvider) Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) (err error) {
|
||||
// 確保在返回錯誤時保存診斷
|
||||
defer func() {
|
||||
if err != nil {
|
||||
fmt.Println("[GeminiWeb] Error occurred, saving diagnostics...")
|
||||
_ = p.saveDiagnostics()
|
||||
}
|
||||
}()
|
||||
|
||||
fmt.Printf("[GeminiWeb] Starting generation with model: %s\n", model)
|
||||
|
||||
// 1. 確保瀏覽器已啟動
|
||||
if err := p.launchIfNeeded(); err != nil {
|
||||
return fmt.Errorf("failed to launch browser: %w", err)
|
||||
}
|
||||
|
||||
// 2. 導航到 Gemini(如果需要)
|
||||
currentURL := p.page.URL()
|
||||
if !strings.Contains(currentURL, "gemini.google.com") {
|
||||
fmt.Println("[GeminiWeb] Navigating to Gemini...")
|
||||
if _, err := p.page.Goto("https://gemini.google.com/app", playwright.PageGotoOptions{
|
||||
WaitUntil: playwright.WaitUntilStateDomcontentloaded,
|
||||
Timeout: playwright.Float(60000),
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to navigate: %w", err)
|
||||
}
|
||||
// 額外等待 JavaScript 載入
|
||||
fmt.Println("[GeminiWeb] Waiting for page to initialize...")
|
||||
time.Sleep(3 * time.Second)
|
||||
}
|
||||
|
||||
// 3. 調試模式:等待用戶確認
|
||||
if p.cfg.GeminiBrowserVisible {
|
||||
fmt.Println("\n" + strings.Repeat("=", 70))
|
||||
fmt.Println("🔍 調試模式:瀏覽器已開啟")
|
||||
fmt.Println("請檢查瀏覽器畫面,然後按 ENTER 繼續...")
|
||||
fmt.Println("如果有問題,請查看: /tmp/gemini-debug.*")
|
||||
fmt.Println(strings.Repeat("=", 70))
|
||||
|
||||
var input string
|
||||
fmt.Scanln(&input)
|
||||
}
|
||||
|
||||
// 4. 等待頁面完全載入(project-golem 策略)
|
||||
fmt.Println("[GeminiWeb] Waiting for page to be ready...")
|
||||
if err := p.waitForPageReady(); err != nil {
|
||||
fmt.Printf("[GeminiWeb] Warning: %v\n", err)
|
||||
|
||||
// 額外調試:輸出頁面 HTML 結構
|
||||
if p.cfg.GeminiBrowserVisible {
|
||||
html, _ := p.page.Content()
|
||||
debugPath := "/tmp/gemini-debug.html"
|
||||
if err := os.WriteFile(debugPath, []byte(html), 0644); err == nil {
|
||||
fmt.Printf("[GeminiWeb] HTML saved to: %s\n", debugPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 檢查登入狀態
|
||||
fmt.Println("[GeminiWeb] Checking login status...")
|
||||
loggedIn := p.isLoggedIn()
|
||||
if !loggedIn {
|
||||
fmt.Println("[GeminiWeb] Not logged in, continuing anyway")
|
||||
if p.cfg.GeminiBrowserVisible {
|
||||
fmt.Println("\n========================================")
|
||||
fmt.Println("Browser is open. You can:")
|
||||
fmt.Println("1. Log in to Gemini now")
|
||||
fmt.Println("2. Continue without login")
|
||||
fmt.Println("========================================\n")
|
||||
}
|
||||
} else {
|
||||
fmt.Println("[GeminiWeb] ✓ Logged in")
|
||||
}
|
||||
|
||||
// 5. 選擇模型(如果支援)
|
||||
if err := p.selectModel(model); err != nil {
|
||||
fmt.Printf("[GeminiWeb] Warning: model selection failed: %v\n", err)
|
||||
}
|
||||
|
||||
// 6. 建構提示詞
|
||||
prompt := buildPromptFromMessagesPlaywright(messages)
|
||||
fmt.Printf("[GeminiWeb] Typing prompt (%d chars)...\n", len(prompt))
|
||||
|
||||
// 7. 輸入文字(使用 Playwright 的 Auto-wait)
|
||||
if err := p.typeInput(prompt); err != nil {
|
||||
return fmt.Errorf("failed to type: %w", err)
|
||||
}
|
||||
|
||||
// 7. 發送訊息
|
||||
fmt.Println("[GeminiWeb] Sending message...")
|
||||
if err := p.sendMessage(); err != nil {
|
||||
return fmt.Errorf("failed to send: %w", err)
|
||||
}
|
||||
|
||||
// 8. 提取回應
|
||||
fmt.Println("[GeminiWeb] Waiting for response...")
|
||||
response, err := p.extractResponse()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extract response: %w", err)
|
||||
}
|
||||
|
||||
// 9. 回調
|
||||
cb(apitypes.StreamChunk{Type: apitypes.ChunkText, Text: response})
|
||||
cb(apitypes.StreamChunk{Type: apitypes.ChunkDone, Done: true})
|
||||
|
||||
fmt.Printf("[GeminiWeb] Response complete (%d chars)\n", len(response))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 關閉 Provider
|
||||
func (p *PlaywrightProvider) Close() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.context != nil {
|
||||
if err := p.context.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
p.context = nil
|
||||
p.page = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveDiagnostics 保存診斷信息
|
||||
func (p *PlaywrightProvider) saveDiagnostics() error {
|
||||
if p.page == nil {
|
||||
return fmt.Errorf("no page available")
|
||||
}
|
||||
|
||||
// 截圖
|
||||
screenshotPath := "/tmp/gemini-debug.png"
|
||||
if _, err := p.page.Screenshot(playwright.PageScreenshotOptions{
|
||||
Path: playwright.String(screenshotPath),
|
||||
}); err == nil {
|
||||
fmt.Printf("[GeminiWeb] Screenshot saved: %s\n", screenshotPath)
|
||||
}
|
||||
|
||||
// HTML
|
||||
htmlPath := "/tmp/gemini-debug.html"
|
||||
if html, err := p.page.Content(); err == nil {
|
||||
if err := os.WriteFile(htmlPath, []byte(html), 0644); err == nil {
|
||||
fmt.Printf("[GeminiWeb] HTML saved: %s\n", htmlPath)
|
||||
}
|
||||
}
|
||||
|
||||
// 輸出頁面信息
|
||||
url := p.page.URL()
|
||||
title, _ := p.page.Title()
|
||||
fmt.Printf("[GeminiWeb] Diagnostics: URL=%s, Title=%s\n", url, title)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitForPageReady 等待頁面完全就緒(project-golem 策略)
|
||||
func (p *PlaywrightProvider) waitForPageReady() error {
|
||||
fmt.Println("[GeminiWeb] Checking for ready state...")
|
||||
|
||||
// 1. 等待停止按鈕消失(如果存在)
|
||||
_, _ = p.page.WaitForSelector("button[aria-label*='Stop'], button[aria-label*='停止']", playwright.PageWaitForSelectorOptions{
|
||||
State: playwright.WaitForSelectorStateDetached,
|
||||
Timeout: playwright.Float(5000),
|
||||
})
|
||||
|
||||
// 2. 嘗試多種等待策略
|
||||
inputSelectors := []string{
|
||||
".ql-editor.ql-blank",
|
||||
".ql-editor",
|
||||
"div[contenteditable='true'][role='textbox']",
|
||||
"div[contenteditable='true']",
|
||||
".ProseMirror",
|
||||
"rich-textarea",
|
||||
"textarea",
|
||||
}
|
||||
|
||||
// 策略 A: 等待任一輸入框出現
|
||||
for i, sel := range inputSelectors {
|
||||
fmt.Printf(" [%d/%d] Waiting for: %s\n", i+1, len(inputSelectors), sel)
|
||||
locator := p.page.Locator(sel)
|
||||
if err := locator.WaitFor(playwright.LocatorWaitForOptions{
|
||||
Timeout: playwright.Float(5000),
|
||||
State: playwright.WaitForSelectorStateVisible,
|
||||
}); err == nil {
|
||||
fmt.Printf(" ✓ Input field found: %s\n", sel)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// 策略 B: 等待頁面完全載入
|
||||
fmt.Println("[GeminiWeb] Waiting for page load...")
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// 策略 C: 使用 JavaScript 檢查
|
||||
fmt.Println("[GeminiWeb] Checking with JavaScript...")
|
||||
result, err := p.page.Evaluate(`
|
||||
() => {
|
||||
// 檢查所有可能的輸入元素
|
||||
const selectors = [
|
||||
'.ql-editor.ql-blank',
|
||||
'.ql-editor',
|
||||
'div[contenteditable="true"][role="textbox"]',
|
||||
'div[contenteditable="true"]',
|
||||
'.ProseMirror',
|
||||
'rich-textarea',
|
||||
'textarea'
|
||||
];
|
||||
|
||||
for (const sel of selectors) {
|
||||
const el = document.querySelector(sel);
|
||||
if (el) {
|
||||
return {
|
||||
found: true,
|
||||
selector: sel,
|
||||
tagName: el.tagName,
|
||||
className: el.className,
|
||||
visible: el.offsetParent !== null
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return { found: false };
|
||||
}
|
||||
`)
|
||||
|
||||
if err == nil {
|
||||
if m, ok := result.(map[string]interface{}); ok {
|
||||
if found, _ := m["found"].(bool); found {
|
||||
sel, _ := m["selector"].(string)
|
||||
fmt.Printf(" ✓ JavaScript found: %s\n", sel)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 策略 D: 調試模式 - 輸出頁面結構
|
||||
if p.cfg.GeminiBrowserVisible {
|
||||
fmt.Println("[GeminiWeb].dump: Page structure analysis")
|
||||
_, _ = p.page.Evaluate(`
|
||||
() => {
|
||||
const allElements = document.querySelectorAll('*');
|
||||
const inputLike = [];
|
||||
for (const el of allElements) {
|
||||
if (el.contentEditable === 'true' ||
|
||||
el.role === 'textbox' ||
|
||||
el.tagName === 'TEXTAREA' ||
|
||||
el.tagName === 'INPUT') {
|
||||
inputLike.push({
|
||||
tag: el.tagName,
|
||||
class: el.className,
|
||||
id: el.id,
|
||||
role: el.role,
|
||||
contentEditable: el.contentEditable
|
||||
});
|
||||
}
|
||||
}
|
||||
console.log('Input-like elements:', inputLike);
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
return fmt.Errorf("no input field found after all strategies")
|
||||
}
|
||||
|
||||
// isLoggedIn 檢查是否已登入
|
||||
func (p *PlaywrightProvider) isLoggedIn() bool {
|
||||
// 嘗試找輸入框(登入狀態的主要特徵)
|
||||
selectors := []string{
|
||||
".ProseMirror",
|
||||
"rich-textarea",
|
||||
"div[role='textbox']",
|
||||
"div[contenteditable='true']",
|
||||
"textarea",
|
||||
}
|
||||
|
||||
for _, sel := range selectors {
|
||||
locator := p.page.Locator(sel)
|
||||
if count, _ := locator.Count(); count > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// typeInput 輸入文字(使用 Playwright 的 Auto-wait)
|
||||
func (p *PlaywrightProvider) typeInput(text string) error {
|
||||
fmt.Println("[GeminiWeb] Looking for input field...")
|
||||
|
||||
selectors := []string{
|
||||
".ql-editor.ql-blank",
|
||||
".ql-editor",
|
||||
"div[contenteditable='true'][role='textbox']",
|
||||
"div[contenteditable='true']",
|
||||
".ProseMirror",
|
||||
"rich-textarea",
|
||||
"textarea",
|
||||
}
|
||||
|
||||
var inputLocator playwright.Locator
|
||||
var found bool
|
||||
|
||||
for _, sel := range selectors {
|
||||
fmt.Printf(" Trying: %s\n", sel)
|
||||
locator := p.page.Locator(sel)
|
||||
if err := locator.WaitFor(playwright.LocatorWaitForOptions{
|
||||
Timeout: playwright.Float(3000),
|
||||
}); err == nil {
|
||||
inputLocator = locator
|
||||
found = true
|
||||
fmt.Printf(" ✓ Found with: %s\n", sel)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
// 錯誤會被 Generate 的 defer 捕獲並保存診斷
|
||||
url := p.page.URL()
|
||||
title, _ := p.page.Title()
|
||||
return fmt.Errorf("input field not found (URL=%s, Title=%s). Diagnostics will be saved to /tmp/", url, title)
|
||||
}
|
||||
|
||||
// Focus 並填充(Playwright 自動等待)
|
||||
fmt.Printf("[GeminiWeb] Typing %d chars...\n", len(text))
|
||||
if err := inputLocator.Fill(text); err != nil {
|
||||
return fmt.Errorf("failed to fill: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("[GeminiWeb] Input complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendMessage 發送訊息
|
||||
func (p *PlaywrightProvider) sendMessage() error {
|
||||
// 方法 1: 按 Enter(最可靠)
|
||||
if err := p.page.Keyboard().Press("Enter"); err != nil {
|
||||
return fmt.Errorf("failed to press Enter: %w", err)
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// 方法 2: 嘗試點擊發送按鈕(補強)
|
||||
_, _ = p.page.Evaluate(`
|
||||
() => {
|
||||
const keywords = ['發送', 'Send', '傳送'];
|
||||
const buttons = Array.from(document.querySelectorAll('button, [role="button"]'));
|
||||
|
||||
for (const btn of buttons) {
|
||||
const text = (btn.innerText || btn.textContent || '').trim();
|
||||
const label = (btn.getAttribute('aria-label') || '').trim();
|
||||
|
||||
// 跳過停止按鈕
|
||||
if (['停止', 'Stop', '中斷'].includes(text) || label.toLowerCase().includes('stop')) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (keywords.some(kw => text.includes(kw) || label.includes(kw))) {
|
||||
btn.click();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
`)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractResponse 提取回應
|
||||
func (p *PlaywrightProvider) extractResponse() (string, error) {
|
||||
var lastText string
|
||||
var stableCount int
|
||||
lastUpdate := time.Now()
|
||||
timeout := 120 * time.Second
|
||||
startTime := time.Now()
|
||||
|
||||
for time.Since(startTime) < timeout {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// 使用 JavaScript 提取回應文字(更精確)
|
||||
result, err := p.page.Evaluate(`
|
||||
() => {
|
||||
// 尋找所有可能的回應容器
|
||||
const selectors = [
|
||||
'model-response',
|
||||
'.model-response',
|
||||
'message-content',
|
||||
'.message-content'
|
||||
];
|
||||
|
||||
for (const sel of selectors) {
|
||||
const el = document.querySelector(sel);
|
||||
if (el) {
|
||||
// 嘗試找markdown內容
|
||||
const markdown = el.querySelector('.markdown, .prose, [class*="markdown"]');
|
||||
if (markdown && markdown.innerText.trim()) {
|
||||
let text = markdown.innerText.trim();
|
||||
// 移除常見的標籤前綴
|
||||
text = text.replace(/^Gemini said\s*\n*/i, '').replace(/^Gemini\s*[::]\s*\n*/i, '').trim();
|
||||
return { text: text, source: sel + ' .markdown' };
|
||||
}
|
||||
|
||||
// 嘗試找純文字內容(排除標籤)
|
||||
let textContent = el.innerText.trim();
|
||||
if (textContent) {
|
||||
// 移除常見的標籤前綴
|
||||
textContent = textContent.replace(/^Gemini said\s*\n*/i, '').replace(/^Gemini\s*[::]\s*\n*/i, '').trim();
|
||||
return { text: textContent, source: sel };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { text: '', source: 'none' };
|
||||
}
|
||||
`)
|
||||
|
||||
if err == nil {
|
||||
if m, ok := result.(map[string]interface{}); ok {
|
||||
text, _ := m["text"].(string)
|
||||
text = strings.TrimSpace(text)
|
||||
|
||||
if text != "" && len(text) > len(lastText) {
|
||||
lastText = text
|
||||
lastUpdate = time.Now()
|
||||
stableCount = 0
|
||||
fmt.Printf("[GeminiWeb] Response: %d chars\n", len(text))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 檢查是否完成(需要連續 3 次穩定)
|
||||
if time.Since(lastUpdate) > 500*time.Millisecond && lastText != "" {
|
||||
stableCount++
|
||||
if stableCount >= 3 {
|
||||
// 最終檢查:停止按鈕是否還存在
|
||||
stopBtn := p.page.Locator("button[aria-label*='Stop'], button[aria-label*='停止'], button[data-test-id='stop-button']")
|
||||
count, _ := stopBtn.Count()
|
||||
|
||||
if count == 0 {
|
||||
fmt.Println("[GeminiWeb] ✓ Response complete")
|
||||
return lastText, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lastText != "" {
|
||||
fmt.Println("[GeminiWeb] ✓ Response complete (timeout)")
|
||||
return lastText, nil
|
||||
}
|
||||
return "", fmt.Errorf("response timeout")
|
||||
}
|
||||
|
||||
// selectModel 選擇 Gemini 模型
|
||||
// Gemini Web 只有三種模型:fast, thinking, pro
|
||||
func (p *PlaywrightProvider) selectModel(model string) error {
|
||||
// 映射模型名稱到 Gemini Web 的模型選擇器
|
||||
modelMap := map[string]string{
|
||||
"fast": "Fast",
|
||||
"thinking": "Thinking",
|
||||
"pro": "Pro",
|
||||
"gemini-fast": "Fast",
|
||||
"gemini-thinking": "Thinking",
|
||||
"gemini-pro": "Pro",
|
||||
"gemini-2.0-fast": "Fast",
|
||||
"gemini-2.0-flash": "Fast", // 相容舊名稱
|
||||
"gemini-2.5-pro": "Pro",
|
||||
"gemini-2.5-pro-thinking": "Thinking",
|
||||
}
|
||||
|
||||
// 從完整模型名稱中提取類型
|
||||
modelType := ""
|
||||
modelLower := strings.ToLower(model)
|
||||
for key, value := range modelMap {
|
||||
if strings.Contains(modelLower, strings.ToLower(key)) || modelLower == strings.ToLower(key) {
|
||||
modelType = value
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if modelType == "" {
|
||||
// 默認使用 Fast
|
||||
fmt.Printf("[GeminiWeb] Unknown model '%s', defaulting to Fast\n", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Printf("[GeminiWeb] Selecting model: %s\n", modelType)
|
||||
|
||||
// 點擊模型選擇器
|
||||
modelSelector := p.page.Locator("button[aria-label*='Model'], button[aria-label*='模型'], [data-test-id='model-selector']")
|
||||
if count, _ := modelSelector.Count(); count > 0 {
|
||||
if err := modelSelector.First().Click(); err != nil {
|
||||
fmt.Printf("[GeminiWeb] Warning: could not click model selector: %v\n", err)
|
||||
} else {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// 選擇對應的模型選項
|
||||
optionSelector := p.page.Locator(fmt.Sprintf("button:has-text('%s'), [role='menuitem']:has-text('%s')", modelType, modelType))
|
||||
if count, _ := optionSelector.Count(); count > 0 {
|
||||
if err := optionSelector.First().Click(); err != nil {
|
||||
fmt.Printf("[GeminiWeb] Warning: could not select model: %v\n", err)
|
||||
} else {
|
||||
fmt.Printf("[GeminiWeb] ✓ Model selected: %s\n", modelType)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildPromptFromMessages 從訊息列表建構提示詞
|
||||
func buildPromptFromMessagesPlaywright(messages []apitypes.Message) string {
|
||||
var prompt string
|
||||
for _, m := range messages {
|
||||
switch m.Role {
|
||||
case "system":
|
||||
prompt += "System: " + m.Content + "\n\n"
|
||||
case "user":
|
||||
prompt += m.Content + "\n\n"
|
||||
case "assistant":
|
||||
prompt += "Assistant: " + m.Content + "\n\n"
|
||||
}
|
||||
}
|
||||
return prompt
|
||||
}
|
||||
|
|
@ -1,169 +0,0 @@
|
|||
package geminiweb
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type GeminiSession struct {
|
||||
Name string `json:"name"`
|
||||
CookieFile string `json:"cookie_file"`
|
||||
LastUsed int64 `json:"last_used"`
|
||||
ActiveCount int `json:"active_count"`
|
||||
RateLimitEnd int64 `json:"rate_limit_end"`
|
||||
}
|
||||
|
||||
type SessionPool struct {
|
||||
mu sync.Mutex
|
||||
sessions []*GeminiSession
|
||||
dir string
|
||||
maxCount int
|
||||
}
|
||||
|
||||
func NewSessionPool(dir string, maxSessions int) (*SessionPool, error) {
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create session dir: %w", err)
|
||||
}
|
||||
|
||||
sessions, err := loadSessions(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load sessions: %w", err)
|
||||
}
|
||||
|
||||
return &SessionPool{
|
||||
sessions: sessions,
|
||||
dir: dir,
|
||||
maxCount: maxSessions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func loadSessions(dir string) ([]*GeminiSession, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var sessions []*GeminiSession
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
metaPath := filepath.Join(dir, name, "session.json")
|
||||
data, err := os.ReadFile(metaPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var s GeminiSession
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, &s)
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (p *SessionPool) Count() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return len(p.sessions)
|
||||
}
|
||||
|
||||
func (p *SessionPool) GetAvailable() *GeminiSession {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
var available []*GeminiSession
|
||||
for _, s := range p.sessions {
|
||||
if s.RateLimitEnd < now {
|
||||
available = append(available, s)
|
||||
}
|
||||
}
|
||||
|
||||
if len(available) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var best *GeminiSession
|
||||
for _, s := range available {
|
||||
if best == nil || s.ActiveCount < best.ActiveCount {
|
||||
best = s
|
||||
} else if s.ActiveCount == best.ActiveCount && s.LastUsed < best.LastUsed {
|
||||
best = s
|
||||
}
|
||||
}
|
||||
|
||||
return best
|
||||
}
|
||||
|
||||
func (p *SessionPool) StartSession(s *GeminiSession) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
s.ActiveCount++
|
||||
s.LastUsed = time.Now().UnixMilli()
|
||||
p.saveSession(s)
|
||||
}
|
||||
|
||||
func (p *SessionPool) EndSession(s *GeminiSession) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if s.ActiveCount > 0 {
|
||||
s.ActiveCount--
|
||||
}
|
||||
p.saveSession(s)
|
||||
}
|
||||
|
||||
func (p *SessionPool) RateLimitSession(s *GeminiSession, durationMs int64) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
s.RateLimitEnd = time.Now().UnixMilli() + durationMs
|
||||
p.saveSession(s)
|
||||
}
|
||||
|
||||
func (p *SessionPool) saveSession(s *GeminiSession) {
|
||||
metaPath := filepath.Join(p.dir, s.Name, "session.json")
|
||||
data, err := json.MarshalIndent(s, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = os.WriteFile(metaPath, data, 0644)
|
||||
}
|
||||
|
||||
func (p *SessionPool) CreateSession(name string) (*GeminiSession, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
sessionDir := filepath.Join(p.dir, name)
|
||||
if err := os.MkdirAll(sessionDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create session dir: %w", err)
|
||||
}
|
||||
|
||||
s := &GeminiSession{
|
||||
Name: name,
|
||||
CookieFile: filepath.Join(sessionDir, "cookies.json"),
|
||||
LastUsed: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
p.sessions = append(p.sessions, s)
|
||||
p.saveSession(s)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (p *SessionPool) GetSessionNames() []string {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
names := make([]string, len(p.sessions))
|
||||
for i, s := range p.sessions {
|
||||
names[i] = s.Name
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
|
@ -1,196 +0,0 @@
|
|||
package geminiweb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"cursor-api-proxy/internal/apitypes"
|
||||
"cursor-api-proxy/internal/config"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Provider 使用持久化瀏覽器管理器
|
||||
type Provider struct {
|
||||
cfg config.BridgeConfig
|
||||
managerOnce sync.Once
|
||||
manager *BrowserManager
|
||||
managerErr error
|
||||
}
|
||||
|
||||
// NewProvider 建立新的 Provider
|
||||
func NewProvider(cfg config.BridgeConfig) *Provider {
|
||||
return &Provider{cfg: cfg}
|
||||
}
|
||||
|
||||
// getName 返回 Provider 名稱
|
||||
func (p *Provider) Name() string {
|
||||
return "gemini-web"
|
||||
}
|
||||
|
||||
// Close 關閉瀏覽器
|
||||
func (p *Provider) Close() error {
|
||||
if p.manager != nil {
|
||||
return p.manager.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getManager 獲取或初始化瀏覽器管理器(單例)
|
||||
func (p *Provider) getManager() (*BrowserManager, error) {
|
||||
p.managerOnce.Do(func() {
|
||||
sessionDir := p.getSessionDir()
|
||||
p.manager, p.managerErr = GetBrowserManager(sessionDir, p.cfg.GeminiBrowserVisible)
|
||||
})
|
||||
return p.manager, p.managerErr
|
||||
}
|
||||
|
||||
// getSessionDir 獲取 session 目錄
|
||||
func (p *Provider) getSessionDir() string {
|
||||
// 使用單一 session 目錄(簡化設計)
|
||||
return filepath.Join(p.cfg.GeminiAccountDir, "default-session")
|
||||
}
|
||||
|
||||
// Generate 生成回應
|
||||
func (p *Provider) Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error {
|
||||
fmt.Printf("[GeminiWeb] Starting generation with model: %s\n", model)
|
||||
|
||||
// 1. 獲取瀏覽器管理器
|
||||
manager, err := p.getManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get browser manager: %w", err)
|
||||
}
|
||||
|
||||
// 2. 啟動瀏覽器(如果尚未啟動)
|
||||
if !manager.IsRunning() {
|
||||
fmt.Printf("[GeminiWeb] Launching browser...\n")
|
||||
if err := manager.Launch(); err != nil {
|
||||
return fmt.Errorf("failed to launch browser: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 獲取頁面
|
||||
page, err := manager.GetPage()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get page: %w", err)
|
||||
}
|
||||
|
||||
// 4. 檢查當前 URL,如果不是 Gemini 則導航
|
||||
currentURL, _ := page.Info()
|
||||
if !strings.Contains(currentURL.URL, "gemini.google.com") {
|
||||
fmt.Printf("[GeminiWeb] Navigating to Gemini...\n")
|
||||
if err := NavigateToGemini(page); err != nil {
|
||||
return fmt.Errorf("failed to navigate: %w", err)
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
|
||||
// 5. 檢查登入狀態
|
||||
fmt.Printf("[GeminiWeb] Checking login status...\n")
|
||||
if !IsLoggedIn(page) {
|
||||
fmt.Printf("[GeminiWeb] Not logged in, continuing anyway\n")
|
||||
|
||||
if p.cfg.GeminiBrowserVisible {
|
||||
fmt.Println("\n========================================")
|
||||
fmt.Println("Browser is open. You can:")
|
||||
fmt.Println("1. Log in to Gemini now")
|
||||
fmt.Println("2. Continue without login")
|
||||
fmt.Println("========================================\n")
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("[GeminiWeb] Logged in\n")
|
||||
}
|
||||
|
||||
// 6. 等待頁面就緒
|
||||
if err := WaitForReady(page); err != nil {
|
||||
fmt.Printf("[GeminiWeb] Warning: %v\n", err)
|
||||
}
|
||||
|
||||
// 7. 建構提示詞
|
||||
prompt := buildPromptFromMessages(messages)
|
||||
fmt.Printf("[GeminiWeb] Typing prompt (%d chars)...\n", len(prompt))
|
||||
|
||||
// 8. 輸入文字
|
||||
if err := TypeInput(page, prompt); err != nil {
|
||||
return fmt.Errorf("failed to type input: %w", err)
|
||||
}
|
||||
|
||||
// 9. 發送
|
||||
fmt.Printf("[GeminiWeb] Sending message...\n")
|
||||
if err := ClickSend(page); err != nil {
|
||||
return fmt.Errorf("failed to send: %w", err)
|
||||
}
|
||||
|
||||
// 10. 提取回應
|
||||
fmt.Printf("[GeminiWeb] Waiting for response...\n")
|
||||
response, err := ExtractResponse(page)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extract response: %w", err)
|
||||
}
|
||||
|
||||
// 11. 串流回調
|
||||
cb(apitypes.StreamChunk{Type: apitypes.ChunkText, Text: response})
|
||||
cb(apitypes.StreamChunk{Type: apitypes.ChunkDone, Done: true})
|
||||
|
||||
fmt.Printf("[GeminiWeb] Response complete (%d chars)\n", len(response))
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildPromptFromMessages 從訊息列表建構提示詞
|
||||
func buildPromptFromMessages(messages []apitypes.Message) string {
|
||||
var prompt string
|
||||
for _, m := range messages {
|
||||
switch m.Role {
|
||||
case "system":
|
||||
prompt += "System: " + m.Content + "\n\n"
|
||||
case "user":
|
||||
prompt += m.Content + "\n\n"
|
||||
case "assistant":
|
||||
prompt += "Assistant: " + m.Content + "\n\n"
|
||||
}
|
||||
}
|
||||
return prompt
|
||||
}
|
||||
|
||||
// RunLogin 執行登入流程(供 gemini-login 命令使用)
|
||||
func RunLogin(cfg config.BridgeConfig, sessionName string) error {
|
||||
if sessionName == "" {
|
||||
sessionName = "default-session"
|
||||
}
|
||||
|
||||
sessionDir := filepath.Join(cfg.GeminiAccountDir, sessionName)
|
||||
if err := os.MkdirAll(sessionDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create session dir: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Starting browser for login. Session: %s\n", sessionName)
|
||||
fmt.Printf("Session directory: %s\n", sessionDir)
|
||||
fmt.Println("Please log in to your Gemini account in the browser window.")
|
||||
fmt.Println("Press Ctrl+C when you have completed the login...")
|
||||
|
||||
manager, err := NewBrowserManager(sessionDir, true) // visible=true
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create browser manager: %w", err)
|
||||
}
|
||||
|
||||
if err := manager.Launch(); err != nil {
|
||||
return fmt.Errorf("failed to launch browser: %w", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
page, err := manager.GetPage()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get page: %w", err)
|
||||
}
|
||||
|
||||
if err := NavigateToGemini(page); err != nil {
|
||||
return fmt.Errorf("failed to navigate: %w", err)
|
||||
}
|
||||
|
||||
// 等待用戶手動登入...
|
||||
// 使用 Ctrl+C 退出,瀏覽器資料會自動保存在 userDataDir
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1,147 +0,0 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/handlers"
|
||||
"cursor-api-proxy/internal/httputil"
|
||||
"cursor-api-proxy/internal/logger"
|
||||
"cursor-api-proxy/internal/pool"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RouterOptions struct {
|
||||
Version string
|
||||
Config config.BridgeConfig
|
||||
ModelCache *handlers.ModelCacheRef
|
||||
LastModel *string
|
||||
Pool pool.PoolHandle
|
||||
}
|
||||
|
||||
func NewRouter(opts RouterOptions) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
cfg := opts.Config
|
||||
pathname := r.URL.Path
|
||||
method := r.Method
|
||||
remoteAddress := r.RemoteAddr
|
||||
if r.Header.Get("X-Real-IP") != "" {
|
||||
remoteAddress = r.Header.Get("X-Real-IP")
|
||||
}
|
||||
|
||||
logger.LogIncoming(method, pathname, remoteAddress)
|
||||
|
||||
defer func() {
|
||||
logger.AppendSessionLine(cfg.SessionsLogPath, method, pathname, remoteAddress, 200)
|
||||
}()
|
||||
|
||||
if cfg.RequiredKey != "" {
|
||||
token := httputil.ExtractBearerToken(r)
|
||||
if token != cfg.RequiredKey {
|
||||
httputil.WriteJSON(w, 401, map[string]interface{}{
|
||||
"error": map[string]string{"message": "Invalid API key", "code": "unauthorized"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case method == "GET" && pathname == "/health":
|
||||
handlers.HandleHealth(w, r, opts.Version, cfg)
|
||||
|
||||
case method == "GET" && pathname == "/v1/models":
|
||||
opts.ModelCache.HandleModels(w, r, cfg)
|
||||
|
||||
case method == "POST" && pathname == "/v1/chat/completions":
|
||||
raw, err := httputil.ReadBody(r)
|
||||
if err != nil {
|
||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
||||
"error": map[string]string{"message": "failed to read body", "code": "bad_request"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
// 根據 Provider 選擇處理方式
|
||||
provider := cfg.Provider
|
||||
if provider == "" {
|
||||
provider = "cursor"
|
||||
}
|
||||
if provider == "gemini-web" {
|
||||
handlers.HandleGeminiChatCompletions(w, r, cfg, raw, method, pathname, remoteAddress)
|
||||
} else {
|
||||
handlers.HandleChatCompletions(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress)
|
||||
}
|
||||
|
||||
case method == "POST" && pathname == "/v1/messages":
|
||||
raw, err := httputil.ReadBody(r)
|
||||
if err != nil {
|
||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
||||
"error": map[string]string{"message": "failed to read body", "code": "bad_request"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
handlers.HandleAnthropicMessages(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress)
|
||||
|
||||
case (method == "POST" || method == "GET") && pathname == "/v1/completions":
|
||||
httputil.WriteJSON(w, 404, map[string]interface{}{
|
||||
"error": map[string]string{
|
||||
"message": "Legacy completions endpoint is not supported. Use POST /v1/chat/completions instead.",
|
||||
"code": "not_found",
|
||||
},
|
||||
}, nil)
|
||||
|
||||
case pathname == "/v1/embeddings":
|
||||
httputil.WriteJSON(w, 404, map[string]interface{}{
|
||||
"error": map[string]string{
|
||||
"message": "Embeddings are not supported by this proxy.",
|
||||
"code": "not_found",
|
||||
},
|
||||
}, nil)
|
||||
|
||||
default:
|
||||
httputil.WriteJSON(w, 404, map[string]interface{}{
|
||||
"error": map[string]string{"message": "Not found", "code": "not_found"},
|
||||
}, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func recoveryMiddleware(logPath string, next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
msg := fmt.Sprintf("%v", rec)
|
||||
fmt.Fprintf(os.Stderr, "[%s] Proxy panic: %s\n", time.Now().UTC().Format(time.RFC3339), msg)
|
||||
line := fmt.Sprintf("%s ERROR %s %s %s %s\n",
|
||||
time.Now().UTC().Format(time.RFC3339), r.Method, r.URL.Path, r.RemoteAddr,
|
||||
msg[:min(200, len(msg))])
|
||||
if f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil {
|
||||
_, _ = f.WriteString(line)
|
||||
f.Close()
|
||||
}
|
||||
if !isHeaderWritten(w) {
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"message": msg, "code": "internal_error"},
|
||||
}, nil)
|
||||
}
|
||||
}
|
||||
}()
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func isHeaderWritten(w http.ResponseWriter) bool {
|
||||
// Can't reliably detect without wrapping; always try to write
|
||||
return false
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func WrapWithRecovery(logPath string, handler http.HandlerFunc) http.HandlerFunc {
|
||||
return recoveryMiddleware(logPath, handler)
|
||||
}
|
||||
|
|
@ -1,95 +0,0 @@
|
|||
package sanitize
|
||||
|
||||
import "regexp"
|
||||
|
||||
type rule struct {
|
||||
pattern *regexp.Regexp
|
||||
replacement string
|
||||
}
|
||||
|
||||
var rules = []rule{
|
||||
{regexp.MustCompile(`(?i)x-anthropic-billing-header:[^\n]*\n?`), ""},
|
||||
{regexp.MustCompile(`(?i)\bcc_version=[^\s;,\n]+[;,]?\s*`), ""},
|
||||
{regexp.MustCompile(`(?i)\bcc_entrypoint=[^\s;,\n]+[;,]?\s*`), ""},
|
||||
{regexp.MustCompile(`(?i)\bcch=[a-f0-9]+[;,]?\s*`), ""},
|
||||
{regexp.MustCompile(`\bClaude Code\b`), "Cursor"},
|
||||
{regexp.MustCompile(`(?i)Anthropic['']s official CLI for Claude`), "Cursor AI assistant"},
|
||||
{regexp.MustCompile(`\bAnthropic\b`), "Cursor"},
|
||||
{regexp.MustCompile(`(?i)anthropic\.com`), "cursor.com"},
|
||||
{regexp.MustCompile(`(?i)claude\.ai`), "cursor.sh"},
|
||||
{regexp.MustCompile(`^[;,\s]+`), ""},
|
||||
}
|
||||
|
||||
func SanitizeText(text string) string {
|
||||
for _, r := range rules {
|
||||
text = r.pattern.ReplaceAllString(text, r.replacement)
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func SanitizeMessages(messages []interface{}) []interface{} {
|
||||
result := make([]interface{}, len(messages))
|
||||
for i, raw := range messages {
|
||||
if raw == nil {
|
||||
result[i] = raw
|
||||
continue
|
||||
}
|
||||
m, ok := raw.(map[string]interface{})
|
||||
if !ok {
|
||||
result[i] = raw
|
||||
continue
|
||||
}
|
||||
newMsg := make(map[string]interface{}, len(m))
|
||||
for k, v := range m {
|
||||
newMsg[k] = v
|
||||
}
|
||||
switch c := m["content"].(type) {
|
||||
case string:
|
||||
newMsg["content"] = SanitizeText(c)
|
||||
case []interface{}:
|
||||
newParts := make([]interface{}, len(c))
|
||||
for j, p := range c {
|
||||
if pm, ok := p.(map[string]interface{}); ok && pm["type"] == "text" {
|
||||
if t, ok := pm["text"].(string); ok {
|
||||
newPart := make(map[string]interface{}, len(pm))
|
||||
for k, v := range pm {
|
||||
newPart[k] = v
|
||||
}
|
||||
newPart["text"] = SanitizeText(t)
|
||||
newParts[j] = newPart
|
||||
continue
|
||||
}
|
||||
}
|
||||
newParts[j] = p
|
||||
}
|
||||
newMsg["content"] = newParts
|
||||
}
|
||||
result[i] = newMsg
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func SanitizeSystem(system interface{}) interface{} {
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
return SanitizeText(v)
|
||||
case []interface{}:
|
||||
result := make([]interface{}, len(v))
|
||||
for i, p := range v {
|
||||
if pm, ok := p.(map[string]interface{}); ok && pm["type"] == "text" {
|
||||
if t, ok := pm["text"].(string); ok {
|
||||
newPart := make(map[string]interface{}, len(pm))
|
||||
for k, val := range pm {
|
||||
newPart[k] = val
|
||||
}
|
||||
newPart["text"] = SanitizeText(t)
|
||||
result[i] = newPart
|
||||
continue
|
||||
}
|
||||
}
|
||||
result[i] = p
|
||||
}
|
||||
return result
|
||||
}
|
||||
return system
|
||||
}
|
||||
|
|
@ -1,60 +0,0 @@
|
|||
package sanitize
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSanitizeTextAnthropicBilling(t *testing.T) {
|
||||
input := "x-anthropic-billing-header: abc123\nHello"
|
||||
got := SanitizeText(input)
|
||||
if strings.Contains(got, "x-anthropic-billing-header") {
|
||||
t.Errorf("billing header not removed: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeTextClaudeCode(t *testing.T) {
|
||||
input := "I am Claude Code assistant"
|
||||
got := SanitizeText(input)
|
||||
if strings.Contains(got, "Claude Code") {
|
||||
t.Errorf("'Claude Code' not replaced: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "Cursor") {
|
||||
t.Errorf("'Cursor' not present in output: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeTextAnthropic(t *testing.T) {
|
||||
input := "Powered by Anthropic's technology and anthropic.com"
|
||||
got := SanitizeText(input)
|
||||
if strings.Contains(got, "Anthropic") {
|
||||
t.Errorf("'Anthropic' not replaced: %q", got)
|
||||
}
|
||||
if strings.Contains(got, "anthropic.com") {
|
||||
t.Errorf("'anthropic.com' not replaced: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeTextNoChange(t *testing.T) {
|
||||
input := "Hello, this is a normal message about cursor."
|
||||
got := SanitizeText(input)
|
||||
if got != input {
|
||||
t.Errorf("unexpected change: %q -> %q", input, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeMessages(t *testing.T) {
|
||||
messages := []interface{}{
|
||||
map[string]interface{}{"role": "user", "content": "Ask Claude Code something"},
|
||||
map[string]interface{}{"role": "system", "content": "Use Anthropic's tools"},
|
||||
}
|
||||
result := SanitizeMessages(messages)
|
||||
|
||||
for _, raw := range result {
|
||||
m := raw.(map[string]interface{})
|
||||
c := m["content"].(string)
|
||||
if strings.Contains(c, "Claude Code") || strings.Contains(c, "Anthropic") {
|
||||
t.Errorf("found unsanitized content: %q", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,159 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/handlers"
|
||||
"cursor-api-proxy/internal/pool"
|
||||
"cursor-api-proxy/internal/process"
|
||||
"cursor-api-proxy/internal/logger"
|
||||
"cursor-api-proxy/internal/router"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ServerOptions struct {
|
||||
Version string
|
||||
Config config.BridgeConfig
|
||||
Pool pool.PoolHandle
|
||||
}
|
||||
|
||||
func StartBridgeServer(opts ServerOptions) []*http.Server {
|
||||
cfg := opts.Config
|
||||
var servers []*http.Server
|
||||
|
||||
if len(cfg.ConfigDirs) > 0 {
|
||||
if cfg.MultiPort {
|
||||
for i, dir := range cfg.ConfigDirs {
|
||||
port := cfg.Port + i
|
||||
subCfg := cfg
|
||||
subCfg.Port = port
|
||||
subCfg.ConfigDirs = []string{dir}
|
||||
subCfg.MultiPort = false
|
||||
subPool := pool.NewAccountPool([]string{dir})
|
||||
srv := startSingleServer(ServerOptions{Version: opts.Version, Config: subCfg, Pool: subPool})
|
||||
servers = append(servers, srv)
|
||||
}
|
||||
return servers
|
||||
}
|
||||
pool.InitAccountPool(cfg.ConfigDirs)
|
||||
}
|
||||
|
||||
servers = append(servers, startSingleServer(opts))
|
||||
return servers
|
||||
}
|
||||
|
||||
func startSingleServer(opts ServerOptions) *http.Server {
|
||||
cfg := opts.Config
|
||||
|
||||
modelCache := &handlers.ModelCacheRef{}
|
||||
lastModel := cfg.DefaultModel
|
||||
|
||||
ph := opts.Pool
|
||||
if ph == nil {
|
||||
ph = pool.GlobalPoolHandle{}
|
||||
}
|
||||
handler := router.NewRouter(router.RouterOptions{
|
||||
Version: opts.Version,
|
||||
Config: cfg,
|
||||
ModelCache: modelCache,
|
||||
LastModel: &lastModel,
|
||||
Pool: ph,
|
||||
})
|
||||
handler = router.WrapWithRecovery(cfg.SessionsLogPath, handler)
|
||||
|
||||
useTLS := cfg.TLSCertPath != "" && cfg.TLSKeyPath != ""
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
if useTLS {
|
||||
cert, err := tls.LoadX509KeyPair(cfg.TLSCertPath, cfg.TLSKeyPath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "TLS error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if useTLS {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
go func() {
|
||||
var err error
|
||||
if useTLS {
|
||||
err = srv.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
err = srv.ListenAndServe()
|
||||
}
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
if isAddrInUse(err) {
|
||||
fmt.Fprintf(os.Stderr, "❌ Port %d is already in use. Set CURSOR_BRIDGE_PORT to use a different port.\n", cfg.Port)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "❌ Server error: %v\n", err)
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
logger.LogServerStart(opts.Version, scheme, cfg.Host, cfg.Port, cfg)
|
||||
|
||||
return srv
|
||||
}
|
||||
|
||||
func SetupGracefulShutdown(servers []*http.Server, timeoutMs int) {
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
|
||||
|
||||
go func() {
|
||||
sig := <-sigCh
|
||||
logger.LogShutdown(sig.String())
|
||||
|
||||
process.KillAllChildProcesses()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for _, srv := range servers {
|
||||
_ = srv.Shutdown(ctx)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
os.Exit(0)
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintln(os.Stderr, "[shutdown] Timed out waiting for connections to drain — forcing exit.")
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func isAddrInUse(err error) bool {
|
||||
return err != nil && (contains(err.Error(), "address already in use") || contains(err.Error(), "bind: address already in use"))
|
||||
}
|
||||
|
||||
func contains(s, sub string) bool {
|
||||
return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsHelper(s, sub))
|
||||
}
|
||||
|
||||
func containsHelper(s, sub string) bool {
|
||||
for i := 0; i <= len(s)-len(sub); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -1,331 +0,0 @@
|
|||
package server_test
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/server"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// freePort 取得一個暫時可用的隨機 port
|
||||
func freePort(t *testing.T) int {
|
||||
t.Helper()
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
port := l.Addr().(*net.TCPAddr).Port
|
||||
l.Close()
|
||||
return port
|
||||
}
|
||||
|
||||
// makeFakeAgentBin 建立一個 shell script,模擬 agent 固定輸出
|
||||
// sync 模式:直接輸出一行文字
|
||||
// stream 模式:輸出 JSON stream 行
|
||||
func makeFakeAgentBin(t *testing.T, syncOutput string) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
script := dir + "/agent"
|
||||
content := fmt.Sprintf(`#!/bin/sh
|
||||
# 若有 --stream-json 則輸出 stream 格式
|
||||
for arg; do
|
||||
if [ "$arg" = "--stream-json" ]; then
|
||||
printf '%%s\n' '{"type":"assistant","message":{"content":[{"type":"text","text":"%s"}]}}'
|
||||
printf '%%s\n' '{"type":"result","subtype":"success"}'
|
||||
exit 0
|
||||
fi
|
||||
done
|
||||
# 否則輸出 sync 格式
|
||||
printf '%%s' '%s'
|
||||
`, syncOutput, syncOutput)
|
||||
if err := os.WriteFile(script, []byte(content), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return script
|
||||
}
|
||||
|
||||
// makeFakeAgentBinWithModels 額外支援 --list-models 輸出
|
||||
func makeFakeAgentBinWithModels(t *testing.T) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
script := dir + "/agent"
|
||||
content := `#!/bin/sh
|
||||
for arg; do
|
||||
if [ "$arg" = "--list-models" ]; then
|
||||
printf 'claude-3-opus - Claude 3 Opus\n'
|
||||
printf 'claude-3-sonnet - Claude 3 Sonnet\n'
|
||||
exit 0
|
||||
fi
|
||||
if [ "$arg" = "--stream-json" ]; then
|
||||
printf '%s\n' '{"type":"assistant","message":{"content":[{"type":"text","text":"Hello"}]}}'
|
||||
printf '%s\n' '{"type":"result","subtype":"success"}'
|
||||
exit 0
|
||||
fi
|
||||
done
|
||||
printf 'Hello from agent'
|
||||
`
|
||||
if err := os.WriteFile(script, []byte(content), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return script
|
||||
}
|
||||
|
||||
func makeTestConfig(agentBin string, port int, overrides ...func(*config.BridgeConfig)) config.BridgeConfig {
|
||||
cfg := config.BridgeConfig{
|
||||
AgentBin: agentBin,
|
||||
Host: "127.0.0.1",
|
||||
Port: port,
|
||||
DefaultModel: "auto",
|
||||
Mode: "ask",
|
||||
Force: false,
|
||||
ApproveMcps: false,
|
||||
StrictModel: true,
|
||||
Workspace: os.TempDir(),
|
||||
TimeoutMs: 30000,
|
||||
SessionsLogPath: os.TempDir() + "/test-sessions.log",
|
||||
ChatOnlyWorkspace: true,
|
||||
Verbose: false,
|
||||
MaxMode: false,
|
||||
ConfigDirs: []string{},
|
||||
MultiPort: false,
|
||||
WinCmdlineMax: 30000,
|
||||
}
|
||||
for _, fn := range overrides {
|
||||
fn(&cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func waitListening(t *testing.T, host string, port int, timeout time.Duration) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), 50*time.Millisecond)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("server on port %d did not start within %v", port, timeout)
|
||||
}
|
||||
|
||||
func doRequest(t *testing.T, method, url, body string, headers map[string]string) (int, string) {
|
||||
t.Helper()
|
||||
var reqBody io.Reader
|
||||
if body != "" {
|
||||
reqBody = strings.NewReader(body)
|
||||
}
|
||||
req, err := http.NewRequest(method, url, reqBody)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if body != "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
data, _ := io.ReadAll(resp.Body)
|
||||
return resp.StatusCode, string(data)
|
||||
}
|
||||
|
||||
func TestBridgeServer_Health(t *testing.T) {
|
||||
port := freePort(t)
|
||||
agentBin := makeFakeAgentBinWithModels(t)
|
||||
cfg := makeTestConfig(agentBin, port)
|
||||
|
||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||||
defer func() {
|
||||
for _, s := range srvs {
|
||||
s.Shutdown(context.Background())
|
||||
}
|
||||
}()
|
||||
|
||||
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", nil)
|
||||
if status != 200 {
|
||||
t.Fatalf("status = %d, want 200; body: %s", status, body)
|
||||
}
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal([]byte(body), &result)
|
||||
if result["ok"] != true {
|
||||
t.Errorf("ok = %v, want true", result["ok"])
|
||||
}
|
||||
if result["version"] != "1.0.0" {
|
||||
t.Errorf("version = %v, want 1.0.0", result["version"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeServer_Models(t *testing.T) {
|
||||
port := freePort(t)
|
||||
agentBin := makeFakeAgentBinWithModels(t)
|
||||
cfg := makeTestConfig(agentBin, port)
|
||||
|
||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||||
defer func() {
|
||||
for _, s := range srvs {
|
||||
s.Shutdown(context.Background())
|
||||
}
|
||||
}()
|
||||
|
||||
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/v1/models", port), "", nil)
|
||||
if status != 200 {
|
||||
t.Fatalf("status = %d, want 200; body: %s", status, body)
|
||||
}
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal([]byte(body), &result)
|
||||
if result["object"] != "list" {
|
||||
t.Errorf("object = %v, want list", result["object"])
|
||||
}
|
||||
data := result["data"].([]interface{})
|
||||
if len(data) < 2 {
|
||||
t.Errorf("data len = %d, want >= 2", len(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeServer_Unauthorized(t *testing.T) {
|
||||
port := freePort(t)
|
||||
agentBin := makeFakeAgentBinWithModels(t)
|
||||
cfg := makeTestConfig(agentBin, port, func(c *config.BridgeConfig) {
|
||||
c.RequiredKey = "secret123"
|
||||
})
|
||||
|
||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||||
defer func() {
|
||||
for _, s := range srvs {
|
||||
s.Shutdown(context.Background())
|
||||
}
|
||||
}()
|
||||
|
||||
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", nil)
|
||||
if status != 401 {
|
||||
t.Fatalf("status = %d, want 401; body: %s", status, body)
|
||||
}
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal([]byte(body), &result)
|
||||
errObj := result["error"].(map[string]interface{})
|
||||
if errObj["message"] != "Invalid API key" {
|
||||
t.Errorf("message = %v, want 'Invalid API key'", errObj["message"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeServer_AuthorizedKey(t *testing.T) {
|
||||
port := freePort(t)
|
||||
agentBin := makeFakeAgentBinWithModels(t)
|
||||
cfg := makeTestConfig(agentBin, port, func(c *config.BridgeConfig) {
|
||||
c.RequiredKey = "secret123"
|
||||
})
|
||||
|
||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||||
defer func() {
|
||||
for _, s := range srvs {
|
||||
s.Shutdown(context.Background())
|
||||
}
|
||||
}()
|
||||
|
||||
status, _ := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", map[string]string{
|
||||
"Authorization": "Bearer secret123",
|
||||
})
|
||||
if status != 200 {
|
||||
t.Errorf("status = %d, want 200", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeServer_NotFound(t *testing.T) {
|
||||
port := freePort(t)
|
||||
agentBin := makeFakeAgentBinWithModels(t)
|
||||
cfg := makeTestConfig(agentBin, port)
|
||||
|
||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||||
defer func() {
|
||||
for _, s := range srvs {
|
||||
s.Shutdown(context.Background())
|
||||
}
|
||||
}()
|
||||
|
||||
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/unknown", port), "", nil)
|
||||
if status != 404 {
|
||||
t.Fatalf("status = %d, want 404; body: %s", status, body)
|
||||
}
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal([]byte(body), &result)
|
||||
errObj := result["error"].(map[string]interface{})
|
||||
if errObj["code"] != "not_found" {
|
||||
t.Errorf("code = %v, want not_found", errObj["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeServer_ChatCompletions_Sync(t *testing.T) {
|
||||
port := freePort(t)
|
||||
agentBin := makeFakeAgentBin(t, "Hello from agent")
|
||||
cfg := makeTestConfig(agentBin, port)
|
||||
|
||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||||
defer func() {
|
||||
for _, s := range srvs {
|
||||
s.Shutdown(context.Background())
|
||||
}
|
||||
}()
|
||||
|
||||
reqBody := `{"model":"claude-3-opus","messages":[{"role":"user","content":"Hi"}]}`
|
||||
status, body := doRequest(t, "POST", fmt.Sprintf("http://127.0.0.1:%d/v1/chat/completions", port), reqBody, nil)
|
||||
if status != 200 {
|
||||
t.Fatalf("status = %d, want 200; body: %s", status, body)
|
||||
}
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal([]byte(body), &result)
|
||||
if result["object"] != "chat.completion" {
|
||||
t.Errorf("object = %v, want chat.completion", result["object"])
|
||||
}
|
||||
choices := result["choices"].([]interface{})
|
||||
msg := choices[0].(map[string]interface{})["message"].(map[string]interface{})
|
||||
if msg["content"] != "Hello from agent" {
|
||||
t.Errorf("content = %v, want 'Hello from agent'", msg["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeServer_MultiPort(t *testing.T) {
|
||||
basePort := freePort(t)
|
||||
agentBin := makeFakeAgentBinWithModels(t)
|
||||
|
||||
dir1 := t.TempDir()
|
||||
dir2 := t.TempDir()
|
||||
|
||||
cfg := makeTestConfig(agentBin, basePort, func(c *config.BridgeConfig) {
|
||||
c.ConfigDirs = []string{dir1, dir2}
|
||||
c.MultiPort = true
|
||||
})
|
||||
|
||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||||
if len(srvs) != 2 {
|
||||
t.Fatalf("got %d servers, want 2", len(srvs))
|
||||
}
|
||||
|
||||
// 等待兩個 server 啟動(port 可能會衝突,這裡不嚴格測試 port 分配)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
defer func() {
|
||||
for _, s := range srvs {
|
||||
s.Shutdown(context.Background())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
@ -1,154 +0,0 @@
|
|||
package toolcall
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ToolCall struct {
|
||||
Name string
|
||||
Arguments string // JSON string
|
||||
}
|
||||
|
||||
type ParsedResponse struct {
|
||||
TextContent string
|
||||
ToolCalls []ToolCall
|
||||
}
|
||||
|
||||
func (p *ParsedResponse) HasToolCalls() bool {
|
||||
return len(p.ToolCalls) > 0
|
||||
}
|
||||
|
||||
// Modified regex to handle nested JSON
|
||||
var toolCallTagRe = regexp.MustCompile(`(?s)行政法规\s*(\{(?:[^{}]|\{[^{}]*\})*\})\s*ugalakh`)
|
||||
var antmlFunctionCallsRe = regexp.MustCompile(`(?s)<function_calls>\s*(.*?)\s*</function_calls>`)
|
||||
var antmlInvokeRe = regexp.MustCompile(`(?s)<invoke\s+name="([^"]+)">\s*(.*?)\s*</invoke>`)
|
||||
var antmlParamRe = regexp.MustCompile(`(?s)<parameter\s+name="([^"]+)">(.*?)</parameter>`)
|
||||
|
||||
func ExtractToolCalls(text string, toolNames map[string]bool) *ParsedResponse {
|
||||
result := &ParsedResponse{}
|
||||
|
||||
if locs := toolCallTagRe.FindAllStringSubmatchIndex(text, -1); len(locs) > 0 {
|
||||
var calls []ToolCall
|
||||
var textParts []string
|
||||
last := 0
|
||||
for _, loc := range locs {
|
||||
if loc[0] > last {
|
||||
textParts = append(textParts, text[last:loc[0]])
|
||||
}
|
||||
jsonStr := text[loc[2]:loc[3]]
|
||||
if tc := parseToolCallJSON(jsonStr, toolNames); tc != nil {
|
||||
calls = append(calls, *tc)
|
||||
} else {
|
||||
textParts = append(textParts, text[loc[0]:loc[1]])
|
||||
}
|
||||
last = loc[1]
|
||||
}
|
||||
if last < len(text) {
|
||||
textParts = append(textParts, text[last:])
|
||||
}
|
||||
if len(calls) > 0 {
|
||||
result.TextContent = strings.TrimSpace(strings.Join(textParts, ""))
|
||||
result.ToolCalls = calls
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
if locs := antmlFunctionCallsRe.FindAllStringSubmatchIndex(text, -1); len(locs) > 0 {
|
||||
var calls []ToolCall
|
||||
var textParts []string
|
||||
last := 0
|
||||
for _, loc := range locs {
|
||||
if loc[0] > last {
|
||||
textParts = append(textParts, text[last:loc[0]])
|
||||
}
|
||||
block := text[loc[2]:loc[3]]
|
||||
invokes := antmlInvokeRe.FindAllStringSubmatch(block, -1)
|
||||
for _, inv := range invokes {
|
||||
name := inv[1]
|
||||
if toolNames != nil && len(toolNames) > 0 && !toolNames[name] {
|
||||
continue
|
||||
}
|
||||
body := inv[2]
|
||||
args := map[string]interface{}{}
|
||||
params := antmlParamRe.FindAllStringSubmatch(body, -1)
|
||||
for _, p := range params {
|
||||
paramName := p[1]
|
||||
paramValue := strings.TrimSpace(p[2])
|
||||
var jsonVal interface{}
|
||||
if err := json.Unmarshal([]byte(paramValue), &jsonVal); err == nil {
|
||||
args[paramName] = jsonVal
|
||||
} else {
|
||||
args[paramName] = paramValue
|
||||
}
|
||||
}
|
||||
argsJSON, _ := json.Marshal(args)
|
||||
calls = append(calls, ToolCall{Name: name, Arguments: string(argsJSON)})
|
||||
}
|
||||
last = loc[1]
|
||||
}
|
||||
if last < len(text) {
|
||||
textParts = append(textParts, text[last:])
|
||||
}
|
||||
if len(calls) > 0 {
|
||||
result.TextContent = strings.TrimSpace(strings.Join(textParts, ""))
|
||||
result.ToolCalls = calls
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
result.TextContent = text
|
||||
return result
|
||||
}
|
||||
|
||||
func parseToolCallJSON(jsonStr string, toolNames map[string]bool) *ToolCall {
|
||||
var raw map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
name, _ := raw["name"].(string)
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
if toolNames != nil && len(toolNames) > 0 && !toolNames[name] {
|
||||
return nil
|
||||
}
|
||||
var argsStr string
|
||||
switch a := raw["arguments"].(type) {
|
||||
case string:
|
||||
argsStr = a
|
||||
case map[string]interface{}, []interface{}:
|
||||
b, _ := json.Marshal(a)
|
||||
argsStr = string(b)
|
||||
default:
|
||||
if p, ok := raw["parameters"]; ok {
|
||||
b, _ := json.Marshal(p)
|
||||
argsStr = string(b)
|
||||
} else {
|
||||
argsStr = "{}"
|
||||
}
|
||||
}
|
||||
return &ToolCall{Name: name, Arguments: argsStr}
|
||||
}
|
||||
|
||||
func CollectToolNames(tools []interface{}) map[string]bool {
|
||||
names := map[string]bool{}
|
||||
for _, t := range tools {
|
||||
m, ok := t.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if m["type"] == "function" {
|
||||
if fn, ok := m["function"].(map[string]interface{}); ok {
|
||||
if name, ok := fn["name"].(string); ok {
|
||||
names[name] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if name, ok := m["name"].(string); ok {
|
||||
names[name] = true
|
||||
}
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
|
@ -1,181 +0,0 @@
|
|||
package winlimit
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/env"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
const WinPromptOmissionPrefix = "[Earlier messages omitted: Windows command-line length limit.]\n\n"
|
||||
const LinuxPromptOmissionPrefix = "[Earlier messages omitted: Linux ARG_MAX command-line length limit.]\n\n"
|
||||
|
||||
// safeLinuxArgMax returns a conservative estimate of ARG_MAX on Linux.
|
||||
// The actual limit is typically 2MB; we use 1.5MB to leave room for env vars.
|
||||
func safeLinuxArgMax() int {
|
||||
return 1536 * 1024
|
||||
}
|
||||
|
||||
type FitPromptResult struct {
|
||||
OK bool
|
||||
Args []string
|
||||
Truncated bool
|
||||
OriginalLength int
|
||||
FinalPromptLength int
|
||||
Error string
|
||||
}
|
||||
|
||||
func estimateCmdlineLength(resolved env.AgentCommand) int {
|
||||
argv := append([]string{resolved.Command}, resolved.Args...)
|
||||
if resolved.WindowsVerbatimArguments {
|
||||
n := 0
|
||||
for _, a := range argv {
|
||||
n += len(a)
|
||||
}
|
||||
if len(argv) > 1 {
|
||||
n += len(argv) - 1
|
||||
}
|
||||
return n + 512
|
||||
}
|
||||
dstLen := 0
|
||||
for _, a := range argv {
|
||||
dstLen += len(a)
|
||||
}
|
||||
dstLen = dstLen*2 + len(argv)*2
|
||||
if len(argv) > 1 {
|
||||
dstLen += len(argv) - 1
|
||||
}
|
||||
return dstLen + 512
|
||||
}
|
||||
|
||||
func FitPromptToWinCmdline(agentBin string, fixedArgs []string, prompt string, maxCmdline int, cwd string) FitPromptResult {
|
||||
if runtime.GOOS != "windows" {
|
||||
return fitPromptLinux(fixedArgs, prompt)
|
||||
}
|
||||
|
||||
e := env.OsEnvToMap()
|
||||
measured := func(p string) int {
|
||||
args := make([]string, len(fixedArgs)+1)
|
||||
copy(args, fixedArgs)
|
||||
args[len(fixedArgs)] = p
|
||||
resolved := env.ResolveAgentCommand(agentBin, args, e, cwd)
|
||||
return estimateCmdlineLength(resolved)
|
||||
}
|
||||
|
||||
if measured("") > maxCmdline {
|
||||
return FitPromptResult{
|
||||
OK: false,
|
||||
Error: "Windows command line exceeds the configured limit even without a prompt; shorten workspace path, model id, or CURSOR_BRIDGE_WIN_CMDLINE_MAX.",
|
||||
}
|
||||
}
|
||||
|
||||
if measured(prompt) <= maxCmdline {
|
||||
args := make([]string, len(fixedArgs)+1)
|
||||
copy(args, fixedArgs)
|
||||
args[len(fixedArgs)] = prompt
|
||||
return FitPromptResult{
|
||||
OK: true,
|
||||
Args: args,
|
||||
Truncated: false,
|
||||
OriginalLength: len(prompt),
|
||||
FinalPromptLength: len(prompt),
|
||||
}
|
||||
}
|
||||
|
||||
prefix := WinPromptOmissionPrefix
|
||||
if measured(prefix) > maxCmdline {
|
||||
return FitPromptResult{
|
||||
OK: false,
|
||||
Error: "Windows command line too long to fit even the truncation notice; shorten workspace path or flags.",
|
||||
}
|
||||
}
|
||||
|
||||
lo, hi, best := 0, len(prompt), 0
|
||||
for lo <= hi {
|
||||
mid := (lo + hi) / 2
|
||||
var tail string
|
||||
if mid > 0 {
|
||||
tail = prompt[len(prompt)-mid:]
|
||||
}
|
||||
candidate := prefix + tail
|
||||
if measured(candidate) <= maxCmdline {
|
||||
best = mid
|
||||
lo = mid + 1
|
||||
} else {
|
||||
hi = mid - 1
|
||||
}
|
||||
}
|
||||
|
||||
var finalPrompt string
|
||||
if best == 0 {
|
||||
finalPrompt = prefix
|
||||
} else {
|
||||
finalPrompt = prefix + prompt[len(prompt)-best:]
|
||||
}
|
||||
|
||||
args := make([]string, len(fixedArgs)+1)
|
||||
copy(args, fixedArgs)
|
||||
args[len(fixedArgs)] = finalPrompt
|
||||
return FitPromptResult{
|
||||
OK: true,
|
||||
Args: args,
|
||||
Truncated: true,
|
||||
OriginalLength: len(prompt),
|
||||
FinalPromptLength: len(finalPrompt),
|
||||
}
|
||||
}
|
||||
|
||||
// fitPromptLinux handles Linux ARG_MAX truncation.
|
||||
func fitPromptLinux(fixedArgs []string, prompt string) FitPromptResult {
|
||||
argMax := safeLinuxArgMax()
|
||||
|
||||
// Estimate total cmdline size: sum of all fixed args + prompt + null terminators
|
||||
fixedLen := 0
|
||||
for _, a := range fixedArgs {
|
||||
fixedLen += len(a) + 1
|
||||
}
|
||||
totalLen := fixedLen + len(prompt) + 1
|
||||
|
||||
if totalLen <= argMax {
|
||||
args := make([]string, len(fixedArgs)+1)
|
||||
copy(args, fixedArgs)
|
||||
args[len(fixedArgs)] = prompt
|
||||
return FitPromptResult{
|
||||
OK: true,
|
||||
Args: args,
|
||||
Truncated: false,
|
||||
OriginalLength: len(prompt),
|
||||
FinalPromptLength: len(prompt),
|
||||
}
|
||||
}
|
||||
|
||||
// Need to truncate: keep the tail of the prompt (most recent messages)
|
||||
prefix := LinuxPromptOmissionPrefix
|
||||
available := argMax - fixedLen - len(prefix) - 1
|
||||
if available < 0 {
|
||||
available = 0
|
||||
}
|
||||
|
||||
var finalPrompt string
|
||||
if available <= 0 {
|
||||
finalPrompt = prefix
|
||||
} else if available >= len(prompt) {
|
||||
finalPrompt = prefix + prompt
|
||||
} else {
|
||||
finalPrompt = prefix + prompt[len(prompt)-available:]
|
||||
}
|
||||
|
||||
args := make([]string, len(fixedArgs)+1)
|
||||
copy(args, fixedArgs)
|
||||
args[len(fixedArgs)] = finalPrompt
|
||||
return FitPromptResult{
|
||||
OK: true,
|
||||
Args: args,
|
||||
Truncated: true,
|
||||
OriginalLength: len(prompt),
|
||||
FinalPromptLength: len(finalPrompt),
|
||||
}
|
||||
}
|
||||
|
||||
func WarnPromptTruncated(originalLength, finalLength int) {
|
||||
_ = originalLength
|
||||
_ = finalLength
|
||||
}
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
package winlimit
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNonWindowsPassThrough(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping non-Windows test on Windows")
|
||||
}
|
||||
|
||||
fixedArgs := []string{"--print", "--model", "gpt-4"}
|
||||
prompt := "Hello world"
|
||||
result := FitPromptToWinCmdline("agent", fixedArgs, prompt, 30000, "/tmp")
|
||||
|
||||
if !result.OK {
|
||||
t.Fatalf("expected OK=true on non-Windows, got error: %s", result.Error)
|
||||
}
|
||||
if result.Truncated {
|
||||
t.Error("expected no truncation on non-Windows")
|
||||
}
|
||||
if result.OriginalLength != len(prompt) {
|
||||
t.Errorf("expected original length %d, got %d", len(prompt), result.OriginalLength)
|
||||
}
|
||||
// Last arg should be the prompt
|
||||
if len(result.Args) == 0 || result.Args[len(result.Args)-1] != prompt {
|
||||
t.Errorf("expected last arg to be prompt, got %v", result.Args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOmissionPrefix(t *testing.T) {
|
||||
if !strings.Contains(WinPromptOmissionPrefix, "Earlier messages omitted") {
|
||||
t.Errorf("omission prefix should mention earlier messages, got: %q", WinPromptOmissionPrefix)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
package workspace
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/config"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type WorkspaceResult struct {
|
||||
WorkspaceDir string
|
||||
TempDir string
|
||||
}
|
||||
|
||||
func ResolveWorkspace(cfg config.BridgeConfig, workspaceHeader string) WorkspaceResult {
|
||||
if cfg.ChatOnlyWorkspace {
|
||||
tempDir, err := os.MkdirTemp("", "cursor-proxy-")
|
||||
if err != nil {
|
||||
tempDir = filepath.Join(os.TempDir(), "cursor-proxy-fallback")
|
||||
_ = os.MkdirAll(tempDir, 0700)
|
||||
}
|
||||
return WorkspaceResult{WorkspaceDir: tempDir, TempDir: tempDir}
|
||||
}
|
||||
|
||||
headerWs := strings.TrimSpace(workspaceHeader)
|
||||
if headerWs != "" {
|
||||
return WorkspaceResult{WorkspaceDir: headerWs}
|
||||
}
|
||||
return WorkspaceResult{WorkspaceDir: cfg.Workspace}
|
||||
}
|
||||
73
main.go
73
main.go
|
|
@ -1,26 +1,56 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/cmd"
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/env"
|
||||
"cursor-api-proxy/internal/server"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/handler"
|
||||
"cursor-api-proxy/internal/svc"
|
||||
|
||||
cmd "cursor-api-proxy/cmd/cli"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
"github.com/zeromicro/go-zero/rest"
|
||||
)
|
||||
|
||||
const version = "1.0.0"
|
||||
|
||||
var configFile = flag.String("f", "etc/chat-api.yaml", "the config file")
|
||||
|
||||
func main() {
|
||||
args, err := cmd.ParseArgs(os.Args[1:])
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
// Check for CLI commands first (before flag.Parse)
|
||||
args := os.Args[1:]
|
||||
if len(args) > 0 {
|
||||
parsed, err := cmd.ParseArgs(args)
|
||||
if err != nil {
|
||||
// Not a CLI command, proceed to HTTP server
|
||||
} else if handleCLICommand(parsed) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// HTTP server mode (go-zero)
|
||||
flag.Parse()
|
||||
|
||||
var c config.Config
|
||||
conf.MustLoad(*configFile, &c)
|
||||
|
||||
server := rest.MustNewServer(c.RestConf)
|
||||
defer server.Stop()
|
||||
|
||||
ctx := svc.NewServiceContext(c)
|
||||
handler.RegisterHandlers(server, ctx)
|
||||
|
||||
fmt.Printf("Starting server at %s:%d...\n", c.Host, c.Port)
|
||||
server.Start()
|
||||
}
|
||||
|
||||
func handleCLICommand(args cmd.ParsedArgs) bool {
|
||||
if args.Help {
|
||||
cmd.PrintHelp(version)
|
||||
return
|
||||
return true
|
||||
}
|
||||
|
||||
if args.Login {
|
||||
|
|
@ -28,7 +58,7 @@ func main() {
|
|||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return
|
||||
return true
|
||||
}
|
||||
|
||||
if args.Logout {
|
||||
|
|
@ -36,7 +66,7 @@ func main() {
|
|||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return
|
||||
return true
|
||||
}
|
||||
|
||||
if args.AccountsList {
|
||||
|
|
@ -44,7 +74,7 @@ func main() {
|
|||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return
|
||||
return true
|
||||
}
|
||||
|
||||
if args.ResetHwid {
|
||||
|
|
@ -52,22 +82,9 @@ func main() {
|
|||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return
|
||||
return true
|
||||
}
|
||||
|
||||
e := env.OsEnvToMap()
|
||||
if args.Tailscale {
|
||||
e["CURSOR_BRIDGE_HOST"] = "0.0.0.0"
|
||||
}
|
||||
|
||||
cwd, _ := os.Getwd()
|
||||
cfg := config.LoadBridgeConfig(e, cwd)
|
||||
|
||||
servers := server.StartBridgeServer(server.ServerOptions{
|
||||
Version: version,
|
||||
Config: cfg,
|
||||
})
|
||||
server.SetupGracefulShutdown(servers, 10000)
|
||||
|
||||
select {}
|
||||
// Not a CLI command
|
||||
return false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/pool"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/pkg/domain/entity"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -192,7 +193,7 @@ func LogAccountAssigned(configDir string) {
|
|||
fmt.Printf("%s %s→%s account %s%s%s\n", ts(), cBCyan, cReset, cBold, name, cReset)
|
||||
}
|
||||
|
||||
func LogAccountStats(verbose bool, stats []pool.AccountStat) {
|
||||
func LogAccountStats(verbose bool, stats []entity.AccountStat) {
|
||||
if !verbose || len(stats) == 0 {
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ package process_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"cursor-api-proxy/internal/process"
|
||||
"cursor-api-proxy/pkg/infrastructure/process"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ package process
|
|||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"cursor-api-proxy/internal/env"
|
||||
"cursor-api-proxy/pkg/infrastructure/env"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
package winlimit
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/env"
|
||||
"cursor-api-proxy/pkg/infrastructure/env"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -182,7 +182,7 @@ func (p *PlaywrightProvider) Generate(ctx context.Context, model string, message
|
|||
fmt.Println("Browser is open. You can:")
|
||||
fmt.Println("1. Log in to Gemini now")
|
||||
fmt.Println("2. Continue without login")
|
||||
fmt.Println("========================================\n")
|
||||
fmt.Println("========================================")
|
||||
}
|
||||
} else {
|
||||
fmt.Println("[GeminiWeb] ✓ Logged in")
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ func (p *Provider) Generate(ctx context.Context, model string, messages []entity
|
|||
fmt.Println("Browser is open. You can:")
|
||||
fmt.Println("1. Log in to Gemini now")
|
||||
fmt.Println("2. Continue without login")
|
||||
fmt.Println("========================================\n")
|
||||
fmt.Println("========================================")
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("[GeminiWeb] Logged in\n")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/providers/geminiweb"
|
||||
"cursor-api-proxy/pkg/provider/geminiweb"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
|
|
@ -92,7 +92,7 @@ func analyzeDOM(page *rod.Page) {
|
|||
ariaLabel, _ := el.Attribute("aria-label")
|
||||
placeholder, _ := el.Attribute("placeholder")
|
||||
fmt.Printf(" [%d] tag=%s class=%s aria-label=%s placeholder=%s\n",
|
||||
i, tag, class, ariaLabel, placeholder)
|
||||
i, tag, ptrToStr(class), ptrToStr(ariaLabel), ptrToStr(placeholder))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -120,7 +120,7 @@ func analyzeDOM(page *rod.Page) {
|
|||
text, _ := el.Text()
|
||||
text = truncate(text, 30)
|
||||
fmt.Printf(" [%d] tag=%s class=%s aria-label=%s text=%s\n",
|
||||
i, tag, class, ariaLabel, text)
|
||||
i, tag, ptrToStr(class), ptrToStr(ariaLabel), text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -145,7 +145,7 @@ func analyzeDOM(page *rod.Page) {
|
|||
ariaLabel, _ := el.Attribute("aria-label")
|
||||
text, _ := el.Text()
|
||||
fmt.Printf(" [%d] tag=%s class=%s aria-label=%s text=%s\n",
|
||||
i, tag, class, ariaLabel, truncate(text, 30))
|
||||
i, tag, ptrToStr(class), ptrToStr(ariaLabel), truncate(text, 30))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -157,3 +157,10 @@ func truncate(s string, max int) string {
|
|||
}
|
||||
return s[:max] + "..."
|
||||
}
|
||||
|
||||
func ptrToStr(s *string) string {
|
||||
if s == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
return *s
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue