From 9919fc7bb975e1ebd86cc56c8bd2cc81642d19ab Mon Sep 17 00:00:00 2001 From: daniel Date: Wed, 1 Apr 2026 00:53:34 +0000 Subject: [PATCH] fix --- .env | 2 +- Makefile | 90 +++++- README.md | 206 +++++++++++++- internal/agent/cmdargs.go | 3 +- internal/anthropic/anthropic.go | 40 +++ internal/handlers/anthropic_handler.go | 351 +++++++++++++++++++---- internal/handlers/chat.go | 367 ++++++++++++++++++++----- internal/logger/logger.go | 46 ++++ internal/models/cursormap.go | 90 ++++-- internal/models/cursormap_test.go | 163 +++++++++++ internal/openai/openai.go | 73 ++++- internal/parser/stream.go | 90 ++++-- internal/parser/stream_test.go | 204 +++++++++++--- internal/pool/pool.go | 25 ++ internal/process/process.go | 2 + internal/router/router.go | 6 +- internal/server/server.go | 10 +- internal/toolcall/toolcall.go | 153 +++++++++++ internal/winlimit/winlimit.go | 71 ++++- 19 files changed, 1731 insertions(+), 261 deletions(-) create mode 100644 internal/models/cursormap_test.go create mode 100644 internal/toolcall/toolcall.go diff --git a/.env b/.env index 280992a..0bcfc8d 100644 --- a/.env +++ b/.env @@ -10,7 +10,7 @@ CURSOR_AGENT_NODE= CURSOR_AGENT_SCRIPT= CURSOR_BRIDGE_DEFAULT_MODEL=auto CURSOR_BRIDGE_STRICT_MODEL=true -CURSOR_BRIDGE_MAX_MODE=true +CURSOR_BRIDGE_MAX_MODE=false CURSOR_BRIDGE_FORCE=false CURSOR_BRIDGE_APPROVE_MCPS=false CURSOR_BRIDGE_WORKSPACE= diff --git a/Makefile b/Makefile index e11af0c..31a3175 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ AGENT_NODE ?= AGENT_SCRIPT ?= DEFAULT_MODEL ?= auto STRICT_MODEL ?= true -MAX_MODE ?= true +MAX_MODE ?= false FORCE ?= false APPROVE_MCPS ?= false @@ -37,7 +37,9 @@ SESSIONS_LOG ?= ENV_FILE ?= .env -.PHONY: env run build clean help +OPENCODE_CONFIG ?= $(HOME)/.config/opencode/opencode.json + +.PHONY: env run build clean help opencode opencode-models pm2 pm2-stop pm2-logs claude-code pm2-claude-code ## 產生 .env 檔(預設輸出至 .env,可用 ENV_FILE=xxx 覆寫) env: @@ -80,13 +82,89 @@ run: build clean: rm -f 
cursor-api-proxy $(ENV_FILE) +## 設定 OpenCode 使用此代理(更新 opencode.json 的 cursor provider) +opencode: build + @if [ ! -f "$(OPENCODE_CONFIG)" ]; then \ + echo "找不到 $(OPENCODE_CONFIG),建立新設定檔"; \ + mkdir -p $$(dirname "$(OPENCODE_CONFIG)"); \ + printf '{\n "provider": {\n "cursor": {\n "npm": "@ai-sdk/openai-compatible",\n "name": "Cursor Agent",\n "options": {\n "baseURL": "http://$(HOST):$(PORT)/v1",\n "apiKey": "unused"\n },\n "models": { "auto": { "name": "Cursor Auto" } }\n }\n }\n}\n' > "$(OPENCODE_CONFIG)"; \ + echo "已建立 $(OPENCODE_CONFIG)"; \ + elif [ -n "$(API_KEY)" ]; then \ + jq '.provider.cursor.options.baseURL = "http://$(HOST):$(PORT)/v1" | .provider.cursor.options.apiKey = "$(API_KEY)"' "$(OPENCODE_CONFIG)" > "$(OPENCODE_CONFIG).tmp" && mv "$(OPENCODE_CONFIG).tmp" "$(OPENCODE_CONFIG)"; \ + echo "已更新 $(OPENCODE_CONFIG)(baseURL → http://$(HOST):$(PORT)/v1,apiKey 已設定)"; \ + else \ + jq '.provider.cursor.options.baseURL = "http://$(HOST):$(PORT)/v1"' "$(OPENCODE_CONFIG)" > "$(OPENCODE_CONFIG).tmp" && mv "$(OPENCODE_CONFIG).tmp" "$(OPENCODE_CONFIG)"; \ + echo "已更新 $(OPENCODE_CONFIG)(baseURL → http://$(HOST):$(PORT)/v1)"; \ + fi + +## 啟動代理並用 curl 同步模型列表到 opencode.json +opencode-models: opencode + @echo "啟動代理以取得模型列表..." + @set -a && . 
./$(ENV_FILE) 2>/dev/null; set +a; \
+	./cursor-api-proxy & PID=$$!; \
+	sleep 2; \
+	MODELS=$$(curl -s http://$(HOST):$(PORT)/v1/models | jq '[.data[].id]'); \
+	kill $$PID 2>/dev/null; wait $$PID 2>/dev/null; \
+	if [ -n "$$MODELS" ] && [ "$$MODELS" != "null" ]; then \
+	jq --argjson ids "$$MODELS" 'reduce $$ids[] as $$id (.; .provider.cursor.models[$$id] = { name: $$id })' "$(OPENCODE_CONFIG)" > "$(OPENCODE_CONFIG).tmp" && mv "$(OPENCODE_CONFIG).tmp" "$(OPENCODE_CONFIG)"; \
+	echo "已同步模型列表到 $(OPENCODE_CONFIG)"; \
+	else \
+	echo "無法取得模型列表,請確認代理已啟動"; \
+	fi
+
+## 編譯並用 pm2 啟動
+pm2: build
+	@if [ -f "$(ENV_FILE)" ]; then \
+	env $$(cat $(ENV_FILE) | grep -v '^#' | xargs) CURSOR_BRIDGE_HOST=$(HOST) CURSOR_BRIDGE_PORT=$(PORT) pm2 start ./cursor-api-proxy --name cursor-api-proxy --update-env; \
+	else \
+	CURSOR_BRIDGE_HOST=$(HOST) CURSOR_BRIDGE_PORT=$(PORT) pm2 start ./cursor-api-proxy --name cursor-api-proxy; \
+	fi
+	@pm2 save
+	@echo "pm2 已啟動 cursor-api-proxy(http://$(HOST):$(PORT))"
+
+## 用 pm2 啟動 OpenCode 代理(設定 + 啟動一步完成)
+pm2-opencode: opencode pm2
+	@echo "OpenCode 設定已更新並用 pm2 啟動代理"
+
+## 編譯並用 pm2 啟動 + 設定 Claude Code 環境變數
+pm2-claude-code: pm2
+	@echo ""
+	@echo "Claude Code 設定:將以下指令加入你的 shell 啟動檔(~/.bashrc 或 ~/.zshrc):"
+	@echo ""
+	@echo "  export ANTHROPIC_BASE_URL=http://$(HOST):$(PORT)"
+	@echo "  export ANTHROPIC_API_KEY=$(if $(API_KEY),$(API_KEY),dummy-key)"
+	@echo ""
+	@echo "或在當前 shell 執行:"
+	@echo ""
+	@echo "  export ANTHROPIC_BASE_URL=http://$(HOST):$(PORT)"
+	@echo "  export ANTHROPIC_API_KEY=$(if $(API_KEY),$(API_KEY),dummy-key)"
+	@echo "  claude"
+	@echo ""
+
+## 停止 pm2 中的代理
+pm2-stop:
+	pm2 stop cursor-api-proxy 2>/dev/null || echo "cursor-api-proxy 未在執行"
+
+## 查看 pm2 日誌
+pm2-logs:
+	pm2 logs cursor-api-proxy
+
 ## 顯示說明
 help:
 	@echo "可用目標:"
-	@echo "  make env    產生 .env(先在 Makefile 頂端填好變數)"
-	@echo "  make build  編譯 cursor-api-proxy 二進位檔"
-	@echo "  make run    編譯並載入 .env 執行"
-	@echo "  make clean  刪除二進位檔與 .env"
+	@echo "  make env              產生 .env(先在 Makefile 頂端填好變數)"
+	
@echo " make build 編譯 cursor-api-proxy 二進位檔" + @echo " make run 編譯並載入 .env 執行" + @echo " make pm2 編譯並用 pm2 啟動代理" + @echo " make pm2-stop 停止 pm2 中的代理" + @echo " make pm2-logs 查看 pm2 日誌" + @echo " make pm2-claude-code 啟動代理 + 輸出 Claude Code 設定指令" + @echo " make opencode 編譯並設定 OpenCode(更新 opencode.json)" + @echo " make pm2-opencode 設定 OpenCode + 啟動代理" + @echo " make opencode-models 編譯、設定 OpenCode 並同步模型列表" + @echo " make clean 刪除二進位檔與 .env" @echo "" @echo "覆寫範例:" @echo " make env PORT=9000 API_KEY=mysecret TIMEOUT_MS=60000" + @echo " make pm2-claude-code PORT=8765 API_KEY=mykey" + @echo " make pm2-opencode PORT=8765" diff --git a/README.md b/README.md index a4f8554..edaa297 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,21 @@ # Cursor API Proxy -[English](./README.md) | 繁體中文 - 一個讓你可以透過標準 OpenAI/Anthropic API 格式存取 Cursor AI 編輯器的代理伺服器。 +可以把 Cursor 的模型無縫接入 **Claude Code**、**OpenCode** 或任何支援 OpenAI/Anthropic API 的工具。 + ## 功能特色 - **API 相容**:支援 OpenAI 格式和 Anthropic 格式的 API 呼叫 - **多帳號管理**:支援新增、移除、切換多個 Cursor 帳號 +- **動態模型對映**:自動將 Cursor CLI 支援的所有模型轉換為 `claude-*` 格式,供 Claude Code / OpenCode 使用 - **Tailscale 支援**:可綁定到 `0.0.0.0` 供區域網路存取 - **HWID 重置**:內建反偵測功能,可重置機器識別碼 - **連線池**:最佳化的連線管理 -## 安裝 +## 快速開始 + +### 1. 建置 ```bash git clone https://github.com/your-repo/cursor-api-proxy-go.git @@ -20,20 +23,179 @@ cd cursor-api-proxy-go go build -o cursor-api-proxy . ``` -## 使用方式 +### 2. 登入 Cursor 帳號 -### 啟動伺服器 +```bash +./cursor-api-proxy login myaccount +``` + +### 3. 啟動伺服器 ```bash ./cursor-api-proxy ``` -預設監聽 `127.0.0.1:8080`。 +預設監聽 `127.0.0.1:8765`。 + +### 4. 設定 API Key(選用) + +如果需要外部存取,建議設定 API Key: + +```bash +export CURSOR_BRIDGE_API_KEY=my-secret-key +./cursor-api-proxy +``` + +--- + +## 接入 Claude Code + +Claude Code 使用 Anthropic SDK,透過環境變數設定即可改用你的代理。 + +### 步驟 + +```bash +# 1. 啟動代理(確認已在背景執行) +./cursor-api-proxy & + +# 2. 設定環境變數,讓 Claude Code 改用你的代理 +export ANTHROPIC_BASE_URL=http://127.0.0.1:8765 + +# 3. 
如果代理有設定 CURSOR_BRIDGE_API_KEY,這裡填相同的值;沒有的話隨便填 +export ANTHROPIC_API_KEY=my-secret-key + +# 4. 啟動 Claude Code +claude +``` + +### 切換模型 + +在 Claude Code 中輸入 `/model`,即可看到你 Cursor CLI 支援的所有模型。 + +代理會自動將 Cursor 模型 ID 轉換為 `claude-*` 格式: + +| Cursor CLI 模型 | Claude Code 看到的 | +|---|---| +| `opus-4.6` | `claude-opus-4-6` | +| `sonnet-4.6` | `claude-sonnet-4-6` | +| `opus-4.5-thinking` | `claude-opus-4-5-thinking` | +| `sonnet-4.7` (未來新增) | `claude-sonnet-4-7` (自動生成) | + +### 也可直接指定模型 + +```bash +claude --model claude-sonnet-4-6 +``` + +--- + +## 接入 OpenCode + +OpenCode 透過 `~/.config/opencode/opencode.json` 設定 provider,使用 OpenAI 相容格式。 + +### 步驟 + +1. 啟動代理伺服器 + +```bash +./cursor-api-proxy & +``` + +2. 編輯 `~/.config/opencode/opencode.json`,在 `provider` 中新增一個 provider: + +```json +{ + "provider": { + "cursor": { + "npm": "@ai-sdk/openai-compatible", + "name": "Cursor Agent", + "options": { + "baseURL": "http://127.0.0.1:8765/v1", + "apiKey": "unused" + }, + "models": { + "auto": { "name": "Cursor Auto" }, + "sonnet-4.6": { "name": "Sonnet 4.6" }, + "opus-4.6": { "name": "Opus 4.6" }, + "opus-4.6-thinking": { "name": "Opus 4.6 Thinking" } + } + } + } +} +``` + +若代理有設定 `CURSOR_BRIDGE_API_KEY`,把 `apiKey` 改成相同的值。 + +3. 
在 `opencode.json` 中設定預設模型: + +```json +{ + "model": "cursor/sonnet-4.6" +} +``` + +### 查看可用模型 + +```bash +curl http://127.0.0.1:8765/v1/models | jq '.data[].id' +``` + +代理的 `/v1/models` 端點會回傳 Cursor CLI 目前支援的所有模型 ID,把結果加到 `opencode.json` 的 `models` 中即可。 + +### 在 OpenCode 中切換模型 + +OpenCode 的模型切換使用 `/model` 指令,從 `opencode.json` 的 `models` 清單中選擇。 + +--- + +## 模型對映原理 + +兩端使用不同的模型 ID 格式: + +``` +Claude Code OpenCode + │ │ + │ claude-opus-4-6 │ opus-4.6 (Cursor 原生 ID) + │ (Anthropic alias) │ (OpenAI 相容格式) + ▼ ▼ +┌──────────────────────┐ ┌──────────────────────┐ +│ ResolveToCursorModel │ │ 直接傳給代理 │ +│ claude-opus-4-6 │ │ opus-4.6 │ +│ ↓ │ │ ↓ │ +│ opus-4.6 │ │ opus-4.6 │ +└──────────────────────┘ └──────────────────────┘ + │ │ + └──────────┬───────────┘ + ▼ + Cursor CLI (agent --model opus-4.6) +``` + +### Anthropic 模型對映表(Claude Code 使用) + +| Cursor 模型 | Anthropic ID | 說明 | +|---|---|---| +| `opus-4.6` | `claude-opus-4-6` | Claude 4.6 Opus | +| `opus-4.6-thinking` | `claude-opus-4-6-thinking` | Claude 4.6 Opus (Thinking) | +| `sonnet-4.6` | `claude-sonnet-4-6` | Claude 4.6 Sonnet | +| `sonnet-4.6-thinking` | `claude-sonnet-4-6-thinking` | Claude 4.6 Sonnet (Thinking) | +| `opus-4.5` | `claude-opus-4-5` | Claude 4.5 Opus | +| `opus-4.5-thinking` | `claude-opus-4-5-thinking` | Claude 4.5 Opus (Thinking) | +| `sonnet-4.5` | `claude-sonnet-4-5` | Claude 4.5 Sonnet | +| `sonnet-4.5-thinking` | `claude-sonnet-4-5-thinking` | Claude 4.5 Sonnet (Thinking) | + +### 動態對映(自動) + +Cursor CLI 新增模型時(如 `opus-4.7`、`sonnet-5.0`),代理自動生成對應的 `claude-*` ID,無需手動更新。 + +規則:`-.` → `claude---` + +--- + +## 帳號管理 ### 登入帳號 ```bash -# 登入帳號 ./cursor-api-proxy login myaccount # 使用代理登入 @@ -62,21 +224,25 @@ go build -o cursor-api-proxy . 
./cursor-api-proxy reset-hwid --deep-clean ``` -### 其他選項 +### 啟動選項 | 選項 | 說明 | |------|------| | `--tailscale` | 綁定到 `0.0.0.0` 供區域網路存取 | | `-h, --help` | 顯示說明 | +--- + ## API 端點 | 端點 | 方法 | 說明 | |------|------|------| -| `http://127.0.0.1:8080/v1/chat/completions` | POST | OpenAI 格式聊天完成 | -| `http://127.0.0.1:8080/v1/models` | GET | 列出可用模型 | -| `http://127.0.0.1:8080/v1/chat/messages` | POST | Anthropic 格式聊天 | -| `http://127.0.0.1:8080/health` | GET | 健康檢查 | +| `http://127.0.0.1:8765/v1/chat/completions` | POST | OpenAI 格式聊天完成 | +| `http://127.0.0.1:8765/v1/models` | GET | 列出可用模型 | +| `http://127.0.0.1:8765/v1/chat/messages` | POST | Anthropic 格式聊天 | +| `http://127.0.0.1:8765/health` | GET | 健康檢查 | + +--- ## 環境變數 @@ -127,17 +293,27 @@ go build -o cursor-api-proxy . | `CURSOR_BRIDGE_WIN_CMDLINE_MAX` | `30000` | Windows 命令列最大長度(4096–32700) | | `COMSPEC` | `cmd.exe` | Windows 命令直譯器路徑 | +--- + ## 常見問題 -**Q: 為什麼需要登入帳號?** +**Q: 為什麼需要登入帳號?** A: Cursor API 需要驗證才能使用,請先登入你的 Cursor 帳號。 -**Q: 如何處理被BAN的問題?** +**Q: 如何處理被BAN的問題?** A: 使用 `reset-hwid` 命令重置機器識別碼,加上 `--deep-clean` 進行更徹底的清理。 -**Q: 可以在其他設備上使用嗎?** +**Q: 可以在其他設備上使用嗎?** A: 可以,使用 `--tailscale` 選項啟動伺服器,然後透過區域網路 IP 存取。 +**Q: 模型列表多久更新一次?** +A: 每次呼叫 `GET /v1/models` 時,代理會即時呼叫 Cursor CLI 的 `--list-models` 取得最新模型,並自動生成對應的 `claude-*` ID。 + +**Q: 未來 Cursor 新增模型怎麼辦?** +A: 不用改任何東西。只要新模型符合 `-.` 命名規則,代理會自動生成對應的 `claude-*` ID。 + +--- + ## 授權 MIT License diff --git a/internal/agent/cmdargs.go b/internal/agent/cmdargs.go index aeab0b7..9b4cd2e 100644 --- a/internal/agent/cmdargs.go +++ b/internal/agent/cmdargs.go @@ -3,7 +3,7 @@ package agent import "cursor-api-proxy/internal/config" func BuildAgentFixedArgs(cfg config.BridgeConfig, workspaceDir, model string, stream bool) []string { - args := []string{"--print"} + args := []string{"--print", "--plan"} if cfg.ApproveMcps { args = append(args, "--approve-mcps") } @@ -13,7 +13,6 @@ func BuildAgentFixedArgs(cfg config.BridgeConfig, workspaceDir, model string, st if cfg.ChatOnlyWorkspace { 
args = append(args, "--trust") } - args = append(args, "--mode", "ask") args = append(args, "--workspace", workspaceDir) args = append(args, "--model", model) if stream { diff --git a/internal/anthropic/anthropic.go b/internal/anthropic/anthropic.go index fb1ac95..1f29404 100644 --- a/internal/anthropic/anthropic.go +++ b/internal/anthropic/anthropic.go @@ -2,6 +2,8 @@ package anthropic import ( "cursor-api-proxy/internal/openai" + "encoding/json" + "fmt" "strings" ) @@ -83,6 +85,44 @@ func anthropicBlockToText(p interface{}) string { return "[Document: " + title + "]" } return "[Document]" + case "tool_use": + name, _ := v["name"].(string) + id, _ := v["id"].(string) + input := v["input"] + inputJSON, _ := json.Marshal(input) + if inputJSON == nil { + inputJSON = []byte("{}") + } + tag := fmt.Sprintf("\n{\"name\": \"%s\", \"arguments\": %s}\n", name, string(inputJSON)) + if id != "" { + tag = fmt.Sprintf("[tool_use_id=%s] ", id) + tag + } + return tag + case "tool_result": + toolUseID, _ := v["tool_use_id"].(string) + content := v["content"] + var contentText string + switch c := content.(type) { + case string: + contentText = c + case []interface{}: + var parts []string + for _, block := range c { + if bm, ok := block.(map[string]interface{}); ok { + if bm["type"] == "text" { + if t, ok := bm["text"].(string); ok { + parts = append(parts, t) + } + } + } + } + contentText = strings.Join(parts, "\n") + } + label := "Tool result" + if toolUseID != "" { + label += " [id=" + toolUseID + "]" + } + return label + ": " + contentText } } return "" diff --git a/internal/handlers/anthropic_handler.go b/internal/handlers/anthropic_handler.go index 9691a36..e6f26da 100644 --- a/internal/handlers/anthropic_handler.go +++ b/internal/handlers/anthropic_handler.go @@ -1,6 +1,7 @@ package handlers import ( + "context" "cursor-api-proxy/internal/agent" "cursor-api-proxy/internal/anthropic" "cursor-api-proxy/internal/config" @@ -11,9 +12,11 @@ import ( 
"cursor-api-proxy/internal/parser" "cursor-api-proxy/internal/pool" "cursor-api-proxy/internal/sanitize" + "cursor-api-proxy/internal/toolcall" "cursor-api-proxy/internal/winlimit" "cursor-api-proxy/internal/workspace" "encoding/json" + "strings" "fmt" "net/http" "time" @@ -21,7 +24,7 @@ import ( "github.com/google/uuid" ) -func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, lastModelRef *string, rawBody, method, pathname, remoteAddress string) { +func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, ph pool.PoolHandle, lastModelRef *string, rawBody, method, pathname, remoteAddress string) { var req anthropic.MessagesRequest if err := json.Unmarshal([]byte(rawBody), &req); err != nil { httputil.WriteJSON(w, 400, map[string]interface{}{ @@ -33,13 +36,11 @@ func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config. requested := openai.NormalizeModelID(req.Model) model := ResolveModel(requested, lastModelRef, cfg) - // Parse system from raw body to handle both string and array var rawMap map[string]interface{} _ = json.Unmarshal([]byte(rawBody), &rawMap) cleanSystem := sanitize.SanitizeSystem(req.System) - // SanitizeMessages expects []interface{} rawMessages := make([]interface{}, len(req.Messages)) for i, m := range req.Messages { rawMessages[i] = map[string]interface{}{"role": m.Role, "content": m.Content} @@ -103,6 +104,15 @@ func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config. 
fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, req.Stream) fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir) + if cfg.Verbose { + if len(prompt) > 200 { + logger.LogDebug("model=%s prompt_len=%d prompt_preview=%q", cursorModel, len(prompt), prompt[:200]+"...") + } else { + logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, len(prompt), prompt) + } + logger.LogDebug("cmd_args=%v", fit.Args) + } + if !fit.OK { httputil.WriteJSON(w, 500, map[string]interface{}{ "error": map[string]string{"type": "api_error", "message": fit.Error}, @@ -110,8 +120,7 @@ func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config. return } if fit.Truncated { - fmt.Printf("[%s] Windows: prompt truncated (%d -> %d chars).\n", - time.Now().UTC().Format(time.RFC3339), fit.OriginalLength, fit.FinalPromptLength) + logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength) } cmdArgs := fit.Args @@ -122,6 +131,12 @@ func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config. truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"} } + hasTools := len(req.Tools) > 0 + var toolNames map[string]bool + if hasTools { + toolNames = toolcall.CollectToolNames(req.Tools) + } + if req.Stream { httputil.WriteSSEHeaders(w, truncatedHeaders) flusher, _ := w.(http.Flusher) @@ -134,6 +149,11 @@ func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config. } } + var accumulated string + var accumulatedThinking string + var chunkNum int + var p parser.Parser + writeEvent(map[string]interface{}{ "type": "message_start", "message": map[string]interface{}{ @@ -144,74 +164,252 @@ func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config. 
"content": []interface{}{}, }, }) - writeEvent(map[string]interface{}{ - "type": "content_block_start", - "index": 0, - "content_block": map[string]string{"type": "text", "text": ""}, - }) - var accumulated string - parseLine := parser.CreateStreamParser( - func(text string) { - accumulated += text - writeEvent(map[string]interface{}{ - "type": "content_block_delta", - "index": 0, - "delta": map[string]string{"type": "text_delta", "text": text}, - }) - }, - func() { - logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true) - writeEvent(map[string]interface{}{"type": "content_block_stop", "index": 0}) - writeEvent(map[string]interface{}{ - "type": "message_delta", - "delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil}, - "usage": map[string]int{"output_tokens": 0}, - }) - writeEvent(map[string]interface{}{"type": "message_stop"}) - }, - ) + if hasTools { + // tools 模式:先累積所有內容,完成後再一次性輸出(因為 tool_calls 需要完整解析) + p = parser.CreateStreamParserWithThinking( + func(text string) { + accumulated += text + chunkNum++ + logger.LogStreamChunk(model, text, chunkNum) + }, + func(thinking string) { + accumulatedThinking += thinking + }, + func() { + logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true) + parsed := toolcall.ExtractToolCalls(accumulated, toolNames) - configDir := pool.GetNextAccountConfigDir() + blockIndex := 0 + if accumulatedThinking != "" { + writeEvent(map[string]interface{}{ + "type": "content_block_start", "index": blockIndex, + "content_block": map[string]string{"type": "thinking", "thinking": ""}, + }) + writeEvent(map[string]interface{}{ + "type": "content_block_delta", "index": blockIndex, + "delta": map[string]string{"type": "thinking_delta", "thinking": accumulatedThinking}, + }) + writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex}) + blockIndex++ + } + + if parsed.HasToolCalls() { + if parsed.TextContent != "" { + writeEvent(map[string]interface{}{ + "type": 
"content_block_start", "index": blockIndex, + "content_block": map[string]string{"type": "text", "text": ""}, + }) + writeEvent(map[string]interface{}{ + "type": "content_block_delta", "index": blockIndex, + "delta": map[string]string{"type": "text_delta", "text": parsed.TextContent}, + }) + writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex}) + blockIndex++ + } + for _, tc := range parsed.ToolCalls { + toolID := "toolu_" + uuid.New().String()[:12] + var inputObj interface{} + _ = json.Unmarshal([]byte(tc.Arguments), &inputObj) + if inputObj == nil { + inputObj = map[string]interface{}{} + } + writeEvent(map[string]interface{}{ + "type": "content_block_start", "index": blockIndex, + "content_block": map[string]interface{}{ + "type": "tool_use", "id": toolID, "name": tc.Name, "input": map[string]interface{}{}, + }, + }) + writeEvent(map[string]interface{}{ + "type": "content_block_delta", "index": blockIndex, + "delta": map[string]interface{}{ + "type": "input_json_delta", "partial_json": tc.Arguments, + }, + }) + writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex}) + blockIndex++ + } + writeEvent(map[string]interface{}{ + "type": "message_delta", + "delta": map[string]interface{}{"stop_reason": "tool_use", "stop_sequence": nil}, + "usage": map[string]int{"output_tokens": 0}, + }) + writeEvent(map[string]interface{}{"type": "message_stop"}) + } else { + writeEvent(map[string]interface{}{ + "type": "content_block_start", "index": blockIndex, + "content_block": map[string]string{"type": "text", "text": ""}, + }) + if accumulated != "" { + writeEvent(map[string]interface{}{ + "type": "content_block_delta", "index": blockIndex, + "delta": map[string]string{"type": "text_delta", "text": accumulated}, + }) + } + writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex}) + writeEvent(map[string]interface{}{ + "type": "message_delta", + "delta": map[string]interface{}{"stop_reason": 
"end_turn", "stop_sequence": nil}, + "usage": map[string]int{"output_tokens": 0}, + }) + writeEvent(map[string]interface{}{"type": "message_stop"}) + } + }, + ) + } else { + // 非 tools 模式:即時串流 thinking 和 text + // blockCount 追蹤已開啟的 block 數量 + // thinkingOpen 代表 thinking block 是否已開啟且尚未關閉 + // textOpen 代表 text block 是否已開啟且尚未關閉 + blockCount := 0 + thinkingOpen := false + textOpen := false + + p = parser.CreateStreamParserWithThinking( + func(text string) { + accumulated += text + chunkNum++ + logger.LogStreamChunk(model, text, chunkNum) + // 若 thinking block 尚未關閉,先關閉它 + if thinkingOpen { + writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1}) + thinkingOpen = false + } + // 若 text block 尚未開啟,先開啟它 + if !textOpen { + writeEvent(map[string]interface{}{ + "type": "content_block_start", + "index": blockCount, + "content_block": map[string]string{"type": "text", "text": ""}, + }) + textOpen = true + blockCount++ + } + writeEvent(map[string]interface{}{ + "type": "content_block_delta", + "index": blockCount - 1, + "delta": map[string]string{"type": "text_delta", "text": text}, + }) + }, + func(thinking string) { + accumulatedThinking += thinking + chunkNum++ + // 若 thinking block 尚未開啟,先開啟它 + if !thinkingOpen { + writeEvent(map[string]interface{}{ + "type": "content_block_start", + "index": blockCount, + "content_block": map[string]string{"type": "thinking", "thinking": ""}, + }) + thinkingOpen = true + blockCount++ + } + writeEvent(map[string]interface{}{ + "type": "content_block_delta", + "index": blockCount - 1, + "delta": map[string]string{"type": "thinking_delta", "thinking": thinking}, + }) + }, + func() { + logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true) + // 關閉尚未關閉的 thinking block + if thinkingOpen { + writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1}) + thinkingOpen = false + } + // 若 text block 尚未開啟(全部都是 thinking,沒有 text),開啟並立即關閉空的 text block + if !textOpen { + 
writeEvent(map[string]interface{}{ + "type": "content_block_start", + "index": blockCount, + "content_block": map[string]string{"type": "text", "text": ""}, + }) + blockCount++ + } + // 關閉 text block + writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1}) + writeEvent(map[string]interface{}{ + "type": "message_delta", + "delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil}, + "usage": map[string]int{"output_tokens": 0}, + }) + writeEvent(map[string]interface{}{"type": "message_stop"}) + }, + ) + } + + configDir := ph.GetNextConfigDir() logger.LogAccountAssigned(configDir) - pool.ReportRequestStart(configDir) + ph.ReportRequestStart(configDir) + logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true) streamStart := time.Now().UnixMilli() ctx := r.Context() - result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, parseLine, ws.TempDir, configDir, ctx) + wrappedParser := func(line string) { + logger.LogRawLine(line) + p.Parse(line) + } + result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, ctx) + + // agent 結束後,若未收到 result/success 訊號,強制 flush 以確保 SSE stream 正確結尾 + if ctx.Err() == nil { + p.Flush() + } latencyMs := time.Now().UnixMilli() - streamStart - pool.ReportRequestEnd(configDir) + ph.ReportRequestEnd(configDir) - if err == nil && isRateLimited(result.Stderr) { - pool.ReportRateLimit(configDir, 60000) + if ctx.Err() == context.DeadlineExceeded { + logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs) + } else if ctx.Err() == context.Canceled { + logger.LogClientDisconnect(method, pathname, model, latencyMs) + } else if err == nil && isRateLimited(result.Stderr) { + ph.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr)) } - if err != nil || result.Code != 0 { - pool.ReportRequestError(configDir, latencyMs) + + if err != nil || (result.Code != 0 && ctx.Err() == nil) { + 
ph.ReportRequestError(configDir, latencyMs) if err != nil { logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error()) } else { logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr) } - } else { - pool.ReportRequestSuccess(configDir, latencyMs) + logger.LogRequestDone(method, pathname, model, latencyMs, result.Code) + } else if ctx.Err() == nil { + ph.ReportRequestSuccess(configDir, latencyMs) + logger.LogRequestDone(method, pathname, model, latencyMs, 0) } - logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats()) + logger.LogAccountStats(cfg.Verbose, ph.GetStats()) return } - configDir := pool.GetNextAccountConfigDir() + configDir := ph.GetNextConfigDir() logger.LogAccountAssigned(configDir) - pool.ReportRequestStart(configDir) + ph.ReportRequestStart(configDir) + logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, false) syncStart := time.Now().UnixMilli() out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context()) syncLatency := time.Now().UnixMilli() - syncStart - pool.ReportRequestEnd(configDir) + ph.ReportRequestEnd(configDir) + + ctx := r.Context() + if ctx.Err() == context.DeadlineExceeded { + logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs) + httputil.WriteJSON(w, 504, map[string]interface{}{ + "error": map[string]string{"type": "api_error", "message": fmt.Sprintf("request timed out after %dms", cfg.TimeoutMs)}, + }, nil) + return + } + if ctx.Err() == context.Canceled { + logger.LogClientDisconnect(method, pathname, model, syncLatency) + return + } if err != nil { - pool.ReportRequestError(configDir, syncLatency) - logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats()) + ph.ReportRequestError(configDir, syncLatency) + logger.LogAccountStats(cfg.Verbose, ph.GetStats()) + logger.LogRequestDone(method, pathname, model, syncLatency, -1) httputil.WriteJSON(w, 500, map[string]interface{}{ "error": 
map[string]string{"type": "api_error", "message": err.Error()}, }, nil) @@ -219,32 +417,67 @@ func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config. } if isRateLimited(out.Stderr) { - pool.ReportRateLimit(configDir, 60000) + ph.ReportRateLimit(configDir, extractRetryAfterMs(out.Stderr)) } if out.Code != 0 { - pool.ReportRequestError(configDir, syncLatency) - logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats()) + ph.ReportRequestError(configDir, syncLatency) + logger.LogAccountStats(cfg.Verbose, ph.GetStats()) errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr) + logger.LogRequestDone(method, pathname, model, syncLatency, out.Code) httputil.WriteJSON(w, 500, map[string]interface{}{ "error": map[string]string{"type": "api_error", "message": errMsg}, }, nil) return } - pool.ReportRequestSuccess(configDir, syncLatency) - content := trimSpace(out.Stdout) + ph.ReportRequestSuccess(configDir, syncLatency) + content := strings.TrimSpace(out.Stdout) logger.LogTrafficResponse(cfg.Verbose, model, content, false) - logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats()) + logger.LogAccountStats(cfg.Verbose, ph.GetStats()) + logger.LogRequestDone(method, pathname, model, syncLatency, 0) + + if hasTools { + parsed := toolcall.ExtractToolCalls(content, toolNames) + if parsed.HasToolCalls() { + var contentBlocks []map[string]interface{} + if parsed.TextContent != "" { + contentBlocks = append(contentBlocks, map[string]interface{}{ + "type": "text", "text": parsed.TextContent, + }) + } + for _, tc := range parsed.ToolCalls { + toolID := "toolu_" + uuid.New().String()[:12] + var inputObj interface{} + _ = json.Unmarshal([]byte(tc.Arguments), &inputObj) + if inputObj == nil { + inputObj = map[string]interface{}{} + } + contentBlocks = append(contentBlocks, map[string]interface{}{ + "type": "tool_use", "id": toolID, "name": tc.Name, "input": inputObj, + }) + } + httputil.WriteJSON(w, 200, 
map[string]interface{}{ + "id": msgID, + "type": "message", + "role": "assistant", + "content": contentBlocks, + "model": model, + "stop_reason": "tool_use", + "usage": map[string]int{"input_tokens": 0, "output_tokens": 0}, + }, truncatedHeaders) + return + } + } httputil.WriteJSON(w, 200, map[string]interface{}{ - "id": msgID, - "type": "message", - "role": "assistant", - "content": []map[string]string{{"type": "text", "text": content}}, - "model": model, + "id": msgID, + "type": "message", + "role": "assistant", + "content": []map[string]string{{"type": "text", "text": content}}, + "model": model, "stop_reason": "end_turn", - "usage": map[string]int{"input_tokens": 0, "output_tokens": 0}, + "usage": map[string]int{"input_tokens": 0, "output_tokens": 0}, }, truncatedHeaders) } diff --git a/internal/handlers/chat.go b/internal/handlers/chat.go index 0a6d0f5..9d67324 100644 --- a/internal/handlers/chat.go +++ b/internal/handlers/chat.go @@ -1,6 +1,7 @@ package handlers import ( + "context" "cursor-api-proxy/internal/agent" "cursor-api-proxy/internal/config" "cursor-api-proxy/internal/httputil" @@ -10,24 +11,37 @@ import ( "cursor-api-proxy/internal/parser" "cursor-api-proxy/internal/pool" "cursor-api-proxy/internal/sanitize" + "cursor-api-proxy/internal/toolcall" "cursor-api-proxy/internal/winlimit" "cursor-api-proxy/internal/workspace" "encoding/json" "fmt" "net/http" "regexp" + "strconv" + "strings" "time" "github.com/google/uuid" ) var rateLimitRe = regexp.MustCompile(`(?i)\b429\b|rate.?limit|too many requests`) +var retryAfterRe = regexp.MustCompile(`(?i)retry-after:\s*(\d+)`) func isRateLimited(stderr string) bool { return rateLimitRe.MatchString(stderr) } -func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, lastModelRef *string, rawBody, method, pathname, remoteAddress string) { +func extractRetryAfterMs(stderr string) int64 { + if m := retryAfterRe.FindStringSubmatch(stderr); len(m) > 1 { + if secs, err := 
strconv.ParseInt(m[1], 10, 64); err == nil && secs > 0 { + return secs * 1000 + } + } + return 60000 +} + +func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, ph pool.PoolHandle, lastModelRef *string, rawBody, method, pathname, remoteAddress string) { var bodyMap map[string]interface{} if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil { httputil.WriteJSON(w, 400, map[string]interface{}{ @@ -86,9 +100,22 @@ func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.Br headerWs := r.Header.Get("x-cursor-workspace") ws := workspace.ResolveWorkspace(cfg, headerWs) + promptLen := len(prompt) + if cfg.Verbose { + if promptLen > 200 { + logger.LogDebug("model=%s prompt_len=%d prompt_start=%q", cursorModel, promptLen, prompt[:200]) + } else { + logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, promptLen, prompt) + } + } + fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, isStream) fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir) + if cfg.Verbose { + logger.LogDebug("cmd=%s args=%v", cfg.AgentBin, fit.Args) + } + if !fit.OK { httputil.WriteJSON(w, 500, map[string]interface{}{ "error": map[string]string{"message": fit.Error, "code": "windows_cmdline_limit"}, @@ -108,90 +135,261 @@ func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.Br truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"} } + hasTools := len(tools) > 0 || len(funcs) > 0 + var toolNames map[string]bool + if hasTools { + toolNames = toolcall.CollectToolNames(tools) + for _, f := range funcs { + if fm, ok := f.(map[string]interface{}); ok { + if name, ok := fm["name"].(string); ok { + toolNames[name] = true + } + } + } + } + if isStream { httputil.WriteSSEHeaders(w, truncatedHeaders) flusher, _ := w.(http.Flusher) var accumulated string - parseLine := parser.CreateStreamParser( - func(text 
string) { - accumulated += text - chunk := map[string]interface{}{ - "id": id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": []map[string]interface{}{ - {"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil}, - }, - } - data, _ := json.Marshal(chunk) - fmt.Fprintf(w, "data: %s\n\n", data) - if flusher != nil { - flusher.Flush() - } - }, - func() { - logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true) - stopChunk := map[string]interface{}{ - "id": id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": []map[string]interface{}{ - {"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"}, - }, - } - data, _ := json.Marshal(stopChunk) - fmt.Fprintf(w, "data: %s\n\n", data) - fmt.Fprintf(w, "data: [DONE]\n\n") - if flusher != nil { - flusher.Flush() - } - }, - ) + var chunkNum int + var p parser.Parser - configDir := pool.GetNextAccountConfigDir() + // toolCallMarkerRe 偵測 tool call 開頭標記,一旦出現就停止即時輸出並進入累積模式 + toolCallMarkerRe := regexp.MustCompile(`|`) + if hasTools { + var toolCallMode bool // 是否已進入 tool call 累積模式 + p = parser.CreateStreamParserWithThinking( + func(text string) { + accumulated += text + chunkNum++ + logger.LogStreamChunk(model, text, chunkNum) + if toolCallMode { + // 已進入累積模式,不即時輸出 + return + } + if toolCallMarkerRe.MatchString(text) { + // 偵測到 tool call 標記,切換為累積模式 + toolCallMode = true + return + } + chunk := map[string]interface{}{ + "id": id, "object": "chat.completion.chunk", "created": created, "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil}, + }, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + if flusher != nil { + flusher.Flush() + } + }, + func(_ string) {}, // thinking ignored in tools mode + func() { + logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true) + parsed := 
toolcall.ExtractToolCalls(accumulated, toolNames) + + if parsed.HasToolCalls() { + if parsed.TextContent != "" && toolCallMode { + // 已有部分 text 被即時輸出,只補發剩餘的 + chunk := map[string]interface{}{ + "id": id, "object": "chat.completion.chunk", "created": created, "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]interface{}{"role": "assistant", "content": parsed.TextContent}, "finish_reason": nil}, + }, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + if flusher != nil { + flusher.Flush() + } + } + for i, tc := range parsed.ToolCalls { + callID := "call_" + uuid.New().String()[:8] + chunk := map[string]interface{}{ + "id": id, "object": "chat.completion.chunk", "created": created, "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]interface{}{ + "tool_calls": []map[string]interface{}{ + { + "index": i, + "id": callID, + "type": "function", + "function": map[string]interface{}{ + "name": tc.Name, + "arguments": tc.Arguments, + }, + }, + }, + }, "finish_reason": nil}, + }, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + if flusher != nil { + flusher.Flush() + } + } + stopChunk := map[string]interface{}{ + "id": id, "object": "chat.completion.chunk", "created": created, "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]interface{}{}, "finish_reason": "tool_calls"}, + }, + } + data, _ := json.Marshal(stopChunk) + fmt.Fprintf(w, "data: %s\n\n", data) + fmt.Fprintf(w, "data: [DONE]\n\n") + if flusher != nil { + flusher.Flush() + } + } else { + stopChunk := map[string]interface{}{ + "id": id, "object": "chat.completion.chunk", "created": created, "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"}, + }, + } + data, _ := json.Marshal(stopChunk) + fmt.Fprintf(w, "data: %s\n\n", data) + fmt.Fprintf(w, "data: [DONE]\n\n") + if 
flusher != nil { + flusher.Flush() + } + } + }, + ) + } else { + p = parser.CreateStreamParserWithThinking( + func(text string) { + accumulated += text + chunkNum++ + logger.LogStreamChunk(model, text, chunkNum) + chunk := map[string]interface{}{ + "id": id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil}, + }, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + if flusher != nil { + flusher.Flush() + } + }, + func(thinking string) { + chunk := map[string]interface{}{ + "id": id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]interface{}{"reasoning_content": thinking}, "finish_reason": nil}, + }, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + if flusher != nil { + flusher.Flush() + } + }, + func() { + logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true) + stopChunk := map[string]interface{}{ + "id": id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"}, + }, + } + data, _ := json.Marshal(stopChunk) + fmt.Fprintf(w, "data: %s\n\n", data) + fmt.Fprintf(w, "data: [DONE]\n\n") + if flusher != nil { + flusher.Flush() + } + }, + ) + } + + configDir := ph.GetNextConfigDir() logger.LogAccountAssigned(configDir) - pool.ReportRequestStart(configDir) + ph.ReportRequestStart(configDir) + logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true) streamStart := time.Now().UnixMilli() ctx := r.Context() - result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, parseLine, ws.TempDir, configDir, ctx) + wrappedParser := func(line string) { + logger.LogRawLine(line) + p.Parse(line) + } + result, 
err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, ctx) + + // agent 結束後,若未收到 result/success 訊號,強制 flush 以確保 SSE stream 正確結尾 + if ctx.Err() == nil { + p.Flush() + } latencyMs := time.Now().UnixMilli() - streamStart - pool.ReportRequestEnd(configDir) + ph.ReportRequestEnd(configDir) - if err == nil && isRateLimited(result.Stderr) { - pool.ReportRateLimit(configDir, 60000) + if ctx.Err() == context.DeadlineExceeded { + logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs) + } else if ctx.Err() == context.Canceled { + logger.LogClientDisconnect(method, pathname, model, latencyMs) + } else if err == nil && isRateLimited(result.Stderr) { + ph.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr)) } if err != nil || (result.Code != 0 && ctx.Err() == nil) { - pool.ReportRequestError(configDir, latencyMs) + ph.ReportRequestError(configDir, latencyMs) if err != nil { logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error()) } else { logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr) } - } else { - pool.ReportRequestSuccess(configDir, latencyMs) + logger.LogRequestDone(method, pathname, model, latencyMs, result.Code) + } else if ctx.Err() == nil { + ph.ReportRequestSuccess(configDir, latencyMs) + logger.LogRequestDone(method, pathname, model, latencyMs, 0) } - logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats()) + logger.LogAccountStats(cfg.Verbose, ph.GetStats()) return } - configDir := pool.GetNextAccountConfigDir() + configDir := ph.GetNextConfigDir() logger.LogAccountAssigned(configDir) - pool.ReportRequestStart(configDir) + ph.ReportRequestStart(configDir) + logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, false) syncStart := time.Now().UnixMilli() out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context()) syncLatency := time.Now().UnixMilli() - 
syncStart - pool.ReportRequestEnd(configDir) + ph.ReportRequestEnd(configDir) + + ctx := r.Context() + if ctx.Err() == context.DeadlineExceeded { + logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs) + httputil.WriteJSON(w, 504, map[string]interface{}{ + "error": map[string]string{"message": fmt.Sprintf("request timed out after %dms", cfg.TimeoutMs), "code": "timeout"}, + }, nil) + return + } + if ctx.Err() == context.Canceled { + logger.LogClientDisconnect(method, pathname, model, syncLatency) + return + } if err != nil { - pool.ReportRequestError(configDir, syncLatency) - logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats()) + ph.ReportRequestError(configDir, syncLatency) + logger.LogAccountStats(cfg.Verbose, ph.GetStats()) + logger.LogRequestDone(method, pathname, model, syncLatency, -1) httputil.WriteJSON(w, 500, map[string]interface{}{ "error": map[string]string{"message": err.Error(), "code": "cursor_cli_error"}, }, nil) @@ -199,23 +397,61 @@ func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.Br } if isRateLimited(out.Stderr) { - pool.ReportRateLimit(configDir, 60000) + ph.ReportRateLimit(configDir, extractRetryAfterMs(out.Stderr)) } if out.Code != 0 { - pool.ReportRequestError(configDir, syncLatency) - logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats()) + ph.ReportRequestError(configDir, syncLatency) + logger.LogAccountStats(cfg.Verbose, ph.GetStats()) errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr) + logger.LogRequestDone(method, pathname, model, syncLatency, out.Code) httputil.WriteJSON(w, 500, map[string]interface{}{ "error": map[string]string{"message": errMsg, "code": "cursor_cli_error"}, }, nil) return } - pool.ReportRequestSuccess(configDir, syncLatency) - content := trimSpace(out.Stdout) + ph.ReportRequestSuccess(configDir, syncLatency) + content := strings.TrimSpace(out.Stdout) logger.LogTrafficResponse(cfg.Verbose, model, content, false) 
- logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats()) + logger.LogAccountStats(cfg.Verbose, ph.GetStats()) + logger.LogRequestDone(method, pathname, model, syncLatency, 0) + + if hasTools { + parsed := toolcall.ExtractToolCalls(content, toolNames) + if parsed.HasToolCalls() { + msg := map[string]interface{}{"role": "assistant"} + if parsed.TextContent != "" { + msg["content"] = parsed.TextContent + } else { + msg["content"] = nil + } + var tcArr []map[string]interface{} + for _, tc := range parsed.ToolCalls { + callID := "call_" + uuid.New().String()[:8] + tcArr = append(tcArr, map[string]interface{}{ + "id": callID, + "type": "function", + "function": map[string]interface{}{ + "name": tc.Name, + "arguments": tc.Arguments, + }, + }) + } + msg["tool_calls"] = tcArr + httputil.WriteJSON(w, 200, map[string]interface{}{ + "id": id, + "object": "chat.completion", + "created": created, + "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "message": msg, "finish_reason": "tool_calls"}, + }, + "usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + }, truncatedHeaders) + return + } + } httputil.WriteJSON(w, 200, map[string]interface{}{ "id": id, @@ -233,16 +469,3 @@ func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.Br }, truncatedHeaders) } -func trimSpace(s string) string { - result := "" - start := 0 - end := len(s) - for start < end && (s[start] == ' ' || s[start] == '\t' || s[start] == '\n' || s[start] == '\r') { - start++ - } - for end > start && (s[end-1] == ' ' || s[end-1] == '\t' || s[end-1] == '\n' || s[end-1] == '\r') { - end-- - } - result = s[start:end] - return result -} diff --git a/internal/logger/logger.go b/internal/logger/logger.go index 5d26d13..c658e92 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -65,6 +65,11 @@ type TrafficMessage struct { Content string } +func LogDebug(format string, args ...interface{}) { + msg := fmt.Sprintf(format, 
args...) + fmt.Printf("%s %s[DEBUG]%s %s\n", ts(), cGray, cReset, msg) +} + func LogServerStart(version, scheme, host string, port int, cfg config.BridgeConfig) { fmt.Printf("\n%s%s╔══════════════════════════════════════════╗%s\n", cBold, cBCyan, cReset) fmt.Printf("%s%s cursor-api-proxy %sv%s%s%s%s ready%s\n", @@ -76,6 +81,7 @@ func LogServerStart(version, scheme, host string, port int, cfg config.BridgeCon fmt.Printf(" %s▸%s workspace %s%s%s\n", cCyan, cReset, cDim, cfg.Workspace, cReset) fmt.Printf(" %s▸%s model %s%s%s\n", cCyan, cReset, cDim, cfg.DefaultModel, cReset) fmt.Printf(" %s▸%s mode %s%s%s\n", cCyan, cReset, cDim, cfg.Mode, cReset) + fmt.Printf(" %s▸%s timeout %s%d ms%s\n", cCyan, cReset, cDim, cfg.TimeoutMs, cReset) flags := []string{} if cfg.Force { @@ -109,6 +115,46 @@ func LogShutdown(sig string) { fmt.Printf("\n%s %s⊘ %s received — shutting down gracefully…%s\n", tsDate(), cYellow, sig, cReset) } +func LogRequestStart(method, pathname, model string, timeoutMs int, isStream bool) { + modeTag := fmt.Sprintf("%ssync%s", cDim, cReset) + if isStream { + modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset) + } + fmt.Printf("%s %s▶%s %s %s %s timeout:%dms %s\n", + ts(), cBCyan, cReset, method, pathname, model, timeoutMs, modeTag) +} + +func LogRequestDone(method, pathname, model string, latencyMs int64, code int) { + statusColor := cBGreen + if code != 0 { + statusColor = cRed + } + fmt.Printf("%s %s■%s %s %s %s %s%dms exit:%d%s\n", + ts(), statusColor, cReset, method, pathname, model, cDim, latencyMs, code, cReset) +} + +func LogRequestTimeout(method, pathname, model string, timeoutMs int) { + fmt.Printf("%s %s⏱%s %s %s %s %stimed-out after %dms%s\n", + ts(), cRed, cReset, method, pathname, model, cRed, timeoutMs, cReset) +} + +func LogClientDisconnect(method, pathname, model string, latencyMs int64) { + fmt.Printf("%s %s⚡%s %s %s %s %sclient disconnected after %dms%s\n", + ts(), cYellow, cReset, method, pathname, model, cYellow, latencyMs, cReset) +} 
+ +func LogStreamChunk(model string, text string, chunkNum int) { + preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 120) + fmt.Printf("%s %s▸%s #%d %s%s%s\n", + ts(), cDim, cReset, chunkNum, cWhite, preview, cReset) +} + +func LogRawLine(line string) { + preview := truncate(strings.ReplaceAll(line, "\n", "↵ "), 200) + fmt.Printf("%s %s│%s %sraw%s %s\n", + ts(), cGray, cReset, cDim, cReset, preview) +} + func LogIncoming(method, pathname, remoteAddress string) { methodColor := cBCyan switch method { diff --git a/internal/models/cursormap.go b/internal/models/cursormap.go index 0f2e00e..bdcc568 100644 --- a/internal/models/cursormap.go +++ b/internal/models/cursormap.go @@ -1,26 +1,29 @@ package models -import "strings" +import ( + "regexp" + "strings" +) var anthropicToCursor = map[string]string{ - "claude-opus-4-6": "opus-4.6", - "claude-opus-4.6": "opus-4.6", - "claude-sonnet-4-6": "sonnet-4.6", - "claude-sonnet-4.6": "sonnet-4.6", - "claude-opus-4-5": "opus-4.5", - "claude-opus-4.5": "opus-4.5", - "claude-sonnet-4-5": "sonnet-4.5", - "claude-sonnet-4.5": "sonnet-4.5", - "claude-opus-4": "opus-4.6", - "claude-sonnet-4": "sonnet-4.6", - "claude-haiku-4-5-20251001": "sonnet-4.5", - "claude-haiku-4-5": "sonnet-4.5", - "claude-haiku-4-6": "sonnet-4.6", - "claude-haiku-4": "sonnet-4.5", - "claude-opus-4-6-thinking": "opus-4.6-thinking", - "claude-sonnet-4-6-thinking": "sonnet-4.6-thinking", - "claude-opus-4-5-thinking": "opus-4.5-thinking", - "claude-sonnet-4-5-thinking": "sonnet-4.5-thinking", + "claude-opus-4-6": "opus-4.6", + "claude-opus-4.6": "opus-4.6", + "claude-sonnet-4-6": "sonnet-4.6", + "claude-sonnet-4.6": "sonnet-4.6", + "claude-opus-4-5": "opus-4.5", + "claude-opus-4.5": "opus-4.5", + "claude-sonnet-4-5": "sonnet-4.5", + "claude-sonnet-4.5": "sonnet-4.5", + "claude-opus-4": "opus-4.6", + "claude-sonnet-4": "sonnet-4.6", + "claude-haiku-4-5-20251001": "sonnet-4.5", + "claude-haiku-4-5": "sonnet-4.5", + "claude-haiku-4-6": "sonnet-4.6", + 
"claude-haiku-4": "sonnet-4.5", + "claude-opus-4-6-thinking": "opus-4.6-thinking", + "claude-sonnet-4-6-thinking": "sonnet-4.6-thinking", + "claude-opus-4-5-thinking": "opus-4.5-thinking", + "claude-sonnet-4-5-thinking": "sonnet-4.5-thinking", } type ModelAlias struct { @@ -40,6 +43,40 @@ var cursorToAnthropicAlias = []ModelAlias{ {"sonnet-4.5-thinking", "claude-sonnet-4-5-thinking", "Claude 4.5 Sonnet (Thinking)"}, } +// cursorModelPattern matches cursor model IDs like "opus-4.6", "sonnet-4.7-thinking". +var cursorModelPattern = regexp.MustCompile(`^([a-zA-Z]+)-(\d+)\.(\d+)(-thinking)?$`) + +// reverseDynamicPattern matches dynamically generated anthropic aliases +// like "claude-opus-4-7", "claude-sonnet-4-7-thinking". +var reverseDynamicPattern = regexp.MustCompile(`^claude-([a-zA-Z]+)-(\d+)-(\d+)(-thinking)?$`) + +func generateDynamicAlias(cursorID string) (AnthropicAlias, bool) { + m := cursorModelPattern.FindStringSubmatch(cursorID) + if m == nil { + return AnthropicAlias{}, false + } + family := m[1] + major := m[2] + minor := m[3] + thinking := m[4] + + anthropicID := "claude-" + family + "-" + major + "-" + minor + thinking + capFamily := strings.ToUpper(family[:1]) + family[1:] + name := capFamily + " " + major + "." + minor + if thinking == "-thinking" { + name += " (Thinking)" + } + return AnthropicAlias{ID: anthropicID, Name: name}, true +} + +func reverseDynamicAlias(anthropicID string) (string, bool) { + m := reverseDynamicPattern.FindStringSubmatch(anthropicID) + if m == nil { + return "", false + } + return m[1] + "-" + m[2] + "." 
+ m[3] + m[4], true +} + func ResolveToCursorModel(requested string) string { if strings.TrimSpace(requested) == "" { return "" @@ -48,6 +85,9 @@ func ResolveToCursorModel(requested string) string { if v, ok := anthropicToCursor[key]; ok { return v } + if v, ok := reverseDynamicAlias(key); ok { + return v + } return strings.TrimSpace(requested) } @@ -61,11 +101,23 @@ func GetAnthropicModelAliases(availableCursorIDs []string) []AnthropicAlias { for _, id := range availableCursorIDs { set[id] = true } + + staticSet := make(map[string]bool, len(cursorToAnthropicAlias)) var result []AnthropicAlias for _, a := range cursorToAnthropicAlias { if set[a.CursorID] { + staticSet[a.CursorID] = true result = append(result, AnthropicAlias{ID: a.AnthropicID, Name: a.Name}) } } + + for _, id := range availableCursorIDs { + if staticSet[id] { + continue + } + if alias, ok := generateDynamicAlias(id); ok { + result = append(result, alias) + } + } return result } diff --git a/internal/models/cursormap_test.go b/internal/models/cursormap_test.go new file mode 100644 index 0000000..3b3d54b --- /dev/null +++ b/internal/models/cursormap_test.go @@ -0,0 +1,163 @@ +package models + +import "testing" + +func TestGetAnthropicModelAliases_StaticOnly(t *testing.T) { + aliases := GetAnthropicModelAliases([]string{"sonnet-4.6", "opus-4.5"}) + if len(aliases) != 2 { + t.Fatalf("expected 2 aliases, got %d: %v", len(aliases), aliases) + } + ids := map[string]string{} + for _, a := range aliases { + ids[a.ID] = a.Name + } + if ids["claude-sonnet-4-6"] != "Claude 4.6 Sonnet" { + t.Errorf("unexpected name for claude-sonnet-4-6: %s", ids["claude-sonnet-4-6"]) + } + if ids["claude-opus-4-5"] != "Claude 4.5 Opus" { + t.Errorf("unexpected name for claude-opus-4-5: %s", ids["claude-opus-4-5"]) + } +} + +func TestGetAnthropicModelAliases_DynamicFallback(t *testing.T) { + aliases := GetAnthropicModelAliases([]string{"sonnet-4.7", "opus-5.0-thinking", "gpt-4o"}) + ids := map[string]string{} + for _, a := 
range aliases { + ids[a.ID] = a.Name + } + if ids["claude-sonnet-4-7"] != "Sonnet 4.7" { + t.Errorf("unexpected name for claude-sonnet-4-7: %s", ids["claude-sonnet-4-7"]) + } + if ids["claude-opus-5-0-thinking"] != "Opus 5.0 (Thinking)" { + t.Errorf("unexpected name for claude-opus-5-0-thinking: %s", ids["claude-opus-5-0-thinking"]) + } + if _, ok := ids["claude-gpt-4o"]; ok { + t.Errorf("gpt-4o should not generate a claude alias") + } +} + +func TestGetAnthropicModelAliases_Mixed(t *testing.T) { + aliases := GetAnthropicModelAliases([]string{"sonnet-4.6", "opus-4.7", "gpt-4o"}) + ids := map[string]string{} + for _, a := range aliases { + ids[a.ID] = a.Name + } + // static entry keeps its custom name + if ids["claude-sonnet-4-6"] != "Claude 4.6 Sonnet" { + t.Errorf("static alias should keep original name, got: %s", ids["claude-sonnet-4-6"]) + } + // dynamic entry uses auto-generated name + if ids["claude-opus-4-7"] != "Opus 4.7" { + t.Errorf("dynamic alias name mismatch: %s", ids["claude-opus-4-7"]) + } +} + +func TestGetAnthropicModelAliases_UnknownPattern(t *testing.T) { + aliases := GetAnthropicModelAliases([]string{"some-unknown-model"}) + if len(aliases) != 0 { + t.Fatalf("expected 0 aliases for unknown pattern, got %d: %v", len(aliases), aliases) + } +} + +func TestResolveToCursorModel_Static(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"claude-opus-4-6", "opus-4.6"}, + {"claude-opus-4.6", "opus-4.6"}, + {"claude-sonnet-4-5", "sonnet-4.5"}, + {"claude-opus-4-6-thinking", "opus-4.6-thinking"}, + } + for _, tc := range tests { + got := ResolveToCursorModel(tc.input) + if got != tc.want { + t.Errorf("ResolveToCursorModel(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} + +func TestResolveToCursorModel_DynamicFallback(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"claude-opus-4-7", "opus-4.7"}, + {"claude-sonnet-5-0", "sonnet-5.0"}, + {"claude-opus-4-7-thinking", "opus-4.7-thinking"}, + 
{"claude-sonnet-5-0-thinking", "sonnet-5.0-thinking"}, + } + for _, tc := range tests { + got := ResolveToCursorModel(tc.input) + if got != tc.want { + t.Errorf("ResolveToCursorModel(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} + +func TestResolveToCursorModel_Passthrough(t *testing.T) { + tests := []string{"sonnet-4.6", "gpt-4o", "custom-model"} + for _, input := range tests { + got := ResolveToCursorModel(input) + if got != input { + t.Errorf("ResolveToCursorModel(%q) = %q, want passthrough %q", input, got, input) + } + } +} + +func TestResolveToCursorModel_Empty(t *testing.T) { + if got := ResolveToCursorModel(""); got != "" { + t.Errorf("ResolveToCursorModel(\"\") = %q, want empty", got) + } + if got := ResolveToCursorModel(" "); got != "" { + t.Errorf("ResolveToCursorModel(\" \") = %q, want empty", got) + } +} + +func TestGenerateDynamicAlias(t *testing.T) { + tests := []struct { + input string + want AnthropicAlias + ok bool + }{ + {"opus-4.7", AnthropicAlias{"claude-opus-4-7", "Opus 4.7"}, true}, + {"sonnet-5.0-thinking", AnthropicAlias{"claude-sonnet-5-0-thinking", "Sonnet 5.0 (Thinking)"}, true}, + {"gpt-4o", AnthropicAlias{}, false}, + {"invalid", AnthropicAlias{}, false}, + } + for _, tc := range tests { + got, ok := generateDynamicAlias(tc.input) + if ok != tc.ok { + t.Errorf("generateDynamicAlias(%q) ok = %v, want %v", tc.input, ok, tc.ok) + continue + } + if ok && (got.ID != tc.want.ID || got.Name != tc.want.Name) { + t.Errorf("generateDynamicAlias(%q) = {%q, %q}, want {%q, %q}", tc.input, got.ID, got.Name, tc.want.ID, tc.want.Name) + } + } +} + +func TestReverseDynamicAlias(t *testing.T) { + tests := []struct { + input string + want string + ok bool + }{ + {"claude-opus-4-7", "opus-4.7", true}, + {"claude-sonnet-5-0-thinking", "sonnet-5.0-thinking", true}, + {"claude-opus-4-6", "opus-4.6", true}, + {"claude-opus-4.6", "", false}, + {"claude-haiku-4-5-20251001", "", false}, + {"some-model", "", false}, + } + for _, tc := range tests { + got, 
ok := reverseDynamicAlias(tc.input) + if ok != tc.ok { + t.Errorf("reverseDynamicAlias(%q) ok = %v, want %v", tc.input, ok, tc.ok) + continue + } + if ok && got != tc.want { + t.Errorf("reverseDynamicAlias(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} diff --git a/internal/openai/openai.go b/internal/openai/openai.go index 120187b..86919a6 100644 --- a/internal/openai/openai.go +++ b/internal/openai/openai.go @@ -2,6 +2,7 @@ package openai import ( "encoding/json" + "fmt" "strings" ) @@ -129,10 +130,28 @@ func ToolsToSystemText(tools []interface{}, functions []interface{}) string { if b, err := json.MarshalIndent(p, "", " "); err == nil { params = string(b) } + } else if p := fn["input_schema"]; p != nil { + if b, err := json.MarshalIndent(p, "", " "); err == nil { + params = string(b) + } } lines = append(lines, "Function: "+name+"\nDescription: "+desc+"\nParameters: "+params) } + lines = append(lines, "", + "When you want to call a tool, use this EXACT format:", + "", + "", + `{"name": "function_name", "arguments": {"param1": "value1"}}`, + "", + "", + "Rules:", + "- Write your reasoning BEFORE the tool call", + "- You may make multiple tool calls by using multiple blocks", + "- STOP writing after the last tag", + "- If no tool is needed, respond normally without tags", + ) + return strings.Join(lines, "\n") } @@ -152,18 +171,58 @@ func BuildPromptFromMessages(messages []interface{}) string { } role, _ := m["role"].(string) text := MessageContentToText(m["content"]) - if text == "" { - continue - } + switch role { case "system", "developer": - systemParts = append(systemParts, text) + if text != "" { + systemParts = append(systemParts, text) + } case "user": - convo = append(convo, "User: "+text) + if text != "" { + convo = append(convo, "User: "+text) + } case "assistant": - convo = append(convo, "Assistant: "+text) + toolCalls, _ := m["tool_calls"].([]interface{}) + if len(toolCalls) > 0 { + var parts []string + if text != "" { + parts = append(parts, 
text) + } + for _, tc := range toolCalls { + tcMap, ok := tc.(map[string]interface{}) + if !ok { + continue + } + fn, _ := tcMap["function"].(map[string]interface{}) + if fn == nil { + continue + } + name, _ := fn["name"].(string) + args, _ := fn["arguments"].(string) + if args == "" { + args = "{}" + } + parts = append(parts, fmt.Sprintf("\n{\"name\": \"%s\", \"arguments\": %s}\n", name, args)) + } + if len(parts) > 0 { + convo = append(convo, "Assistant: "+strings.Join(parts, "\n")) + } + } else if text != "" { + convo = append(convo, "Assistant: "+text) + } case "tool", "function": - convo = append(convo, "Tool: "+text) + name, _ := m["name"].(string) + toolCallID, _ := m["tool_call_id"].(string) + label := "Tool result" + if name != "" { + label = "Tool result (" + name + ")" + } + if toolCallID != "" { + label += " [id=" + toolCallID + "]" + } + if text != "" { + convo = append(convo, label+": "+text) + } } } diff --git a/internal/parser/stream.go b/internal/parser/stream.go index aa9dc98..aa13cb5 100644 --- a/internal/parser/stream.go +++ b/internal/parser/stream.go @@ -4,11 +4,24 @@ import "encoding/json" type StreamParser func(line string) -func CreateStreamParser(onText func(string), onDone func()) StreamParser { - accumulated := "" +type Parser struct { + Parse StreamParser + Flush func() +} + +// CreateStreamParser 建立串流解析器(向後相容,不傳遞 thinking) +func CreateStreamParser(onText func(string), onDone func()) Parser { + return CreateStreamParserWithThinking(onText, nil, onDone) +} + +// CreateStreamParserWithThinking 建立串流解析器,支援思考過程輸出。 +// onThinking 可為 nil,表示忽略思考過程。 +func CreateStreamParserWithThinking(onText func(string), onThinking func(string), onDone func()) Parser { + accumulatedText := "" + accumulatedThinking := "" done := false - return func(line string) { + parse := func(line string) { if done { return } @@ -18,8 +31,9 @@ func CreateStreamParser(onText func(string), onDone func()) StreamParser { Subtype string `json:"subtype"` Message *struct { Content 
[]struct { - Type string `json:"type"` - Text string `json:"text"` + Type string `json:"type"` + Text string `json:"text"` + Thinking string `json:"thinking"` } `json:"content"` } `json:"message"` } @@ -29,27 +43,52 @@ func CreateStreamParser(onText func(string), onDone func()) StreamParser { } if obj.Type == "assistant" && obj.Message != nil { - text := "" + fullText := "" + fullThinking := "" for _, p := range obj.Message.Content { - if p.Type == "text" && p.Text != "" { - text += p.Text + switch p.Type { + case "text": + if p.Text != "" { + fullText += p.Text + } + case "thinking": + if p.Thinking != "" { + fullThinking += p.Thinking + } } } - if text == "" { - return - } - if text == accumulated { - return - } - if len(accumulated) > 0 && len(text) > len(accumulated) && text[:len(accumulated)] == accumulated { - delta := text[len(accumulated):] - if delta != "" { - onText(delta) + + // 處理思考過程 delta + if onThinking != nil && fullThinking != "" { + if fullThinking == accumulatedThinking { + // 重複的完整思考文字,跳過 + } else if len(fullThinking) > len(accumulatedThinking) && fullThinking[:len(accumulatedThinking)] == accumulatedThinking { + delta := fullThinking[len(accumulatedThinking):] + onThinking(delta) + accumulatedThinking = fullThinking + } else { + onThinking(fullThinking) + accumulatedThinking += fullThinking } - accumulated = text + } + + // 處理一般文字 delta + if fullText == "" { + return + } + // 若此訊息文字等於已累積內容(重複的完整文字),跳過 + if fullText == accumulatedText { + return + } + // 若此訊息是已累積內容的延伸,只輸出新的 delta + if len(fullText) > len(accumulatedText) && fullText[:len(accumulatedText)] == accumulatedText { + delta := fullText[len(accumulatedText):] + onText(delta) + accumulatedText = fullText } else { - onText(text) - accumulated += text + // 獨立的 token fragment(一般情況),直接輸出 + onText(fullText) + accumulatedText += fullText } } @@ -58,4 +97,13 @@ func CreateStreamParser(onText func(string), onDone func()) StreamParser { onDone() } } + + flush := func() { + if !done { + done = 
true + onDone() + } + } + + return Parser{Parse: parse, Flush: flush} } diff --git a/internal/parser/stream_test.go b/internal/parser/stream_test.go index b54f7c8..29d1b03 100644 --- a/internal/parser/stream_test.go +++ b/internal/parser/stream_test.go @@ -23,57 +23,57 @@ func makeResultLine() string { return string(b) } -func TestStreamParserIncrementalDeltas(t *testing.T) { +func TestStreamParserFragmentMode(t *testing.T) { + // cursor --stream-partial-output 模式:每個訊息是獨立 token fragment var texts []string - doneCount := 0 - parse := CreateStreamParser( - func(text string) { texts = append(texts, text) }, - func() { doneCount++ }, - ) - - parse(makeAssistantLine("Hello")) - if len(texts) != 1 || texts[0] != "Hello" { - t.Fatalf("expected [Hello], got %v", texts) - } - - parse(makeAssistantLine("Hello world")) - if len(texts) != 2 || texts[1] != " world" { - t.Fatalf("expected second call with ' world', got %v", texts) - } -} - -func TestStreamParserDeduplicatesFinalMessage(t *testing.T) { - var texts []string - parse := CreateStreamParser( + p := CreateStreamParser( func(text string) { texts = append(texts, text) }, func() {}, ) - parse(makeAssistantLine("Hi")) - parse(makeAssistantLine("Hi there")) - if len(texts) != 2 { - t.Fatalf("expected 2 calls, got %d: %v", len(texts), texts) + p.Parse(makeAssistantLine("你")) + p.Parse(makeAssistantLine("好!有")) + p.Parse(makeAssistantLine("什")) + p.Parse(makeAssistantLine("麼")) + + if len(texts) != 4 { + t.Fatalf("expected 4 fragments, got %d: %v", len(texts), texts) } - if texts[0] != "Hi" || texts[1] != " there" { + if texts[0] != "你" || texts[1] != "好!有" || texts[2] != "什" || texts[3] != "麼" { t.Fatalf("unexpected texts: %v", texts) } +} + +func TestStreamParserDeduplicatesFinalFullText(t *testing.T) { + // 最後一個訊息是完整的累積文字,應被跳過(去重) + var texts []string + p := CreateStreamParser( + func(text string) { texts = append(texts, text) }, + func() {}, + ) + + p.Parse(makeAssistantLine("Hello")) + p.Parse(makeAssistantLine(" 
world")) + // 最後一個是完整累積文字,應被去重 + p.Parse(makeAssistantLine("Hello world")) - // Final duplicate: full accumulated text again - parse(makeAssistantLine("Hi there")) if len(texts) != 2 { - t.Fatalf("expected no new call after duplicate, got %d: %v", len(texts), texts) + t.Fatalf("expected 2 fragments (final full text deduplicated), got %d: %v", len(texts), texts) + } + if texts[0] != "Hello" || texts[1] != " world" { + t.Fatalf("unexpected texts: %v", texts) } } func TestStreamParserCallsOnDone(t *testing.T) { var texts []string doneCount := 0 - parse := CreateStreamParser( + p := CreateStreamParser( func(text string) { texts = append(texts, text) }, func() { doneCount++ }, ) - parse(makeResultLine()) + p.Parse(makeResultLine()) if doneCount != 1 { t.Fatalf("expected onDone called once, got %d", doneCount) } @@ -85,13 +85,13 @@ func TestStreamParserCallsOnDone(t *testing.T) { func TestStreamParserIgnoresLinesAfterDone(t *testing.T) { var texts []string doneCount := 0 - parse := CreateStreamParser( + p := CreateStreamParser( func(text string) { texts = append(texts, text) }, func() { doneCount++ }, ) - parse(makeResultLine()) - parse(makeAssistantLine("late")) + p.Parse(makeResultLine()) + p.Parse(makeAssistantLine("late")) if len(texts) != 0 { t.Fatalf("expected no text after done, got %v", texts) } @@ -102,19 +102,19 @@ func TestStreamParserIgnoresLinesAfterDone(t *testing.T) { func TestStreamParserIgnoresNonAssistantLines(t *testing.T) { var texts []string - parse := CreateStreamParser( + p := CreateStreamParser( func(text string) { texts = append(texts, text) }, func() {}, ) b1, _ := json.Marshal(map[string]interface{}{"type": "user", "message": map[string]interface{}{}}) - parse(string(b1)) + p.Parse(string(b1)) b2, _ := json.Marshal(map[string]interface{}{ "type": "assistant", "message": map[string]interface{}{"content": []interface{}{}}, }) - parse(string(b2)) - parse(`{"type":"assistant","message":{"content":[{"type":"code","text":"x"}]}}`) + 
p.Parse(string(b2)) + p.Parse(`{"type":"assistant","message":{"content":[{"type":"code","text":"x"}]}}`) if len(texts) != 0 { t.Fatalf("expected no texts, got %v", texts) @@ -124,14 +124,14 @@ func TestStreamParserIgnoresNonAssistantLines(t *testing.T) { func TestStreamParserIgnoresParseErrors(t *testing.T) { var texts []string doneCount := 0 - parse := CreateStreamParser( + p := CreateStreamParser( func(text string) { texts = append(texts, text) }, func() { doneCount++ }, ) - parse("not json") - parse("{") - parse("") + p.Parse("not json") + p.Parse("{") + p.Parse("") if len(texts) != 0 || doneCount != 0 { t.Fatalf("expected nothing, got texts=%v done=%d", texts, doneCount) @@ -140,7 +140,7 @@ func TestStreamParserIgnoresParseErrors(t *testing.T) { func TestStreamParserJoinsMultipleTextParts(t *testing.T) { var texts []string - parse := CreateStreamParser( + p := CreateStreamParser( func(text string) { texts = append(texts, text) }, func() {}, ) @@ -156,9 +156,125 @@ func TestStreamParserJoinsMultipleTextParts(t *testing.T) { }, } b, _ := json.Marshal(obj) - parse(string(b)) + p.Parse(string(b)) if len(texts) != 1 || texts[0] != "Hello world" { t.Fatalf("expected ['Hello world'], got %v", texts) } } + +func TestStreamParserFlushTriggersDone(t *testing.T) { + var texts []string + doneCount := 0 + p := CreateStreamParser( + func(text string) { texts = append(texts, text) }, + func() { doneCount++ }, + ) + + p.Parse(makeAssistantLine("Hello")) + // agent 結束但沒有 result/success,手動 flush + p.Flush() + if doneCount != 1 { + t.Fatalf("expected onDone called once after Flush, got %d", doneCount) + } + // 再 flush 不應重複觸發 + p.Flush() + if doneCount != 1 { + t.Fatalf("expected onDone called only once, got %d", doneCount) + } +} + +func TestStreamParserFlushAfterDoneIsNoop(t *testing.T) { + doneCount := 0 + p := CreateStreamParser( + func(text string) {}, + func() { doneCount++ }, + ) + + p.Parse(makeResultLine()) + p.Flush() + if doneCount != 1 { + t.Fatalf("expected onDone 
called once, got %d", doneCount) + } +} + +func makeThinkingLine(thinking string) string { + obj := map[string]interface{}{ + "type": "assistant", + "message": map[string]interface{}{ + "content": []map[string]interface{}{ + {"type": "thinking", "thinking": thinking}, + }, + }, + } + b, _ := json.Marshal(obj) + return string(b) +} + +func makeThinkingAndTextLine(thinking, text string) string { + obj := map[string]interface{}{ + "type": "assistant", + "message": map[string]interface{}{ + "content": []map[string]interface{}{ + {"type": "thinking", "thinking": thinking}, + {"type": "text", "text": text}, + }, + }, + } + b, _ := json.Marshal(obj) + return string(b) +} + +func TestStreamParserWithThinkingCallsOnThinking(t *testing.T) { + var texts []string + var thinkings []string + p := CreateStreamParserWithThinking( + func(text string) { texts = append(texts, text) }, + func(thinking string) { thinkings = append(thinkings, thinking) }, + func() {}, + ) + + p.Parse(makeThinkingLine("思考中...")) + p.Parse(makeAssistantLine("回答")) + + if len(thinkings) != 1 || thinkings[0] != "思考中..." 
{ + t.Fatalf("expected thinkings=['思考中...'], got %v", thinkings) + } + if len(texts) != 1 || texts[0] != "回答" { + t.Fatalf("expected texts=['回答'], got %v", texts) + } +} + +func TestStreamParserWithThinkingNilOnThinkingIgnoresThinking(t *testing.T) { + var texts []string + p := CreateStreamParserWithThinking( + func(text string) { texts = append(texts, text) }, + nil, + func() {}, + ) + + p.Parse(makeThinkingLine("忽略的思考")) + p.Parse(makeAssistantLine("文字")) + + if len(texts) != 1 || texts[0] != "文字" { + t.Fatalf("expected texts=['文字'], got %v", texts) + } +} + +func TestStreamParserWithThinkingDeduplication(t *testing.T) { + var thinkings []string + p := CreateStreamParserWithThinking( + func(text string) {}, + func(thinking string) { thinkings = append(thinkings, thinking) }, + func() {}, + ) + + p.Parse(makeThinkingLine("A")) + p.Parse(makeThinkingLine("B")) + // 重複的完整思考,應被跳過 + p.Parse(makeThinkingLine("AB")) + + if len(thinkings) != 2 || thinkings[0] != "A" || thinkings[1] != "B" { + t.Fatalf("expected thinkings=['A','B'], got %v", thinkings) + } +} diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 0bf8013..20a2a53 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -180,6 +180,31 @@ func (p *AccountPool) Count() int { return len(p.accounts) } + +// ─── PoolHandle interface ────────────────────────────────────────────────── +// PoolHandle 讓 handler 可以注入獨立的 pool 實例,避免多 port 模式共用全域 pool。 + +type PoolHandle interface { + GetNextConfigDir() string + ReportRequestStart(configDir string) + ReportRequestEnd(configDir string) + ReportRequestSuccess(configDir string, latencyMs int64) + ReportRequestError(configDir string, latencyMs int64) + ReportRateLimit(configDir string, penaltyMs int64) + GetStats() []AccountStat +} + +// GlobalPoolHandle 包裝全域函式以實作 PoolHandle 介面(單 port 模式使用) +type GlobalPoolHandle struct{} + +func (GlobalPoolHandle) GetNextConfigDir() string { return GetNextAccountConfigDir() } +func (GlobalPoolHandle) 
ReportRequestStart(d string) { ReportRequestStart(d) } +func (GlobalPoolHandle) ReportRequestEnd(d string) { ReportRequestEnd(d) } +func (GlobalPoolHandle) ReportRequestSuccess(d string, l int64) { ReportRequestSuccess(d, l) } +func (GlobalPoolHandle) ReportRequestError(d string, l int64) { ReportRequestError(d, l) } +func (GlobalPoolHandle) ReportRateLimit(d string, p int64) { ReportRateLimit(d, p) } +func (GlobalPoolHandle) GetStats() []AccountStat { return GetAccountStats() } + // ─── Global pool ─────────────────────────────────────────────────────────── var ( diff --git a/internal/process/process.go b/internal/process/process.go index 44aa14a..681e42f 100644 --- a/internal/process/process.go +++ b/internal/process/process.go @@ -215,6 +215,7 @@ func RunStreaming(cmdStr string, args []string, opts RunStreamingOptions) (Strea go func() { defer wg.Done() scanner := bufio.NewScanner(stdoutPipe) + scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024) for scanner.Scan() { line := scanner.Text() if strings.TrimSpace(line) != "" { @@ -227,6 +228,7 @@ func RunStreaming(cmdStr string, args []string, opts RunStreamingOptions) (Strea go func() { defer wg.Done() scanner := bufio.NewScanner(stderrPipe) + scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024) for scanner.Scan() { stderrBuf.WriteString(scanner.Text()) stderrBuf.WriteString("\n") diff --git a/internal/router/router.go b/internal/router/router.go index 5d9f6e2..58e8552 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -5,6 +5,7 @@ import ( "cursor-api-proxy/internal/handlers" "cursor-api-proxy/internal/httputil" "cursor-api-proxy/internal/logger" + "cursor-api-proxy/internal/pool" "fmt" "net/http" "os" @@ -16,6 +17,7 @@ type RouterOptions struct { Config config.BridgeConfig ModelCache *handlers.ModelCacheRef LastModel *string + Pool pool.PoolHandle } func NewRouter(opts RouterOptions) http.HandlerFunc { @@ -59,7 +61,7 @@ func NewRouter(opts RouterOptions) http.HandlerFunc { }, nil) 
return } - handlers.HandleChatCompletions(w, r, cfg, opts.LastModel, raw, method, pathname, remoteAddress) + handlers.HandleChatCompletions(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress) case method == "POST" && pathname == "/v1/messages": raw, err := httputil.ReadBody(r) @@ -69,7 +71,7 @@ func NewRouter(opts RouterOptions) http.HandlerFunc { }, nil) return } - handlers.HandleAnthropicMessages(w, r, cfg, opts.LastModel, raw, method, pathname, remoteAddress) + handlers.HandleAnthropicMessages(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress) case (method == "POST" || method == "GET") && pathname == "/v1/completions": httputil.WriteJSON(w, 404, map[string]interface{}{ diff --git a/internal/server/server.go b/internal/server/server.go index 6a7270b..adaa5d6 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -20,6 +20,7 @@ import ( type ServerOptions struct { Version string Config config.BridgeConfig + Pool pool.PoolHandle } func StartBridgeServer(opts ServerOptions) []*http.Server { @@ -34,8 +35,8 @@ func StartBridgeServer(opts ServerOptions) []*http.Server { subCfg.Port = port subCfg.ConfigDirs = []string{dir} subCfg.MultiPort = false - pool.InitAccountPool([]string{dir}) - srv := startSingleServer(ServerOptions{Version: opts.Version, Config: subCfg}) + subPool := pool.NewAccountPool([]string{dir}) + srv := startSingleServer(ServerOptions{Version: opts.Version, Config: subCfg, Pool: subPool}) servers = append(servers, srv) } return servers @@ -53,11 +54,16 @@ func startSingleServer(opts ServerOptions) *http.Server { modelCache := &handlers.ModelCacheRef{} lastModel := cfg.DefaultModel + ph := opts.Pool + if ph == nil { + ph = pool.GlobalPoolHandle{} + } handler := router.NewRouter(router.RouterOptions{ Version: opts.Version, Config: cfg, ModelCache: modelCache, LastModel: &lastModel, + Pool: ph, }) handler = router.WrapWithRecovery(cfg.SessionsLogPath, handler) diff --git 
a/internal/toolcall/toolcall.go b/internal/toolcall/toolcall.go new file mode 100644 index 0000000..4c7a161 --- /dev/null +++ b/internal/toolcall/toolcall.go @@ -0,0 +1,153 @@ +package toolcall + +import ( + "encoding/json" + "regexp" + "strings" +) + +type ToolCall struct { + Name string + Arguments string // JSON string +} + +type ParsedResponse struct { + TextContent string + ToolCalls []ToolCall +} + +func (p *ParsedResponse) HasToolCalls() bool { + return len(p.ToolCalls) > 0 +} + +var toolCallTagRe = regexp.MustCompile(`(?s)\s*(\{.*?\})\s*`) +var antmlFunctionCallsRe = regexp.MustCompile(`(?s)\s*(.*?)\s*`) +var antmlInvokeRe = regexp.MustCompile(`(?s)\s*(.*?)\s*`) +var antmlParamRe = regexp.MustCompile(`(?s)(.*?)`) + +func ExtractToolCalls(text string, toolNames map[string]bool) *ParsedResponse { + result := &ParsedResponse{} + + if locs := toolCallTagRe.FindAllStringSubmatchIndex(text, -1); len(locs) > 0 { + var calls []ToolCall + var textParts []string + last := 0 + for _, loc := range locs { + if loc[0] > last { + textParts = append(textParts, text[last:loc[0]]) + } + jsonStr := text[loc[2]:loc[3]] + if tc := parseToolCallJSON(jsonStr, toolNames); tc != nil { + calls = append(calls, *tc) + } else { + textParts = append(textParts, text[loc[0]:loc[1]]) + } + last = loc[1] + } + if last < len(text) { + textParts = append(textParts, text[last:]) + } + if len(calls) > 0 { + result.TextContent = strings.TrimSpace(strings.Join(textParts, "")) + result.ToolCalls = calls + return result + } + } + + if locs := antmlFunctionCallsRe.FindAllStringSubmatchIndex(text, -1); len(locs) > 0 { + var calls []ToolCall + var textParts []string + last := 0 + for _, loc := range locs { + if loc[0] > last { + textParts = append(textParts, text[last:loc[0]]) + } + block := text[loc[2]:loc[3]] + invokes := antmlInvokeRe.FindAllStringSubmatch(block, -1) + for _, inv := range invokes { + name := inv[1] + if toolNames != nil && len(toolNames) > 0 && !toolNames[name] { + continue + } + 
body := inv[2] + args := map[string]interface{}{} + params := antmlParamRe.FindAllStringSubmatch(body, -1) + for _, p := range params { + paramName := p[1] + paramValue := strings.TrimSpace(p[2]) + var jsonVal interface{} + if err := json.Unmarshal([]byte(paramValue), &jsonVal); err == nil { + args[paramName] = jsonVal + } else { + args[paramName] = paramValue + } + } + argsJSON, _ := json.Marshal(args) + calls = append(calls, ToolCall{Name: name, Arguments: string(argsJSON)}) + } + last = loc[1] + } + if last < len(text) { + textParts = append(textParts, text[last:]) + } + if len(calls) > 0 { + result.TextContent = strings.TrimSpace(strings.Join(textParts, "")) + result.ToolCalls = calls + return result + } + } + + result.TextContent = text + return result +} + +func parseToolCallJSON(jsonStr string, toolNames map[string]bool) *ToolCall { + var raw map[string]interface{} + if err := json.Unmarshal([]byte(jsonStr), &raw); err != nil { + return nil + } + name, _ := raw["name"].(string) + if name == "" { + return nil + } + if toolNames != nil && len(toolNames) > 0 && !toolNames[name] { + return nil + } + var argsStr string + switch a := raw["arguments"].(type) { + case string: + argsStr = a + case map[string]interface{}, []interface{}: + b, _ := json.Marshal(a) + argsStr = string(b) + default: + if p, ok := raw["parameters"]; ok { + b, _ := json.Marshal(p) + argsStr = string(b) + } else { + argsStr = "{}" + } + } + return &ToolCall{Name: name, Arguments: argsStr} +} + +func CollectToolNames(tools []interface{}) map[string]bool { + names := map[string]bool{} + for _, t := range tools { + m, ok := t.(map[string]interface{}) + if !ok { + continue + } + if m["type"] == "function" { + if fn, ok := m["function"].(map[string]interface{}); ok { + if name, ok := fn["name"].(string); ok { + names[name] = true + } + } + } + if name, ok := m["name"].(string); ok { + names[name] = true + } + } + return names +} diff --git a/internal/winlimit/winlimit.go 
b/internal/winlimit/winlimit.go index 035c447..06044b1 100644 --- a/internal/winlimit/winlimit.go +++ b/internal/winlimit/winlimit.go @@ -6,6 +6,13 @@ import ( ) const WinPromptOmissionPrefix = "[Earlier messages omitted: Windows command-line length limit.]\n\n" +const LinuxPromptOmissionPrefix = "[Earlier messages omitted: Linux ARG_MAX command-line length limit.]\n\n" + +// safeLinuxArgMax returns a conservative estimate of ARG_MAX on Linux. +// The actual limit is typically 2MB; we use 1.5MB to leave room for env vars. +func safeLinuxArgMax() int { + return 1536 * 1024 +} type FitPromptResult struct { OK bool @@ -41,16 +48,7 @@ func estimateCmdlineLength(resolved env.AgentCommand) int { func FitPromptToWinCmdline(agentBin string, fixedArgs []string, prompt string, maxCmdline int, cwd string) FitPromptResult { if runtime.GOOS != "windows" { - args := make([]string, len(fixedArgs)+1) - copy(args, fixedArgs) - args[len(fixedArgs)] = prompt - return FitPromptResult{ - OK: true, - Args: args, - Truncated: false, - OriginalLength: len(prompt), - FinalPromptLength: len(prompt), - } + return fitPromptLinux(fixedArgs, prompt) } e := env.OsEnvToMap() @@ -125,8 +123,59 @@ func FitPromptToWinCmdline(agentBin string, fixedArgs []string, prompt string, m } } +// fitPromptLinux handles Linux ARG_MAX truncation. 
+func fitPromptLinux(fixedArgs []string, prompt string) FitPromptResult { + argMax := safeLinuxArgMax() + + // Estimate total cmdline size: sum of all fixed args + prompt + null terminators + fixedLen := 0 + for _, a := range fixedArgs { + fixedLen += len(a) + 1 + } + totalLen := fixedLen + len(prompt) + 1 + + if totalLen <= argMax { + args := make([]string, len(fixedArgs)+1) + copy(args, fixedArgs) + args[len(fixedArgs)] = prompt + return FitPromptResult{ + OK: true, + Args: args, + Truncated: false, + OriginalLength: len(prompt), + FinalPromptLength: len(prompt), + } + } + + // Need to truncate: keep the tail of the prompt (most recent messages) + prefix := LinuxPromptOmissionPrefix + available := argMax - fixedLen - len(prefix) - 1 + if available < 0 { + available = 0 + } + + var finalPrompt string + if available <= 0 { + finalPrompt = prefix + } else if available >= len(prompt) { + finalPrompt = prefix + prompt + } else { + finalPrompt = prefix + prompt[len(prompt)-available:] + } + + args := make([]string, len(fixedArgs)+1) + copy(args, fixedArgs) + args[len(fixedArgs)] = finalPrompt + return FitPromptResult{ + OK: true, + Args: args, + Truncated: true, + OriginalLength: len(prompt), + FinalPromptLength: len(finalPrompt), + } +} + func WarnPromptTruncated(originalLength, finalLength int) { _ = originalLength _ = finalLength - // fmt.Fprintf skipped to avoid import; caller may log as needed }