refactor(task-2): migrate infrastructure layer

- Migrate process (runner, kill_unix, kill_windows)
- Migrate parser (stream)
- Migrate httputil (httputil)
- Migrate logger (logger)
- Migrate env (env)
- Migrate workspace (workspace)
- Migrate winlimit (winlimit)
This commit is contained in:
王性驊 2026-04-03 17:22:51 +08:00
parent 8b6abbbba7
commit 80d7a4bb29
14 changed files with 2085 additions and 0 deletions

381
pkg/infrastructure/env/env.go vendored Normal file
View File

@ -0,0 +1,381 @@
package env
import (
"encoding/json"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
)
type EnvSource map[string]string
type LoadedEnv struct {
AgentBin string
AgentNode string
AgentScript string
CommandShell string
Host string
Port int
RequiredKey string
DefaultModel string
Provider string
Force bool
ApproveMcps bool
StrictModel bool
Workspace string
TimeoutMs int
TLSCertPath string
TLSKeyPath string
SessionsLogPath string
ChatOnlyWorkspace bool
Verbose bool
MaxMode bool
ConfigDirs []string
MultiPort bool
WinCmdlineMax int
GeminiAccountDir string
GeminiBrowserVisible bool
GeminiMaxSessions int
}
type AgentCommand struct {
Command string
Args []string
Env map[string]string
WindowsVerbatimArguments bool
AgentScriptPath string
ConfigDir string
}
func getEnvVal(e EnvSource, names []string) string {
for _, name := range names {
if v, ok := e[name]; ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
}
return ""
}
func envBool(e EnvSource, names []string, def bool) bool {
raw := getEnvVal(e, names)
if raw == "" {
return def
}
switch strings.ToLower(raw) {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
}
return def
}
func envInt(e EnvSource, names []string, def int) int {
raw := getEnvVal(e, names)
if raw == "" {
return def
}
v, err := strconv.Atoi(raw)
if err != nil {
return def
}
return v
}
func normalizeModelId(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return "auto"
}
parts := strings.Split(raw, "/")
last := parts[len(parts)-1]
if last == "" {
return "auto"
}
return last
}
func resolveAbs(raw, cwd string) string {
if raw == "" {
return ""
}
if filepath.IsAbs(raw) {
return raw
}
return filepath.Join(cwd, raw)
}
func isAuthenticatedAccountDir(dir string) bool {
data, err := os.ReadFile(filepath.Join(dir, "cli-config.json"))
if err != nil {
return false
}
var cfg struct {
AuthInfo *struct {
Email string `json:"email"`
} `json:"authInfo"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return false
}
return cfg.AuthInfo != nil && cfg.AuthInfo.Email != ""
}
func discoverAccountDirs(homeDir string) []string {
if homeDir == "" {
return nil
}
accountsDir := filepath.Join(homeDir, ".cursor-api-proxy", "accounts")
entries, err := os.ReadDir(accountsDir)
if err != nil {
return nil
}
var dirs []string
for _, e := range entries {
if !e.IsDir() {
continue
}
dir := filepath.Join(accountsDir, e.Name())
if isAuthenticatedAccountDir(dir) {
dirs = append(dirs, dir)
}
}
return dirs
}
func parseDotEnv(path string) EnvSource {
data, err := os.ReadFile(path)
if err != nil {
return nil
}
m := make(EnvSource)
for _, line := range strings.Split(string(data), "\n") {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) == 2 {
m[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
}
}
return m
}
func OsEnvToMap(cwdHint ...string) EnvSource {
m := make(EnvSource)
for _, kv := range os.Environ() {
parts := strings.SplitN(kv, "=", 2)
if len(parts) == 2 {
m[parts[0]] = parts[1]
}
}
cwd := ""
if len(cwdHint) > 0 && cwdHint[0] != "" {
cwd = cwdHint[0]
} else {
cwd, _ = os.Getwd()
}
if dotenv := parseDotEnv(filepath.Join(cwd, ".env")); dotenv != nil {
for k, v := range dotenv {
if _, exists := m[k]; !exists {
m[k] = v
}
}
}
return m
}
func LoadEnvConfig(e EnvSource, cwd string) LoadedEnv {
if e == nil {
e = OsEnvToMap()
}
if cwd == "" {
var err error
cwd, err = os.Getwd()
if err != nil {
cwd = "."
}
}
host := getEnvVal(e, []string{"CURSOR_BRIDGE_HOST"})
if host == "" {
host = "127.0.0.1"
}
port := envInt(e, []string{"CURSOR_BRIDGE_PORT"}, 8765)
if port <= 0 {
port = 8765
}
home := getEnvVal(e, []string{"HOME", "USERPROFILE"})
sessionsLogPath := func() string {
if p := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_SESSIONS_LOG"}), cwd); p != "" {
return p
}
if home != "" {
return filepath.Join(home, ".cursor-api-proxy", "sessions.log")
}
return filepath.Join(cwd, "sessions.log")
}()
var configDirs []string
if raw := getEnvVal(e, []string{"CURSOR_CONFIG_DIRS", "CURSOR_ACCOUNT_DIRS"}); raw != "" {
for _, d := range strings.Split(raw, ",") {
d = strings.TrimSpace(d)
if d != "" {
if p := resolveAbs(d, cwd); p != "" {
configDirs = append(configDirs, p)
}
}
}
}
if len(configDirs) == 0 {
configDirs = discoverAccountDirs(home)
}
winMax := envInt(e, []string{"CURSOR_BRIDGE_WIN_CMDLINE_MAX"}, 30000)
if winMax < 4096 {
winMax = 4096
}
if winMax > 32700 {
winMax = 32700
}
agentBin := getEnvVal(e, []string{"CURSOR_AGENT_BIN", "CURSOR_CLI_BIN", "CURSOR_CLI_PATH"})
if agentBin == "" {
agentBin = "agent"
}
commandShell := getEnvVal(e, []string{"COMSPEC"})
if commandShell == "" {
commandShell = "cmd.exe"
}
workspace := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_WORKSPACE"}), cwd)
if workspace == "" {
workspace = cwd
}
geminiAccountDir := getEnvVal(e, []string{"GEMINI_ACCOUNT_DIR"})
if geminiAccountDir == "" {
geminiAccountDir = filepath.Join(home, ".cursor-api-proxy", "gemini-accounts")
} else {
geminiAccountDir = resolveAbs(geminiAccountDir, cwd)
}
return LoadedEnv{
AgentBin: agentBin,
AgentNode: getEnvVal(e, []string{"CURSOR_AGENT_NODE"}),
AgentScript: getEnvVal(e, []string{"CURSOR_AGENT_SCRIPT"}),
CommandShell: commandShell,
Host: host,
Port: port,
RequiredKey: getEnvVal(e, []string{"CURSOR_BRIDGE_API_KEY"}),
DefaultModel: normalizeModelId(getEnvVal(e, []string{"CURSOR_BRIDGE_DEFAULT_MODEL"})),
Provider: getEnvVal(e, []string{"CURSOR_BRIDGE_PROVIDER"}),
Force: envBool(e, []string{"CURSOR_BRIDGE_FORCE"}, false),
ApproveMcps: envBool(e, []string{"CURSOR_BRIDGE_APPROVE_MCPS"}, false),
StrictModel: envBool(e, []string{"CURSOR_BRIDGE_STRICT_MODEL"}, true),
Workspace: workspace,
TimeoutMs: envInt(e, []string{"CURSOR_BRIDGE_TIMEOUT_MS"}, 300000),
TLSCertPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_CERT"}), cwd),
TLSKeyPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_KEY"}), cwd),
SessionsLogPath: sessionsLogPath,
ChatOnlyWorkspace: envBool(e, []string{"CURSOR_BRIDGE_CHAT_ONLY_WORKSPACE"}, true),
Verbose: envBool(e, []string{"CURSOR_BRIDGE_VERBOSE"}, false),
MaxMode: envBool(e, []string{"CURSOR_BRIDGE_MAX_MODE"}, false),
ConfigDirs: configDirs,
MultiPort: envBool(e, []string{"CURSOR_BRIDGE_MULTI_PORT"}, false),
WinCmdlineMax: winMax,
GeminiAccountDir: geminiAccountDir,
GeminiBrowserVisible: envBool(e, []string{"GEMINI_BROWSER_VISIBLE"}, false),
GeminiMaxSessions: envInt(e, []string{"GEMINI_MAX_SESSIONS"}, 3),
}
}
func ResolveAgentCommand(cmd string, args []string, e EnvSource, cwd string) AgentCommand {
if e == nil {
e = OsEnvToMap()
}
loaded := LoadEnvConfig(e, cwd)
cloneEnv := func() map[string]string {
m := make(map[string]string, len(e))
for k, v := range e {
m[k] = v
}
return m
}
if runtime.GOOS == "windows" {
if loaded.AgentNode != "" && loaded.AgentScript != "" {
agentScriptPath := loaded.AgentScript
if !filepath.IsAbs(agentScriptPath) {
agentScriptPath = filepath.Join(cwd, agentScriptPath)
}
agentDir := filepath.Dir(agentScriptPath)
configDir := filepath.Join(agentDir, "..", "data", "config")
env2 := cloneEnv()
env2["CURSOR_INVOKED_AS"] = "agent.cmd"
ac := AgentCommand{
Command: loaded.AgentNode,
Args: append([]string{loaded.AgentScript}, args...),
Env: env2,
AgentScriptPath: agentScriptPath,
}
if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil {
ac.ConfigDir = configDir
}
return ac
}
if strings.HasSuffix(strings.ToLower(cmd), ".cmd") {
cmdResolved := cmd
if !filepath.IsAbs(cmd) {
cmdResolved = filepath.Join(cwd, cmd)
}
dir := filepath.Dir(cmdResolved)
nodeBin := filepath.Join(dir, "node.exe")
script := filepath.Join(dir, "index.js")
if _, err1 := os.Stat(nodeBin); err1 == nil {
if _, err2 := os.Stat(script); err2 == nil {
configDir := filepath.Join(dir, "..", "data", "config")
env2 := cloneEnv()
env2["CURSOR_INVOKED_AS"] = "agent.cmd"
ac := AgentCommand{
Command: nodeBin,
Args: append([]string{script}, args...),
Env: env2,
AgentScriptPath: script,
}
if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil {
ac.ConfigDir = configDir
}
return ac
}
}
quotedArgs := make([]string, len(args))
for i, a := range args {
if strings.Contains(a, " ") {
quotedArgs[i] = `"` + a + `"`
} else {
quotedArgs[i] = a
}
}
cmdLine := `""` + cmd + `" ` + strings.Join(quotedArgs, " ") + `"`
return AgentCommand{
Command: loaded.CommandShell,
Args: []string{"/d", "/s", "/c", cmdLine},
Env: cloneEnv(),
WindowsVerbatimArguments: true,
}
}
}
return AgentCommand{Command: cmd, Args: args, Env: cloneEnv()}
}

65
pkg/infrastructure/env/env_test.go vendored Normal file
View File

@ -0,0 +1,65 @@
package env
import "testing"
func TestLoadEnvConfigDefaults(t *testing.T) {
e := EnvSource{}
loaded := LoadEnvConfig(e, "/tmp")
if loaded.Host != "127.0.0.1" {
t.Errorf("expected 127.0.0.1, got %s", loaded.Host)
}
if loaded.Port != 8765 {
t.Errorf("expected 8765, got %d", loaded.Port)
}
if loaded.DefaultModel != "auto" {
t.Errorf("expected auto, got %s", loaded.DefaultModel)
}
if loaded.AgentBin != "agent" {
t.Errorf("expected agent, got %s", loaded.AgentBin)
}
if !loaded.StrictModel {
t.Error("expected strictModel=true by default")
}
}
func TestLoadEnvConfigOverride(t *testing.T) {
e := EnvSource{
"CURSOR_BRIDGE_HOST": "0.0.0.0",
"CURSOR_BRIDGE_PORT": "9000",
"CURSOR_BRIDGE_DEFAULT_MODEL": "gpt-4",
"CURSOR_AGENT_BIN": "/usr/local/bin/agent",
}
loaded := LoadEnvConfig(e, "/tmp")
if loaded.Host != "0.0.0.0" {
t.Errorf("expected 0.0.0.0, got %s", loaded.Host)
}
if loaded.Port != 9000 {
t.Errorf("expected 9000, got %d", loaded.Port)
}
if loaded.DefaultModel != "gpt-4" {
t.Errorf("expected gpt-4, got %s", loaded.DefaultModel)
}
if loaded.AgentBin != "/usr/local/bin/agent" {
t.Errorf("expected /usr/local/bin/agent, got %s", loaded.AgentBin)
}
}
func TestNormalizeModelID(t *testing.T) {
tests := []struct {
input string
want string
}{
{"gpt-4", "gpt-4"},
{"openai/gpt-4", "gpt-4"},
{"", "auto"},
{" ", "auto"},
}
for _, tc := range tests {
got := normalizeModelId(tc.input)
if got != tc.want {
t.Errorf("normalizeModelId(%q) = %q, want %q", tc.input, got, tc.want)
}
}
}

View File

@ -0,0 +1,50 @@
package httputil
import (
"encoding/json"
"io"
"net/http"
"regexp"
)
var bearerRe = regexp.MustCompile(`(?i)^Bearer\s+(.+)$`)
func ExtractBearerToken(r *http.Request) string {
h := r.Header.Get("Authorization")
if h == "" {
return ""
}
m := bearerRe.FindStringSubmatch(h)
if m == nil {
return ""
}
return m[1]
}
func WriteJSON(w http.ResponseWriter, status int, body interface{}, extraHeaders map[string]string) {
for k, v := range extraHeaders {
w.Header().Set(k, v)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(body)
}
func WriteSSEHeaders(w http.ResponseWriter, extraHeaders map[string]string) {
for k, v := range extraHeaders {
w.Header().Set(k, v)
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
w.WriteHeader(200)
}
func ReadBody(r *http.Request) (string, error) {
data, err := io.ReadAll(r.Body)
if err != nil {
return "", err
}
return string(data), nil
}

View File

@ -0,0 +1,50 @@
package httputil
import (
"net/http/httptest"
"testing"
)
func TestExtractBearerToken(t *testing.T) {
tests := []struct {
header string
want string
}{
{"Bearer mytoken123", "mytoken123"},
{"bearer MYTOKEN", "MYTOKEN"},
{"", ""},
{"Basic abc", ""},
{"Bearer ", ""},
}
for _, tc := range tests {
req := httptest.NewRequest("GET", "/", nil)
if tc.header != "" {
req.Header.Set("Authorization", tc.header)
}
got := ExtractBearerToken(req)
if got != tc.want {
t.Errorf("ExtractBearerToken(%q) = %q, want %q", tc.header, got, tc.want)
}
}
}
func TestWriteJSON(t *testing.T) {
w := httptest.NewRecorder()
WriteJSON(w, 200, map[string]string{"ok": "true"}, nil)
if w.Code != 200 {
t.Errorf("expected 200, got %d", w.Code)
}
if w.Header().Get("Content-Type") != "application/json" {
t.Errorf("expected application/json, got %s", w.Header().Get("Content-Type"))
}
}
func TestWriteJSONWithExtraHeaders(t *testing.T) {
w := httptest.NewRecorder()
WriteJSON(w, 201, nil, map[string]string{"X-Custom": "value"})
if w.Header().Get("X-Custom") != "value" {
t.Errorf("expected X-Custom=value, got %s", w.Header().Get("X-Custom"))
}
}

View File

@ -0,0 +1,309 @@
package logger
import (
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/pool"
"fmt"
"os"
"path/filepath"
"strings"
"time"
)
const (
cReset = "\x1b[0m"
cBold = "\x1b[1m"
cDim = "\x1b[2m"
cCyan = "\x1b[36m"
cBCyan = "\x1b[1;96m"
cGreen = "\x1b[32m"
cBGreen = "\x1b[1;92m"
cYellow = "\x1b[33m"
cMagenta = "\x1b[35m"
cBMagenta = "\x1b[1;95m"
cRed = "\x1b[31m"
cGray = "\x1b[90m"
cWhite = "\x1b[97m"
)
var roleStyle = map[string]string{
"system": cYellow,
"user": cCyan,
"assistant": cGreen,
}
var roleEmoji = map[string]string{
"system": "🔧",
"user": "👤",
"assistant": "🤖",
}
func ts() string {
return cGray + time.Now().UTC().Format("15:04:05") + cReset
}
func tsDate() string {
return cGray + time.Now().UTC().Format("2006-01-02 15:04:05") + cReset
}
func truncate(s string, max int) string {
if len(s) <= max {
return s
}
head := int(float64(max) * 0.6)
tail := max - head
omitted := len(s) - head - tail
return s[:head] + fmt.Sprintf("%s … (%d chars omitted) … ", cDim, omitted) + s[len(s)-tail:] + cReset
}
func hr(ch string, length int) string {
return cGray + strings.Repeat(ch, length) + cReset
}
type TrafficMessage struct {
Role string
Content string
}
func LogDebug(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
fmt.Printf("%s %s[DEBUG]%s %s\n", ts(), cGray, cReset, msg)
}
func LogServerStart(version, scheme, host string, port int, cfg config.BridgeConfig) {
provider := cfg.Provider
if provider == "" {
provider = "cursor"
}
fmt.Printf("\n%s%s╔══════════════════════════════════════════╗%s\n", cBold, cBCyan, cReset)
fmt.Printf("%s%s cursor-api-proxy %sv%s%s%s%s ready%s\n",
cBold, cBCyan, cReset, cBold, cWhite, version, cBCyan, cReset)
fmt.Printf("%s%s╚══════════════════════════════════════════╝%s\n\n", cBold, cBCyan, cReset)
url := fmt.Sprintf("%s://%s:%d", scheme, host, port)
fmt.Printf(" %s●%s listening %s%s%s\n", cBGreen, cReset, cBold, url, cReset)
fmt.Printf(" %s▸%s provider %s%s%s\n", cCyan, cReset, cBold, provider, cReset)
fmt.Printf(" %s▸%s agent %s%s%s\n", cCyan, cReset, cDim, cfg.AgentBin, cReset)
fmt.Printf(" %s▸%s workspace %s%s%s\n", cCyan, cReset, cDim, cfg.Workspace, cReset)
fmt.Printf(" %s▸%s model %s%s%s\n", cCyan, cReset, cDim, cfg.DefaultModel, cReset)
fmt.Printf(" %s▸%s mode %s%s%s\n", cCyan, cReset, cDim, cfg.Mode, cReset)
fmt.Printf(" %s▸%s timeout %s%d ms%s\n", cCyan, cReset, cDim, cfg.TimeoutMs, cReset)
// 顯示 Gemini Web Provider 相關設定
if provider == "gemini-web" {
fmt.Printf(" %s▸%s gemini-dir %s%s%s\n", cCyan, cReset, cDim, cfg.GeminiAccountDir, cReset)
fmt.Printf(" %s▸%s max-sess %s%d%s\n", cCyan, cReset, cDim, cfg.GeminiMaxSessions, cReset)
}
flags := []string{}
if cfg.Force {
flags = append(flags, "force")
}
if cfg.ApproveMcps {
flags = append(flags, "approve-mcps")
}
if cfg.MaxMode {
flags = append(flags, "max-mode")
}
if cfg.Verbose {
flags = append(flags, "verbose")
}
if cfg.ChatOnlyWorkspace {
flags = append(flags, "chat-only")
}
if cfg.RequiredKey != "" {
flags = append(flags, "api-key-required")
}
if len(flags) > 0 {
fmt.Printf(" %s▸%s flags %s%s%s\n", cCyan, cReset, cYellow, strings.Join(flags, " · "), cReset)
}
if len(cfg.ConfigDirs) > 0 {
fmt.Printf(" %s▸%s pool %s%d accounts%s\n", cCyan, cReset, cBGreen, len(cfg.ConfigDirs), cReset)
}
fmt.Println()
}
func LogShutdown(sig string) {
fmt.Printf("\n%s %s⊘ %s received — shutting down gracefully…%s\n", tsDate(), cYellow, sig, cReset)
}
func LogRequestStart(method, pathname, model string, timeoutMs int, isStream bool) {
modeTag := fmt.Sprintf("%ssync%s", cDim, cReset)
if isStream {
modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset)
}
fmt.Printf("%s %s▶%s %s %s %s timeout:%dms %s\n",
ts(), cBCyan, cReset, method, pathname, model, timeoutMs, modeTag)
}
func LogRequestDone(method, pathname, model string, latencyMs int64, code int) {
statusColor := cBGreen
if code != 0 {
statusColor = cRed
}
fmt.Printf("%s %s■%s %s %s %s %s%dms exit:%d%s\n",
ts(), statusColor, cReset, method, pathname, model, cDim, latencyMs, code, cReset)
}
func LogRequestTimeout(method, pathname, model string, timeoutMs int) {
fmt.Printf("%s %s⏱%s %s %s %s %stimed-out after %dms%s\n",
ts(), cRed, cReset, method, pathname, model, cRed, timeoutMs, cReset)
}
func LogClientDisconnect(method, pathname, model string, latencyMs int64) {
fmt.Printf("%s %s⚡%s %s %s %s %sclient disconnected after %dms%s\n",
ts(), cYellow, cReset, method, pathname, model, cYellow, latencyMs, cReset)
}
func LogStreamChunk(model string, text string, chunkNum int) {
preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 120)
fmt.Printf("%s %s▸%s #%d %s%s%s\n",
ts(), cDim, cReset, chunkNum, cWhite, preview, cReset)
}
func LogRawLine(line string) {
preview := truncate(strings.ReplaceAll(line, "\n", "↵ "), 200)
fmt.Printf("%s %s│%s %sraw%s %s\n",
ts(), cGray, cReset, cDim, cReset, preview)
}
func LogIncoming(method, pathname, remoteAddress string) {
methodColor := cBCyan
switch method {
case "POST":
methodColor = cBMagenta
case "GET":
methodColor = cBCyan
case "DELETE":
methodColor = cRed
}
fmt.Printf("%s %s%s%s%s %s%s%s %s(%s)%s\n",
ts(),
methodColor, cBold, method, cReset,
cWhite, pathname, cReset,
cDim, remoteAddress, cReset,
)
}
func LogAccountAssigned(configDir string) {
if configDir == "" {
return
}
name := filepath.Base(configDir)
fmt.Printf("%s %s→%s account %s%s%s\n", ts(), cBCyan, cReset, cBold, name, cReset)
}
func LogAccountStats(verbose bool, stats []pool.AccountStat) {
if !verbose || len(stats) == 0 {
return
}
now := time.Now().UnixMilli()
fmt.Printf("%s┌─ Account Stats %s┐%s\n", cGray, strings.Repeat("─", 44), cReset)
for _, s := range stats {
name := fmt.Sprintf("%-20s", filepath.Base(s.ConfigDir))
active := fmt.Sprintf("%sactive:0%s", cDim, cReset)
if s.ActiveRequests > 0 {
active = fmt.Sprintf("%sactive:%d%s", cBCyan, s.ActiveRequests, cReset)
}
total := fmt.Sprintf("total:%s%d%s", cBold, s.TotalRequests, cReset)
ok := fmt.Sprintf("%sok:%d%s", cGreen, s.TotalSuccess, cReset)
errStr := fmt.Sprintf("%serr:0%s", cDim, cReset)
if s.TotalErrors > 0 {
errStr = fmt.Sprintf("%serr:%d%s", cRed, s.TotalErrors, cReset)
}
rl := fmt.Sprintf("%srl:0%s", cDim, cReset)
if s.TotalRateLimits > 0 {
rl = fmt.Sprintf("%srl:%d%s", cYellow, s.TotalRateLimits, cReset)
}
avg := "avg:-"
if s.TotalRequests > 0 {
avg = fmt.Sprintf("avg:%dms", s.TotalLatencyMs/int64(s.TotalRequests))
}
status := fmt.Sprintf("%s✓%s", cGreen, cReset)
if s.IsRateLimited {
recovers := time.UnixMilli(s.RateLimitUntil).UTC().Format(time.RFC3339)
_ = now
status = fmt.Sprintf("%s⛔ rate-limited (recovers %s)%s", cRed, recovers, cReset)
}
fmt.Printf(" %s%s%s %s %s %s %s %s %s%s%s %s\n",
cBold, name, cReset, active, total, ok, errStr, rl, cDim, avg, cReset, status)
}
fmt.Printf("%s└%s┘%s\n", cGray, strings.Repeat("─", 60), cReset)
}
func LogTrafficRequest(verbose bool, model string, messages []TrafficMessage, isStream bool) {
if !verbose {
return
}
modeTag := fmt.Sprintf("%ssync%s", cDim, cReset)
if isStream {
modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset)
}
modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset)
fmt.Println(hr("─", 60))
fmt.Printf("%s 📤 %s%sREQUEST%s %s %s\n", ts(), cBCyan, cBold, cReset, modelStr, modeTag)
for _, m := range messages {
roleColor := cWhite
if c, ok := roleStyle[m.Role]; ok {
roleColor = c
}
emoji := "💬"
if e, ok := roleEmoji[m.Role]; ok {
emoji = e
}
label := fmt.Sprintf("%s%s[%s]%s", roleColor, cBold, m.Role, cReset)
charCount := fmt.Sprintf("%s(%d chars)%s", cDim, len(m.Content), cReset)
preview := truncate(strings.ReplaceAll(m.Content, "\n", "↵ "), 280)
fmt.Printf(" %s %s %s\n", emoji, label, charCount)
fmt.Printf(" %s%s%s\n", cDim, preview, cReset)
}
}
func LogTrafficResponse(verbose bool, model, text string, isStream bool) {
if !verbose {
return
}
modeTag := fmt.Sprintf("%ssync%s", cDim, cReset)
if isStream {
modeTag = fmt.Sprintf("%s⚡ stream%s", cBGreen, cReset)
}
modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset)
charCount := fmt.Sprintf("%s%d%s%s chars%s", cBold, len(text), cReset, cDim, cReset)
preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 480)
fmt.Printf("%s 📥 %s%sRESPONSE%s %s %s %s\n", ts(), cBGreen, cBold, cReset, modelStr, modeTag, charCount)
fmt.Printf(" 🤖 %s%s%s\n", cGreen, preview, cReset)
fmt.Println(hr("─", 60))
}
func AppendSessionLine(logPath, method, pathname, remoteAddress string, statusCode int) {
line := fmt.Sprintf("%s %s %s %s %d\n", time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, statusCode)
dir := filepath.Dir(logPath)
if err := os.MkdirAll(dir, 0755); err == nil {
f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err == nil {
_, _ = f.WriteString(line)
f.Close()
}
}
}
func LogTruncation(originalLen, finalLen int) {
fmt.Printf("%s %s⚠ prompt truncated%s %s(%d → %d chars, tail preserved)%s\n",
ts(), cYellow, cReset, cDim, originalLen, finalLen, cReset)
}
func LogAgentError(logPath, method, pathname, remoteAddress string, exitCode int, stderr string) string {
errMsg := fmt.Sprintf("Cursor CLI failed (exit %d): %s", exitCode, strings.TrimSpace(stderr))
fmt.Fprintf(os.Stderr, "%s %s✗ agent error%s %s%s%s\n", ts(), cRed, cReset, cDim, errMsg, cReset)
truncated := strings.TrimSpace(stderr)
if len(truncated) > 200 {
truncated = truncated[:200]
}
truncated = strings.ReplaceAll(truncated, "\n", " ")
line := fmt.Sprintf("%s ERROR %s %s %s agent_exit_%d %s\n",
time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, exitCode, truncated)
if f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil {
_, _ = f.WriteString(line)
f.Close()
}
return errMsg
}

View File

@ -0,0 +1,110 @@
package parser
import "encoding/json"
type StreamParser func(line string)
type Parser struct {
Parse StreamParser
Flush func()
}
// CreateStreamParser 建立串流解析器(向後相容,不傳遞 thinking
func CreateStreamParser(onText func(string), onDone func()) Parser {
return CreateStreamParserWithThinking(onText, nil, onDone)
}
// CreateStreamParserWithThinking 建立串流解析器,支援思考過程輸出。
// onThinking 可為 nil表示忽略思考過程。
func CreateStreamParserWithThinking(onText func(string), onThinking func(string), onDone func()) Parser {
// accumulated 是所有已輸出內容的串接
accumulatedText := ""
accumulatedThinking := ""
done := false
parse := func(line string) {
if done {
return
}
var obj struct {
Type string `json:"type"`
Subtype string `json:"subtype"`
Message *struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
Thinking string `json:"thinking"`
} `json:"content"`
} `json:"message"`
}
if err := json.Unmarshal([]byte(line), &obj); err != nil {
return
}
if obj.Type == "assistant" && obj.Message != nil {
fullText := ""
fullThinking := ""
for _, p := range obj.Message.Content {
switch p.Type {
case "text":
if p.Text != "" {
fullText += p.Text
}
case "thinking":
if p.Thinking != "" {
fullThinking += p.Thinking
}
}
}
// 處理思考過程(不因去重而 return避免跳過同行的文字內容
if onThinking != nil && fullThinking != "" && fullThinking != accumulatedThinking {
// 增量模式:新內容以 accumulated 為前綴
if len(fullThinking) >= len(accumulatedThinking) && fullThinking[:len(accumulatedThinking)] == accumulatedThinking {
delta := fullThinking[len(accumulatedThinking):]
if delta != "" {
onThinking(delta)
}
accumulatedThinking = fullThinking
} else {
// 獨立片段:直接輸出,但 accumulated 要串接
onThinking(fullThinking)
accumulatedThinking = accumulatedThinking + fullThinking
}
}
// 處理一般文字
if fullText == "" || fullText == accumulatedText {
return
}
// 增量模式:新內容以 accumulated 為前綴
if len(fullText) >= len(accumulatedText) && fullText[:len(accumulatedText)] == accumulatedText {
delta := fullText[len(accumulatedText):]
if delta != "" {
onText(delta)
}
accumulatedText = fullText
} else {
// 獨立片段:直接輸出,但 accumulated 要串接
onText(fullText)
accumulatedText = accumulatedText + fullText
}
}
if obj.Type == "result" && obj.Subtype == "success" {
done = true
onDone()
}
}
flush := func() {
if !done {
done = true
onDone()
}
}
return Parser{Parse: parse, Flush: flush}
}

View File

@ -0,0 +1,304 @@
package parser
import (
"encoding/json"
"testing"
)
func makeAssistantLine(text string) string {
obj := map[string]interface{}{
"type": "assistant",
"message": map[string]interface{}{
"content": []map[string]interface{}{
{"type": "text", "text": text},
},
},
}
b, _ := json.Marshal(obj)
return string(b)
}
func makeResultLine() string {
b, _ := json.Marshal(map[string]string{"type": "result", "subtype": "success"})
return string(b)
}
func TestStreamParserFragmentMode(t *testing.T) {
// cursor --stream-partial-output 模式:每個訊息是獨立 token fragment
var texts []string
p := CreateStreamParser(
func(text string) { texts = append(texts, text) },
func() {},
)
p.Parse(makeAssistantLine("你"))
p.Parse(makeAssistantLine("好!有"))
p.Parse(makeAssistantLine("什"))
p.Parse(makeAssistantLine("麼"))
if len(texts) != 4 {
t.Fatalf("expected 4 fragments, got %d: %v", len(texts), texts)
}
if texts[0] != "你" || texts[1] != "好!有" || texts[2] != "什" || texts[3] != "麼" {
t.Fatalf("unexpected texts: %v", texts)
}
}
func TestStreamParserDeduplicatesFinalFullText(t *testing.T) {
// 最後一個訊息是完整的累積文字,應被跳過(去重)
var texts []string
p := CreateStreamParser(
func(text string) { texts = append(texts, text) },
func() {},
)
p.Parse(makeAssistantLine("Hello"))
p.Parse(makeAssistantLine(" world"))
// 最後一個是完整累積文字,應被去重
p.Parse(makeAssistantLine("Hello world"))
if len(texts) != 2 {
t.Fatalf("expected 2 fragments (final full text deduplicated), got %d: %v", len(texts), texts)
}
if texts[0] != "Hello" || texts[1] != " world" {
t.Fatalf("unexpected texts: %v", texts)
}
}
func TestStreamParserCallsOnDone(t *testing.T) {
var texts []string
doneCount := 0
p := CreateStreamParser(
func(text string) { texts = append(texts, text) },
func() { doneCount++ },
)
p.Parse(makeResultLine())
if doneCount != 1 {
t.Fatalf("expected onDone called once, got %d", doneCount)
}
if len(texts) != 0 {
t.Fatalf("expected no text, got %v", texts)
}
}
func TestStreamParserIgnoresLinesAfterDone(t *testing.T) {
var texts []string
doneCount := 0
p := CreateStreamParser(
func(text string) { texts = append(texts, text) },
func() { doneCount++ },
)
p.Parse(makeResultLine())
p.Parse(makeAssistantLine("late"))
if len(texts) != 0 {
t.Fatalf("expected no text after done, got %v", texts)
}
if doneCount != 1 {
t.Fatalf("expected onDone called once, got %d", doneCount)
}
}
func TestStreamParserIgnoresNonAssistantLines(t *testing.T) {
var texts []string
p := CreateStreamParser(
func(text string) { texts = append(texts, text) },
func() {},
)
b1, _ := json.Marshal(map[string]interface{}{"type": "user", "message": map[string]interface{}{}})
p.Parse(string(b1))
b2, _ := json.Marshal(map[string]interface{}{
"type": "assistant",
"message": map[string]interface{}{"content": []interface{}{}},
})
p.Parse(string(b2))
p.Parse(`{"type":"assistant","message":{"content":[{"type":"code","text":"x"}]}}`)
if len(texts) != 0 {
t.Fatalf("expected no texts, got %v", texts)
}
}
func TestStreamParserIgnoresParseErrors(t *testing.T) {
var texts []string
doneCount := 0
p := CreateStreamParser(
func(text string) { texts = append(texts, text) },
func() { doneCount++ },
)
p.Parse("not json")
p.Parse("{")
p.Parse("")
if len(texts) != 0 || doneCount != 0 {
t.Fatalf("expected nothing, got texts=%v done=%d", texts, doneCount)
}
}
func TestStreamParserJoinsMultipleTextParts(t *testing.T) {
var texts []string
p := CreateStreamParser(
func(text string) { texts = append(texts, text) },
func() {},
)
obj := map[string]interface{}{
"type": "assistant",
"message": map[string]interface{}{
"content": []map[string]interface{}{
{"type": "text", "text": "Hello"},
{"type": "text", "text": " "},
{"type": "text", "text": "world"},
},
},
}
b, _ := json.Marshal(obj)
p.Parse(string(b))
if len(texts) != 1 || texts[0] != "Hello world" {
t.Fatalf("expected ['Hello world'], got %v", texts)
}
}
func TestStreamParserFlushTriggersDone(t *testing.T) {
var texts []string
doneCount := 0
p := CreateStreamParser(
func(text string) { texts = append(texts, text) },
func() { doneCount++ },
)
p.Parse(makeAssistantLine("Hello"))
// agent 結束但沒有 result/success手動 flush
p.Flush()
if doneCount != 1 {
t.Fatalf("expected onDone called once after Flush, got %d", doneCount)
}
// 再 flush 不應重複觸發
p.Flush()
if doneCount != 1 {
t.Fatalf("expected onDone called only once, got %d", doneCount)
}
}
func TestStreamParserFlushAfterDoneIsNoop(t *testing.T) {
doneCount := 0
p := CreateStreamParser(
func(text string) {},
func() { doneCount++ },
)
p.Parse(makeResultLine())
p.Flush()
if doneCount != 1 {
t.Fatalf("expected onDone called once, got %d", doneCount)
}
}
func makeThinkingLine(thinking string) string {
obj := map[string]interface{}{
"type": "assistant",
"message": map[string]interface{}{
"content": []map[string]interface{}{
{"type": "thinking", "thinking": thinking},
},
},
}
b, _ := json.Marshal(obj)
return string(b)
}
func makeThinkingAndTextLine(thinking, text string) string {
obj := map[string]interface{}{
"type": "assistant",
"message": map[string]interface{}{
"content": []map[string]interface{}{
{"type": "thinking", "thinking": thinking},
{"type": "text", "text": text},
},
},
}
b, _ := json.Marshal(obj)
return string(b)
}
func TestStreamParserWithThinkingCallsOnThinking(t *testing.T) {
var texts []string
var thinkings []string
p := CreateStreamParserWithThinking(
func(text string) { texts = append(texts, text) },
func(thinking string) { thinkings = append(thinkings, thinking) },
func() {},
)
p.Parse(makeThinkingLine("思考中..."))
p.Parse(makeAssistantLine("回答"))
if len(thinkings) != 1 || thinkings[0] != "思考中..." {
t.Fatalf("expected thinkings=['思考中...'], got %v", thinkings)
}
if len(texts) != 1 || texts[0] != "回答" {
t.Fatalf("expected texts=['回答'], got %v", texts)
}
}
func TestStreamParserWithThinkingNilOnThinkingIgnoresThinking(t *testing.T) {
var texts []string
p := CreateStreamParserWithThinking(
func(text string) { texts = append(texts, text) },
nil,
func() {},
)
p.Parse(makeThinkingLine("忽略的思考"))
p.Parse(makeAssistantLine("文字"))
if len(texts) != 1 || texts[0] != "文字" {
t.Fatalf("expected texts=['文字'], got %v", texts)
}
}
func TestStreamParserWithThinkingDeduplication(t *testing.T) {
var thinkings []string
p := CreateStreamParserWithThinking(
func(text string) {},
func(thinking string) { thinkings = append(thinkings, thinking) },
func() {},
)
p.Parse(makeThinkingLine("A"))
p.Parse(makeThinkingLine("B"))
// 重複的完整思考,應被跳過
p.Parse(makeThinkingLine("AB"))
if len(thinkings) != 2 || thinkings[0] != "A" || thinkings[1] != "B" {
t.Fatalf("expected thinkings=['A','B'], got %v", thinkings)
}
}
// TestStreamParserThinkingDuplicateButTextStillEmitted 驗證 bug 修復:
// 當 thinking 重複(去重跳過)但同一行有 text 時text 仍必須輸出。
func TestStreamParserThinkingDuplicateButTextStillEmitted(t *testing.T) {
var texts []string
var thinkings []string
p := CreateStreamParserWithThinking(
func(text string) { texts = append(texts, text) },
func(thinking string) { thinkings = append(thinkings, thinking) },
func() {},
)
// 第一行thinking="思考中" + textthinking 為新增,兩者都應輸出)
p.Parse(makeThinkingAndTextLine("思考中", "第一段"))
// 第二行thinking 與上一行相同(去重),但 text 是新的text 仍應輸出
p.Parse(makeThinkingAndTextLine("思考中", "第二段"))
if len(thinkings) != 1 || thinkings[0] != "思考中" {
t.Fatalf("expected thinkings=['思考中'], got %v", thinkings)
}
if len(texts) != 2 || texts[0] != "第一段" || texts[1] != "第二段" {
t.Fatalf("expected texts=['第一段','第二段'], got %v", texts)
}
}

View File

@ -0,0 +1,21 @@
//go:build !windows
package process
import (
"os/exec"
"syscall"
)
func killProcessGroup(c *exec.Cmd) error {
if c.Process == nil {
return nil
}
// 殺死整個 process group負號表示 group
pgid, err := syscall.Getpgid(c.Process.Pid)
if err == nil {
_ = syscall.Kill(-pgid, syscall.SIGKILL)
}
// 同時也 kill 主程序,以防萬一
return c.Process.Kill()
}

View File

@ -0,0 +1,14 @@
//go:build windows
package process
import (
"os/exec"
)
func killProcessGroup(c *exec.Cmd) error {
if c.Process == nil {
return nil
}
return c.Process.Kill()
}

View File

@ -0,0 +1,283 @@
package process_test
import (
"context"
"cursor-api-proxy/internal/process"
"os"
"testing"
"time"
)
// sh 是跨平台 shell 執行小 script 的輔助函式
func sh(t *testing.T, script string, opts process.RunOptions) (process.RunResult, error) {
t.Helper()
return process.Run("sh", []string{"-c", script}, opts)
}
func TestRun_StdoutAndStderr(t *testing.T) {
result, err := sh(t, "echo hello; echo world >&2", process.RunOptions{})
if err != nil {
t.Fatal(err)
}
if result.Code != 0 {
t.Errorf("Code = %d, want 0", result.Code)
}
if result.Stdout != "hello\n" {
t.Errorf("Stdout = %q, want %q", result.Stdout, "hello\n")
}
if result.Stderr != "world\n" {
t.Errorf("Stderr = %q, want %q", result.Stderr, "world\n")
}
}
func TestRun_BasicSpawn(t *testing.T) {
result, err := sh(t, "printf ok", process.RunOptions{})
if err != nil {
t.Fatal(err)
}
if result.Code != 0 {
t.Errorf("Code = %d, want 0", result.Code)
}
if result.Stdout != "ok" {
t.Errorf("Stdout = %q, want ok", result.Stdout)
}
}
func TestRun_ConfigDir_Propagated(t *testing.T) {
result, err := process.Run("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`},
process.RunOptions{ConfigDir: "/test/account/dir"})
if err != nil {
t.Fatal(err)
}
if result.Stdout != "/test/account/dir" {
t.Errorf("Stdout = %q, want /test/account/dir", result.Stdout)
}
}
func TestRun_ConfigDir_Absent(t *testing.T) {
// 確保沒有殘留的環境變數
_ = os.Unsetenv("CURSOR_CONFIG_DIR")
result, err := process.Run("sh", []string{"-c", `printf "${CURSOR_CONFIG_DIR:-unset}"`},
process.RunOptions{})
if err != nil {
t.Fatal(err)
}
if result.Stdout != "unset" {
t.Errorf("Stdout = %q, want unset", result.Stdout)
}
}
func TestRun_NonZeroExit(t *testing.T) {
result, err := sh(t, "exit 42", process.RunOptions{})
if err != nil {
t.Fatal(err)
}
if result.Code != 42 {
t.Errorf("Code = %d, want 42", result.Code)
}
}
func TestRun_Timeout(t *testing.T) {
start := time.Now()
result, err := sh(t, "sleep 30", process.RunOptions{TimeoutMs: 300})
elapsed := time.Since(start)
if err != nil {
t.Fatal(err)
}
if result.Code == 0 {
t.Error("expected non-zero exit code after timeout")
}
if elapsed > 2*time.Second {
t.Errorf("elapsed %v, want < 2s", elapsed)
}
}
func TestRunStreaming_OnLine(t *testing.T) {
var lines []string
result, err := process.RunStreaming("sh", []string{"-c", "printf 'a\nb\nc\n'"},
process.RunStreamingOptions{
OnLine: func(line string) { lines = append(lines, line) },
})
if err != nil {
t.Fatal(err)
}
if result.Code != 0 {
t.Errorf("Code = %d, want 0", result.Code)
}
if len(lines) != 3 {
t.Errorf("got %d lines, want 3: %v", len(lines), lines)
}
if lines[0] != "a" || lines[1] != "b" || lines[2] != "c" {
t.Errorf("lines = %v, want [a b c]", lines)
}
}
func TestRunStreaming_FlushFinalLine(t *testing.T) {
var lines []string
result, err := process.RunStreaming("sh", []string{"-c", "printf tail"},
process.RunStreamingOptions{
OnLine: func(line string) { lines = append(lines, line) },
})
if err != nil {
t.Fatal(err)
}
if result.Code != 0 {
t.Errorf("Code = %d, want 0", result.Code)
}
if len(lines) != 1 {
t.Errorf("got %d lines, want 1: %v", len(lines), lines)
}
if lines[0] != "tail" {
t.Errorf("lines[0] = %q, want tail", lines[0])
}
}
func TestRunStreaming_ConfigDir(t *testing.T) {
var lines []string
_, err := process.RunStreaming("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`},
process.RunStreamingOptions{
RunOptions: process.RunOptions{ConfigDir: "/my/config/dir"},
OnLine: func(line string) { lines = append(lines, line) },
})
if err != nil {
t.Fatal(err)
}
if len(lines) != 1 || lines[0] != "/my/config/dir" {
t.Errorf("lines = %v, want [/my/config/dir]", lines)
}
}
func TestRunStreaming_Stderr(t *testing.T) {
result, err := process.RunStreaming("sh", []string{"-c", "echo err-output >&2"},
process.RunStreamingOptions{OnLine: func(string) {}})
if err != nil {
t.Fatal(err)
}
if result.Stderr == "" {
t.Error("expected stderr to contain output")
}
}
func TestRunStreaming_Timeout(t *testing.T) {
start := time.Now()
result, err := process.RunStreaming("sh", []string{"-c", "sleep 30"},
process.RunStreamingOptions{
RunOptions: process.RunOptions{TimeoutMs: 300},
OnLine: func(string) {},
})
elapsed := time.Since(start)
if err != nil {
t.Fatal(err)
}
if result.Code == 0 {
t.Error("expected non-zero exit code after timeout")
}
if elapsed > 2*time.Second {
t.Errorf("elapsed %v, want < 2s", elapsed)
}
}
func TestRunStreaming_Concurrent(t *testing.T) {
var lines1, lines2 []string
done := make(chan struct{}, 2)
run := func(label string, target *[]string) {
process.RunStreaming("sh", []string{"-c", "printf '" + label + "'"},
process.RunStreamingOptions{
OnLine: func(line string) { *target = append(*target, line) },
})
done <- struct{}{}
}
go run("stream1", &lines1)
go run("stream2", &lines2)
<-done
<-done
if len(lines1) != 1 || lines1[0] != "stream1" {
t.Errorf("lines1 = %v, want [stream1]", lines1)
}
if len(lines2) != 1 || lines2[0] != "stream2" {
t.Errorf("lines2 = %v, want [stream2]", lines2)
}
}
func TestRunStreaming_ContextCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
start := time.Now()
done := make(chan struct{})
go func() {
process.RunStreaming("sh", []string{"-c", "sleep 30"},
process.RunStreamingOptions{
RunOptions: process.RunOptions{Ctx: ctx},
OnLine: func(string) {},
})
close(done)
}()
time.AfterFunc(100*time.Millisecond, cancel)
<-done
elapsed := time.Since(start)
if elapsed > 2*time.Second {
t.Errorf("elapsed %v, want < 2s", elapsed)
}
}
func TestRun_ContextCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
start := time.Now()
done := make(chan process.RunResult, 1)
go func() {
r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx})
done <- r
}()
time.AfterFunc(100*time.Millisecond, cancel)
result := <-done
elapsed := time.Since(start)
if result.Code == 0 {
t.Error("expected non-zero exit code after cancel")
}
if elapsed > 2*time.Second {
t.Errorf("elapsed %v, want < 2s", elapsed)
}
}
func TestRun_AlreadyCancelledContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // 已取消
start := time.Now()
result, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx})
elapsed := time.Since(start)
if result.Code == 0 {
t.Error("expected non-zero exit code")
}
if elapsed > 2*time.Second {
t.Errorf("elapsed %v, want < 2s", elapsed)
}
}
func TestKillAllChildProcesses(t *testing.T) {
done := make(chan process.RunResult, 1)
go func() {
r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{})
done <- r
}()
time.Sleep(80 * time.Millisecond)
process.KillAllChildProcesses()
result := <-done
if result.Code == 0 {
t.Error("expected non-zero exit code after kill")
}
// 再次呼叫不應 panic
process.KillAllChildProcesses()
}

View File

@ -0,0 +1,250 @@
package process
import (
"bufio"
"context"
"cursor-api-proxy/internal/env"
"fmt"
"os/exec"
"strings"
"sync"
"syscall"
"time"
)
type RunResult struct {
Code int
Stdout string
Stderr string
}
type RunOptions struct {
Cwd string
TimeoutMs int
MaxMode bool
ConfigDir string
Ctx context.Context
}
type RunStreamingOptions struct {
RunOptions
OnLine func(line string)
}
// ─── Global child process registry ──────────────────────────────────────────
var (
activeMu sync.Mutex
activeChildren []*exec.Cmd
)
func registerChild(c *exec.Cmd) {
activeMu.Lock()
activeChildren = append(activeChildren, c)
activeMu.Unlock()
}
func unregisterChild(c *exec.Cmd) {
activeMu.Lock()
for i, ch := range activeChildren {
if ch == c {
activeChildren = append(activeChildren[:i], activeChildren[i+1:]...)
break
}
}
activeMu.Unlock()
}
func KillAllChildProcesses() {
activeMu.Lock()
all := make([]*exec.Cmd, len(activeChildren))
copy(all, activeChildren)
activeChildren = nil
activeMu.Unlock()
for _, c := range all {
killProcessGroup(c)
}
}
// ─── Spawn ────────────────────────────────────────────────────────────────
func spawnChild(cmdStr string, args []string, opts *RunOptions, maxModeFn func(scriptPath, configDir string)) *exec.Cmd {
envSrc := env.OsEnvToMap()
resolved := env.ResolveAgentCommand(cmdStr, args, envSrc, opts.Cwd)
if opts.MaxMode && maxModeFn != nil {
maxModeFn(resolved.AgentScriptPath, opts.ConfigDir)
}
envMap := make(map[string]string, len(resolved.Env))
for k, v := range resolved.Env {
envMap[k] = v
}
if opts.ConfigDir != "" {
envMap["CURSOR_CONFIG_DIR"] = opts.ConfigDir
} else if resolved.ConfigDir != "" {
if _, exists := envMap["CURSOR_CONFIG_DIR"]; !exists {
envMap["CURSOR_CONFIG_DIR"] = resolved.ConfigDir
}
}
envSlice := make([]string, 0, len(envMap))
for k, v := range envMap {
envSlice = append(envSlice, k+"="+v)
}
ctx := opts.Ctx
if ctx == nil {
ctx = context.Background()
}
// 使用 WaitDelay 確保 context cancel 後子程序 goroutine 能及時退出
c := exec.CommandContext(ctx, resolved.Command, resolved.Args...)
c.Dir = opts.Cwd
c.Env = envSlice
// 設定新的 process group使 kill 能傳遞給所有子孫程序
c.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
// WaitDelaycontext cancel 後額外等待這麼久再強制關閉 pipes
c.WaitDelay = 5 * time.Second
// Cancel 函式:殺死整個 process group
c.Cancel = func() error {
return killProcessGroup(c)
}
return c
}
// MaxModeFn is set by the agent package to avoid import cycle.
var MaxModeFn func(agentScriptPath, configDir string)
func Run(cmdStr string, args []string, opts RunOptions) (RunResult, error) {
ctx := opts.Ctx
var cancel context.CancelFunc
if opts.TimeoutMs > 0 {
if ctx == nil {
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond)
} else {
ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond)
}
defer cancel()
opts.Ctx = ctx
} else if ctx == nil {
opts.Ctx = context.Background()
}
c := spawnChild(cmdStr, args, &opts, MaxModeFn)
var stdoutBuf, stderrBuf strings.Builder
c.Stdout = &stdoutBuf
c.Stderr = &stderrBuf
if err := c.Start(); err != nil {
// context 已取消或命令找不到時
if opts.Ctx != nil && opts.Ctx.Err() != nil {
return RunResult{Code: -1}, nil
}
if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") {
return RunResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr)
}
return RunResult{}, err
}
registerChild(c)
defer unregisterChild(c)
err := c.Wait()
code := 0
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
code = exitErr.ExitCode()
if code == 0 {
code = -1
}
} else {
// context cancelled or killed — return -1 but no error
return RunResult{Code: -1, Stdout: stdoutBuf.String(), Stderr: stderrBuf.String()}, nil
}
}
return RunResult{
Code: code,
Stdout: stdoutBuf.String(),
Stderr: stderrBuf.String(),
}, nil
}
type StreamResult struct {
Code int
Stderr string
}
func RunStreaming(cmdStr string, args []string, opts RunStreamingOptions) (StreamResult, error) {
ctx := opts.Ctx
var cancel context.CancelFunc
if opts.TimeoutMs > 0 {
if ctx == nil {
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond)
} else {
ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond)
}
defer cancel()
opts.RunOptions.Ctx = ctx
} else if opts.RunOptions.Ctx == nil {
opts.RunOptions.Ctx = context.Background()
}
c := spawnChild(cmdStr, args, &opts.RunOptions, MaxModeFn)
stdoutPipe, err := c.StdoutPipe()
if err != nil {
return StreamResult{}, err
}
stderrPipe, err := c.StderrPipe()
if err != nil {
return StreamResult{}, err
}
if err := c.Start(); err != nil {
if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") {
return StreamResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr)
}
return StreamResult{}, err
}
registerChild(c)
defer unregisterChild(c)
var stderrBuf strings.Builder
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
scanner := bufio.NewScanner(stdoutPipe)
scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024)
for scanner.Scan() {
line := scanner.Text()
if strings.TrimSpace(line) != "" {
opts.OnLine(line)
}
}
}()
wg.Add(1)
go func() {
defer wg.Done()
scanner := bufio.NewScanner(stderrPipe)
scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024)
for scanner.Scan() {
stderrBuf.WriteString(scanner.Text())
stderrBuf.WriteString("\n")
}
}()
wg.Wait()
err = c.Wait()
code := 0
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
code = exitErr.ExitCode()
if code == 0 {
code = -1
}
}
}
return StreamResult{Code: code, Stderr: stderrBuf.String()}, nil
}

View File

@ -0,0 +1,181 @@
package winlimit
import (
"cursor-api-proxy/internal/env"
"runtime"
)
const WinPromptOmissionPrefix = "[Earlier messages omitted: Windows command-line length limit.]\n\n"
const LinuxPromptOmissionPrefix = "[Earlier messages omitted: Linux ARG_MAX command-line length limit.]\n\n"
// safeLinuxArgMax returns a conservative estimate of ARG_MAX on Linux.
// The actual limit is typically 2MB; we use 1.5MB to leave room for env vars.
func safeLinuxArgMax() int {
return 1536 * 1024
}
type FitPromptResult struct {
OK bool
Args []string
Truncated bool
OriginalLength int
FinalPromptLength int
Error string
}
func estimateCmdlineLength(resolved env.AgentCommand) int {
argv := append([]string{resolved.Command}, resolved.Args...)
if resolved.WindowsVerbatimArguments {
n := 0
for _, a := range argv {
n += len(a)
}
if len(argv) > 1 {
n += len(argv) - 1
}
return n + 512
}
dstLen := 0
for _, a := range argv {
dstLen += len(a)
}
dstLen = dstLen*2 + len(argv)*2
if len(argv) > 1 {
dstLen += len(argv) - 1
}
return dstLen + 512
}
func FitPromptToWinCmdline(agentBin string, fixedArgs []string, prompt string, maxCmdline int, cwd string) FitPromptResult {
if runtime.GOOS != "windows" {
return fitPromptLinux(fixedArgs, prompt)
}
e := env.OsEnvToMap()
measured := func(p string) int {
args := make([]string, len(fixedArgs)+1)
copy(args, fixedArgs)
args[len(fixedArgs)] = p
resolved := env.ResolveAgentCommand(agentBin, args, e, cwd)
return estimateCmdlineLength(resolved)
}
if measured("") > maxCmdline {
return FitPromptResult{
OK: false,
Error: "Windows command line exceeds the configured limit even without a prompt; shorten workspace path, model id, or CURSOR_BRIDGE_WIN_CMDLINE_MAX.",
}
}
if measured(prompt) <= maxCmdline {
args := make([]string, len(fixedArgs)+1)
copy(args, fixedArgs)
args[len(fixedArgs)] = prompt
return FitPromptResult{
OK: true,
Args: args,
Truncated: false,
OriginalLength: len(prompt),
FinalPromptLength: len(prompt),
}
}
prefix := WinPromptOmissionPrefix
if measured(prefix) > maxCmdline {
return FitPromptResult{
OK: false,
Error: "Windows command line too long to fit even the truncation notice; shorten workspace path or flags.",
}
}
lo, hi, best := 0, len(prompt), 0
for lo <= hi {
mid := (lo + hi) / 2
var tail string
if mid > 0 {
tail = prompt[len(prompt)-mid:]
}
candidate := prefix + tail
if measured(candidate) <= maxCmdline {
best = mid
lo = mid + 1
} else {
hi = mid - 1
}
}
var finalPrompt string
if best == 0 {
finalPrompt = prefix
} else {
finalPrompt = prefix + prompt[len(prompt)-best:]
}
args := make([]string, len(fixedArgs)+1)
copy(args, fixedArgs)
args[len(fixedArgs)] = finalPrompt
return FitPromptResult{
OK: true,
Args: args,
Truncated: true,
OriginalLength: len(prompt),
FinalPromptLength: len(finalPrompt),
}
}
// fitPromptLinux handles Linux ARG_MAX truncation.
func fitPromptLinux(fixedArgs []string, prompt string) FitPromptResult {
argMax := safeLinuxArgMax()
// Estimate total cmdline size: sum of all fixed args + prompt + null terminators
fixedLen := 0
for _, a := range fixedArgs {
fixedLen += len(a) + 1
}
totalLen := fixedLen + len(prompt) + 1
if totalLen <= argMax {
args := make([]string, len(fixedArgs)+1)
copy(args, fixedArgs)
args[len(fixedArgs)] = prompt
return FitPromptResult{
OK: true,
Args: args,
Truncated: false,
OriginalLength: len(prompt),
FinalPromptLength: len(prompt),
}
}
// Need to truncate: keep the tail of the prompt (most recent messages)
prefix := LinuxPromptOmissionPrefix
available := argMax - fixedLen - len(prefix) - 1
if available < 0 {
available = 0
}
var finalPrompt string
if available <= 0 {
finalPrompt = prefix
} else if available >= len(prompt) {
finalPrompt = prefix + prompt
} else {
finalPrompt = prefix + prompt[len(prompt)-available:]
}
args := make([]string, len(fixedArgs)+1)
copy(args, fixedArgs)
args[len(fixedArgs)] = finalPrompt
return FitPromptResult{
OK: true,
Args: args,
Truncated: true,
OriginalLength: len(prompt),
FinalPromptLength: len(finalPrompt),
}
}
func WarnPromptTruncated(originalLength, finalLength int) {
_ = originalLength
_ = finalLength
}

View File

@ -0,0 +1,37 @@
package winlimit
import (
"runtime"
"strings"
"testing"
)
func TestNonWindowsPassThrough(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Skipping non-Windows test on Windows")
}
fixedArgs := []string{"--print", "--model", "gpt-4"}
prompt := "Hello world"
result := FitPromptToWinCmdline("agent", fixedArgs, prompt, 30000, "/tmp")
if !result.OK {
t.Fatalf("expected OK=true on non-Windows, got error: %s", result.Error)
}
if result.Truncated {
t.Error("expected no truncation on non-Windows")
}
if result.OriginalLength != len(prompt) {
t.Errorf("expected original length %d, got %d", len(prompt), result.OriginalLength)
}
// Last arg should be the prompt
if len(result.Args) == 0 || result.Args[len(result.Args)-1] != prompt {
t.Errorf("expected last arg to be prompt, got %v", result.Args)
}
}
func TestOmissionPrefix(t *testing.T) {
if !strings.Contains(WinPromptOmissionPrefix, "Earlier messages omitted") {
t.Errorf("omission prefix should mention earlier messages, got: %q", WinPromptOmissionPrefix)
}
}

View File

@ -0,0 +1,30 @@
package workspace
import (
"cursor-api-proxy/internal/config"
"os"
"path/filepath"
"strings"
)
type WorkspaceResult struct {
WorkspaceDir string
TempDir string
}
func ResolveWorkspace(cfg config.BridgeConfig, workspaceHeader string) WorkspaceResult {
if cfg.ChatOnlyWorkspace {
tempDir, err := os.MkdirTemp("", "cursor-proxy-")
if err != nil {
tempDir = filepath.Join(os.TempDir(), "cursor-proxy-fallback")
_ = os.MkdirAll(tempDir, 0700)
}
return WorkspaceResult{WorkspaceDir: tempDir, TempDir: tempDir}
}
headerWs := strings.TrimSpace(workspaceHeader)
if headerWs != "" {
return WorkspaceResult{WorkspaceDir: headerWs}
}
return WorkspaceResult{WorkspaceDir: cfg.Workspace}
}