diff --git a/pkg/infrastructure/env/env.go b/pkg/infrastructure/env/env.go new file mode 100644 index 0000000..45dc3be --- /dev/null +++ b/pkg/infrastructure/env/env.go @@ -0,0 +1,381 @@ +package env + +import ( + "encoding/json" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" +) + +type EnvSource map[string]string + +type LoadedEnv struct { + AgentBin string + AgentNode string + AgentScript string + CommandShell string + Host string + Port int + RequiredKey string + DefaultModel string + Provider string + Force bool + ApproveMcps bool + StrictModel bool + Workspace string + TimeoutMs int + TLSCertPath string + TLSKeyPath string + SessionsLogPath string + ChatOnlyWorkspace bool + Verbose bool + MaxMode bool + ConfigDirs []string + MultiPort bool + WinCmdlineMax int + GeminiAccountDir string + GeminiBrowserVisible bool + GeminiMaxSessions int +} + +type AgentCommand struct { + Command string + Args []string + Env map[string]string + WindowsVerbatimArguments bool + AgentScriptPath string + ConfigDir string +} + +func getEnvVal(e EnvSource, names []string) string { + for _, name := range names { + if v, ok := e[name]; ok && strings.TrimSpace(v) != "" { + return strings.TrimSpace(v) + } + } + return "" +} + +func envBool(e EnvSource, names []string, def bool) bool { + raw := getEnvVal(e, names) + if raw == "" { + return def + } + switch strings.ToLower(raw) { + case "1", "true", "yes", "on": + return true + case "0", "false", "no", "off": + return false + } + return def +} + +func envInt(e EnvSource, names []string, def int) int { + raw := getEnvVal(e, names) + if raw == "" { + return def + } + v, err := strconv.Atoi(raw) + if err != nil { + return def + } + return v +} + +func normalizeModelId(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "auto" + } + parts := strings.Split(raw, "/") + last := parts[len(parts)-1] + if last == "" { + return "auto" + } + return last +} + +func resolveAbs(raw, cwd string) string { + if raw == "" { + return "" + } + if filepath.IsAbs(raw) { + return raw + } + return filepath.Join(cwd, raw) +} + +func isAuthenticatedAccountDir(dir string) bool { + data, err := os.ReadFile(filepath.Join(dir, "cli-config.json")) + if err != nil { + return false + } + var cfg struct { + AuthInfo *struct { + Email string `json:"email"` + } `json:"authInfo"` + } + if err := json.Unmarshal(data, &cfg); err != nil { + return false + } + return cfg.AuthInfo != nil && cfg.AuthInfo.Email != "" +} + +func discoverAccountDirs(homeDir string) []string { + if homeDir == "" { + return nil + } + accountsDir := filepath.Join(homeDir, ".cursor-api-proxy", "accounts") + entries, err := os.ReadDir(accountsDir) + if err != nil { + return nil + } + var dirs []string + for _, e := range entries { + if !e.IsDir() { + continue + } + dir := filepath.Join(accountsDir, e.Name()) + if isAuthenticatedAccountDir(dir) { + dirs = append(dirs, dir) + } + } + return dirs +} + +func parseDotEnv(path string) EnvSource { + data, err := os.ReadFile(path) + if err != nil { + return nil + } + m := make(EnvSource) + for _, line := range strings.Split(string(data), "\n") { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + parts := strings.SplitN(line, "=", 2) + if len(parts) == 2 { + m[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) + } + } + return m +} + +func OsEnvToMap(cwdHint ...string) EnvSource { + m := make(EnvSource) + for _, kv := range os.Environ() { + parts := strings.SplitN(kv, "=", 2) + if len(parts) == 2 { + m[parts[0]] = parts[1] + } + } + + cwd := "" + if len(cwdHint) > 0 && cwdHint[0] != "" { + cwd = cwdHint[0] + } else { + cwd, _ = os.Getwd() + } + + if dotenv := parseDotEnv(filepath.Join(cwd, ".env")); dotenv != nil { + for k, v := range dotenv { + if _, exists := m[k]; !exists { + m[k] = v + } + } + } + + return m +} + +func LoadEnvConfig(e EnvSource, cwd string) LoadedEnv { + if e == nil { + e = OsEnvToMap() + } + if cwd == "" { + var err error + cwd, err = os.Getwd() + if err != nil { + cwd = "." + } + } + + host := getEnvVal(e, []string{"CURSOR_BRIDGE_HOST"}) + if host == "" { + host = "127.0.0.1" + } + port := envInt(e, []string{"CURSOR_BRIDGE_PORT"}, 8765) + if port <= 0 { + port = 8765 + } + + home := getEnvVal(e, []string{"HOME", "USERPROFILE"}) + + sessionsLogPath := func() string { + if p := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_SESSIONS_LOG"}), cwd); p != "" { + return p + } + if home != "" { + return filepath.Join(home, ".cursor-api-proxy", "sessions.log") + } + return filepath.Join(cwd, "sessions.log") + }() + + var configDirs []string + if raw := getEnvVal(e, []string{"CURSOR_CONFIG_DIRS", "CURSOR_ACCOUNT_DIRS"}); raw != "" { + for _, d := range strings.Split(raw, ",") { + d = strings.TrimSpace(d) + if d != "" { + if p := resolveAbs(d, cwd); p != "" { + configDirs = append(configDirs, p) + } + } + } + } + if len(configDirs) == 0 { + configDirs = discoverAccountDirs(home) + } + + winMax := envInt(e, []string{"CURSOR_BRIDGE_WIN_CMDLINE_MAX"}, 30000) + if winMax < 4096 { + winMax = 4096 + } + if winMax > 32700 { + winMax = 32700 + } + + agentBin := getEnvVal(e, []string{"CURSOR_AGENT_BIN", "CURSOR_CLI_BIN", "CURSOR_CLI_PATH"}) + if agentBin == "" { + agentBin = "agent" + } + commandShell := getEnvVal(e, []string{"COMSPEC"}) + if commandShell == "" { + commandShell = "cmd.exe" + } + workspace := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_WORKSPACE"}), cwd) + if workspace == "" { + workspace = cwd + } + + geminiAccountDir := getEnvVal(e, []string{"GEMINI_ACCOUNT_DIR"}) + if geminiAccountDir == "" { + geminiAccountDir = filepath.Join(home, ".cursor-api-proxy", "gemini-accounts") + } else { + geminiAccountDir = resolveAbs(geminiAccountDir, cwd) + } + + return LoadedEnv{ + AgentBin: agentBin, + AgentNode: getEnvVal(e, []string{"CURSOR_AGENT_NODE"}), + AgentScript: getEnvVal(e, []string{"CURSOR_AGENT_SCRIPT"}), + CommandShell: commandShell, + Host: host, + Port: port, + RequiredKey: getEnvVal(e, []string{"CURSOR_BRIDGE_API_KEY"}), + DefaultModel: normalizeModelId(getEnvVal(e, []string{"CURSOR_BRIDGE_DEFAULT_MODEL"})), + Provider: getEnvVal(e, []string{"CURSOR_BRIDGE_PROVIDER"}), + Force: envBool(e, []string{"CURSOR_BRIDGE_FORCE"}, false), + ApproveMcps: envBool(e, []string{"CURSOR_BRIDGE_APPROVE_MCPS"}, false), + StrictModel: envBool(e, []string{"CURSOR_BRIDGE_STRICT_MODEL"}, true), + Workspace: workspace, + TimeoutMs: envInt(e, []string{"CURSOR_BRIDGE_TIMEOUT_MS"}, 300000), + TLSCertPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_CERT"}), cwd), + TLSKeyPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_KEY"}), cwd), + SessionsLogPath: sessionsLogPath, + ChatOnlyWorkspace: envBool(e, []string{"CURSOR_BRIDGE_CHAT_ONLY_WORKSPACE"}, true), + Verbose: envBool(e, []string{"CURSOR_BRIDGE_VERBOSE"}, false), + MaxMode: envBool(e, []string{"CURSOR_BRIDGE_MAX_MODE"}, false), + ConfigDirs: configDirs, + MultiPort: envBool(e, []string{"CURSOR_BRIDGE_MULTI_PORT"}, false), + WinCmdlineMax: winMax, + GeminiAccountDir: geminiAccountDir, + GeminiBrowserVisible: envBool(e, []string{"GEMINI_BROWSER_VISIBLE"}, false), + GeminiMaxSessions: envInt(e, []string{"GEMINI_MAX_SESSIONS"}, 3), + } +} + +func ResolveAgentCommand(cmd string, args []string, e EnvSource, cwd string) AgentCommand { + if e == nil { + e = OsEnvToMap() + } + loaded := LoadEnvConfig(e, cwd) + + cloneEnv := func() map[string]string { + m := make(map[string]string, len(e)) + for k, v := range e { + m[k] = v + } + return m + } + + if runtime.GOOS == "windows" { + if loaded.AgentNode != "" && loaded.AgentScript != "" { + agentScriptPath := loaded.AgentScript + if !filepath.IsAbs(agentScriptPath) { + agentScriptPath = filepath.Join(cwd, agentScriptPath) + } + agentDir := filepath.Dir(agentScriptPath) + configDir := filepath.Join(agentDir, "..", "data", "config") + env2 := cloneEnv() + env2["CURSOR_INVOKED_AS"] = "agent.cmd" + ac := AgentCommand{ + Command: loaded.AgentNode, + Args: append([]string{loaded.AgentScript}, args...), + Env: env2, + AgentScriptPath: agentScriptPath, + } + if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil { + ac.ConfigDir = configDir + } + return ac + } + + if strings.HasSuffix(strings.ToLower(cmd), ".cmd") { + cmdResolved := cmd + if !filepath.IsAbs(cmd) { + cmdResolved = filepath.Join(cwd, cmd) + } + dir := filepath.Dir(cmdResolved) + nodeBin := filepath.Join(dir, "node.exe") + script := filepath.Join(dir, "index.js") + if _, err1 := os.Stat(nodeBin); err1 == nil { + if _, err2 := os.Stat(script); err2 == nil { + configDir := filepath.Join(dir, "..", "data", "config") + env2 := cloneEnv() + env2["CURSOR_INVOKED_AS"] = "agent.cmd" + ac := AgentCommand{ + Command: nodeBin, + Args: append([]string{script}, args...), + Env: env2, + AgentScriptPath: script, + } + if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil { + ac.ConfigDir = configDir + } + return ac + } + } + + quotedArgs := make([]string, len(args)) + for i, a := range args { + if strings.Contains(a, " ") { + quotedArgs[i] = `"` + a + `"` + } else { + quotedArgs[i] = a + } + } + cmdLine := `""` + cmd + `" ` + strings.Join(quotedArgs, " ") + `"` + return AgentCommand{ + Command: loaded.CommandShell, + Args: []string{"/d", "/s", "/c", cmdLine}, + Env: cloneEnv(), + WindowsVerbatimArguments: true, + } + } + } + + return AgentCommand{Command: cmd, Args: args, Env: cloneEnv()} +} diff --git a/pkg/infrastructure/env/env_test.go b/pkg/infrastructure/env/env_test.go new file mode 100644 index 0000000..e589d59 --- /dev/null +++ b/pkg/infrastructure/env/env_test.go @@ -0,0 +1,65 @@ +package env + +import "testing" + +func TestLoadEnvConfigDefaults(t *testing.T) { + e := EnvSource{} + loaded := LoadEnvConfig(e, "/tmp") + + if loaded.Host != "127.0.0.1" { + t.Errorf("expected 127.0.0.1, got %s", loaded.Host) + } + if loaded.Port != 8765 { + t.Errorf("expected 8765, got %d", loaded.Port) + } + if loaded.DefaultModel != "auto" { + t.Errorf("expected auto, got %s", loaded.DefaultModel) + } + if loaded.AgentBin != "agent" { + t.Errorf("expected agent, got %s", loaded.AgentBin) + } + if !loaded.StrictModel { + t.Error("expected strictModel=true by default") + } +} + +func TestLoadEnvConfigOverride(t *testing.T) { + e := EnvSource{ + "CURSOR_BRIDGE_HOST": "0.0.0.0", + "CURSOR_BRIDGE_PORT": "9000", + "CURSOR_BRIDGE_DEFAULT_MODEL": "gpt-4", + "CURSOR_AGENT_BIN": "/usr/local/bin/agent", + } + loaded := LoadEnvConfig(e, "/tmp") + + if loaded.Host != "0.0.0.0" { + t.Errorf("expected 0.0.0.0, got %s", loaded.Host) + } + if loaded.Port != 9000 { + t.Errorf("expected 9000, got %d", loaded.Port) + } + if loaded.DefaultModel != "gpt-4" { + t.Errorf("expected gpt-4, got %s", loaded.DefaultModel) + } + if loaded.AgentBin != "/usr/local/bin/agent" { + t.Errorf("expected /usr/local/bin/agent, got %s", loaded.AgentBin) + } +} + +func TestNormalizeModelID(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"gpt-4", "gpt-4"}, + {"openai/gpt-4", "gpt-4"}, + {"", "auto"}, + {" ", "auto"}, + } + for _, tc := range tests { + got := normalizeModelId(tc.input) + if got != tc.want { + t.Errorf("normalizeModelId(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} diff --git a/pkg/infrastructure/httputil/httputil.go b/pkg/infrastructure/httputil/httputil.go new file mode 100644 index 0000000..bb39663 --- /dev/null +++ b/pkg/infrastructure/httputil/httputil.go @@ -0,0 +1,50 @@ +package httputil + +import ( + "encoding/json" + "io" + "net/http" + "regexp" +) + +var bearerRe = regexp.MustCompile(`(?i)^Bearer\s+(.+)$`) + +func ExtractBearerToken(r *http.Request) string { + h := r.Header.Get("Authorization") + if h == "" { + return "" + } + m := bearerRe.FindStringSubmatch(h) + if m == nil { + return "" + } + return m[1] +} + +func WriteJSON(w http.ResponseWriter, status int, body interface{}, extraHeaders map[string]string) { + for k, v := range extraHeaders { + w.Header().Set(k, v) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(body) +} + +func WriteSSEHeaders(w http.ResponseWriter, extraHeaders map[string]string) { + for k, v := range extraHeaders { + w.Header().Set(k, v) + } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + w.WriteHeader(200) +} + +func ReadBody(r *http.Request) (string, error) { + data, err := io.ReadAll(r.Body) + if err != nil { + return "", err + } + return string(data), nil +} diff --git a/pkg/infrastructure/httputil/httputil_test.go b/pkg/infrastructure/httputil/httputil_test.go new file mode 100644 index 0000000..530e536 --- /dev/null +++ b/pkg/infrastructure/httputil/httputil_test.go @@ -0,0 +1,50 @@ +package httputil + +import ( + "net/http/httptest" + "testing" +) + +func TestExtractBearerToken(t *testing.T) { + tests := []struct { + header string + want string + }{ + {"Bearer mytoken123", "mytoken123"}, + {"bearer MYTOKEN", "MYTOKEN"}, + {"", ""}, + {"Basic abc", ""}, + {"Bearer ", ""}, + } + for _, tc := range tests { + req := httptest.NewRequest("GET", "/", nil) + if tc.header != "" { + req.Header.Set("Authorization", tc.header) + } + got := ExtractBearerToken(req) + if got != tc.want { + t.Errorf("ExtractBearerToken(%q) = %q, want %q", tc.header, got, tc.want) + } + } +} + +func TestWriteJSON(t *testing.T) { + w := httptest.NewRecorder() + WriteJSON(w, 200, map[string]string{"ok": "true"}, nil) + + if w.Code != 200 { + t.Errorf("expected 200, got %d", w.Code) + } + if w.Header().Get("Content-Type") != "application/json" { + t.Errorf("expected application/json, got %s", w.Header().Get("Content-Type")) + } +} + +func TestWriteJSONWithExtraHeaders(t *testing.T) { + w := httptest.NewRecorder() + WriteJSON(w, 201, nil, map[string]string{"X-Custom": "value"}) + + if w.Header().Get("X-Custom") != "value" { + t.Errorf("expected X-Custom=value, got %s", w.Header().Get("X-Custom")) + } +} diff --git a/pkg/infrastructure/logger/logger.go b/pkg/infrastructure/logger/logger.go new file mode 100644 index 0000000..db18a0a --- /dev/null +++ b/pkg/infrastructure/logger/logger.go @@ -0,0 +1,309 @@ +package logger + +import ( + "cursor-api-proxy/internal/config" + "cursor-api-proxy/internal/pool" + "fmt" + "os" + "path/filepath" + "strings" + "time" +) + +const ( + cReset = "\x1b[0m" + cBold = "\x1b[1m" + cDim = "\x1b[2m" + cCyan = "\x1b[36m" + cBCyan = "\x1b[1;96m" + cGreen = "\x1b[32m" + cBGreen = "\x1b[1;92m" + cYellow = "\x1b[33m" + cMagenta = "\x1b[35m" + cBMagenta = "\x1b[1;95m" + cRed = "\x1b[31m" + cGray = "\x1b[90m" + cWhite = "\x1b[97m" +) + +var roleStyle = map[string]string{ + "system": cYellow, + "user": cCyan, + "assistant": cGreen, +} + +var roleEmoji = map[string]string{ + "system": "🔧", + "user": "👤", + "assistant": "🤖", +} + +func ts() string { + return cGray + time.Now().UTC().Format("15:04:05") + cReset +} + +func tsDate() string { + return cGray + time.Now().UTC().Format("2006-01-02 15:04:05") + cReset +} + +func truncate(s string, max int) string { + if len(s) <= max { + return s + } + head := int(float64(max) * 0.6) + tail := max - head + omitted := len(s) - head - tail + return s[:head] + fmt.Sprintf("%s … (%d chars omitted) … ", cDim, omitted) + s[len(s)-tail:] + cReset +} + +func hr(ch string, length int) string { + return cGray + strings.Repeat(ch, length) + cReset +} + +type TrafficMessage struct { + Role string + Content string +} + +func LogDebug(format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + fmt.Printf("%s %s[DEBUG]%s %s\n", ts(), cGray, cReset, msg) +} + +func LogServerStart(version, scheme, host string, port int, cfg config.BridgeConfig) { + provider := cfg.Provider + if provider == "" { + provider = "cursor" + } + fmt.Printf("\n%s%s╔══════════════════════════════════════════╗%s\n", cBold, cBCyan, cReset) + fmt.Printf("%s%s cursor-api-proxy %sv%s%s%s%s ready%s\n", + cBold, cBCyan, cReset, cBold, cWhite, version, cBCyan, cReset) + fmt.Printf("%s%s╚══════════════════════════════════════════╝%s\n\n", cBold, cBCyan, cReset) + url := fmt.Sprintf("%s://%s:%d", scheme, host, port) + fmt.Printf(" %s●%s listening %s%s%s\n", cBGreen, cReset, cBold, url, cReset) + fmt.Printf(" %s▸%s provider %s%s%s\n", cCyan, cReset, cBold, provider, cReset) + fmt.Printf(" %s▸%s agent %s%s%s\n", cCyan, cReset, cDim, cfg.AgentBin, cReset) + fmt.Printf(" %s▸%s workspace %s%s%s\n", cCyan, cReset, cDim, cfg.Workspace, cReset) + fmt.Printf(" %s▸%s model %s%s%s\n", cCyan, cReset, cDim, cfg.DefaultModel, cReset) + fmt.Printf(" %s▸%s mode %s%s%s\n", cCyan, cReset, cDim, cfg.Mode, cReset) + fmt.Printf(" %s▸%s timeout %s%d ms%s\n", cCyan, cReset, cDim, cfg.TimeoutMs, cReset) + + // 顯示 Gemini Web Provider 相關設定 + if provider == "gemini-web" { + fmt.Printf(" %s▸%s gemini-dir %s%s%s\n", cCyan, cReset, cDim, cfg.GeminiAccountDir, cReset) + fmt.Printf(" %s▸%s max-sess %s%d%s\n", cCyan, cReset, cDim, cfg.GeminiMaxSessions, cReset) + } + + flags := []string{} + if cfg.Force { + flags = append(flags, "force") + } + if cfg.ApproveMcps { + flags = append(flags, "approve-mcps") + } + if cfg.MaxMode { + flags = append(flags, "max-mode") + } + if cfg.Verbose { + flags = append(flags, "verbose") + } + if cfg.ChatOnlyWorkspace { + flags = append(flags, "chat-only") + } + if cfg.RequiredKey != "" { + flags = append(flags, "api-key-required") + } + if len(flags) > 0 { + fmt.Printf(" %s▸%s flags %s%s%s\n", cCyan, cReset, cYellow, strings.Join(flags, " · "), cReset) + } + if len(cfg.ConfigDirs) > 0 { + fmt.Printf(" %s▸%s pool %s%d accounts%s\n", cCyan, cReset, cBGreen, len(cfg.ConfigDirs), cReset) + } + fmt.Println() +} + +func LogShutdown(sig string) { + fmt.Printf("\n%s %s⊘ %s received — shutting down gracefully…%s\n", tsDate(), cYellow, sig, cReset) +} + +func LogRequestStart(method, pathname, model string, timeoutMs int, isStream bool) { + modeTag := fmt.Sprintf("%ssync%s", cDim, cReset) + if isStream { + modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset) + } + fmt.Printf("%s %s▶%s %s %s %s timeout:%dms %s\n", + ts(), cBCyan, cReset, method, pathname, model, timeoutMs, modeTag) +} + +func LogRequestDone(method, pathname, model string, latencyMs int64, code int) { + statusColor := cBGreen + if code != 0 { + statusColor = cRed + } + fmt.Printf("%s %s■%s %s %s %s %s%dms exit:%d%s\n", + ts(), statusColor, cReset, method, pathname, model, cDim, latencyMs, code, cReset) +} + +func LogRequestTimeout(method, pathname, model string, timeoutMs int) { + fmt.Printf("%s %s⏱%s %s %s %s %stimed-out after %dms%s\n", + ts(), cRed, cReset, method, pathname, model, cRed, timeoutMs, cReset) +} + +func LogClientDisconnect(method, pathname, model string, latencyMs int64) { + fmt.Printf("%s %s⚡%s %s %s %s %sclient disconnected after %dms%s\n", + ts(), cYellow, cReset, method, pathname, model, cYellow, latencyMs, cReset) +} + +func LogStreamChunk(model string, text string, chunkNum int) { + preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 120) + fmt.Printf("%s %s▸%s #%d %s%s%s\n", + ts(), cDim, cReset, chunkNum, cWhite, preview, cReset) +} + +func LogRawLine(line string) { + preview := truncate(strings.ReplaceAll(line, "\n", "↵ "), 200) + fmt.Printf("%s %s│%s %sraw%s %s\n", + ts(), cGray, cReset, cDim, cReset, preview) +} + +func LogIncoming(method, pathname, remoteAddress string) { + methodColor := cBCyan + switch method { + case "POST": + methodColor = cBMagenta + case "GET": + methodColor = cBCyan + case "DELETE": + methodColor = cRed + } + fmt.Printf("%s %s%s%s%s %s%s%s %s(%s)%s\n", + ts(), + methodColor, cBold, method, cReset, + cWhite, pathname, cReset, + cDim, remoteAddress, cReset, + ) +} + +func LogAccountAssigned(configDir string) { + if configDir == "" { + return + } + name := filepath.Base(configDir) + fmt.Printf("%s %s→%s account %s%s%s\n", ts(), cBCyan, cReset, cBold, name, cReset) +} + +func LogAccountStats(verbose bool, stats []pool.AccountStat) { + if !verbose || len(stats) == 0 { + return + } + now := time.Now().UnixMilli() + fmt.Printf("%s┌─ Account Stats %s┐%s\n", cGray, strings.Repeat("─", 44), cReset) + for _, s := range stats { + name := fmt.Sprintf("%-20s", filepath.Base(s.ConfigDir)) + active := fmt.Sprintf("%sactive:0%s", cDim, cReset) + if s.ActiveRequests > 0 { + active = fmt.Sprintf("%sactive:%d%s", cBCyan, s.ActiveRequests, cReset) + } + total := fmt.Sprintf("total:%s%d%s", cBold, s.TotalRequests, cReset) + ok := fmt.Sprintf("%sok:%d%s", cGreen, s.TotalSuccess, cReset) + errStr := fmt.Sprintf("%serr:0%s", cDim, cReset) + if s.TotalErrors > 0 { + errStr = fmt.Sprintf("%serr:%d%s", cRed, s.TotalErrors, cReset) + } + rl := fmt.Sprintf("%srl:0%s", cDim, cReset) + if s.TotalRateLimits > 0 { + rl = fmt.Sprintf("%srl:%d%s", cYellow, s.TotalRateLimits, cReset) + } + avg := "avg:-" + if s.TotalRequests > 0 { + avg = fmt.Sprintf("avg:%dms", s.TotalLatencyMs/int64(s.TotalRequests)) + } + status := fmt.Sprintf("%s✓%s", cGreen, cReset) + if s.IsRateLimited { + recovers := time.UnixMilli(s.RateLimitUntil).UTC().Format(time.RFC3339) + _ = now + status = fmt.Sprintf("%s⛔ rate-limited (recovers %s)%s", cRed, recovers, cReset) + } + fmt.Printf(" %s%s%s %s %s %s %s %s %s%s%s %s\n", + cBold, name, cReset, active, total, ok, errStr, rl, cDim, avg, cReset, status) + } + fmt.Printf("%s└%s┘%s\n", cGray, strings.Repeat("─", 60), cReset) +} + +func LogTrafficRequest(verbose bool, model string, messages []TrafficMessage, isStream bool) { + if !verbose { + return + } + modeTag := fmt.Sprintf("%ssync%s", cDim, cReset) + if isStream { + modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset) + } + modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset) + fmt.Println(hr("─", 60)) + fmt.Printf("%s 📤 %s%sREQUEST%s %s %s\n", ts(), cBCyan, cBold, cReset, modelStr, modeTag) + for _, m := range messages { + roleColor := cWhite + if c, ok := roleStyle[m.Role]; ok { + roleColor = c + } + emoji := "💬" + if e, ok := roleEmoji[m.Role]; ok { + emoji = e + } + label := fmt.Sprintf("%s%s[%s]%s", roleColor, cBold, m.Role, cReset) + charCount := fmt.Sprintf("%s(%d chars)%s", cDim, len(m.Content), cReset) + preview := truncate(strings.ReplaceAll(m.Content, "\n", "↵ "), 280) + fmt.Printf(" %s %s %s\n", emoji, label, charCount) + fmt.Printf(" %s%s%s\n", cDim, preview, cReset) + } +} + +func LogTrafficResponse(verbose bool, model, text string, isStream bool) { + if !verbose { + return + } + modeTag := fmt.Sprintf("%ssync%s", cDim, cReset) + if isStream { + modeTag = fmt.Sprintf("%s⚡ stream%s", cBGreen, cReset) + } + modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset) + charCount := fmt.Sprintf("%s%d%s%s chars%s", cBold, len(text), cReset, cDim, cReset) + preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 480) + fmt.Printf("%s 📥 %s%sRESPONSE%s %s %s %s\n", ts(), cBGreen, cBold, cReset, modelStr, modeTag, charCount) + fmt.Printf(" 🤖 %s%s%s\n", cGreen, preview, cReset) + fmt.Println(hr("─", 60)) +} + +func AppendSessionLine(logPath, method, pathname, remoteAddress string, statusCode int) { + line := fmt.Sprintf("%s %s %s %s %d\n", time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, statusCode) + dir := filepath.Dir(logPath) + if err := os.MkdirAll(dir, 0755); err == nil { + f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err == nil { + _, _ = f.WriteString(line) + f.Close() + } + } +} + +func LogTruncation(originalLen, finalLen int) { + fmt.Printf("%s %s⚠ prompt truncated%s %s(%d → %d chars, tail preserved)%s\n", + ts(), cYellow, cReset, cDim, originalLen, finalLen, cReset) +} + +func LogAgentError(logPath, method, pathname, remoteAddress string, exitCode int, stderr string) string { + errMsg := fmt.Sprintf("Cursor CLI failed (exit %d): %s", exitCode, strings.TrimSpace(stderr)) + fmt.Fprintf(os.Stderr, "%s %s✗ agent error%s %s%s%s\n", ts(), cRed, cReset, cDim, errMsg, cReset) + truncated := strings.TrimSpace(stderr) + if len(truncated) > 200 { + truncated = truncated[:200] + } + truncated = strings.ReplaceAll(truncated, "\n", " ") + line := fmt.Sprintf("%s ERROR %s %s %s agent_exit_%d %s\n", + time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, exitCode, truncated) + if f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil { + _, _ = f.WriteString(line) + f.Close() + } + return errMsg +} diff --git a/pkg/infrastructure/parser/stream.go b/pkg/infrastructure/parser/stream.go new file mode 100644 index 0000000..bbdd231 --- /dev/null +++ b/pkg/infrastructure/parser/stream.go @@ -0,0 +1,110 @@ +package parser + +import "encoding/json" + +type StreamParser func(line string) + +type Parser struct { + Parse StreamParser + Flush func() +} + +// CreateStreamParser 建立串流解析器(向後相容,不傳遞 thinking) +func CreateStreamParser(onText func(string), onDone func()) Parser { + return CreateStreamParserWithThinking(onText, nil, onDone) +} + +// CreateStreamParserWithThinking 建立串流解析器,支援思考過程輸出。 +// onThinking 可為 nil,表示忽略思考過程。 +func CreateStreamParserWithThinking(onText func(string), onThinking func(string), onDone func()) Parser { + // accumulated 是所有已輸出內容的串接 + accumulatedText := "" + accumulatedThinking := "" + done := false + + parse := func(line string) { + if done { + return + } + + var obj struct { + Type string `json:"type"` + Subtype string `json:"subtype"` + Message *struct { + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + Thinking string `json:"thinking"` + } `json:"content"` + } `json:"message"` + } + + if err := json.Unmarshal([]byte(line), &obj); err != nil { + return + } + + if obj.Type == "assistant" && obj.Message != nil { + fullText := "" + fullThinking := "" + for _, p := range obj.Message.Content { + switch p.Type { + case "text": + if p.Text != "" { + fullText += p.Text + } + case "thinking": + if p.Thinking != "" { + fullThinking += p.Thinking + } + } + } + + // 處理思考過程(不因去重而 return,避免跳過同行的文字內容) + if onThinking != nil && fullThinking != "" && fullThinking != accumulatedThinking { + // 增量模式:新內容以 accumulated 為前綴 + if len(fullThinking) >= len(accumulatedThinking) && fullThinking[:len(accumulatedThinking)] == accumulatedThinking { + delta := fullThinking[len(accumulatedThinking):] + if delta != "" { + onThinking(delta) + } + accumulatedThinking = fullThinking + } else { + // 獨立片段:直接輸出,但 accumulated 要串接 + onThinking(fullThinking) + accumulatedThinking = accumulatedThinking + fullThinking + } + } + + // 處理一般文字 + if fullText == "" || fullText == accumulatedText { + return + } + // 增量模式:新內容以 accumulated 為前綴 + if len(fullText) >= len(accumulatedText) && fullText[:len(accumulatedText)] == accumulatedText { + delta := fullText[len(accumulatedText):] + if delta != "" { + onText(delta) + } + accumulatedText = fullText + } else { + // 獨立片段:直接輸出,但 accumulated 要串接 + onText(fullText) + accumulatedText = accumulatedText + fullText + } + } + + if obj.Type == "result" && obj.Subtype == "success" { + done = true + onDone() + } + } + + flush := func() { + if !done { + done = true + onDone() + } + } + + return Parser{Parse: parse, Flush: flush} +} diff --git a/pkg/infrastructure/parser/stream_test.go b/pkg/infrastructure/parser/stream_test.go new file mode 100644 index 0000000..146cdc2 --- /dev/null +++ b/pkg/infrastructure/parser/stream_test.go @@ -0,0 +1,304 @@ +package parser + +import ( + "encoding/json" + "testing" +) + +func makeAssistantLine(text string) string { + obj := map[string]interface{}{ + "type": "assistant", + "message": map[string]interface{}{ + "content": []map[string]interface{}{ + {"type": "text", "text": text}, + }, + }, + } + b, _ := json.Marshal(obj) + return string(b) +} + +func makeResultLine() string { + b, _ := json.Marshal(map[string]string{"type": "result", "subtype": "success"}) + return string(b) +} + +func TestStreamParserFragmentMode(t *testing.T) { + // cursor --stream-partial-output 模式:每個訊息是獨立 token fragment + var texts []string + p := CreateStreamParser( + func(text string) { texts = append(texts, text) }, + func() {}, + ) + + p.Parse(makeAssistantLine("你")) + p.Parse(makeAssistantLine("好!有")) + p.Parse(makeAssistantLine("什")) + p.Parse(makeAssistantLine("麼")) + + if len(texts) != 4 { + t.Fatalf("expected 4 fragments, got %d: %v", len(texts), texts) + } + if texts[0] != "你" || texts[1] != "好!有" || texts[2] != "什" || texts[3] != "麼" { + t.Fatalf("unexpected texts: %v", texts) + } +} + +func TestStreamParserDeduplicatesFinalFullText(t *testing.T) { + // 最後一個訊息是完整的累積文字,應被跳過(去重) + var texts []string + p := CreateStreamParser( + func(text string) { texts = append(texts, text) }, + func() {}, + ) + + p.Parse(makeAssistantLine("Hello")) + p.Parse(makeAssistantLine(" world")) + // 最後一個是完整累積文字,應被去重 + p.Parse(makeAssistantLine("Hello world")) + + if len(texts) != 2 { + t.Fatalf("expected 2 fragments (final full text deduplicated), got %d: %v", len(texts), texts) + } + if texts[0] != "Hello" || texts[1] != " world" { + t.Fatalf("unexpected texts: %v", texts) + } +} + +func TestStreamParserCallsOnDone(t *testing.T) { + var texts []string + doneCount := 0 + p := CreateStreamParser( + func(text string) { texts = append(texts, text) }, + func() { doneCount++ }, + ) + + p.Parse(makeResultLine()) + if doneCount != 1 { + t.Fatalf("expected onDone called once, got %d", doneCount) + } + if len(texts) != 0 { + t.Fatalf("expected no text, got %v", texts) + } +} + +func TestStreamParserIgnoresLinesAfterDone(t *testing.T) { + var texts []string + doneCount := 0 + p := CreateStreamParser( + func(text string) { texts = append(texts, text) }, + func() { doneCount++ }, + ) + + p.Parse(makeResultLine()) + p.Parse(makeAssistantLine("late")) + if len(texts) != 0 { + t.Fatalf("expected no text after done, got %v", texts) + } + if doneCount != 1 { + t.Fatalf("expected onDone called once, got %d", doneCount) + } +} + +func TestStreamParserIgnoresNonAssistantLines(t *testing.T) { + var texts []string + p := CreateStreamParser( + func(text string) { texts = append(texts, text) }, + func() {}, + ) + + b1, _ := json.Marshal(map[string]interface{}{"type": "user", "message": map[string]interface{}{}}) + p.Parse(string(b1)) + b2, _ := json.Marshal(map[string]interface{}{ + "type": "assistant", + "message": map[string]interface{}{"content": []interface{}{}}, + }) + p.Parse(string(b2)) + p.Parse(`{"type":"assistant","message":{"content":[{"type":"code","text":"x"}]}}`) + + if len(texts) != 0 { + t.Fatalf("expected no texts, got %v", texts) + } +} + +func TestStreamParserIgnoresParseErrors(t *testing.T) { + var texts []string + doneCount := 0 + p := CreateStreamParser( + func(text string) { texts = append(texts, text) }, + func() { doneCount++ }, + ) + + p.Parse("not json") + p.Parse("{") + p.Parse("") + + if len(texts) != 0 || doneCount != 0 { + t.Fatalf("expected nothing, got texts=%v done=%d", texts, doneCount) + } +} + +func TestStreamParserJoinsMultipleTextParts(t *testing.T) { + var texts []string + p := CreateStreamParser( + func(text string) { texts = append(texts, text) }, + func() {}, + ) + + obj := map[string]interface{}{ + "type": "assistant", + "message": map[string]interface{}{ + "content": []map[string]interface{}{ + {"type": "text", "text": "Hello"}, + {"type": "text", "text": " "}, + {"type": "text", "text": "world"}, + }, + }, + } + b, _ := json.Marshal(obj) + p.Parse(string(b)) + + if len(texts) != 1 || texts[0] != "Hello world" { + t.Fatalf("expected ['Hello world'], got %v", texts) + } +} + +func TestStreamParserFlushTriggersDone(t *testing.T) { + var texts []string + doneCount := 0 + p := CreateStreamParser( + func(text string) { texts = append(texts, text) }, + func() { doneCount++ }, + ) + + p.Parse(makeAssistantLine("Hello")) + // agent 結束但沒有 result/success,手動 flush + p.Flush() + if doneCount != 1 { + t.Fatalf("expected onDone called once after Flush, got %d", doneCount) + } + // 再 flush 不應重複觸發 + p.Flush() + if doneCount != 1 { + t.Fatalf("expected onDone called only once, got %d", doneCount) + } +} + +func TestStreamParserFlushAfterDoneIsNoop(t *testing.T) { + doneCount := 0 + p := CreateStreamParser( + func(text string) {}, + func() { doneCount++ }, + ) + + p.Parse(makeResultLine()) + p.Flush() + if doneCount != 1 { + t.Fatalf("expected onDone called once, got %d", doneCount) + } +} + +func makeThinkingLine(thinking string) string { + obj := map[string]interface{}{ + "type": "assistant", + "message": map[string]interface{}{ + "content": []map[string]interface{}{ + {"type": "thinking", "thinking": thinking}, + }, + }, + } + b, _ := json.Marshal(obj) + return string(b) +} + +func makeThinkingAndTextLine(thinking, text string) string { + obj := map[string]interface{}{ + "type": "assistant", + "message": map[string]interface{}{ + "content": []map[string]interface{}{ + {"type": "thinking", "thinking": thinking}, + {"type": "text", "text": text}, + }, + }, + } + b, _ := json.Marshal(obj) + return string(b) +} + +func TestStreamParserWithThinkingCallsOnThinking(t *testing.T) { + var texts []string + var thinkings []string + p := CreateStreamParserWithThinking( + func(text string) { texts = append(texts, text) }, + func(thinking string) { thinkings = append(thinkings, thinking) }, + func() {}, + ) + + p.Parse(makeThinkingLine("思考中...")) + p.Parse(makeAssistantLine("回答")) + + if len(thinkings) != 1 || thinkings[0] != "思考中..." { + t.Fatalf("expected thinkings=['思考中...'], got %v", thinkings) + } + if len(texts) != 1 || texts[0] != "回答" { + t.Fatalf("expected texts=['回答'], got %v", texts) + } +} + +func TestStreamParserWithThinkingNilOnThinkingIgnoresThinking(t *testing.T) { + var texts []string + p := CreateStreamParserWithThinking( + func(text string) { texts = append(texts, text) }, + nil, + func() {}, + ) + + p.Parse(makeThinkingLine("忽略的思考")) + p.Parse(makeAssistantLine("文字")) + + if len(texts) != 1 || texts[0] != "文字" { + t.Fatalf("expected texts=['文字'], got %v", texts) + } +} + +func TestStreamParserWithThinkingDeduplication(t *testing.T) { + var thinkings []string + p := CreateStreamParserWithThinking( + func(text string) {}, + func(thinking string) { thinkings = append(thinkings, thinking) }, + func() {}, + ) + + p.Parse(makeThinkingLine("A")) + p.Parse(makeThinkingLine("B")) + // 重複的完整思考,應被跳過 + p.Parse(makeThinkingLine("AB")) + + if len(thinkings) != 2 || thinkings[0] != "A" || thinkings[1] != "B" { + t.Fatalf("expected thinkings=['A','B'], got %v", thinkings) + } +} + +// TestStreamParserThinkingDuplicateButTextStillEmitted 驗證 bug 修復: +// 當 thinking 重複(去重跳過)但同一行有 text 時,text 仍必須輸出。 +func TestStreamParserThinkingDuplicateButTextStillEmitted(t *testing.T) { + var texts []string + var thinkings []string + p := CreateStreamParserWithThinking( + func(text string) { texts = append(texts, text) }, + func(thinking string) { thinkings = append(thinkings, thinking) }, + func() {}, + ) + + // 第一行:thinking="思考中" + text(thinking 為新增,兩者都應輸出) + p.Parse(makeThinkingAndTextLine("思考中", "第一段")) + // 第二行:thinking 與上一行相同(去重),但 text 是新的,text 仍應輸出 + p.Parse(makeThinkingAndTextLine("思考中", "第二段")) + + if len(thinkings) != 1 || thinkings[0] != "思考中" { + t.Fatalf("expected thinkings=['思考中'], got %v", thinkings) + } + if len(texts) != 2 || texts[0] != "第一段" || texts[1] != "第二段" { + t.Fatalf("expected texts=['第一段','第二段'], got %v", texts) + } +} diff --git a/pkg/infrastructure/process/kill_unix.go b/pkg/infrastructure/process/kill_unix.go new file mode 100644 index 0000000..b235d96 --- /dev/null +++ b/pkg/infrastructure/process/kill_unix.go @@ -0,0 +1,21 @@ +//go:build !windows + +package process + +import ( + "os/exec" + "syscall" +) + +func killProcessGroup(c *exec.Cmd) error { + if c.Process == nil { + return nil + } + // 殺死整個 process group(負號表示 group) + pgid, err := syscall.Getpgid(c.Process.Pid) + if err == nil { + _ = syscall.Kill(-pgid, syscall.SIGKILL) + } + // 同時也 kill 主程序,以防萬一 + return c.Process.Kill() +} diff --git a/pkg/infrastructure/process/kill_windows.go b/pkg/infrastructure/process/kill_windows.go new file mode 100644 index 0000000..1874bbf --- /dev/null +++ b/pkg/infrastructure/process/kill_windows.go @@ -0,0 +1,14 @@ +//go:build windows + +package process + +import ( + "os/exec" +) + +func killProcessGroup(c *exec.Cmd) error { + if c.Process == nil { + return nil + } + return c.Process.Kill() +} diff --git a/pkg/infrastructure/process/process_test.go b/pkg/infrastructure/process/process_test.go new file mode 100644 index 0000000..c48d4b5 --- /dev/null +++ b/pkg/infrastructure/process/process_test.go @@ -0,0 +1,283 @@ +package process_test + +import ( + "context" + "cursor-api-proxy/internal/process" + "os" + "testing" + "time" +) + +// sh 是跨平台 shell 執行小 script 的輔助函式 +func sh(t *testing.T, script string, opts process.RunOptions) (process.RunResult, error) { + t.Helper() + return process.Run("sh", []string{"-c", script}, opts) +} + +func TestRun_StdoutAndStderr(t *testing.T) { + result, err := sh(t, "echo hello; echo world >&2", process.RunOptions{}) + if err != nil { + t.Fatal(err) + } + if result.Code != 0 { + t.Errorf("Code = %d, want 0", result.Code) + } + if result.Stdout != "hello\n" { + t.Errorf("Stdout = %q, want %q", result.Stdout, "hello\n") + } + if result.Stderr != "world\n" { + t.Errorf("Stderr = %q, want %q", result.Stderr, "world\n") + } +} + +func TestRun_BasicSpawn(t *testing.T) { + result, err := sh(t, "printf ok", process.RunOptions{}) + if err != nil { + t.Fatal(err) + } + if result.Code != 0 { + t.Errorf("Code = %d, want 0", result.Code) + } + if result.Stdout != "ok" { + t.Errorf("Stdout = %q, want ok", result.Stdout) + } +} + +func TestRun_ConfigDir_Propagated(t *testing.T) { + result, err := process.Run("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`}, + process.RunOptions{ConfigDir: "/test/account/dir"}) + if err != nil { + t.Fatal(err) + } + if result.Stdout != "/test/account/dir" { + t.Errorf("Stdout = %q, want /test/account/dir", result.Stdout) + } +} + +func TestRun_ConfigDir_Absent(t *testing.T) { + // 確保沒有殘留的環境變數 + _ = os.Unsetenv("CURSOR_CONFIG_DIR") + result, err := process.Run("sh", []string{"-c", `printf "${CURSOR_CONFIG_DIR:-unset}"`}, + process.RunOptions{}) + if err != nil { + t.Fatal(err) + } + if result.Stdout != "unset" { + t.Errorf("Stdout = %q, want unset", result.Stdout) + } +} + +func TestRun_NonZeroExit(t *testing.T) { + result, err := sh(t, "exit 42", process.RunOptions{}) + if err != nil { + t.Fatal(err) + } + if result.Code != 42 { + t.Errorf("Code = %d, want 42", result.Code) + } +} + +func TestRun_Timeout(t *testing.T) { + start := time.Now() + result, err := sh(t, "sleep 30", process.RunOptions{TimeoutMs: 300}) + elapsed := time.Since(start) + if err != nil { + t.Fatal(err) + } + if result.Code == 0 { + t.Error("expected non-zero exit code after timeout") + } + if elapsed > 2*time.Second { + t.Errorf("elapsed %v, want < 2s", elapsed) + } +} + +func TestRunStreaming_OnLine(t *testing.T) { + var lines []string + result, err := process.RunStreaming("sh", []string{"-c", "printf 'a\nb\nc\n'"}, + process.RunStreamingOptions{ + OnLine: func(line string) { lines = append(lines, line) }, + }) + if err != nil { + t.Fatal(err) + } + if result.Code != 0 { + t.Errorf("Code = %d, want 0", result.Code) + } + if len(lines) != 3 { + t.Errorf("got %d lines, want 3: %v", len(lines), lines) + } + if lines[0] != "a" || lines[1] != "b" || lines[2] != "c" { + t.Errorf("lines = %v, want [a b c]", lines) + } +} + +func TestRunStreaming_FlushFinalLine(t *testing.T) { + var lines []string + result, err := process.RunStreaming("sh", []string{"-c", "printf tail"}, + process.RunStreamingOptions{ + OnLine: func(line string) { lines = append(lines, line) }, + }) + if err != nil { + t.Fatal(err) + } + if result.Code != 0 { + t.Errorf("Code = %d, want 0", result.Code) + } + if len(lines) != 1 { + t.Errorf("got %d lines, want 1: %v", len(lines), lines) + } + if lines[0] != "tail" { + t.Errorf("lines[0] = %q, want tail", lines[0]) + } +} + +func TestRunStreaming_ConfigDir(t *testing.T) { + var lines []string + _, err := process.RunStreaming("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`}, + process.RunStreamingOptions{ + RunOptions: process.RunOptions{ConfigDir: "/my/config/dir"}, + OnLine: func(line string) { lines = append(lines, line) }, + }) + if err != nil { + t.Fatal(err) + } + if len(lines) != 1 || lines[0] != "/my/config/dir" { + t.Errorf("lines = %v, want [/my/config/dir]", lines) + } +} + +func TestRunStreaming_Stderr(t *testing.T) { + result, err := process.RunStreaming("sh", []string{"-c", "echo err-output >&2"}, + process.RunStreamingOptions{OnLine: func(string) {}}) + if err != nil { + t.Fatal(err) + } + if result.Stderr == "" { + t.Error("expected stderr to contain output") + } +} + +func TestRunStreaming_Timeout(t *testing.T) { + start := time.Now() + result, err := process.RunStreaming("sh", []string{"-c", "sleep 30"}, + process.RunStreamingOptions{ + RunOptions: process.RunOptions{TimeoutMs: 300}, + OnLine: func(string) {}, + }) + elapsed := time.Since(start) + if err != nil { + t.Fatal(err) + } + if result.Code == 0 { + t.Error("expected non-zero exit code after timeout") + } + if elapsed > 2*time.Second { + t.Errorf("elapsed %v, want < 2s", elapsed) + } +} + +func TestRunStreaming_Concurrent(t *testing.T) { + var lines1, lines2 []string + done := make(chan struct{}, 2) + + run := func(label string, target *[]string) { + process.RunStreaming("sh", []string{"-c", "printf '" + label + "'"}, + process.RunStreamingOptions{ + OnLine: func(line string) { *target = append(*target, line) }, + }) + done <- struct{}{} + } + + go run("stream1", &lines1) + go run("stream2", &lines2) + + <-done + <-done + + if len(lines1) != 1 || lines1[0] != "stream1" { + t.Errorf("lines1 = %v, want [stream1]", lines1) + } + if len(lines2) != 1 || lines2[0] != "stream2" { + t.Errorf("lines2 = %v, want [stream2]", lines2) + } +} + +func TestRunStreaming_ContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + start := time.Now() + done := make(chan struct{}) + + go func() { + process.RunStreaming("sh", []string{"-c", "sleep 30"}, + process.RunStreamingOptions{ + RunOptions: process.RunOptions{Ctx: ctx}, + OnLine: func(string) {}, + }) + close(done) + }() + + time.AfterFunc(100*time.Millisecond, cancel) + <-done + elapsed := time.Since(start) + + if elapsed > 2*time.Second { + t.Errorf("elapsed %v, want < 2s", elapsed) + } +} + +func TestRun_ContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + start := time.Now() + done := make(chan process.RunResult, 1) + + go func() { + r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx}) + done <- r + }() + + time.AfterFunc(100*time.Millisecond, cancel) + result := <-done + elapsed := time.Since(start) + + if result.Code == 0 { + t.Error("expected non-zero exit code after cancel") + } + if elapsed > 2*time.Second { + t.Errorf("elapsed %v, want < 2s", elapsed) + } +} + +func TestRun_AlreadyCancelledContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // 已取消 + + start := time.Now() + result, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx}) + elapsed := time.Since(start) + + if result.Code == 0 { + t.Error("expected non-zero exit code") + } + if elapsed > 2*time.Second { + t.Errorf("elapsed %v, want < 2s", elapsed) + } +} + +func TestKillAllChildProcesses(t *testing.T) { + done := make(chan process.RunResult, 1) + go func() { + r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{}) + done <- r + }() + + time.Sleep(80 * time.Millisecond) + process.KillAllChildProcesses() + result := <-done + + if result.Code == 0 { + t.Error("expected non-zero exit code after kill") + } + // 再次呼叫不應 panic + process.KillAllChildProcesses() +} diff --git a/pkg/infrastructure/process/runner.go b/pkg/infrastructure/process/runner.go new file mode 100644 index 0000000..681e42f --- /dev/null +++ b/pkg/infrastructure/process/runner.go @@ -0,0 +1,250 @@ +package process + +import ( + "bufio" + "context" + "cursor-api-proxy/internal/env" + "fmt" + "os/exec" + "strings" + "sync" + "syscall" + "time" +) + +type RunResult struct { + Code int + Stdout string + Stderr string +} + +type RunOptions struct { + Cwd string + TimeoutMs int + MaxMode bool + ConfigDir string + Ctx context.Context +} + +type RunStreamingOptions struct { + RunOptions + OnLine func(line string) +} + +// ─── Global child process registry ────────────────────────────────────────── + +var ( + activeMu sync.Mutex + activeChildren []*exec.Cmd +) + +func registerChild(c *exec.Cmd) { + activeMu.Lock() + activeChildren = append(activeChildren, c) + activeMu.Unlock() +} + +func unregisterChild(c *exec.Cmd) { + activeMu.Lock() + for i, ch := range activeChildren { + if ch == c { + activeChildren = append(activeChildren[:i], activeChildren[i+1:]...) + break + } + } + activeMu.Unlock() +} + +func KillAllChildProcesses() { + activeMu.Lock() + all := make([]*exec.Cmd, len(activeChildren)) + copy(all, activeChildren) + activeChildren = nil + activeMu.Unlock() + for _, c := range all { + killProcessGroup(c) + } +} + +// ─── Spawn ──────────────────────────────────────────────────────────────── + +func spawnChild(cmdStr string, args []string, opts *RunOptions, maxModeFn func(scriptPath, configDir string)) *exec.Cmd { + envSrc := env.OsEnvToMap() + resolved := env.ResolveAgentCommand(cmdStr, args, envSrc, opts.Cwd) + + if opts.MaxMode && maxModeFn != nil { + maxModeFn(resolved.AgentScriptPath, opts.ConfigDir) + } + + envMap := make(map[string]string, len(resolved.Env)) + for k, v := range resolved.Env { + envMap[k] = v + } + if opts.ConfigDir != "" { + envMap["CURSOR_CONFIG_DIR"] = opts.ConfigDir + } else if resolved.ConfigDir != "" { + if _, exists := envMap["CURSOR_CONFIG_DIR"]; !exists { + envMap["CURSOR_CONFIG_DIR"] = resolved.ConfigDir + } + } + + envSlice := make([]string, 0, len(envMap)) + for k, v := range envMap { + envSlice = append(envSlice, k+"="+v) + } + + ctx := opts.Ctx + if ctx == nil { + ctx = context.Background() + } + + // 使用 WaitDelay 確保 context cancel 後子程序 goroutine 能及時退出 + c := exec.CommandContext(ctx, resolved.Command, resolved.Args...) + c.Dir = opts.Cwd + c.Env = envSlice + // 設定新的 process group,使 kill 能傳遞給所有子孫程序 + c.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + // WaitDelay:context cancel 後額外等待這麼久再強制關閉 pipes + c.WaitDelay = 5 * time.Second + // Cancel 函式:殺死整個 process group + c.Cancel = func() error { + return killProcessGroup(c) + } + return c +} + +// MaxModeFn is set by the agent package to avoid import cycle. +var MaxModeFn func(agentScriptPath, configDir string) + +func Run(cmdStr string, args []string, opts RunOptions) (RunResult, error) { + ctx := opts.Ctx + var cancel context.CancelFunc + if opts.TimeoutMs > 0 { + if ctx == nil { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond) + } else { + ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond) + } + defer cancel() + opts.Ctx = ctx + } else if ctx == nil { + opts.Ctx = context.Background() + } + + c := spawnChild(cmdStr, args, &opts, MaxModeFn) + var stdoutBuf, stderrBuf strings.Builder + c.Stdout = &stdoutBuf + c.Stderr = &stderrBuf + + if err := c.Start(); err != nil { + // context 已取消或命令找不到時 + if opts.Ctx != nil && opts.Ctx.Err() != nil { + return RunResult{Code: -1}, nil + } + if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") { + return RunResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr) + } + return RunResult{}, err + } + registerChild(c) + defer unregisterChild(c) + + err := c.Wait() + code := 0 + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + code = exitErr.ExitCode() + if code == 0 { + code = -1 + } + } else { + // context cancelled or killed — return -1 but no error + return RunResult{Code: -1, Stdout: stdoutBuf.String(), Stderr: stderrBuf.String()}, nil + } + } + return RunResult{ + Code: code, + Stdout: stdoutBuf.String(), + Stderr: stderrBuf.String(), + }, nil +} + +type StreamResult struct { + Code int + Stderr string +} + +func RunStreaming(cmdStr string, args []string, opts RunStreamingOptions) (StreamResult, error) { + ctx := opts.Ctx + var cancel context.CancelFunc + if opts.TimeoutMs > 0 { + if ctx == nil { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond) + } else { + ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond) + } + defer cancel() + opts.RunOptions.Ctx = ctx + } else if opts.RunOptions.Ctx == nil { + opts.RunOptions.Ctx = context.Background() + } + + c := spawnChild(cmdStr, args, &opts.RunOptions, MaxModeFn) + stdoutPipe, err := c.StdoutPipe() + if err != nil { + return StreamResult{}, err + } + stderrPipe, err := c.StderrPipe() + if err != nil { + return StreamResult{}, err + } + + if err := c.Start(); err != nil { + if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") { + return StreamResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr) + } + return StreamResult{}, err + } + registerChild(c) + defer unregisterChild(c) + + var stderrBuf strings.Builder + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + scanner := bufio.NewScanner(stdoutPipe) + scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024) + for scanner.Scan() { + line := scanner.Text() + if strings.TrimSpace(line) != "" { + opts.OnLine(line) + } + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + scanner := bufio.NewScanner(stderrPipe) + scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024) + for scanner.Scan() { + stderrBuf.WriteString(scanner.Text()) + stderrBuf.WriteString("\n") + } + }() + + wg.Wait() + err = c.Wait() + code := 0 + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + code = exitErr.ExitCode() + if code == 0 { + code = -1 + } + } + } + return StreamResult{Code: code, Stderr: stderrBuf.String()}, nil +} diff --git a/pkg/infrastructure/winlimit/winlimit.go b/pkg/infrastructure/winlimit/winlimit.go new file mode 100644 index 0000000..06044b1 --- /dev/null +++ b/pkg/infrastructure/winlimit/winlimit.go @@ -0,0 +1,181 @@ +package winlimit + +import ( + "cursor-api-proxy/internal/env" + "runtime" +) + +const WinPromptOmissionPrefix = "[Earlier messages omitted: Windows command-line length limit.]\n\n" +const LinuxPromptOmissionPrefix = "[Earlier messages omitted: Linux ARG_MAX command-line length limit.]\n\n" + +// safeLinuxArgMax returns a conservative estimate of ARG_MAX on Linux. +// The actual limit is typically 2MB; we use 1.5MB to leave room for env vars. +func safeLinuxArgMax() int { + return 1536 * 1024 +} + +type FitPromptResult struct { + OK bool + Args []string + Truncated bool + OriginalLength int + FinalPromptLength int + Error string +} + +func estimateCmdlineLength(resolved env.AgentCommand) int { + argv := append([]string{resolved.Command}, resolved.Args...) + if resolved.WindowsVerbatimArguments { + n := 0 + for _, a := range argv { + n += len(a) + } + if len(argv) > 1 { + n += len(argv) - 1 + } + return n + 512 + } + dstLen := 0 + for _, a := range argv { + dstLen += len(a) + } + dstLen = dstLen*2 + len(argv)*2 + if len(argv) > 1 { + dstLen += len(argv) - 1 + } + return dstLen + 512 +} + +func FitPromptToWinCmdline(agentBin string, fixedArgs []string, prompt string, maxCmdline int, cwd string) FitPromptResult { + if runtime.GOOS != "windows" { + return fitPromptLinux(fixedArgs, prompt) + } + + e := env.OsEnvToMap() + measured := func(p string) int { + args := make([]string, len(fixedArgs)+1) + copy(args, fixedArgs) + args[len(fixedArgs)] = p + resolved := env.ResolveAgentCommand(agentBin, args, e, cwd) + return estimateCmdlineLength(resolved) + } + + if measured("") > maxCmdline { + return FitPromptResult{ + OK: false, + Error: "Windows command line exceeds the configured limit even without a prompt; shorten workspace path, model id, or CURSOR_BRIDGE_WIN_CMDLINE_MAX.", + } + } + + if measured(prompt) <= maxCmdline { + args := make([]string, len(fixedArgs)+1) + copy(args, fixedArgs) + args[len(fixedArgs)] = prompt + return FitPromptResult{ + OK: true, + Args: args, + Truncated: false, + OriginalLength: len(prompt), + FinalPromptLength: len(prompt), + } + } + + prefix := WinPromptOmissionPrefix + if measured(prefix) > maxCmdline { + return FitPromptResult{ + OK: false, + Error: "Windows command line too long to fit even the truncation notice; shorten workspace path or flags.", + } + } + + lo, hi, best := 0, len(prompt), 0 + for lo <= hi { + mid := (lo + hi) / 2 + var tail string + if mid > 0 { + tail = prompt[len(prompt)-mid:] + } + candidate := prefix + tail + if measured(candidate) <= maxCmdline { + best = mid + lo = mid + 1 + } else { + hi = mid - 1 + } + } + + var finalPrompt string + if best == 0 { + finalPrompt = prefix + } else { + finalPrompt = prefix + prompt[len(prompt)-best:] + } + + args := make([]string, len(fixedArgs)+1) + copy(args, fixedArgs) + args[len(fixedArgs)] = finalPrompt + return FitPromptResult{ + OK: true, + Args: args, + Truncated: true, + OriginalLength: len(prompt), + FinalPromptLength: len(finalPrompt), + } +} + +// fitPromptLinux handles Linux ARG_MAX truncation. +func fitPromptLinux(fixedArgs []string, prompt string) FitPromptResult { + argMax := safeLinuxArgMax() + + // Estimate total cmdline size: sum of all fixed args + prompt + null terminators + fixedLen := 0 + for _, a := range fixedArgs { + fixedLen += len(a) + 1 + } + totalLen := fixedLen + len(prompt) + 1 + + if totalLen <= argMax { + args := make([]string, len(fixedArgs)+1) + copy(args, fixedArgs) + args[len(fixedArgs)] = prompt + return FitPromptResult{ + OK: true, + Args: args, + Truncated: false, + OriginalLength: len(prompt), + FinalPromptLength: len(prompt), + } + } + + // Need to truncate: keep the tail of the prompt (most recent messages) + prefix := LinuxPromptOmissionPrefix + available := argMax - fixedLen - len(prefix) - 1 + if available < 0 { + available = 0 + } + + var finalPrompt string + if available <= 0 { + finalPrompt = prefix + } else if available >= len(prompt) { + finalPrompt = prefix + prompt + } else { + finalPrompt = prefix + prompt[len(prompt)-available:] + } + + args := make([]string, len(fixedArgs)+1) + copy(args, fixedArgs) + args[len(fixedArgs)] = finalPrompt + return FitPromptResult{ + OK: true, + Args: args, + Truncated: true, + OriginalLength: len(prompt), + FinalPromptLength: len(finalPrompt), + } +} + +func WarnPromptTruncated(originalLength, finalLength int) { + _ = originalLength + _ = finalLength +} diff --git a/pkg/infrastructure/winlimit/winlimit_test.go b/pkg/infrastructure/winlimit/winlimit_test.go new file mode 100644 index 0000000..03d2488 --- /dev/null +++ b/pkg/infrastructure/winlimit/winlimit_test.go @@ -0,0 +1,37 @@ +package winlimit + +import ( + "runtime" + "strings" + "testing" +) + +func TestNonWindowsPassThrough(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping non-Windows test on Windows") + } + + fixedArgs := []string{"--print", "--model", "gpt-4"} + prompt := "Hello world" + result := FitPromptToWinCmdline("agent", fixedArgs, prompt, 30000, "/tmp") + + if !result.OK { + t.Fatalf("expected OK=true on non-Windows, got error: %s", result.Error) + } + if result.Truncated { + t.Error("expected no truncation on non-Windows") + } + if result.OriginalLength != len(prompt) { + t.Errorf("expected original length %d, got %d", len(prompt), result.OriginalLength) + } + // Last arg should be the prompt + if len(result.Args) == 0 || result.Args[len(result.Args)-1] != prompt { + t.Errorf("expected last arg to be prompt, got %v", result.Args) + } +} + +func TestOmissionPrefix(t *testing.T) { + if !strings.Contains(WinPromptOmissionPrefix, "Earlier messages omitted") { + t.Errorf("omission prefix should mention earlier messages, got: %q", WinPromptOmissionPrefix) + } +} diff --git a/pkg/infrastructure/workspace/workspace.go b/pkg/infrastructure/workspace/workspace.go new file mode 100644 index 0000000..8fe76bd --- /dev/null +++ b/pkg/infrastructure/workspace/workspace.go @@ -0,0 +1,30 @@ +package workspace + +import ( + "cursor-api-proxy/internal/config" + "os" + "path/filepath" + "strings" +) + +type WorkspaceResult struct { + WorkspaceDir string + TempDir string +} + +func ResolveWorkspace(cfg config.BridgeConfig, workspaceHeader string) WorkspaceResult { + if cfg.ChatOnlyWorkspace { + tempDir, err := os.MkdirTemp("", "cursor-proxy-") + if err != nil { + tempDir = filepath.Join(os.TempDir(), "cursor-proxy-fallback") + _ = os.MkdirAll(tempDir, 0700) + } + return WorkspaceResult{WorkspaceDir: tempDir, TempDir: tempDir} + } + + headerWs := strings.TrimSpace(workspaceHeader) + if headerWs != "" { + return WorkspaceResult{WorkspaceDir: headerWs} + } + return WorkspaceResult{WorkspaceDir: cfg.Workspace} +}