refactor(task-2): migrate infrastructure layer
- Migrate process (runner, kill_unix, kill_windows) - Migrate parser (stream) - Migrate httputil (httputil) - Migrate logger (logger) - Migrate env (env) - Migrate workspace (workspace) - Migrate winlimit (winlimit)
This commit is contained in:
parent
8b6abbbba7
commit
80d7a4bb29
|
|
@ -0,0 +1,381 @@
|
|||
package env
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type EnvSource map[string]string
|
||||
|
||||
type LoadedEnv struct {
|
||||
AgentBin string
|
||||
AgentNode string
|
||||
AgentScript string
|
||||
CommandShell string
|
||||
Host string
|
||||
Port int
|
||||
RequiredKey string
|
||||
DefaultModel string
|
||||
Provider string
|
||||
Force bool
|
||||
ApproveMcps bool
|
||||
StrictModel bool
|
||||
Workspace string
|
||||
TimeoutMs int
|
||||
TLSCertPath string
|
||||
TLSKeyPath string
|
||||
SessionsLogPath string
|
||||
ChatOnlyWorkspace bool
|
||||
Verbose bool
|
||||
MaxMode bool
|
||||
ConfigDirs []string
|
||||
MultiPort bool
|
||||
WinCmdlineMax int
|
||||
GeminiAccountDir string
|
||||
GeminiBrowserVisible bool
|
||||
GeminiMaxSessions int
|
||||
}
|
||||
|
||||
type AgentCommand struct {
|
||||
Command string
|
||||
Args []string
|
||||
Env map[string]string
|
||||
WindowsVerbatimArguments bool
|
||||
AgentScriptPath string
|
||||
ConfigDir string
|
||||
}
|
||||
|
||||
func getEnvVal(e EnvSource, names []string) string {
|
||||
for _, name := range names {
|
||||
if v, ok := e[name]; ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func envBool(e EnvSource, names []string, def bool) bool {
|
||||
raw := getEnvVal(e, names)
|
||||
if raw == "" {
|
||||
return def
|
||||
}
|
||||
switch strings.ToLower(raw) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
case "0", "false", "no", "off":
|
||||
return false
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func envInt(e EnvSource, names []string, def int) int {
|
||||
raw := getEnvVal(e, names)
|
||||
if raw == "" {
|
||||
return def
|
||||
}
|
||||
v, err := strconv.Atoi(raw)
|
||||
if err != nil {
|
||||
return def
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func normalizeModelId(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "auto"
|
||||
}
|
||||
parts := strings.Split(raw, "/")
|
||||
last := parts[len(parts)-1]
|
||||
if last == "" {
|
||||
return "auto"
|
||||
}
|
||||
return last
|
||||
}
|
||||
|
||||
func resolveAbs(raw, cwd string) string {
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
if filepath.IsAbs(raw) {
|
||||
return raw
|
||||
}
|
||||
return filepath.Join(cwd, raw)
|
||||
}
|
||||
|
||||
func isAuthenticatedAccountDir(dir string) bool {
|
||||
data, err := os.ReadFile(filepath.Join(dir, "cli-config.json"))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
var cfg struct {
|
||||
AuthInfo *struct {
|
||||
Email string `json:"email"`
|
||||
} `json:"authInfo"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return false
|
||||
}
|
||||
return cfg.AuthInfo != nil && cfg.AuthInfo.Email != ""
|
||||
}
|
||||
|
||||
func discoverAccountDirs(homeDir string) []string {
|
||||
if homeDir == "" {
|
||||
return nil
|
||||
}
|
||||
accountsDir := filepath.Join(homeDir, ".cursor-api-proxy", "accounts")
|
||||
entries, err := os.ReadDir(accountsDir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var dirs []string
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() {
|
||||
continue
|
||||
}
|
||||
dir := filepath.Join(accountsDir, e.Name())
|
||||
if isAuthenticatedAccountDir(dir) {
|
||||
dirs = append(dirs, dir)
|
||||
}
|
||||
}
|
||||
return dirs
|
||||
}
|
||||
|
||||
func parseDotEnv(path string) EnvSource {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
m := make(EnvSource)
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) == 2 {
|
||||
m[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func OsEnvToMap(cwdHint ...string) EnvSource {
|
||||
m := make(EnvSource)
|
||||
for _, kv := range os.Environ() {
|
||||
parts := strings.SplitN(kv, "=", 2)
|
||||
if len(parts) == 2 {
|
||||
m[parts[0]] = parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
cwd := ""
|
||||
if len(cwdHint) > 0 && cwdHint[0] != "" {
|
||||
cwd = cwdHint[0]
|
||||
} else {
|
||||
cwd, _ = os.Getwd()
|
||||
}
|
||||
|
||||
if dotenv := parseDotEnv(filepath.Join(cwd, ".env")); dotenv != nil {
|
||||
for k, v := range dotenv {
|
||||
if _, exists := m[k]; !exists {
|
||||
m[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func LoadEnvConfig(e EnvSource, cwd string) LoadedEnv {
|
||||
if e == nil {
|
||||
e = OsEnvToMap()
|
||||
}
|
||||
if cwd == "" {
|
||||
var err error
|
||||
cwd, err = os.Getwd()
|
||||
if err != nil {
|
||||
cwd = "."
|
||||
}
|
||||
}
|
||||
|
||||
host := getEnvVal(e, []string{"CURSOR_BRIDGE_HOST"})
|
||||
if host == "" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
port := envInt(e, []string{"CURSOR_BRIDGE_PORT"}, 8765)
|
||||
if port <= 0 {
|
||||
port = 8765
|
||||
}
|
||||
|
||||
home := getEnvVal(e, []string{"HOME", "USERPROFILE"})
|
||||
|
||||
sessionsLogPath := func() string {
|
||||
if p := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_SESSIONS_LOG"}), cwd); p != "" {
|
||||
return p
|
||||
}
|
||||
if home != "" {
|
||||
return filepath.Join(home, ".cursor-api-proxy", "sessions.log")
|
||||
}
|
||||
return filepath.Join(cwd, "sessions.log")
|
||||
}()
|
||||
|
||||
var configDirs []string
|
||||
if raw := getEnvVal(e, []string{"CURSOR_CONFIG_DIRS", "CURSOR_ACCOUNT_DIRS"}); raw != "" {
|
||||
for _, d := range strings.Split(raw, ",") {
|
||||
d = strings.TrimSpace(d)
|
||||
if d != "" {
|
||||
if p := resolveAbs(d, cwd); p != "" {
|
||||
configDirs = append(configDirs, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(configDirs) == 0 {
|
||||
configDirs = discoverAccountDirs(home)
|
||||
}
|
||||
|
||||
winMax := envInt(e, []string{"CURSOR_BRIDGE_WIN_CMDLINE_MAX"}, 30000)
|
||||
if winMax < 4096 {
|
||||
winMax = 4096
|
||||
}
|
||||
if winMax > 32700 {
|
||||
winMax = 32700
|
||||
}
|
||||
|
||||
agentBin := getEnvVal(e, []string{"CURSOR_AGENT_BIN", "CURSOR_CLI_BIN", "CURSOR_CLI_PATH"})
|
||||
if agentBin == "" {
|
||||
agentBin = "agent"
|
||||
}
|
||||
commandShell := getEnvVal(e, []string{"COMSPEC"})
|
||||
if commandShell == "" {
|
||||
commandShell = "cmd.exe"
|
||||
}
|
||||
workspace := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_WORKSPACE"}), cwd)
|
||||
if workspace == "" {
|
||||
workspace = cwd
|
||||
}
|
||||
|
||||
geminiAccountDir := getEnvVal(e, []string{"GEMINI_ACCOUNT_DIR"})
|
||||
if geminiAccountDir == "" {
|
||||
geminiAccountDir = filepath.Join(home, ".cursor-api-proxy", "gemini-accounts")
|
||||
} else {
|
||||
geminiAccountDir = resolveAbs(geminiAccountDir, cwd)
|
||||
}
|
||||
|
||||
return LoadedEnv{
|
||||
AgentBin: agentBin,
|
||||
AgentNode: getEnvVal(e, []string{"CURSOR_AGENT_NODE"}),
|
||||
AgentScript: getEnvVal(e, []string{"CURSOR_AGENT_SCRIPT"}),
|
||||
CommandShell: commandShell,
|
||||
Host: host,
|
||||
Port: port,
|
||||
RequiredKey: getEnvVal(e, []string{"CURSOR_BRIDGE_API_KEY"}),
|
||||
DefaultModel: normalizeModelId(getEnvVal(e, []string{"CURSOR_BRIDGE_DEFAULT_MODEL"})),
|
||||
Provider: getEnvVal(e, []string{"CURSOR_BRIDGE_PROVIDER"}),
|
||||
Force: envBool(e, []string{"CURSOR_BRIDGE_FORCE"}, false),
|
||||
ApproveMcps: envBool(e, []string{"CURSOR_BRIDGE_APPROVE_MCPS"}, false),
|
||||
StrictModel: envBool(e, []string{"CURSOR_BRIDGE_STRICT_MODEL"}, true),
|
||||
Workspace: workspace,
|
||||
TimeoutMs: envInt(e, []string{"CURSOR_BRIDGE_TIMEOUT_MS"}, 300000),
|
||||
TLSCertPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_CERT"}), cwd),
|
||||
TLSKeyPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_KEY"}), cwd),
|
||||
SessionsLogPath: sessionsLogPath,
|
||||
ChatOnlyWorkspace: envBool(e, []string{"CURSOR_BRIDGE_CHAT_ONLY_WORKSPACE"}, true),
|
||||
Verbose: envBool(e, []string{"CURSOR_BRIDGE_VERBOSE"}, false),
|
||||
MaxMode: envBool(e, []string{"CURSOR_BRIDGE_MAX_MODE"}, false),
|
||||
ConfigDirs: configDirs,
|
||||
MultiPort: envBool(e, []string{"CURSOR_BRIDGE_MULTI_PORT"}, false),
|
||||
WinCmdlineMax: winMax,
|
||||
GeminiAccountDir: geminiAccountDir,
|
||||
GeminiBrowserVisible: envBool(e, []string{"GEMINI_BROWSER_VISIBLE"}, false),
|
||||
GeminiMaxSessions: envInt(e, []string{"GEMINI_MAX_SESSIONS"}, 3),
|
||||
}
|
||||
}
|
||||
|
||||
func ResolveAgentCommand(cmd string, args []string, e EnvSource, cwd string) AgentCommand {
|
||||
if e == nil {
|
||||
e = OsEnvToMap()
|
||||
}
|
||||
loaded := LoadEnvConfig(e, cwd)
|
||||
|
||||
cloneEnv := func() map[string]string {
|
||||
m := make(map[string]string, len(e))
|
||||
for k, v := range e {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
if loaded.AgentNode != "" && loaded.AgentScript != "" {
|
||||
agentScriptPath := loaded.AgentScript
|
||||
if !filepath.IsAbs(agentScriptPath) {
|
||||
agentScriptPath = filepath.Join(cwd, agentScriptPath)
|
||||
}
|
||||
agentDir := filepath.Dir(agentScriptPath)
|
||||
configDir := filepath.Join(agentDir, "..", "data", "config")
|
||||
env2 := cloneEnv()
|
||||
env2["CURSOR_INVOKED_AS"] = "agent.cmd"
|
||||
ac := AgentCommand{
|
||||
Command: loaded.AgentNode,
|
||||
Args: append([]string{loaded.AgentScript}, args...),
|
||||
Env: env2,
|
||||
AgentScriptPath: agentScriptPath,
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil {
|
||||
ac.ConfigDir = configDir
|
||||
}
|
||||
return ac
|
||||
}
|
||||
|
||||
if strings.HasSuffix(strings.ToLower(cmd), ".cmd") {
|
||||
cmdResolved := cmd
|
||||
if !filepath.IsAbs(cmd) {
|
||||
cmdResolved = filepath.Join(cwd, cmd)
|
||||
}
|
||||
dir := filepath.Dir(cmdResolved)
|
||||
nodeBin := filepath.Join(dir, "node.exe")
|
||||
script := filepath.Join(dir, "index.js")
|
||||
if _, err1 := os.Stat(nodeBin); err1 == nil {
|
||||
if _, err2 := os.Stat(script); err2 == nil {
|
||||
configDir := filepath.Join(dir, "..", "data", "config")
|
||||
env2 := cloneEnv()
|
||||
env2["CURSOR_INVOKED_AS"] = "agent.cmd"
|
||||
ac := AgentCommand{
|
||||
Command: nodeBin,
|
||||
Args: append([]string{script}, args...),
|
||||
Env: env2,
|
||||
AgentScriptPath: script,
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil {
|
||||
ac.ConfigDir = configDir
|
||||
}
|
||||
return ac
|
||||
}
|
||||
}
|
||||
|
||||
quotedArgs := make([]string, len(args))
|
||||
for i, a := range args {
|
||||
if strings.Contains(a, " ") {
|
||||
quotedArgs[i] = `"` + a + `"`
|
||||
} else {
|
||||
quotedArgs[i] = a
|
||||
}
|
||||
}
|
||||
cmdLine := `""` + cmd + `" ` + strings.Join(quotedArgs, " ") + `"`
|
||||
return AgentCommand{
|
||||
Command: loaded.CommandShell,
|
||||
Args: []string{"/d", "/s", "/c", cmdLine},
|
||||
Env: cloneEnv(),
|
||||
WindowsVerbatimArguments: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return AgentCommand{Command: cmd, Args: args, Env: cloneEnv()}
|
||||
}
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
package env
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestLoadEnvConfigDefaults(t *testing.T) {
|
||||
e := EnvSource{}
|
||||
loaded := LoadEnvConfig(e, "/tmp")
|
||||
|
||||
if loaded.Host != "127.0.0.1" {
|
||||
t.Errorf("expected 127.0.0.1, got %s", loaded.Host)
|
||||
}
|
||||
if loaded.Port != 8765 {
|
||||
t.Errorf("expected 8765, got %d", loaded.Port)
|
||||
}
|
||||
if loaded.DefaultModel != "auto" {
|
||||
t.Errorf("expected auto, got %s", loaded.DefaultModel)
|
||||
}
|
||||
if loaded.AgentBin != "agent" {
|
||||
t.Errorf("expected agent, got %s", loaded.AgentBin)
|
||||
}
|
||||
if !loaded.StrictModel {
|
||||
t.Error("expected strictModel=true by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadEnvConfigOverride(t *testing.T) {
|
||||
e := EnvSource{
|
||||
"CURSOR_BRIDGE_HOST": "0.0.0.0",
|
||||
"CURSOR_BRIDGE_PORT": "9000",
|
||||
"CURSOR_BRIDGE_DEFAULT_MODEL": "gpt-4",
|
||||
"CURSOR_AGENT_BIN": "/usr/local/bin/agent",
|
||||
}
|
||||
loaded := LoadEnvConfig(e, "/tmp")
|
||||
|
||||
if loaded.Host != "0.0.0.0" {
|
||||
t.Errorf("expected 0.0.0.0, got %s", loaded.Host)
|
||||
}
|
||||
if loaded.Port != 9000 {
|
||||
t.Errorf("expected 9000, got %d", loaded.Port)
|
||||
}
|
||||
if loaded.DefaultModel != "gpt-4" {
|
||||
t.Errorf("expected gpt-4, got %s", loaded.DefaultModel)
|
||||
}
|
||||
if loaded.AgentBin != "/usr/local/bin/agent" {
|
||||
t.Errorf("expected /usr/local/bin/agent, got %s", loaded.AgentBin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeModelID(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"gpt-4", "gpt-4"},
|
||||
{"openai/gpt-4", "gpt-4"},
|
||||
{"", "auto"},
|
||||
{" ", "auto"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got := normalizeModelId(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("normalizeModelId(%q) = %q, want %q", tc.input, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var bearerRe = regexp.MustCompile(`(?i)^Bearer\s+(.+)$`)
|
||||
|
||||
func ExtractBearerToken(r *http.Request) string {
|
||||
h := r.Header.Get("Authorization")
|
||||
if h == "" {
|
||||
return ""
|
||||
}
|
||||
m := bearerRe.FindStringSubmatch(h)
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
return m[1]
|
||||
}
|
||||
|
||||
func WriteJSON(w http.ResponseWriter, status int, body interface{}, extraHeaders map[string]string) {
|
||||
for k, v := range extraHeaders {
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(body)
|
||||
}
|
||||
|
||||
func WriteSSEHeaders(w http.ResponseWriter, extraHeaders map[string]string) {
|
||||
for k, v := range extraHeaders {
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
|
||||
func ReadBody(r *http.Request) (string, error) {
|
||||
data, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractBearerToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
header string
|
||||
want string
|
||||
}{
|
||||
{"Bearer mytoken123", "mytoken123"},
|
||||
{"bearer MYTOKEN", "MYTOKEN"},
|
||||
{"", ""},
|
||||
{"Basic abc", ""},
|
||||
{"Bearer ", ""},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
if tc.header != "" {
|
||||
req.Header.Set("Authorization", tc.header)
|
||||
}
|
||||
got := ExtractBearerToken(req)
|
||||
if got != tc.want {
|
||||
t.Errorf("ExtractBearerToken(%q) = %q, want %q", tc.header, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteJSON(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
WriteJSON(w, 200, map[string]string{"ok": "true"}, nil)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Errorf("expected 200, got %d", w.Code)
|
||||
}
|
||||
if w.Header().Get("Content-Type") != "application/json" {
|
||||
t.Errorf("expected application/json, got %s", w.Header().Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteJSONWithExtraHeaders(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
WriteJSON(w, 201, nil, map[string]string{"X-Custom": "value"})
|
||||
|
||||
if w.Header().Get("X-Custom") != "value" {
|
||||
t.Errorf("expected X-Custom=value, got %s", w.Header().Get("X-Custom"))
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,309 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/pool"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
cReset = "\x1b[0m"
|
||||
cBold = "\x1b[1m"
|
||||
cDim = "\x1b[2m"
|
||||
cCyan = "\x1b[36m"
|
||||
cBCyan = "\x1b[1;96m"
|
||||
cGreen = "\x1b[32m"
|
||||
cBGreen = "\x1b[1;92m"
|
||||
cYellow = "\x1b[33m"
|
||||
cMagenta = "\x1b[35m"
|
||||
cBMagenta = "\x1b[1;95m"
|
||||
cRed = "\x1b[31m"
|
||||
cGray = "\x1b[90m"
|
||||
cWhite = "\x1b[97m"
|
||||
)
|
||||
|
||||
var roleStyle = map[string]string{
|
||||
"system": cYellow,
|
||||
"user": cCyan,
|
||||
"assistant": cGreen,
|
||||
}
|
||||
|
||||
var roleEmoji = map[string]string{
|
||||
"system": "🔧",
|
||||
"user": "👤",
|
||||
"assistant": "🤖",
|
||||
}
|
||||
|
||||
func ts() string {
|
||||
return cGray + time.Now().UTC().Format("15:04:05") + cReset
|
||||
}
|
||||
|
||||
func tsDate() string {
|
||||
return cGray + time.Now().UTC().Format("2006-01-02 15:04:05") + cReset
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
head := int(float64(max) * 0.6)
|
||||
tail := max - head
|
||||
omitted := len(s) - head - tail
|
||||
return s[:head] + fmt.Sprintf("%s … (%d chars omitted) … ", cDim, omitted) + s[len(s)-tail:] + cReset
|
||||
}
|
||||
|
||||
func hr(ch string, length int) string {
|
||||
return cGray + strings.Repeat(ch, length) + cReset
|
||||
}
|
||||
|
||||
type TrafficMessage struct {
|
||||
Role string
|
||||
Content string
|
||||
}
|
||||
|
||||
func LogDebug(format string, args ...interface{}) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
fmt.Printf("%s %s[DEBUG]%s %s\n", ts(), cGray, cReset, msg)
|
||||
}
|
||||
|
||||
func LogServerStart(version, scheme, host string, port int, cfg config.BridgeConfig) {
|
||||
provider := cfg.Provider
|
||||
if provider == "" {
|
||||
provider = "cursor"
|
||||
}
|
||||
fmt.Printf("\n%s%s╔══════════════════════════════════════════╗%s\n", cBold, cBCyan, cReset)
|
||||
fmt.Printf("%s%s cursor-api-proxy %sv%s%s%s%s ready%s\n",
|
||||
cBold, cBCyan, cReset, cBold, cWhite, version, cBCyan, cReset)
|
||||
fmt.Printf("%s%s╚══════════════════════════════════════════╝%s\n\n", cBold, cBCyan, cReset)
|
||||
url := fmt.Sprintf("%s://%s:%d", scheme, host, port)
|
||||
fmt.Printf(" %s●%s listening %s%s%s\n", cBGreen, cReset, cBold, url, cReset)
|
||||
fmt.Printf(" %s▸%s provider %s%s%s\n", cCyan, cReset, cBold, provider, cReset)
|
||||
fmt.Printf(" %s▸%s agent %s%s%s\n", cCyan, cReset, cDim, cfg.AgentBin, cReset)
|
||||
fmt.Printf(" %s▸%s workspace %s%s%s\n", cCyan, cReset, cDim, cfg.Workspace, cReset)
|
||||
fmt.Printf(" %s▸%s model %s%s%s\n", cCyan, cReset, cDim, cfg.DefaultModel, cReset)
|
||||
fmt.Printf(" %s▸%s mode %s%s%s\n", cCyan, cReset, cDim, cfg.Mode, cReset)
|
||||
fmt.Printf(" %s▸%s timeout %s%d ms%s\n", cCyan, cReset, cDim, cfg.TimeoutMs, cReset)
|
||||
|
||||
// 顯示 Gemini Web Provider 相關設定
|
||||
if provider == "gemini-web" {
|
||||
fmt.Printf(" %s▸%s gemini-dir %s%s%s\n", cCyan, cReset, cDim, cfg.GeminiAccountDir, cReset)
|
||||
fmt.Printf(" %s▸%s max-sess %s%d%s\n", cCyan, cReset, cDim, cfg.GeminiMaxSessions, cReset)
|
||||
}
|
||||
|
||||
flags := []string{}
|
||||
if cfg.Force {
|
||||
flags = append(flags, "force")
|
||||
}
|
||||
if cfg.ApproveMcps {
|
||||
flags = append(flags, "approve-mcps")
|
||||
}
|
||||
if cfg.MaxMode {
|
||||
flags = append(flags, "max-mode")
|
||||
}
|
||||
if cfg.Verbose {
|
||||
flags = append(flags, "verbose")
|
||||
}
|
||||
if cfg.ChatOnlyWorkspace {
|
||||
flags = append(flags, "chat-only")
|
||||
}
|
||||
if cfg.RequiredKey != "" {
|
||||
flags = append(flags, "api-key-required")
|
||||
}
|
||||
if len(flags) > 0 {
|
||||
fmt.Printf(" %s▸%s flags %s%s%s\n", cCyan, cReset, cYellow, strings.Join(flags, " · "), cReset)
|
||||
}
|
||||
if len(cfg.ConfigDirs) > 0 {
|
||||
fmt.Printf(" %s▸%s pool %s%d accounts%s\n", cCyan, cReset, cBGreen, len(cfg.ConfigDirs), cReset)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
func LogShutdown(sig string) {
|
||||
fmt.Printf("\n%s %s⊘ %s received — shutting down gracefully…%s\n", tsDate(), cYellow, sig, cReset)
|
||||
}
|
||||
|
||||
func LogRequestStart(method, pathname, model string, timeoutMs int, isStream bool) {
|
||||
modeTag := fmt.Sprintf("%ssync%s", cDim, cReset)
|
||||
if isStream {
|
||||
modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset)
|
||||
}
|
||||
fmt.Printf("%s %s▶%s %s %s %s timeout:%dms %s\n",
|
||||
ts(), cBCyan, cReset, method, pathname, model, timeoutMs, modeTag)
|
||||
}
|
||||
|
||||
func LogRequestDone(method, pathname, model string, latencyMs int64, code int) {
|
||||
statusColor := cBGreen
|
||||
if code != 0 {
|
||||
statusColor = cRed
|
||||
}
|
||||
fmt.Printf("%s %s■%s %s %s %s %s%dms exit:%d%s\n",
|
||||
ts(), statusColor, cReset, method, pathname, model, cDim, latencyMs, code, cReset)
|
||||
}
|
||||
|
||||
func LogRequestTimeout(method, pathname, model string, timeoutMs int) {
|
||||
fmt.Printf("%s %s⏱%s %s %s %s %stimed-out after %dms%s\n",
|
||||
ts(), cRed, cReset, method, pathname, model, cRed, timeoutMs, cReset)
|
||||
}
|
||||
|
||||
func LogClientDisconnect(method, pathname, model string, latencyMs int64) {
|
||||
fmt.Printf("%s %s⚡%s %s %s %s %sclient disconnected after %dms%s\n",
|
||||
ts(), cYellow, cReset, method, pathname, model, cYellow, latencyMs, cReset)
|
||||
}
|
||||
|
||||
func LogStreamChunk(model string, text string, chunkNum int) {
|
||||
preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 120)
|
||||
fmt.Printf("%s %s▸%s #%d %s%s%s\n",
|
||||
ts(), cDim, cReset, chunkNum, cWhite, preview, cReset)
|
||||
}
|
||||
|
||||
func LogRawLine(line string) {
|
||||
preview := truncate(strings.ReplaceAll(line, "\n", "↵ "), 200)
|
||||
fmt.Printf("%s %s│%s %sraw%s %s\n",
|
||||
ts(), cGray, cReset, cDim, cReset, preview)
|
||||
}
|
||||
|
||||
func LogIncoming(method, pathname, remoteAddress string) {
|
||||
methodColor := cBCyan
|
||||
switch method {
|
||||
case "POST":
|
||||
methodColor = cBMagenta
|
||||
case "GET":
|
||||
methodColor = cBCyan
|
||||
case "DELETE":
|
||||
methodColor = cRed
|
||||
}
|
||||
fmt.Printf("%s %s%s%s%s %s%s%s %s(%s)%s\n",
|
||||
ts(),
|
||||
methodColor, cBold, method, cReset,
|
||||
cWhite, pathname, cReset,
|
||||
cDim, remoteAddress, cReset,
|
||||
)
|
||||
}
|
||||
|
||||
func LogAccountAssigned(configDir string) {
|
||||
if configDir == "" {
|
||||
return
|
||||
}
|
||||
name := filepath.Base(configDir)
|
||||
fmt.Printf("%s %s→%s account %s%s%s\n", ts(), cBCyan, cReset, cBold, name, cReset)
|
||||
}
|
||||
|
||||
func LogAccountStats(verbose bool, stats []pool.AccountStat) {
|
||||
if !verbose || len(stats) == 0 {
|
||||
return
|
||||
}
|
||||
now := time.Now().UnixMilli()
|
||||
fmt.Printf("%s┌─ Account Stats %s┐%s\n", cGray, strings.Repeat("─", 44), cReset)
|
||||
for _, s := range stats {
|
||||
name := fmt.Sprintf("%-20s", filepath.Base(s.ConfigDir))
|
||||
active := fmt.Sprintf("%sactive:0%s", cDim, cReset)
|
||||
if s.ActiveRequests > 0 {
|
||||
active = fmt.Sprintf("%sactive:%d%s", cBCyan, s.ActiveRequests, cReset)
|
||||
}
|
||||
total := fmt.Sprintf("total:%s%d%s", cBold, s.TotalRequests, cReset)
|
||||
ok := fmt.Sprintf("%sok:%d%s", cGreen, s.TotalSuccess, cReset)
|
||||
errStr := fmt.Sprintf("%serr:0%s", cDim, cReset)
|
||||
if s.TotalErrors > 0 {
|
||||
errStr = fmt.Sprintf("%serr:%d%s", cRed, s.TotalErrors, cReset)
|
||||
}
|
||||
rl := fmt.Sprintf("%srl:0%s", cDim, cReset)
|
||||
if s.TotalRateLimits > 0 {
|
||||
rl = fmt.Sprintf("%srl:%d%s", cYellow, s.TotalRateLimits, cReset)
|
||||
}
|
||||
avg := "avg:-"
|
||||
if s.TotalRequests > 0 {
|
||||
avg = fmt.Sprintf("avg:%dms", s.TotalLatencyMs/int64(s.TotalRequests))
|
||||
}
|
||||
status := fmt.Sprintf("%s✓%s", cGreen, cReset)
|
||||
if s.IsRateLimited {
|
||||
recovers := time.UnixMilli(s.RateLimitUntil).UTC().Format(time.RFC3339)
|
||||
_ = now
|
||||
status = fmt.Sprintf("%s⛔ rate-limited (recovers %s)%s", cRed, recovers, cReset)
|
||||
}
|
||||
fmt.Printf(" %s%s%s %s %s %s %s %s %s%s%s %s\n",
|
||||
cBold, name, cReset, active, total, ok, errStr, rl, cDim, avg, cReset, status)
|
||||
}
|
||||
fmt.Printf("%s└%s┘%s\n", cGray, strings.Repeat("─", 60), cReset)
|
||||
}
|
||||
|
||||
func LogTrafficRequest(verbose bool, model string, messages []TrafficMessage, isStream bool) {
|
||||
if !verbose {
|
||||
return
|
||||
}
|
||||
modeTag := fmt.Sprintf("%ssync%s", cDim, cReset)
|
||||
if isStream {
|
||||
modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset)
|
||||
}
|
||||
modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset)
|
||||
fmt.Println(hr("─", 60))
|
||||
fmt.Printf("%s 📤 %s%sREQUEST%s %s %s\n", ts(), cBCyan, cBold, cReset, modelStr, modeTag)
|
||||
for _, m := range messages {
|
||||
roleColor := cWhite
|
||||
if c, ok := roleStyle[m.Role]; ok {
|
||||
roleColor = c
|
||||
}
|
||||
emoji := "💬"
|
||||
if e, ok := roleEmoji[m.Role]; ok {
|
||||
emoji = e
|
||||
}
|
||||
label := fmt.Sprintf("%s%s[%s]%s", roleColor, cBold, m.Role, cReset)
|
||||
charCount := fmt.Sprintf("%s(%d chars)%s", cDim, len(m.Content), cReset)
|
||||
preview := truncate(strings.ReplaceAll(m.Content, "\n", "↵ "), 280)
|
||||
fmt.Printf(" %s %s %s\n", emoji, label, charCount)
|
||||
fmt.Printf(" %s%s%s\n", cDim, preview, cReset)
|
||||
}
|
||||
}
|
||||
|
||||
func LogTrafficResponse(verbose bool, model, text string, isStream bool) {
|
||||
if !verbose {
|
||||
return
|
||||
}
|
||||
modeTag := fmt.Sprintf("%ssync%s", cDim, cReset)
|
||||
if isStream {
|
||||
modeTag = fmt.Sprintf("%s⚡ stream%s", cBGreen, cReset)
|
||||
}
|
||||
modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset)
|
||||
charCount := fmt.Sprintf("%s%d%s%s chars%s", cBold, len(text), cReset, cDim, cReset)
|
||||
preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 480)
|
||||
fmt.Printf("%s 📥 %s%sRESPONSE%s %s %s %s\n", ts(), cBGreen, cBold, cReset, modelStr, modeTag, charCount)
|
||||
fmt.Printf(" 🤖 %s%s%s\n", cGreen, preview, cReset)
|
||||
fmt.Println(hr("─", 60))
|
||||
}
|
||||
|
||||
func AppendSessionLine(logPath, method, pathname, remoteAddress string, statusCode int) {
|
||||
line := fmt.Sprintf("%s %s %s %s %d\n", time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, statusCode)
|
||||
dir := filepath.Dir(logPath)
|
||||
if err := os.MkdirAll(dir, 0755); err == nil {
|
||||
f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err == nil {
|
||||
_, _ = f.WriteString(line)
|
||||
f.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func LogTruncation(originalLen, finalLen int) {
|
||||
fmt.Printf("%s %s⚠ prompt truncated%s %s(%d → %d chars, tail preserved)%s\n",
|
||||
ts(), cYellow, cReset, cDim, originalLen, finalLen, cReset)
|
||||
}
|
||||
|
||||
func LogAgentError(logPath, method, pathname, remoteAddress string, exitCode int, stderr string) string {
|
||||
errMsg := fmt.Sprintf("Cursor CLI failed (exit %d): %s", exitCode, strings.TrimSpace(stderr))
|
||||
fmt.Fprintf(os.Stderr, "%s %s✗ agent error%s %s%s%s\n", ts(), cRed, cReset, cDim, errMsg, cReset)
|
||||
truncated := strings.TrimSpace(stderr)
|
||||
if len(truncated) > 200 {
|
||||
truncated = truncated[:200]
|
||||
}
|
||||
truncated = strings.ReplaceAll(truncated, "\n", " ")
|
||||
line := fmt.Sprintf("%s ERROR %s %s %s agent_exit_%d %s\n",
|
||||
time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, exitCode, truncated)
|
||||
if f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil {
|
||||
_, _ = f.WriteString(line)
|
||||
f.Close()
|
||||
}
|
||||
return errMsg
|
||||
}
|
||||
|
|
@ -0,0 +1,110 @@
|
|||
package parser
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type StreamParser func(line string)
|
||||
|
||||
type Parser struct {
|
||||
Parse StreamParser
|
||||
Flush func()
|
||||
}
|
||||
|
||||
// CreateStreamParser 建立串流解析器(向後相容,不傳遞 thinking)
|
||||
func CreateStreamParser(onText func(string), onDone func()) Parser {
|
||||
return CreateStreamParserWithThinking(onText, nil, onDone)
|
||||
}
|
||||
|
||||
// CreateStreamParserWithThinking 建立串流解析器,支援思考過程輸出。
|
||||
// onThinking 可為 nil,表示忽略思考過程。
|
||||
func CreateStreamParserWithThinking(onText func(string), onThinking func(string), onDone func()) Parser {
|
||||
// accumulated 是所有已輸出內容的串接
|
||||
accumulatedText := ""
|
||||
accumulatedThinking := ""
|
||||
done := false
|
||||
|
||||
parse := func(line string) {
|
||||
if done {
|
||||
return
|
||||
}
|
||||
|
||||
var obj struct {
|
||||
Type string `json:"type"`
|
||||
Subtype string `json:"subtype"`
|
||||
Message *struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
Thinking string `json:"thinking"`
|
||||
} `json:"content"`
|
||||
} `json:"message"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(line), &obj); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if obj.Type == "assistant" && obj.Message != nil {
|
||||
fullText := ""
|
||||
fullThinking := ""
|
||||
for _, p := range obj.Message.Content {
|
||||
switch p.Type {
|
||||
case "text":
|
||||
if p.Text != "" {
|
||||
fullText += p.Text
|
||||
}
|
||||
case "thinking":
|
||||
if p.Thinking != "" {
|
||||
fullThinking += p.Thinking
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 處理思考過程(不因去重而 return,避免跳過同行的文字內容)
|
||||
if onThinking != nil && fullThinking != "" && fullThinking != accumulatedThinking {
|
||||
// 增量模式:新內容以 accumulated 為前綴
|
||||
if len(fullThinking) >= len(accumulatedThinking) && fullThinking[:len(accumulatedThinking)] == accumulatedThinking {
|
||||
delta := fullThinking[len(accumulatedThinking):]
|
||||
if delta != "" {
|
||||
onThinking(delta)
|
||||
}
|
||||
accumulatedThinking = fullThinking
|
||||
} else {
|
||||
// 獨立片段:直接輸出,但 accumulated 要串接
|
||||
onThinking(fullThinking)
|
||||
accumulatedThinking = accumulatedThinking + fullThinking
|
||||
}
|
||||
}
|
||||
|
||||
// 處理一般文字
|
||||
if fullText == "" || fullText == accumulatedText {
|
||||
return
|
||||
}
|
||||
// 增量模式:新內容以 accumulated 為前綴
|
||||
if len(fullText) >= len(accumulatedText) && fullText[:len(accumulatedText)] == accumulatedText {
|
||||
delta := fullText[len(accumulatedText):]
|
||||
if delta != "" {
|
||||
onText(delta)
|
||||
}
|
||||
accumulatedText = fullText
|
||||
} else {
|
||||
// 獨立片段:直接輸出,但 accumulated 要串接
|
||||
onText(fullText)
|
||||
accumulatedText = accumulatedText + fullText
|
||||
}
|
||||
}
|
||||
|
||||
if obj.Type == "result" && obj.Subtype == "success" {
|
||||
done = true
|
||||
onDone()
|
||||
}
|
||||
}
|
||||
|
||||
flush := func() {
|
||||
if !done {
|
||||
done = true
|
||||
onDone()
|
||||
}
|
||||
}
|
||||
|
||||
return Parser{Parse: parse, Flush: flush}
|
||||
}
|
||||
|
|
@ -0,0 +1,304 @@
|
|||
package parser
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func makeAssistantLine(text string) string {
|
||||
obj := map[string]interface{}{
|
||||
"type": "assistant",
|
||||
"message": map[string]interface{}{
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "text", "text": text},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(obj)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func makeResultLine() string {
|
||||
b, _ := json.Marshal(map[string]string{"type": "result", "subtype": "success"})
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func TestStreamParserFragmentMode(t *testing.T) {
|
||||
// cursor --stream-partial-output 模式:每個訊息是獨立 token fragment
|
||||
var texts []string
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
p.Parse(makeAssistantLine("你"))
|
||||
p.Parse(makeAssistantLine("好!有"))
|
||||
p.Parse(makeAssistantLine("什"))
|
||||
p.Parse(makeAssistantLine("麼"))
|
||||
|
||||
if len(texts) != 4 {
|
||||
t.Fatalf("expected 4 fragments, got %d: %v", len(texts), texts)
|
||||
}
|
||||
if texts[0] != "你" || texts[1] != "好!有" || texts[2] != "什" || texts[3] != "麼" {
|
||||
t.Fatalf("unexpected texts: %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserDeduplicatesFinalFullText(t *testing.T) {
|
||||
// 最後一個訊息是完整的累積文字,應被跳過(去重)
|
||||
var texts []string
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
p.Parse(makeAssistantLine("Hello"))
|
||||
p.Parse(makeAssistantLine(" world"))
|
||||
// 最後一個是完整累積文字,應被去重
|
||||
p.Parse(makeAssistantLine("Hello world"))
|
||||
|
||||
if len(texts) != 2 {
|
||||
t.Fatalf("expected 2 fragments (final full text deduplicated), got %d: %v", len(texts), texts)
|
||||
}
|
||||
if texts[0] != "Hello" || texts[1] != " world" {
|
||||
t.Fatalf("unexpected texts: %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserCallsOnDone(t *testing.T) {
|
||||
var texts []string
|
||||
doneCount := 0
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
p.Parse(makeResultLine())
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onDone called once, got %d", doneCount)
|
||||
}
|
||||
if len(texts) != 0 {
|
||||
t.Fatalf("expected no text, got %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserIgnoresLinesAfterDone(t *testing.T) {
|
||||
var texts []string
|
||||
doneCount := 0
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
p.Parse(makeResultLine())
|
||||
p.Parse(makeAssistantLine("late"))
|
||||
if len(texts) != 0 {
|
||||
t.Fatalf("expected no text after done, got %v", texts)
|
||||
}
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onDone called once, got %d", doneCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserIgnoresNonAssistantLines(t *testing.T) {
|
||||
var texts []string
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
b1, _ := json.Marshal(map[string]interface{}{"type": "user", "message": map[string]interface{}{}})
|
||||
p.Parse(string(b1))
|
||||
b2, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "assistant",
|
||||
"message": map[string]interface{}{"content": []interface{}{}},
|
||||
})
|
||||
p.Parse(string(b2))
|
||||
p.Parse(`{"type":"assistant","message":{"content":[{"type":"code","text":"x"}]}}`)
|
||||
|
||||
if len(texts) != 0 {
|
||||
t.Fatalf("expected no texts, got %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserIgnoresParseErrors(t *testing.T) {
|
||||
var texts []string
|
||||
doneCount := 0
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
p.Parse("not json")
|
||||
p.Parse("{")
|
||||
p.Parse("")
|
||||
|
||||
if len(texts) != 0 || doneCount != 0 {
|
||||
t.Fatalf("expected nothing, got texts=%v done=%d", texts, doneCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserJoinsMultipleTextParts(t *testing.T) {
|
||||
var texts []string
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
obj := map[string]interface{}{
|
||||
"type": "assistant",
|
||||
"message": map[string]interface{}{
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "text", "text": " "},
|
||||
{"type": "text", "text": "world"},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(obj)
|
||||
p.Parse(string(b))
|
||||
|
||||
if len(texts) != 1 || texts[0] != "Hello world" {
|
||||
t.Fatalf("expected ['Hello world'], got %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserFlushTriggersDone(t *testing.T) {
|
||||
var texts []string
|
||||
doneCount := 0
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
p.Parse(makeAssistantLine("Hello"))
|
||||
// agent 結束但沒有 result/success,手動 flush
|
||||
p.Flush()
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onDone called once after Flush, got %d", doneCount)
|
||||
}
|
||||
// 再 flush 不應重複觸發
|
||||
p.Flush()
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onDone called only once, got %d", doneCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserFlushAfterDoneIsNoop(t *testing.T) {
|
||||
doneCount := 0
|
||||
p := CreateStreamParser(
|
||||
func(text string) {},
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
p.Parse(makeResultLine())
|
||||
p.Flush()
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onDone called once, got %d", doneCount)
|
||||
}
|
||||
}
|
||||
|
||||
func makeThinkingLine(thinking string) string {
|
||||
obj := map[string]interface{}{
|
||||
"type": "assistant",
|
||||
"message": map[string]interface{}{
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "thinking", "thinking": thinking},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(obj)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func makeThinkingAndTextLine(thinking, text string) string {
|
||||
obj := map[string]interface{}{
|
||||
"type": "assistant",
|
||||
"message": map[string]interface{}{
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "thinking", "thinking": thinking},
|
||||
{"type": "text", "text": text},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(obj)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func TestStreamParserWithThinkingCallsOnThinking(t *testing.T) {
|
||||
var texts []string
|
||||
var thinkings []string
|
||||
p := CreateStreamParserWithThinking(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func(thinking string) { thinkings = append(thinkings, thinking) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
p.Parse(makeThinkingLine("思考中..."))
|
||||
p.Parse(makeAssistantLine("回答"))
|
||||
|
||||
if len(thinkings) != 1 || thinkings[0] != "思考中..." {
|
||||
t.Fatalf("expected thinkings=['思考中...'], got %v", thinkings)
|
||||
}
|
||||
if len(texts) != 1 || texts[0] != "回答" {
|
||||
t.Fatalf("expected texts=['回答'], got %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserWithThinkingNilOnThinkingIgnoresThinking(t *testing.T) {
|
||||
var texts []string
|
||||
p := CreateStreamParserWithThinking(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
nil,
|
||||
func() {},
|
||||
)
|
||||
|
||||
p.Parse(makeThinkingLine("忽略的思考"))
|
||||
p.Parse(makeAssistantLine("文字"))
|
||||
|
||||
if len(texts) != 1 || texts[0] != "文字" {
|
||||
t.Fatalf("expected texts=['文字'], got %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserWithThinkingDeduplication(t *testing.T) {
|
||||
var thinkings []string
|
||||
p := CreateStreamParserWithThinking(
|
||||
func(text string) {},
|
||||
func(thinking string) { thinkings = append(thinkings, thinking) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
p.Parse(makeThinkingLine("A"))
|
||||
p.Parse(makeThinkingLine("B"))
|
||||
// 重複的完整思考,應被跳過
|
||||
p.Parse(makeThinkingLine("AB"))
|
||||
|
||||
if len(thinkings) != 2 || thinkings[0] != "A" || thinkings[1] != "B" {
|
||||
t.Fatalf("expected thinkings=['A','B'], got %v", thinkings)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamParserThinkingDuplicateButTextStillEmitted 驗證 bug 修復:
|
||||
// 當 thinking 重複(去重跳過)但同一行有 text 時,text 仍必須輸出。
|
||||
func TestStreamParserThinkingDuplicateButTextStillEmitted(t *testing.T) {
|
||||
var texts []string
|
||||
var thinkings []string
|
||||
p := CreateStreamParserWithThinking(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func(thinking string) { thinkings = append(thinkings, thinking) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
// 第一行:thinking="思考中" + text(thinking 為新增,兩者都應輸出)
|
||||
p.Parse(makeThinkingAndTextLine("思考中", "第一段"))
|
||||
// 第二行:thinking 與上一行相同(去重),但 text 是新的,text 仍應輸出
|
||||
p.Parse(makeThinkingAndTextLine("思考中", "第二段"))
|
||||
|
||||
if len(thinkings) != 1 || thinkings[0] != "思考中" {
|
||||
t.Fatalf("expected thinkings=['思考中'], got %v", thinkings)
|
||||
}
|
||||
if len(texts) != 2 || texts[0] != "第一段" || texts[1] != "第二段" {
|
||||
t.Fatalf("expected texts=['第一段','第二段'], got %v", texts)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
//go:build !windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func killProcessGroup(c *exec.Cmd) error {
|
||||
if c.Process == nil {
|
||||
return nil
|
||||
}
|
||||
// 殺死整個 process group(負號表示 group)
|
||||
pgid, err := syscall.Getpgid(c.Process.Pid)
|
||||
if err == nil {
|
||||
_ = syscall.Kill(-pgid, syscall.SIGKILL)
|
||||
}
|
||||
// 同時也 kill 主程序,以防萬一
|
||||
return c.Process.Kill()
|
||||
}
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
//go:build windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
func killProcessGroup(c *exec.Cmd) error {
|
||||
if c.Process == nil {
|
||||
return nil
|
||||
}
|
||||
return c.Process.Kill()
|
||||
}
|
||||
|
|
@ -0,0 +1,283 @@
|
|||
package process_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"cursor-api-proxy/internal/process"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// sh 是跨平台 shell 執行小 script 的輔助函式
|
||||
func sh(t *testing.T, script string, opts process.RunOptions) (process.RunResult, error) {
|
||||
t.Helper()
|
||||
return process.Run("sh", []string{"-c", script}, opts)
|
||||
}
|
||||
|
||||
func TestRun_StdoutAndStderr(t *testing.T) {
|
||||
result, err := sh(t, "echo hello; echo world >&2", process.RunOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
t.Errorf("Code = %d, want 0", result.Code)
|
||||
}
|
||||
if result.Stdout != "hello\n" {
|
||||
t.Errorf("Stdout = %q, want %q", result.Stdout, "hello\n")
|
||||
}
|
||||
if result.Stderr != "world\n" {
|
||||
t.Errorf("Stderr = %q, want %q", result.Stderr, "world\n")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_BasicSpawn(t *testing.T) {
|
||||
result, err := sh(t, "printf ok", process.RunOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
t.Errorf("Code = %d, want 0", result.Code)
|
||||
}
|
||||
if result.Stdout != "ok" {
|
||||
t.Errorf("Stdout = %q, want ok", result.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_ConfigDir_Propagated(t *testing.T) {
|
||||
result, err := process.Run("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`},
|
||||
process.RunOptions{ConfigDir: "/test/account/dir"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Stdout != "/test/account/dir" {
|
||||
t.Errorf("Stdout = %q, want /test/account/dir", result.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_ConfigDir_Absent(t *testing.T) {
|
||||
// 確保沒有殘留的環境變數
|
||||
_ = os.Unsetenv("CURSOR_CONFIG_DIR")
|
||||
result, err := process.Run("sh", []string{"-c", `printf "${CURSOR_CONFIG_DIR:-unset}"`},
|
||||
process.RunOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Stdout != "unset" {
|
||||
t.Errorf("Stdout = %q, want unset", result.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_NonZeroExit(t *testing.T) {
|
||||
result, err := sh(t, "exit 42", process.RunOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code != 42 {
|
||||
t.Errorf("Code = %d, want 42", result.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_Timeout(t *testing.T) {
|
||||
start := time.Now()
|
||||
result, err := sh(t, "sleep 30", process.RunOptions{TimeoutMs: 300})
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code == 0 {
|
||||
t.Error("expected non-zero exit code after timeout")
|
||||
}
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_OnLine(t *testing.T) {
|
||||
var lines []string
|
||||
result, err := process.RunStreaming("sh", []string{"-c", "printf 'a\nb\nc\n'"},
|
||||
process.RunStreamingOptions{
|
||||
OnLine: func(line string) { lines = append(lines, line) },
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
t.Errorf("Code = %d, want 0", result.Code)
|
||||
}
|
||||
if len(lines) != 3 {
|
||||
t.Errorf("got %d lines, want 3: %v", len(lines), lines)
|
||||
}
|
||||
if lines[0] != "a" || lines[1] != "b" || lines[2] != "c" {
|
||||
t.Errorf("lines = %v, want [a b c]", lines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_FlushFinalLine(t *testing.T) {
|
||||
var lines []string
|
||||
result, err := process.RunStreaming("sh", []string{"-c", "printf tail"},
|
||||
process.RunStreamingOptions{
|
||||
OnLine: func(line string) { lines = append(lines, line) },
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
t.Errorf("Code = %d, want 0", result.Code)
|
||||
}
|
||||
if len(lines) != 1 {
|
||||
t.Errorf("got %d lines, want 1: %v", len(lines), lines)
|
||||
}
|
||||
if lines[0] != "tail" {
|
||||
t.Errorf("lines[0] = %q, want tail", lines[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_ConfigDir(t *testing.T) {
|
||||
var lines []string
|
||||
_, err := process.RunStreaming("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`},
|
||||
process.RunStreamingOptions{
|
||||
RunOptions: process.RunOptions{ConfigDir: "/my/config/dir"},
|
||||
OnLine: func(line string) { lines = append(lines, line) },
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(lines) != 1 || lines[0] != "/my/config/dir" {
|
||||
t.Errorf("lines = %v, want [/my/config/dir]", lines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_Stderr(t *testing.T) {
|
||||
result, err := process.RunStreaming("sh", []string{"-c", "echo err-output >&2"},
|
||||
process.RunStreamingOptions{OnLine: func(string) {}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Stderr == "" {
|
||||
t.Error("expected stderr to contain output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_Timeout(t *testing.T) {
|
||||
start := time.Now()
|
||||
result, err := process.RunStreaming("sh", []string{"-c", "sleep 30"},
|
||||
process.RunStreamingOptions{
|
||||
RunOptions: process.RunOptions{TimeoutMs: 300},
|
||||
OnLine: func(string) {},
|
||||
})
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Code == 0 {
|
||||
t.Error("expected non-zero exit code after timeout")
|
||||
}
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_Concurrent(t *testing.T) {
|
||||
var lines1, lines2 []string
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
run := func(label string, target *[]string) {
|
||||
process.RunStreaming("sh", []string{"-c", "printf '" + label + "'"},
|
||||
process.RunStreamingOptions{
|
||||
OnLine: func(line string) { *target = append(*target, line) },
|
||||
})
|
||||
done <- struct{}{}
|
||||
}
|
||||
|
||||
go run("stream1", &lines1)
|
||||
go run("stream2", &lines2)
|
||||
|
||||
<-done
|
||||
<-done
|
||||
|
||||
if len(lines1) != 1 || lines1[0] != "stream1" {
|
||||
t.Errorf("lines1 = %v, want [stream1]", lines1)
|
||||
}
|
||||
if len(lines2) != 1 || lines2[0] != "stream2" {
|
||||
t.Errorf("lines2 = %v, want [stream2]", lines2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreaming_ContextCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
start := time.Now()
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
process.RunStreaming("sh", []string{"-c", "sleep 30"},
|
||||
process.RunStreamingOptions{
|
||||
RunOptions: process.RunOptions{Ctx: ctx},
|
||||
OnLine: func(string) {},
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
time.AfterFunc(100*time.Millisecond, cancel)
|
||||
<-done
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_ContextCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
start := time.Now()
|
||||
done := make(chan process.RunResult, 1)
|
||||
|
||||
go func() {
|
||||
r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx})
|
||||
done <- r
|
||||
}()
|
||||
|
||||
time.AfterFunc(100*time.Millisecond, cancel)
|
||||
result := <-done
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if result.Code == 0 {
|
||||
t.Error("expected non-zero exit code after cancel")
|
||||
}
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_AlreadyCancelledContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // 已取消
|
||||
|
||||
start := time.Now()
|
||||
result, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx})
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if result.Code == 0 {
|
||||
t.Error("expected non-zero exit code")
|
||||
}
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKillAllChildProcesses(t *testing.T) {
|
||||
done := make(chan process.RunResult, 1)
|
||||
go func() {
|
||||
r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{})
|
||||
done <- r
|
||||
}()
|
||||
|
||||
time.Sleep(80 * time.Millisecond)
|
||||
process.KillAllChildProcesses()
|
||||
result := <-done
|
||||
|
||||
if result.Code == 0 {
|
||||
t.Error("expected non-zero exit code after kill")
|
||||
}
|
||||
// 再次呼叫不應 panic
|
||||
process.KillAllChildProcesses()
|
||||
}
|
||||
|
|
@ -0,0 +1,250 @@
|
|||
package process
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"cursor-api-proxy/internal/env"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RunResult struct {
|
||||
Code int
|
||||
Stdout string
|
||||
Stderr string
|
||||
}
|
||||
|
||||
type RunOptions struct {
|
||||
Cwd string
|
||||
TimeoutMs int
|
||||
MaxMode bool
|
||||
ConfigDir string
|
||||
Ctx context.Context
|
||||
}
|
||||
|
||||
type RunStreamingOptions struct {
|
||||
RunOptions
|
||||
OnLine func(line string)
|
||||
}
|
||||
|
||||
// ─── Global child process registry ──────────────────────────────────────────
|
||||
|
||||
var (
|
||||
activeMu sync.Mutex
|
||||
activeChildren []*exec.Cmd
|
||||
)
|
||||
|
||||
func registerChild(c *exec.Cmd) {
|
||||
activeMu.Lock()
|
||||
activeChildren = append(activeChildren, c)
|
||||
activeMu.Unlock()
|
||||
}
|
||||
|
||||
func unregisterChild(c *exec.Cmd) {
|
||||
activeMu.Lock()
|
||||
for i, ch := range activeChildren {
|
||||
if ch == c {
|
||||
activeChildren = append(activeChildren[:i], activeChildren[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
activeMu.Unlock()
|
||||
}
|
||||
|
||||
func KillAllChildProcesses() {
|
||||
activeMu.Lock()
|
||||
all := make([]*exec.Cmd, len(activeChildren))
|
||||
copy(all, activeChildren)
|
||||
activeChildren = nil
|
||||
activeMu.Unlock()
|
||||
for _, c := range all {
|
||||
killProcessGroup(c)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Spawn ────────────────────────────────────────────────────────────────
|
||||
|
||||
func spawnChild(cmdStr string, args []string, opts *RunOptions, maxModeFn func(scriptPath, configDir string)) *exec.Cmd {
|
||||
envSrc := env.OsEnvToMap()
|
||||
resolved := env.ResolveAgentCommand(cmdStr, args, envSrc, opts.Cwd)
|
||||
|
||||
if opts.MaxMode && maxModeFn != nil {
|
||||
maxModeFn(resolved.AgentScriptPath, opts.ConfigDir)
|
||||
}
|
||||
|
||||
envMap := make(map[string]string, len(resolved.Env))
|
||||
for k, v := range resolved.Env {
|
||||
envMap[k] = v
|
||||
}
|
||||
if opts.ConfigDir != "" {
|
||||
envMap["CURSOR_CONFIG_DIR"] = opts.ConfigDir
|
||||
} else if resolved.ConfigDir != "" {
|
||||
if _, exists := envMap["CURSOR_CONFIG_DIR"]; !exists {
|
||||
envMap["CURSOR_CONFIG_DIR"] = resolved.ConfigDir
|
||||
}
|
||||
}
|
||||
|
||||
envSlice := make([]string, 0, len(envMap))
|
||||
for k, v := range envMap {
|
||||
envSlice = append(envSlice, k+"="+v)
|
||||
}
|
||||
|
||||
ctx := opts.Ctx
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// 使用 WaitDelay 確保 context cancel 後子程序 goroutine 能及時退出
|
||||
c := exec.CommandContext(ctx, resolved.Command, resolved.Args...)
|
||||
c.Dir = opts.Cwd
|
||||
c.Env = envSlice
|
||||
// 設定新的 process group,使 kill 能傳遞給所有子孫程序
|
||||
c.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
// WaitDelay:context cancel 後額外等待這麼久再強制關閉 pipes
|
||||
c.WaitDelay = 5 * time.Second
|
||||
// Cancel 函式:殺死整個 process group
|
||||
c.Cancel = func() error {
|
||||
return killProcessGroup(c)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// MaxModeFn is set by the agent package to avoid import cycle.
|
||||
var MaxModeFn func(agentScriptPath, configDir string)
|
||||
|
||||
func Run(cmdStr string, args []string, opts RunOptions) (RunResult, error) {
|
||||
ctx := opts.Ctx
|
||||
var cancel context.CancelFunc
|
||||
if opts.TimeoutMs > 0 {
|
||||
if ctx == nil {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond)
|
||||
} else {
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond)
|
||||
}
|
||||
defer cancel()
|
||||
opts.Ctx = ctx
|
||||
} else if ctx == nil {
|
||||
opts.Ctx = context.Background()
|
||||
}
|
||||
|
||||
c := spawnChild(cmdStr, args, &opts, MaxModeFn)
|
||||
var stdoutBuf, stderrBuf strings.Builder
|
||||
c.Stdout = &stdoutBuf
|
||||
c.Stderr = &stderrBuf
|
||||
|
||||
if err := c.Start(); err != nil {
|
||||
// context 已取消或命令找不到時
|
||||
if opts.Ctx != nil && opts.Ctx.Err() != nil {
|
||||
return RunResult{Code: -1}, nil
|
||||
}
|
||||
if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") {
|
||||
return RunResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr)
|
||||
}
|
||||
return RunResult{}, err
|
||||
}
|
||||
registerChild(c)
|
||||
defer unregisterChild(c)
|
||||
|
||||
err := c.Wait()
|
||||
code := 0
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
code = exitErr.ExitCode()
|
||||
if code == 0 {
|
||||
code = -1
|
||||
}
|
||||
} else {
|
||||
// context cancelled or killed — return -1 but no error
|
||||
return RunResult{Code: -1, Stdout: stdoutBuf.String(), Stderr: stderrBuf.String()}, nil
|
||||
}
|
||||
}
|
||||
return RunResult{
|
||||
Code: code,
|
||||
Stdout: stdoutBuf.String(),
|
||||
Stderr: stderrBuf.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type StreamResult struct {
|
||||
Code int
|
||||
Stderr string
|
||||
}
|
||||
|
||||
func RunStreaming(cmdStr string, args []string, opts RunStreamingOptions) (StreamResult, error) {
|
||||
ctx := opts.Ctx
|
||||
var cancel context.CancelFunc
|
||||
if opts.TimeoutMs > 0 {
|
||||
if ctx == nil {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond)
|
||||
} else {
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond)
|
||||
}
|
||||
defer cancel()
|
||||
opts.RunOptions.Ctx = ctx
|
||||
} else if opts.RunOptions.Ctx == nil {
|
||||
opts.RunOptions.Ctx = context.Background()
|
||||
}
|
||||
|
||||
c := spawnChild(cmdStr, args, &opts.RunOptions, MaxModeFn)
|
||||
stdoutPipe, err := c.StdoutPipe()
|
||||
if err != nil {
|
||||
return StreamResult{}, err
|
||||
}
|
||||
stderrPipe, err := c.StderrPipe()
|
||||
if err != nil {
|
||||
return StreamResult{}, err
|
||||
}
|
||||
|
||||
if err := c.Start(); err != nil {
|
||||
if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") {
|
||||
return StreamResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr)
|
||||
}
|
||||
return StreamResult{}, err
|
||||
}
|
||||
registerChild(c)
|
||||
defer unregisterChild(c)
|
||||
|
||||
var stderrBuf strings.Builder
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
scanner := bufio.NewScanner(stdoutPipe)
|
||||
scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.TrimSpace(line) != "" {
|
||||
opts.OnLine(line)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
scanner := bufio.NewScanner(stderrPipe)
|
||||
scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024)
|
||||
for scanner.Scan() {
|
||||
stderrBuf.WriteString(scanner.Text())
|
||||
stderrBuf.WriteString("\n")
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
err = c.Wait()
|
||||
code := 0
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
code = exitErr.ExitCode()
|
||||
if code == 0 {
|
||||
code = -1
|
||||
}
|
||||
}
|
||||
}
|
||||
return StreamResult{Code: code, Stderr: stderrBuf.String()}, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,181 @@
|
|||
package winlimit
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/env"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
const WinPromptOmissionPrefix = "[Earlier messages omitted: Windows command-line length limit.]\n\n"
|
||||
const LinuxPromptOmissionPrefix = "[Earlier messages omitted: Linux ARG_MAX command-line length limit.]\n\n"
|
||||
|
||||
// safeLinuxArgMax returns a conservative estimate of ARG_MAX on Linux.
|
||||
// The actual limit is typically 2MB; we use 1.5MB to leave room for env vars.
|
||||
func safeLinuxArgMax() int {
|
||||
return 1536 * 1024
|
||||
}
|
||||
|
||||
type FitPromptResult struct {
|
||||
OK bool
|
||||
Args []string
|
||||
Truncated bool
|
||||
OriginalLength int
|
||||
FinalPromptLength int
|
||||
Error string
|
||||
}
|
||||
|
||||
func estimateCmdlineLength(resolved env.AgentCommand) int {
|
||||
argv := append([]string{resolved.Command}, resolved.Args...)
|
||||
if resolved.WindowsVerbatimArguments {
|
||||
n := 0
|
||||
for _, a := range argv {
|
||||
n += len(a)
|
||||
}
|
||||
if len(argv) > 1 {
|
||||
n += len(argv) - 1
|
||||
}
|
||||
return n + 512
|
||||
}
|
||||
dstLen := 0
|
||||
for _, a := range argv {
|
||||
dstLen += len(a)
|
||||
}
|
||||
dstLen = dstLen*2 + len(argv)*2
|
||||
if len(argv) > 1 {
|
||||
dstLen += len(argv) - 1
|
||||
}
|
||||
return dstLen + 512
|
||||
}
|
||||
|
||||
func FitPromptToWinCmdline(agentBin string, fixedArgs []string, prompt string, maxCmdline int, cwd string) FitPromptResult {
|
||||
if runtime.GOOS != "windows" {
|
||||
return fitPromptLinux(fixedArgs, prompt)
|
||||
}
|
||||
|
||||
e := env.OsEnvToMap()
|
||||
measured := func(p string) int {
|
||||
args := make([]string, len(fixedArgs)+1)
|
||||
copy(args, fixedArgs)
|
||||
args[len(fixedArgs)] = p
|
||||
resolved := env.ResolveAgentCommand(agentBin, args, e, cwd)
|
||||
return estimateCmdlineLength(resolved)
|
||||
}
|
||||
|
||||
if measured("") > maxCmdline {
|
||||
return FitPromptResult{
|
||||
OK: false,
|
||||
Error: "Windows command line exceeds the configured limit even without a prompt; shorten workspace path, model id, or CURSOR_BRIDGE_WIN_CMDLINE_MAX.",
|
||||
}
|
||||
}
|
||||
|
||||
if measured(prompt) <= maxCmdline {
|
||||
args := make([]string, len(fixedArgs)+1)
|
||||
copy(args, fixedArgs)
|
||||
args[len(fixedArgs)] = prompt
|
||||
return FitPromptResult{
|
||||
OK: true,
|
||||
Args: args,
|
||||
Truncated: false,
|
||||
OriginalLength: len(prompt),
|
||||
FinalPromptLength: len(prompt),
|
||||
}
|
||||
}
|
||||
|
||||
prefix := WinPromptOmissionPrefix
|
||||
if measured(prefix) > maxCmdline {
|
||||
return FitPromptResult{
|
||||
OK: false,
|
||||
Error: "Windows command line too long to fit even the truncation notice; shorten workspace path or flags.",
|
||||
}
|
||||
}
|
||||
|
||||
lo, hi, best := 0, len(prompt), 0
|
||||
for lo <= hi {
|
||||
mid := (lo + hi) / 2
|
||||
var tail string
|
||||
if mid > 0 {
|
||||
tail = prompt[len(prompt)-mid:]
|
||||
}
|
||||
candidate := prefix + tail
|
||||
if measured(candidate) <= maxCmdline {
|
||||
best = mid
|
||||
lo = mid + 1
|
||||
} else {
|
||||
hi = mid - 1
|
||||
}
|
||||
}
|
||||
|
||||
var finalPrompt string
|
||||
if best == 0 {
|
||||
finalPrompt = prefix
|
||||
} else {
|
||||
finalPrompt = prefix + prompt[len(prompt)-best:]
|
||||
}
|
||||
|
||||
args := make([]string, len(fixedArgs)+1)
|
||||
copy(args, fixedArgs)
|
||||
args[len(fixedArgs)] = finalPrompt
|
||||
return FitPromptResult{
|
||||
OK: true,
|
||||
Args: args,
|
||||
Truncated: true,
|
||||
OriginalLength: len(prompt),
|
||||
FinalPromptLength: len(finalPrompt),
|
||||
}
|
||||
}
|
||||
|
||||
// fitPromptLinux handles Linux ARG_MAX truncation.
|
||||
func fitPromptLinux(fixedArgs []string, prompt string) FitPromptResult {
|
||||
argMax := safeLinuxArgMax()
|
||||
|
||||
// Estimate total cmdline size: sum of all fixed args + prompt + null terminators
|
||||
fixedLen := 0
|
||||
for _, a := range fixedArgs {
|
||||
fixedLen += len(a) + 1
|
||||
}
|
||||
totalLen := fixedLen + len(prompt) + 1
|
||||
|
||||
if totalLen <= argMax {
|
||||
args := make([]string, len(fixedArgs)+1)
|
||||
copy(args, fixedArgs)
|
||||
args[len(fixedArgs)] = prompt
|
||||
return FitPromptResult{
|
||||
OK: true,
|
||||
Args: args,
|
||||
Truncated: false,
|
||||
OriginalLength: len(prompt),
|
||||
FinalPromptLength: len(prompt),
|
||||
}
|
||||
}
|
||||
|
||||
// Need to truncate: keep the tail of the prompt (most recent messages)
|
||||
prefix := LinuxPromptOmissionPrefix
|
||||
available := argMax - fixedLen - len(prefix) - 1
|
||||
if available < 0 {
|
||||
available = 0
|
||||
}
|
||||
|
||||
var finalPrompt string
|
||||
if available <= 0 {
|
||||
finalPrompt = prefix
|
||||
} else if available >= len(prompt) {
|
||||
finalPrompt = prefix + prompt
|
||||
} else {
|
||||
finalPrompt = prefix + prompt[len(prompt)-available:]
|
||||
}
|
||||
|
||||
args := make([]string, len(fixedArgs)+1)
|
||||
copy(args, fixedArgs)
|
||||
args[len(fixedArgs)] = finalPrompt
|
||||
return FitPromptResult{
|
||||
OK: true,
|
||||
Args: args,
|
||||
Truncated: true,
|
||||
OriginalLength: len(prompt),
|
||||
FinalPromptLength: len(finalPrompt),
|
||||
}
|
||||
}
|
||||
|
||||
func WarnPromptTruncated(originalLength, finalLength int) {
|
||||
_ = originalLength
|
||||
_ = finalLength
|
||||
}
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
package winlimit
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNonWindowsPassThrough(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping non-Windows test on Windows")
|
||||
}
|
||||
|
||||
fixedArgs := []string{"--print", "--model", "gpt-4"}
|
||||
prompt := "Hello world"
|
||||
result := FitPromptToWinCmdline("agent", fixedArgs, prompt, 30000, "/tmp")
|
||||
|
||||
if !result.OK {
|
||||
t.Fatalf("expected OK=true on non-Windows, got error: %s", result.Error)
|
||||
}
|
||||
if result.Truncated {
|
||||
t.Error("expected no truncation on non-Windows")
|
||||
}
|
||||
if result.OriginalLength != len(prompt) {
|
||||
t.Errorf("expected original length %d, got %d", len(prompt), result.OriginalLength)
|
||||
}
|
||||
// Last arg should be the prompt
|
||||
if len(result.Args) == 0 || result.Args[len(result.Args)-1] != prompt {
|
||||
t.Errorf("expected last arg to be prompt, got %v", result.Args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOmissionPrefix(t *testing.T) {
|
||||
if !strings.Contains(WinPromptOmissionPrefix, "Earlier messages omitted") {
|
||||
t.Errorf("omission prefix should mention earlier messages, got: %q", WinPromptOmissionPrefix)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
package workspace
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/internal/config"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type WorkspaceResult struct {
|
||||
WorkspaceDir string
|
||||
TempDir string
|
||||
}
|
||||
|
||||
func ResolveWorkspace(cfg config.BridgeConfig, workspaceHeader string) WorkspaceResult {
|
||||
if cfg.ChatOnlyWorkspace {
|
||||
tempDir, err := os.MkdirTemp("", "cursor-proxy-")
|
||||
if err != nil {
|
||||
tempDir = filepath.Join(os.TempDir(), "cursor-proxy-fallback")
|
||||
_ = os.MkdirAll(tempDir, 0700)
|
||||
}
|
||||
return WorkspaceResult{WorkspaceDir: tempDir, TempDir: tempDir}
|
||||
}
|
||||
|
||||
headerWs := strings.TrimSpace(workspaceHeader)
|
||||
if headerWs != "" {
|
||||
return WorkspaceResult{WorkspaceDir: headerWs}
|
||||
}
|
||||
return WorkspaceResult{WorkspaceDir: cfg.Workspace}
|
||||
}
|
||||
Loading…
Reference in New Issue