fix
This commit is contained in:
parent
974f2f2bb5
commit
9919fc7bb9
2
.env
2
.env
|
|
@ -10,7 +10,7 @@ CURSOR_AGENT_NODE=
|
|||
CURSOR_AGENT_SCRIPT=
|
||||
CURSOR_BRIDGE_DEFAULT_MODEL=auto
|
||||
CURSOR_BRIDGE_STRICT_MODEL=true
|
||||
CURSOR_BRIDGE_MAX_MODE=true
|
||||
CURSOR_BRIDGE_MAX_MODE=false
|
||||
CURSOR_BRIDGE_FORCE=false
|
||||
CURSOR_BRIDGE_APPROVE_MCPS=false
|
||||
CURSOR_BRIDGE_WORKSPACE=
|
||||
|
|
|
|||
90
Makefile
90
Makefile
|
|
@ -17,7 +17,7 @@ AGENT_NODE ?=
|
|||
AGENT_SCRIPT ?=
|
||||
DEFAULT_MODEL ?= auto
|
||||
STRICT_MODEL ?= true
|
||||
# Defaults to false. `?=` only assigns when the variable is not already set,
# so a value from the environment or the command line (make MAX_MODE=true)
# still takes precedence. Keep a single definition: with two `?=` lines the
# first one silently wins.
MAX_MODE ?= false
|
||||
FORCE ?= false
|
||||
APPROVE_MCPS ?= false
|
||||
|
||||
|
|
@ -37,7 +37,9 @@ SESSIONS_LOG ?=
|
|||
|
||||
ENV_FILE ?= .env
|
||||
|
||||
# Path of the OpenCode configuration file that the opencode* targets create or update.
OPENCODE_CONFIG ?= $(HOME)/.config/opencode/opencode.json

# Command-style targets (not files). Every such target defined in this file
# must be listed here, otherwise a file with the same name would make the
# target look up to date and its recipe would be skipped.
.PHONY: env run build clean help opencode opencode-models pm2 pm2-opencode pm2-stop pm2-logs claude-code pm2-claude-code
|
||||
|
||||
## 產生 .env 檔(預設輸出至 .env,可用 ENV_FILE=xxx 覆寫)
|
||||
env:
|
||||
|
|
@ -80,13 +82,89 @@ run: build
|
|||
# Remove the compiled binary and the generated env file.
clean:
	$(RM) cursor-api-proxy $(ENV_FILE)
|
||||
|
||||
## Configure OpenCode to use this proxy (create or update the cursor provider in opencode.json)
#
# - If $(OPENCODE_CONFIG) does not exist, a minimal config is written.
# - Otherwise only baseURL (and apiKey, when API_KEY is non-empty) are patched with jq.
#
# Values are handed to jq via --arg instead of being spliced into the filter,
# so quotes or backslashes in API_KEY cannot break the jq program. The update
# is written to a temp file first; the success message is printed only when
# both jq and mv succeeded, and the temp file is removed on failure.
opencode: build
	@cfg="$(OPENCODE_CONFIG)"; \
	url="http://$(HOST):$(PORT)/v1"; \
	key="$(API_KEY)"; \
	if [ ! -f "$${cfg}" ]; then \
		echo "找不到 $${cfg},建立新設定檔"; \
		mkdir -p "$$(dirname "$${cfg}")" || exit 1; \
		printf '{\n  "provider": {\n    "cursor": {\n      "npm": "@ai-sdk/openai-compatible",\n      "name": "Cursor Agent",\n      "options": {\n        "baseURL": "%s",\n        "apiKey": "%s"\n      },\n      "models": { "auto": { "name": "Cursor Auto" } }\n    }\n  }\n}\n' "$${url}" "$${key:-unused}" > "$${cfg}" || exit 1; \
		echo "已建立 $${cfg}"; \
	elif [ -n "$${key}" ]; then \
		jq --arg url "$${url}" --arg key "$${key}" '.provider.cursor.options.baseURL = $$url | .provider.cursor.options.apiKey = $$key' "$${cfg}" > "$${cfg}.tmp" \
			&& mv "$${cfg}.tmp" "$${cfg}" \
			|| { rm -f "$${cfg}.tmp"; echo "更新 $${cfg} 失敗" >&2; exit 1; }; \
		echo "已更新 $${cfg}(baseURL → $${url},apiKey 已設定)"; \
	else \
		jq --arg url "$${url}" '.provider.cursor.options.baseURL = $$url' "$${cfg}" > "$${cfg}.tmp" \
			&& mv "$${cfg}.tmp" "$${cfg}" \
			|| { rm -f "$${cfg}.tmp"; echo "更新 $${cfg} 失敗" >&2; exit 1; }; \
		echo "已更新 $${cfg}(baseURL → $${url})"; \
	fi
|
||||
|
||||
## Start the proxy temporarily and sync its model list (GET /v1/models) into opencode.json
#
# Notes:
# - jq variables must be written as $$ids / $$id in a recipe. A single `$`
#   is consumed by Make, which turns `$ids` into `ds` and breaks the filter.
# - The env file is sourced only when it exists; `.` on a missing file can
#   abort a POSIX shell.
# - The proxy is polled for up to ~10 seconds instead of relying on a fixed
#   sleep, and an EXIT trap stops the background proxy on every exit path.
# - The success message is printed only when jq and mv both succeeded.
opencode-models: opencode
	@echo "啟動代理以取得模型列表..."
	@set -a; if [ -f "./$(ENV_FILE)" ]; then . "./$(ENV_FILE)"; fi; set +a; \
	./cursor-api-proxy & PID=$$!; \
	trap 'kill $$PID 2>/dev/null; wait $$PID 2>/dev/null' EXIT; \
	MODELS=""; \
	for attempt in 1 2 3 4 5 6 7 8 9 10; do \
		MODELS=$$(curl -sf "http://$(HOST):$(PORT)/v1/models" | jq -c '[.data[].id]' 2>/dev/null); \
		if [ -n "$$MODELS" ] && [ "$$MODELS" != "null" ]; then break; fi; \
		MODELS=""; \
		sleep 1; \
	done; \
	if [ -n "$$MODELS" ]; then \
		jq --argjson ids "$$MODELS" 'reduce $$ids[] as $$id (.; .provider.cursor.models[$$id] = { name: $$id })' "$(OPENCODE_CONFIG)" > "$(OPENCODE_CONFIG).tmp" \
			&& mv "$(OPENCODE_CONFIG).tmp" "$(OPENCODE_CONFIG)" \
			|| { rm -f "$(OPENCODE_CONFIG).tmp"; echo "更新 $(OPENCODE_CONFIG) 失敗" >&2; exit 1; }; \
		echo "已同步模型列表到 $(OPENCODE_CONFIG)"; \
	else \
		echo "無法取得模型列表,請確認代理已啟動"; \
		exit 1; \
	fi
|
||||
|
||||
## Build, then start the proxy under pm2
#
# When $(ENV_FILE) exists, its non-comment lines are passed to pm2 as
# environment assignments (with --update-env); HOST and PORT from this
# Makefile are always passed as CURSOR_BRIDGE_HOST / CURSOR_BRIDGE_PORT.
pm2: build
	@if [ ! -f "$(ENV_FILE)" ]; then \
		CURSOR_BRIDGE_HOST=$(HOST) CURSOR_BRIDGE_PORT=$(PORT) pm2 start ./cursor-api-proxy --name cursor-api-proxy; \
	else \
		env $$(grep -v '^#' $(ENV_FILE) | xargs) CURSOR_BRIDGE_HOST=$(HOST) CURSOR_BRIDGE_PORT=$(PORT) pm2 start ./cursor-api-proxy --name cursor-api-proxy --update-env; \
	fi
	@pm2 save
	@printf '%s\n' "pm2 已啟動 cursor-api-proxy(http://$(HOST):$(PORT))"
|
||||
|
||||
## Configure OpenCode and start the proxy under pm2 in one step
# All real work happens in the prerequisites; the recipe only prints a confirmation.
pm2-opencode: opencode pm2
	@printf '%s\n' "OpenCode 設定已更新並用 pm2 啟動代理"
|
||||
|
||||
## Build and start under pm2, then print the environment variables Claude Code needs
#
# Nothing is exported by this target; it only prints the `export` lines,
# once as a snippet for the shell start-up file and once as commands for
# the current shell. ANTHROPIC_API_KEY is API_KEY when that is non-empty,
# otherwise the placeholder "dummy-key".
pm2-claude-code: pm2
	@printf '%s\n' \
		"" \
		"Claude Code 設定:將以下指令加入你的 shell 啟動檔(~/.bashrc 或 ~/.zshrc):" \
		"" \
		" export ANTHROPIC_BASE_URL=http://$(HOST):$(PORT)" \
		" export ANTHROPIC_API_KEY=$(if $(API_KEY),$(API_KEY),dummy-key)" \
		"" \
		"或在當前 shell 執行:" \
		"" \
		" export ANTHROPIC_BASE_URL=http://$(HOST):$(PORT)" \
		" export ANTHROPIC_API_KEY=$(if $(API_KEY),$(API_KEY),dummy-key)" \
		" claude" \
		""
|
||||
|
||||
## Stop the proxy running under pm2
# pm2's own error output is discarded; if the stop fails, a short notice is printed instead.
pm2-stop:
	if ! pm2 stop cursor-api-proxy 2>/dev/null; then echo "cursor-api-proxy 未在執行"; fi

## Show the pm2 logs of the proxy
pm2-logs:
	pm2 logs cursor-api-proxy
|
||||
|
||||
## Show the available targets and a few override examples
# Each target is listed exactly once; keep this list in sync with .PHONY.
help:
	@echo "可用目標:"
	@echo " make env 產生 .env(先在 Makefile 頂端填好變數)"
	@echo " make build 編譯 cursor-api-proxy 二進位檔"
	@echo " make run 編譯並載入 .env 執行"
	@echo " make pm2 編譯並用 pm2 啟動代理"
	@echo " make pm2-stop 停止 pm2 中的代理"
	@echo " make pm2-logs 查看 pm2 日誌"
	@echo " make pm2-claude-code 啟動代理 + 輸出 Claude Code 設定指令"
	@echo " make opencode 編譯並設定 OpenCode(更新 opencode.json)"
	@echo " make pm2-opencode 設定 OpenCode + 啟動代理"
	@echo " make opencode-models 編譯、設定 OpenCode 並同步模型列表"
	@echo " make clean 刪除二進位檔與 .env"
	@echo ""
	@echo "覆寫範例:"
	@echo " make env PORT=9000 API_KEY=mysecret TIMEOUT_MS=60000"
	@echo " make pm2-claude-code PORT=8765 API_KEY=mykey"
	@echo " make pm2-opencode PORT=8765"
|
||||
|
|
|
|||
206
README.md
206
README.md
|
|
@ -1,18 +1,21 @@
|
|||
# Cursor API Proxy
|
||||
|
||||
[English](./README.md) | 繁體中文
|
||||
|
||||
一個讓你可以透過標準 OpenAI/Anthropic API 格式存取 Cursor AI 編輯器的代理伺服器。
|
||||
|
||||
可以把 Cursor 的模型無縫接入 **Claude Code**、**OpenCode** 或任何支援 OpenAI/Anthropic API 的工具。
|
||||
|
||||
## 功能特色
|
||||
|
||||
- **API 相容**:支援 OpenAI 格式和 Anthropic 格式的 API 呼叫
|
||||
- **多帳號管理**:支援新增、移除、切換多個 Cursor 帳號
|
||||
- **動態模型對映**:自動將 Cursor CLI 支援的所有模型轉換為 `claude-*` 格式,供 Claude Code / OpenCode 使用
|
||||
- **Tailscale 支援**:可綁定到 `0.0.0.0` 供區域網路存取
|
||||
- **HWID 重置**:內建反偵測功能,可重置機器識別碼
|
||||
- **連線池**:最佳化的連線管理
|
||||
|
||||
## 安裝
|
||||
## 快速開始
|
||||
|
||||
### 1. 建置
|
||||
|
||||
```bash
|
||||
git clone https://github.com/your-repo/cursor-api-proxy-go.git
|
||||
|
|
@ -20,20 +23,179 @@ cd cursor-api-proxy-go
|
|||
go build -o cursor-api-proxy .
|
||||
```
|
||||
|
||||
## 使用方式
|
||||
### 2. 登入 Cursor 帳號
|
||||
|
||||
### 啟動伺服器
|
||||
```bash
|
||||
./cursor-api-proxy login myaccount
|
||||
```
|
||||
|
||||
### 3. 啟動伺服器
|
||||
|
||||
```bash
|
||||
./cursor-api-proxy
|
||||
```
|
||||
|
||||
預設監聽 `127.0.0.1:8080`。
|
||||
預設監聽 `127.0.0.1:8765`。
|
||||
|
||||
### 4. 設定 API Key(選用)
|
||||
|
||||
如果需要外部存取,建議設定 API Key:
|
||||
|
||||
```bash
|
||||
export CURSOR_BRIDGE_API_KEY=my-secret-key
|
||||
./cursor-api-proxy
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 接入 Claude Code
|
||||
|
||||
Claude Code 使用 Anthropic SDK,透過環境變數設定即可改用你的代理。
|
||||
|
||||
### 步驟
|
||||
|
||||
```bash
|
||||
# 1. 啟動代理(確認已在背景執行)
|
||||
./cursor-api-proxy &
|
||||
|
||||
# 2. 設定環境變數,讓 Claude Code 改用你的代理
|
||||
export ANTHROPIC_BASE_URL=http://127.0.0.1:8765
|
||||
|
||||
# 3. 如果代理有設定 CURSOR_BRIDGE_API_KEY,這裡填相同的值;沒有的話隨便填
|
||||
export ANTHROPIC_API_KEY=my-secret-key
|
||||
|
||||
# 4. 啟動 Claude Code
|
||||
claude
|
||||
```
|
||||
|
||||
### 切換模型
|
||||
|
||||
在 Claude Code 中輸入 `/model`,即可看到你 Cursor CLI 支援的所有模型。
|
||||
|
||||
代理會自動將 Cursor 模型 ID 轉換為 `claude-*` 格式:
|
||||
|
||||
| Cursor CLI 模型 | Claude Code 看到的 |
|
||||
|---|---|
|
||||
| `opus-4.6` | `claude-opus-4-6` |
|
||||
| `sonnet-4.6` | `claude-sonnet-4-6` |
|
||||
| `opus-4.5-thinking` | `claude-opus-4-5-thinking` |
|
||||
| `sonnet-4.7` (未來新增) | `claude-sonnet-4-7` (自動生成) |
|
||||
|
||||
### 也可直接指定模型
|
||||
|
||||
```bash
|
||||
claude --model claude-sonnet-4-6
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 接入 OpenCode
|
||||
|
||||
OpenCode 透過 `~/.config/opencode/opencode.json` 設定 provider,使用 OpenAI 相容格式。
|
||||
|
||||
### 步驟
|
||||
|
||||
1. 啟動代理伺服器
|
||||
|
||||
```bash
|
||||
./cursor-api-proxy &
|
||||
```
|
||||
|
||||
2. 編輯 `~/.config/opencode/opencode.json`,在 `provider` 中新增一個 provider:
|
||||
|
||||
```json
|
||||
{
|
||||
"provider": {
|
||||
"cursor": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Cursor Agent",
|
||||
"options": {
|
||||
"baseURL": "http://127.0.0.1:8765/v1",
|
||||
"apiKey": "unused"
|
||||
},
|
||||
"models": {
|
||||
"auto": { "name": "Cursor Auto" },
|
||||
"sonnet-4.6": { "name": "Sonnet 4.6" },
|
||||
"opus-4.6": { "name": "Opus 4.6" },
|
||||
"opus-4.6-thinking": { "name": "Opus 4.6 Thinking" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
若代理有設定 `CURSOR_BRIDGE_API_KEY`,把 `apiKey` 改成相同的值。
|
||||
|
||||
3. 在 `opencode.json` 中設定預設模型:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "cursor/sonnet-4.6"
|
||||
}
|
||||
```
|
||||
|
||||
### 查看可用模型
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8765/v1/models | jq '.data[].id'
|
||||
```
|
||||
|
||||
代理的 `/v1/models` 端點會回傳 Cursor CLI 目前支援的所有模型 ID,把結果加到 `opencode.json` 的 `models` 中即可。
|
||||
|
||||
### 在 OpenCode 中切換模型
|
||||
|
||||
OpenCode 的模型切換使用 `/model` 指令,從 `opencode.json` 的 `models` 清單中選擇。
|
||||
|
||||
---
|
||||
|
||||
## 模型對映原理
|
||||
|
||||
兩端使用不同的模型 ID 格式:
|
||||
|
||||
```
|
||||
Claude Code OpenCode
|
||||
│ │
|
||||
│ claude-opus-4-6 │ opus-4.6 (Cursor 原生 ID)
|
||||
│ (Anthropic alias) │ (OpenAI 相容格式)
|
||||
▼ ▼
|
||||
┌──────────────────────┐ ┌──────────────────────┐
|
||||
│ ResolveToCursorModel │ │ 直接傳給代理 │
|
||||
│ claude-opus-4-6 │ │ opus-4.6 │
|
||||
│ ↓ │ │ ↓ │
|
||||
│ opus-4.6 │ │ opus-4.6 │
|
||||
└──────────────────────┘ └──────────────────────┘
|
||||
│ │
|
||||
└──────────┬───────────┘
|
||||
▼
|
||||
Cursor CLI (agent --model opus-4.6)
|
||||
```
|
||||
|
||||
### Anthropic 模型對映表(Claude Code 使用)
|
||||
|
||||
| Cursor 模型 | Anthropic ID | 說明 |
|
||||
|---|---|---|
|
||||
| `opus-4.6` | `claude-opus-4-6` | Claude 4.6 Opus |
|
||||
| `opus-4.6-thinking` | `claude-opus-4-6-thinking` | Claude 4.6 Opus (Thinking) |
|
||||
| `sonnet-4.6` | `claude-sonnet-4-6` | Claude 4.6 Sonnet |
|
||||
| `sonnet-4.6-thinking` | `claude-sonnet-4-6-thinking` | Claude 4.6 Sonnet (Thinking) |
|
||||
| `opus-4.5` | `claude-opus-4-5` | Claude 4.5 Opus |
|
||||
| `opus-4.5-thinking` | `claude-opus-4-5-thinking` | Claude 4.5 Opus (Thinking) |
|
||||
| `sonnet-4.5` | `claude-sonnet-4-5` | Claude 4.5 Sonnet |
|
||||
| `sonnet-4.5-thinking` | `claude-sonnet-4-5-thinking` | Claude 4.5 Sonnet (Thinking) |
|
||||
|
||||
### 動態對映(自動)
|
||||
|
||||
Cursor CLI 新增模型時(如 `opus-4.7`、`sonnet-5.0`),代理自動生成對應的 `claude-*` ID,無需手動更新。
|
||||
|
||||
規則:`<family>-<major>.<minor>` → `claude-<family>-<major>-<minor>`
|
||||
|
||||
---
|
||||
|
||||
## 帳號管理
|
||||
|
||||
### 登入帳號
|
||||
|
||||
```bash
|
||||
# 登入帳號
|
||||
./cursor-api-proxy login myaccount
|
||||
|
||||
# 使用代理登入
|
||||
|
|
@ -62,21 +224,25 @@ go build -o cursor-api-proxy .
|
|||
./cursor-api-proxy reset-hwid --deep-clean
|
||||
```
|
||||
|
||||
### 其他選項
|
||||
### 啟動選項
|
||||
|
||||
| 選項 | 說明 |
|
||||
|------|------|
|
||||
| `--tailscale` | 綁定到 `0.0.0.0` 供區域網路存取 |
|
||||
| `-h, --help` | 顯示說明 |
|
||||
|
||||
---
|
||||
|
||||
## API 端點
|
||||
|
||||
| 端點 | 方法 | 說明 |
|
||||
|------|------|------|
|
||||
| `http://127.0.0.1:8080/v1/chat/completions` | POST | OpenAI 格式聊天完成 |
|
||||
| `http://127.0.0.1:8080/v1/models` | GET | 列出可用模型 |
|
||||
| `http://127.0.0.1:8080/v1/chat/messages` | POST | Anthropic 格式聊天 |
|
||||
| `http://127.0.0.1:8080/health` | GET | 健康檢查 |
|
||||
| `http://127.0.0.1:8765/v1/chat/completions` | POST | OpenAI 格式聊天完成 |
|
||||
| `http://127.0.0.1:8765/v1/models` | GET | 列出可用模型 |
|
||||
| `http://127.0.0.1:8765/v1/chat/messages` | POST | Anthropic 格式聊天 |
|
||||
| `http://127.0.0.1:8765/health` | GET | 健康檢查 |
|
||||
|
||||
---
|
||||
|
||||
## 環境變數
|
||||
|
||||
|
|
@ -127,17 +293,27 @@ go build -o cursor-api-proxy .
|
|||
| `CURSOR_BRIDGE_WIN_CMDLINE_MAX` | `30000` | Windows 命令列最大長度(4096–32700) |
|
||||
| `COMSPEC` | `cmd.exe` | Windows 命令直譯器路徑 |
|
||||
|
||||
---
|
||||
|
||||
## 常見問題
|
||||
|
||||
**Q: 為什麼需要登入帳號?**
|
||||
**Q: 為什麼需要登入帳號?**
|
||||
A: Cursor API 需要驗證才能使用,請先登入你的 Cursor 帳號。
|
||||
|
||||
**Q: 如何處理被BAN的問題?**
|
||||
**Q: 如何處理被BAN的問題?**
|
||||
A: 使用 `reset-hwid` 命令重置機器識別碼,加上 `--deep-clean` 進行更徹底的清理。
|
||||
|
||||
**Q: 可以在其他設備上使用嗎?**
|
||||
**Q: 可以在其他設備上使用嗎?**
|
||||
A: 可以,使用 `--tailscale` 選項啟動伺服器,然後透過區域網路 IP 存取。
|
||||
|
||||
**Q: 模型列表多久更新一次?**
|
||||
A: 每次呼叫 `GET /v1/models` 時,代理會即時呼叫 Cursor CLI 的 `--list-models` 取得最新模型,並自動生成對應的 `claude-*` ID。
|
||||
|
||||
**Q: 未來 Cursor 新增模型怎麼辦?**
|
||||
A: 不用改任何東西。只要新模型符合 `<family>-<major>.<minor>` 命名規則,代理會自動生成對應的 `claude-*` ID。
|
||||
|
||||
---
|
||||
|
||||
## 授權
|
||||
|
||||
MIT License
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ package agent
|
|||
import "cursor-api-proxy/internal/config"
|
||||
|
||||
func BuildAgentFixedArgs(cfg config.BridgeConfig, workspaceDir, model string, stream bool) []string {
|
||||
args := []string{"--print"}
|
||||
args := []string{"--print", "--plan"}
|
||||
if cfg.ApproveMcps {
|
||||
args = append(args, "--approve-mcps")
|
||||
}
|
||||
|
|
@ -13,7 +13,6 @@ func BuildAgentFixedArgs(cfg config.BridgeConfig, workspaceDir, model string, st
|
|||
if cfg.ChatOnlyWorkspace {
|
||||
args = append(args, "--trust")
|
||||
}
|
||||
args = append(args, "--mode", "ask")
|
||||
args = append(args, "--workspace", workspaceDir)
|
||||
args = append(args, "--model", model)
|
||||
if stream {
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ package anthropic
|
|||
|
||||
import (
|
||||
"cursor-api-proxy/internal/openai"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
|
@ -83,6 +85,44 @@ func anthropicBlockToText(p interface{}) string {
|
|||
return "[Document: " + title + "]"
|
||||
}
|
||||
return "[Document]"
|
||||
case "tool_use":
|
||||
name, _ := v["name"].(string)
|
||||
id, _ := v["id"].(string)
|
||||
input := v["input"]
|
||||
inputJSON, _ := json.Marshal(input)
|
||||
if inputJSON == nil {
|
||||
inputJSON = []byte("{}")
|
||||
}
|
||||
tag := fmt.Sprintf("<tool_call>\n{\"name\": \"%s\", \"arguments\": %s}\n</tool_call>", name, string(inputJSON))
|
||||
if id != "" {
|
||||
tag = fmt.Sprintf("[tool_use_id=%s] ", id) + tag
|
||||
}
|
||||
return tag
|
||||
case "tool_result":
|
||||
toolUseID, _ := v["tool_use_id"].(string)
|
||||
content := v["content"]
|
||||
var contentText string
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
contentText = c
|
||||
case []interface{}:
|
||||
var parts []string
|
||||
for _, block := range c {
|
||||
if bm, ok := block.(map[string]interface{}); ok {
|
||||
if bm["type"] == "text" {
|
||||
if t, ok := bm["text"].(string); ok {
|
||||
parts = append(parts, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
contentText = strings.Join(parts, "\n")
|
||||
}
|
||||
label := "Tool result"
|
||||
if toolUseID != "" {
|
||||
label += " [id=" + toolUseID + "]"
|
||||
}
|
||||
return label + ": " + contentText
|
||||
}
|
||||
}
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"cursor-api-proxy/internal/agent"
|
||||
"cursor-api-proxy/internal/anthropic"
|
||||
"cursor-api-proxy/internal/config"
|
||||
|
|
@ -11,9 +12,11 @@ import (
|
|||
"cursor-api-proxy/internal/parser"
|
||||
"cursor-api-proxy/internal/pool"
|
||||
"cursor-api-proxy/internal/sanitize"
|
||||
"cursor-api-proxy/internal/toolcall"
|
||||
"cursor-api-proxy/internal/winlimit"
|
||||
"cursor-api-proxy/internal/workspace"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
|
@ -21,7 +24,7 @@ import (
|
|||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
|
||||
func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, ph pool.PoolHandle, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
|
||||
var req anthropic.MessagesRequest
|
||||
if err := json.Unmarshal([]byte(rawBody), &req); err != nil {
|
||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
||||
|
|
@ -33,13 +36,11 @@ func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.
|
|||
requested := openai.NormalizeModelID(req.Model)
|
||||
model := ResolveModel(requested, lastModelRef, cfg)
|
||||
|
||||
// Parse system from raw body to handle both string and array
|
||||
var rawMap map[string]interface{}
|
||||
_ = json.Unmarshal([]byte(rawBody), &rawMap)
|
||||
|
||||
cleanSystem := sanitize.SanitizeSystem(req.System)
|
||||
|
||||
// SanitizeMessages expects []interface{}
|
||||
rawMessages := make([]interface{}, len(req.Messages))
|
||||
for i, m := range req.Messages {
|
||||
rawMessages[i] = map[string]interface{}{"role": m.Role, "content": m.Content}
|
||||
|
|
@ -103,6 +104,15 @@ func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.
|
|||
fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, req.Stream)
|
||||
fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir)
|
||||
|
||||
if cfg.Verbose {
|
||||
if len(prompt) > 200 {
|
||||
logger.LogDebug("model=%s prompt_len=%d prompt_preview=%q", cursorModel, len(prompt), prompt[:200]+"...")
|
||||
} else {
|
||||
logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, len(prompt), prompt)
|
||||
}
|
||||
logger.LogDebug("cmd_args=%v", fit.Args)
|
||||
}
|
||||
|
||||
if !fit.OK {
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"type": "api_error", "message": fit.Error},
|
||||
|
|
@ -110,8 +120,7 @@ func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.
|
|||
return
|
||||
}
|
||||
if fit.Truncated {
|
||||
fmt.Printf("[%s] Windows: prompt truncated (%d -> %d chars).\n",
|
||||
time.Now().UTC().Format(time.RFC3339), fit.OriginalLength, fit.FinalPromptLength)
|
||||
logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength)
|
||||
}
|
||||
|
||||
cmdArgs := fit.Args
|
||||
|
|
@ -122,6 +131,12 @@ func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.
|
|||
truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
|
||||
}
|
||||
|
||||
hasTools := len(req.Tools) > 0
|
||||
var toolNames map[string]bool
|
||||
if hasTools {
|
||||
toolNames = toolcall.CollectToolNames(req.Tools)
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
httputil.WriteSSEHeaders(w, truncatedHeaders)
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
|
@ -134,6 +149,11 @@ func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.
|
|||
}
|
||||
}
|
||||
|
||||
var accumulated string
|
||||
var accumulatedThinking string
|
||||
var chunkNum int
|
||||
var p parser.Parser
|
||||
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "message_start",
|
||||
"message": map[string]interface{}{
|
||||
|
|
@ -144,74 +164,252 @@ func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.
|
|||
"content": []interface{}{},
|
||||
},
|
||||
})
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": map[string]string{"type": "text", "text": ""},
|
||||
})
|
||||
|
||||
var accumulated string
|
||||
parseLine := parser.CreateStreamParser(
|
||||
func(text string) {
|
||||
accumulated += text
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]string{"type": "text_delta", "text": text},
|
||||
})
|
||||
},
|
||||
func() {
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": 0})
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
|
||||
"usage": map[string]int{"output_tokens": 0},
|
||||
})
|
||||
writeEvent(map[string]interface{}{"type": "message_stop"})
|
||||
},
|
||||
)
|
||||
if hasTools {
|
||||
// tools 模式:先累積所有內容,完成後再一次性輸出(因為 tool_calls 需要完整解析)
|
||||
p = parser.CreateStreamParserWithThinking(
|
||||
func(text string) {
|
||||
accumulated += text
|
||||
chunkNum++
|
||||
logger.LogStreamChunk(model, text, chunkNum)
|
||||
},
|
||||
func(thinking string) {
|
||||
accumulatedThinking += thinking
|
||||
},
|
||||
func() {
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
|
||||
parsed := toolcall.ExtractToolCalls(accumulated, toolNames)
|
||||
|
||||
configDir := pool.GetNextAccountConfigDir()
|
||||
blockIndex := 0
|
||||
if accumulatedThinking != "" {
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start", "index": blockIndex,
|
||||
"content_block": map[string]string{"type": "thinking", "thinking": ""},
|
||||
})
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_delta", "index": blockIndex,
|
||||
"delta": map[string]string{"type": "thinking_delta", "thinking": accumulatedThinking},
|
||||
})
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
|
||||
blockIndex++
|
||||
}
|
||||
|
||||
if parsed.HasToolCalls() {
|
||||
if parsed.TextContent != "" {
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start", "index": blockIndex,
|
||||
"content_block": map[string]string{"type": "text", "text": ""},
|
||||
})
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_delta", "index": blockIndex,
|
||||
"delta": map[string]string{"type": "text_delta", "text": parsed.TextContent},
|
||||
})
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
|
||||
blockIndex++
|
||||
}
|
||||
for _, tc := range parsed.ToolCalls {
|
||||
toolID := "toolu_" + uuid.New().String()[:12]
|
||||
var inputObj interface{}
|
||||
_ = json.Unmarshal([]byte(tc.Arguments), &inputObj)
|
||||
if inputObj == nil {
|
||||
inputObj = map[string]interface{}{}
|
||||
}
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start", "index": blockIndex,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "tool_use", "id": toolID, "name": tc.Name, "input": map[string]interface{}{},
|
||||
},
|
||||
})
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_delta", "index": blockIndex,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "input_json_delta", "partial_json": tc.Arguments,
|
||||
},
|
||||
})
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
|
||||
blockIndex++
|
||||
}
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{"stop_reason": "tool_use", "stop_sequence": nil},
|
||||
"usage": map[string]int{"output_tokens": 0},
|
||||
})
|
||||
writeEvent(map[string]interface{}{"type": "message_stop"})
|
||||
} else {
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start", "index": blockIndex,
|
||||
"content_block": map[string]string{"type": "text", "text": ""},
|
||||
})
|
||||
if accumulated != "" {
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_delta", "index": blockIndex,
|
||||
"delta": map[string]string{"type": "text_delta", "text": accumulated},
|
||||
})
|
||||
}
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
|
||||
"usage": map[string]int{"output_tokens": 0},
|
||||
})
|
||||
writeEvent(map[string]interface{}{"type": "message_stop"})
|
||||
}
|
||||
},
|
||||
)
|
||||
} else {
|
||||
// 非 tools 模式:即時串流 thinking 和 text
|
||||
// blockCount 追蹤已開啟的 block 數量
|
||||
// thinkingOpen 代表 thinking block 是否已開啟且尚未關閉
|
||||
// textOpen 代表 text block 是否已開啟且尚未關閉
|
||||
blockCount := 0
|
||||
thinkingOpen := false
|
||||
textOpen := false
|
||||
|
||||
p = parser.CreateStreamParserWithThinking(
|
||||
func(text string) {
|
||||
accumulated += text
|
||||
chunkNum++
|
||||
logger.LogStreamChunk(model, text, chunkNum)
|
||||
// 若 thinking block 尚未關閉,先關閉它
|
||||
if thinkingOpen {
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
|
||||
thinkingOpen = false
|
||||
}
|
||||
// 若 text block 尚未開啟,先開啟它
|
||||
if !textOpen {
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": blockCount,
|
||||
"content_block": map[string]string{"type": "text", "text": ""},
|
||||
})
|
||||
textOpen = true
|
||||
blockCount++
|
||||
}
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": blockCount - 1,
|
||||
"delta": map[string]string{"type": "text_delta", "text": text},
|
||||
})
|
||||
},
|
||||
func(thinking string) {
|
||||
accumulatedThinking += thinking
|
||||
chunkNum++
|
||||
// 若 thinking block 尚未開啟,先開啟它
|
||||
if !thinkingOpen {
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": blockCount,
|
||||
"content_block": map[string]string{"type": "thinking", "thinking": ""},
|
||||
})
|
||||
thinkingOpen = true
|
||||
blockCount++
|
||||
}
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": blockCount - 1,
|
||||
"delta": map[string]string{"type": "thinking_delta", "thinking": thinking},
|
||||
})
|
||||
},
|
||||
func() {
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
|
||||
// 關閉尚未關閉的 thinking block
|
||||
if thinkingOpen {
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
|
||||
thinkingOpen = false
|
||||
}
|
||||
// 若 text block 尚未開啟(全部都是 thinking,沒有 text),開啟並立即關閉空的 text block
|
||||
if !textOpen {
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": blockCount,
|
||||
"content_block": map[string]string{"type": "text", "text": ""},
|
||||
})
|
||||
blockCount++
|
||||
}
|
||||
// 關閉 text block
|
||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
|
||||
writeEvent(map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
|
||||
"usage": map[string]int{"output_tokens": 0},
|
||||
})
|
||||
writeEvent(map[string]interface{}{"type": "message_stop"})
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
configDir := ph.GetNextConfigDir()
|
||||
logger.LogAccountAssigned(configDir)
|
||||
pool.ReportRequestStart(configDir)
|
||||
ph.ReportRequestStart(configDir)
|
||||
logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true)
|
||||
streamStart := time.Now().UnixMilli()
|
||||
|
||||
ctx := r.Context()
|
||||
result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, parseLine, ws.TempDir, configDir, ctx)
|
||||
wrappedParser := func(line string) {
|
||||
logger.LogRawLine(line)
|
||||
p.Parse(line)
|
||||
}
|
||||
result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, ctx)
|
||||
|
||||
// agent 結束後,若未收到 result/success 訊號,強制 flush 以確保 SSE stream 正確結尾
|
||||
if ctx.Err() == nil {
|
||||
p.Flush()
|
||||
}
|
||||
|
||||
latencyMs := time.Now().UnixMilli() - streamStart
|
||||
pool.ReportRequestEnd(configDir)
|
||||
ph.ReportRequestEnd(configDir)
|
||||
|
||||
if err == nil && isRateLimited(result.Stderr) {
|
||||
pool.ReportRateLimit(configDir, 60000)
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
|
||||
} else if ctx.Err() == context.Canceled {
|
||||
logger.LogClientDisconnect(method, pathname, model, latencyMs)
|
||||
} else if err == nil && isRateLimited(result.Stderr) {
|
||||
ph.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr))
|
||||
}
|
||||
if err != nil || result.Code != 0 {
|
||||
pool.ReportRequestError(configDir, latencyMs)
|
||||
|
||||
if err != nil || (result.Code != 0 && ctx.Err() == nil) {
|
||||
ph.ReportRequestError(configDir, latencyMs)
|
||||
if err != nil {
|
||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
|
||||
} else {
|
||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr)
|
||||
}
|
||||
} else {
|
||||
pool.ReportRequestSuccess(configDir, latencyMs)
|
||||
logger.LogRequestDone(method, pathname, model, latencyMs, result.Code)
|
||||
} else if ctx.Err() == nil {
|
||||
ph.ReportRequestSuccess(configDir, latencyMs)
|
||||
logger.LogRequestDone(method, pathname, model, latencyMs, 0)
|
||||
}
|
||||
logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
|
||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
||||
return
|
||||
}
|
||||
|
||||
configDir := pool.GetNextAccountConfigDir()
|
||||
configDir := ph.GetNextConfigDir()
|
||||
logger.LogAccountAssigned(configDir)
|
||||
pool.ReportRequestStart(configDir)
|
||||
ph.ReportRequestStart(configDir)
|
||||
logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, false)
|
||||
syncStart := time.Now().UnixMilli()
|
||||
|
||||
out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context())
|
||||
syncLatency := time.Now().UnixMilli() - syncStart
|
||||
pool.ReportRequestEnd(configDir)
|
||||
ph.ReportRequestEnd(configDir)
|
||||
|
||||
ctx := r.Context()
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
|
||||
httputil.WriteJSON(w, 504, map[string]interface{}{
|
||||
"error": map[string]string{"type": "api_error", "message": fmt.Sprintf("request timed out after %dms", cfg.TimeoutMs)},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
if ctx.Err() == context.Canceled {
|
||||
logger.LogClientDisconnect(method, pathname, model, syncLatency)
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
pool.ReportRequestError(configDir, syncLatency)
|
||||
logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
|
||||
ph.ReportRequestError(configDir, syncLatency)
|
||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
||||
logger.LogRequestDone(method, pathname, model, syncLatency, -1)
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"type": "api_error", "message": err.Error()},
|
||||
}, nil)
|
||||
|
|
@ -219,32 +417,67 @@ func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.
|
|||
}
|
||||
|
||||
if isRateLimited(out.Stderr) {
|
||||
pool.ReportRateLimit(configDir, 60000)
|
||||
ph.ReportRateLimit(configDir, extractRetryAfterMs(out.Stderr))
|
||||
}
|
||||
|
||||
if out.Code != 0 {
|
||||
pool.ReportRequestError(configDir, syncLatency)
|
||||
logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
|
||||
ph.ReportRequestError(configDir, syncLatency)
|
||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
||||
errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr)
|
||||
logger.LogRequestDone(method, pathname, model, syncLatency, out.Code)
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"type": "api_error", "message": errMsg},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
pool.ReportRequestSuccess(configDir, syncLatency)
|
||||
content := trimSpace(out.Stdout)
|
||||
ph.ReportRequestSuccess(configDir, syncLatency)
|
||||
content := strings.TrimSpace(out.Stdout)
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, content, false)
|
||||
logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
|
||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
||||
logger.LogRequestDone(method, pathname, model, syncLatency, 0)
|
||||
|
||||
if hasTools {
|
||||
parsed := toolcall.ExtractToolCalls(content, toolNames)
|
||||
if parsed.HasToolCalls() {
|
||||
var contentBlocks []map[string]interface{}
|
||||
if parsed.TextContent != "" {
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "text", "text": parsed.TextContent,
|
||||
})
|
||||
}
|
||||
for _, tc := range parsed.ToolCalls {
|
||||
toolID := "toolu_" + uuid.New().String()[:12]
|
||||
var inputObj interface{}
|
||||
_ = json.Unmarshal([]byte(tc.Arguments), &inputObj)
|
||||
if inputObj == nil {
|
||||
inputObj = map[string]interface{}{}
|
||||
}
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "tool_use", "id": toolID, "name": tc.Name, "input": inputObj,
|
||||
})
|
||||
}
|
||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": contentBlocks,
|
||||
"model": model,
|
||||
"stop_reason": "tool_use",
|
||||
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
|
||||
}, truncatedHeaders)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []map[string]string{{"type": "text", "text": content}},
|
||||
"model": model,
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []map[string]string{{"type": "text", "text": content}},
|
||||
"model": model,
|
||||
"stop_reason": "end_turn",
|
||||
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
|
||||
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
|
||||
}, truncatedHeaders)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"cursor-api-proxy/internal/agent"
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/httputil"
|
||||
|
|
@ -10,24 +11,37 @@ import (
|
|||
"cursor-api-proxy/internal/parser"
|
||||
"cursor-api-proxy/internal/pool"
|
||||
"cursor-api-proxy/internal/sanitize"
|
||||
"cursor-api-proxy/internal/toolcall"
|
||||
"cursor-api-proxy/internal/winlimit"
|
||||
"cursor-api-proxy/internal/workspace"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var rateLimitRe = regexp.MustCompile(`(?i)\b429\b|rate.?limit|too many requests`)
|
||||
var retryAfterRe = regexp.MustCompile(`(?i)retry-after:\s*(\d+)`)
|
||||
|
||||
func isRateLimited(stderr string) bool {
|
||||
return rateLimitRe.MatchString(stderr)
|
||||
}
|
||||
|
||||
func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
|
||||
func extractRetryAfterMs(stderr string) int64 {
|
||||
if m := retryAfterRe.FindStringSubmatch(stderr); len(m) > 1 {
|
||||
if secs, err := strconv.ParseInt(m[1], 10, 64); err == nil && secs > 0 {
|
||||
return secs * 1000
|
||||
}
|
||||
}
|
||||
return 60000
|
||||
}
|
||||
|
||||
func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, ph pool.PoolHandle, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
|
||||
var bodyMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil {
|
||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
||||
|
|
@ -86,9 +100,22 @@ func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.Br
|
|||
headerWs := r.Header.Get("x-cursor-workspace")
|
||||
ws := workspace.ResolveWorkspace(cfg, headerWs)
|
||||
|
||||
promptLen := len(prompt)
|
||||
if cfg.Verbose {
|
||||
if promptLen > 200 {
|
||||
logger.LogDebug("model=%s prompt_len=%d prompt_start=%q", cursorModel, promptLen, prompt[:200])
|
||||
} else {
|
||||
logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, promptLen, prompt)
|
||||
}
|
||||
}
|
||||
|
||||
fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, isStream)
|
||||
fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir)
|
||||
|
||||
if cfg.Verbose {
|
||||
logger.LogDebug("cmd=%s args=%v", cfg.AgentBin, fit.Args)
|
||||
}
|
||||
|
||||
if !fit.OK {
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"message": fit.Error, "code": "windows_cmdline_limit"},
|
||||
|
|
@ -108,90 +135,261 @@ func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.Br
|
|||
truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
|
||||
}
|
||||
|
||||
hasTools := len(tools) > 0 || len(funcs) > 0
|
||||
var toolNames map[string]bool
|
||||
if hasTools {
|
||||
toolNames = toolcall.CollectToolNames(tools)
|
||||
for _, f := range funcs {
|
||||
if fm, ok := f.(map[string]interface{}); ok {
|
||||
if name, ok := fm["name"].(string); ok {
|
||||
toolNames[name] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isStream {
|
||||
httputil.WriteSSEHeaders(w, truncatedHeaders)
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
var accumulated string
|
||||
parseLine := parser.CreateStreamParser(
|
||||
func(text string) {
|
||||
accumulated += text
|
||||
chunk := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
},
|
||||
func() {
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
|
||||
stopChunk := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(stopChunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
},
|
||||
)
|
||||
var chunkNum int
|
||||
var p parser.Parser
|
||||
|
||||
configDir := pool.GetNextAccountConfigDir()
|
||||
// toolCallMarkerRe 偵測 tool call 開頭標記,一旦出現就停止即時輸出並進入累積模式
|
||||
toolCallMarkerRe := regexp.MustCompile(`<tool_call>|<function_calls>`)
|
||||
if hasTools {
|
||||
var toolCallMode bool // 是否已進入 tool call 累積模式
|
||||
p = parser.CreateStreamParserWithThinking(
|
||||
func(text string) {
|
||||
accumulated += text
|
||||
chunkNum++
|
||||
logger.LogStreamChunk(model, text, chunkNum)
|
||||
if toolCallMode {
|
||||
// 已進入累積模式,不即時輸出
|
||||
return
|
||||
}
|
||||
if toolCallMarkerRe.MatchString(text) {
|
||||
// 偵測到 tool call 標記,切換為累積模式
|
||||
toolCallMode = true
|
||||
return
|
||||
}
|
||||
chunk := map[string]interface{}{
|
||||
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
},
|
||||
func(_ string) {}, // thinking ignored in tools mode
|
||||
func() {
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
|
||||
parsed := toolcall.ExtractToolCalls(accumulated, toolNames)
|
||||
|
||||
if parsed.HasToolCalls() {
|
||||
if parsed.TextContent != "" && toolCallMode {
|
||||
// 已有部分 text 被即時輸出,只補發剩餘的
|
||||
chunk := map[string]interface{}{
|
||||
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]interface{}{"role": "assistant", "content": parsed.TextContent}, "finish_reason": nil},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
for i, tc := range parsed.ToolCalls {
|
||||
callID := "call_" + uuid.New().String()[:8]
|
||||
chunk := map[string]interface{}{
|
||||
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]interface{}{
|
||||
"tool_calls": []map[string]interface{}{
|
||||
{
|
||||
"index": i,
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": tc.Name,
|
||||
"arguments": tc.Arguments,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, "finish_reason": nil},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
stopChunk := map[string]interface{}{
|
||||
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "tool_calls"},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(stopChunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
} else {
|
||||
stopChunk := map[string]interface{}{
|
||||
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(stopChunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
} else {
|
||||
p = parser.CreateStreamParserWithThinking(
|
||||
func(text string) {
|
||||
accumulated += text
|
||||
chunkNum++
|
||||
logger.LogStreamChunk(model, text, chunkNum)
|
||||
chunk := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
},
|
||||
func(thinking string) {
|
||||
chunk := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]interface{}{"reasoning_content": thinking}, "finish_reason": nil},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
},
|
||||
func() {
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
|
||||
stopChunk := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(stopChunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
configDir := ph.GetNextConfigDir()
|
||||
logger.LogAccountAssigned(configDir)
|
||||
pool.ReportRequestStart(configDir)
|
||||
ph.ReportRequestStart(configDir)
|
||||
logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true)
|
||||
streamStart := time.Now().UnixMilli()
|
||||
|
||||
ctx := r.Context()
|
||||
result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, parseLine, ws.TempDir, configDir, ctx)
|
||||
wrappedParser := func(line string) {
|
||||
logger.LogRawLine(line)
|
||||
p.Parse(line)
|
||||
}
|
||||
result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, ctx)
|
||||
|
||||
// agent 結束後,若未收到 result/success 訊號,強制 flush 以確保 SSE stream 正確結尾
|
||||
if ctx.Err() == nil {
|
||||
p.Flush()
|
||||
}
|
||||
|
||||
latencyMs := time.Now().UnixMilli() - streamStart
|
||||
pool.ReportRequestEnd(configDir)
|
||||
ph.ReportRequestEnd(configDir)
|
||||
|
||||
if err == nil && isRateLimited(result.Stderr) {
|
||||
pool.ReportRateLimit(configDir, 60000)
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
|
||||
} else if ctx.Err() == context.Canceled {
|
||||
logger.LogClientDisconnect(method, pathname, model, latencyMs)
|
||||
} else if err == nil && isRateLimited(result.Stderr) {
|
||||
ph.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr))
|
||||
}
|
||||
|
||||
if err != nil || (result.Code != 0 && ctx.Err() == nil) {
|
||||
pool.ReportRequestError(configDir, latencyMs)
|
||||
ph.ReportRequestError(configDir, latencyMs)
|
||||
if err != nil {
|
||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
|
||||
} else {
|
||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr)
|
||||
}
|
||||
} else {
|
||||
pool.ReportRequestSuccess(configDir, latencyMs)
|
||||
logger.LogRequestDone(method, pathname, model, latencyMs, result.Code)
|
||||
} else if ctx.Err() == nil {
|
||||
ph.ReportRequestSuccess(configDir, latencyMs)
|
||||
logger.LogRequestDone(method, pathname, model, latencyMs, 0)
|
||||
}
|
||||
logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
|
||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
||||
return
|
||||
}
|
||||
|
||||
configDir := pool.GetNextAccountConfigDir()
|
||||
configDir := ph.GetNextConfigDir()
|
||||
logger.LogAccountAssigned(configDir)
|
||||
pool.ReportRequestStart(configDir)
|
||||
ph.ReportRequestStart(configDir)
|
||||
logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, false)
|
||||
syncStart := time.Now().UnixMilli()
|
||||
|
||||
out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context())
|
||||
syncLatency := time.Now().UnixMilli() - syncStart
|
||||
pool.ReportRequestEnd(configDir)
|
||||
ph.ReportRequestEnd(configDir)
|
||||
|
||||
ctx := r.Context()
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
|
||||
httputil.WriteJSON(w, 504, map[string]interface{}{
|
||||
"error": map[string]string{"message": fmt.Sprintf("request timed out after %dms", cfg.TimeoutMs), "code": "timeout"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
if ctx.Err() == context.Canceled {
|
||||
logger.LogClientDisconnect(method, pathname, model, syncLatency)
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
pool.ReportRequestError(configDir, syncLatency)
|
||||
logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
|
||||
ph.ReportRequestError(configDir, syncLatency)
|
||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
||||
logger.LogRequestDone(method, pathname, model, syncLatency, -1)
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"message": err.Error(), "code": "cursor_cli_error"},
|
||||
}, nil)
|
||||
|
|
@ -199,23 +397,61 @@ func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.Br
|
|||
}
|
||||
|
||||
if isRateLimited(out.Stderr) {
|
||||
pool.ReportRateLimit(configDir, 60000)
|
||||
ph.ReportRateLimit(configDir, extractRetryAfterMs(out.Stderr))
|
||||
}
|
||||
|
||||
if out.Code != 0 {
|
||||
pool.ReportRequestError(configDir, syncLatency)
|
||||
logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
|
||||
ph.ReportRequestError(configDir, syncLatency)
|
||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
||||
errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr)
|
||||
logger.LogRequestDone(method, pathname, model, syncLatency, out.Code)
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"message": errMsg, "code": "cursor_cli_error"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
pool.ReportRequestSuccess(configDir, syncLatency)
|
||||
content := trimSpace(out.Stdout)
|
||||
ph.ReportRequestSuccess(configDir, syncLatency)
|
||||
content := strings.TrimSpace(out.Stdout)
|
||||
logger.LogTrafficResponse(cfg.Verbose, model, content, false)
|
||||
logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
|
||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
||||
logger.LogRequestDone(method, pathname, model, syncLatency, 0)
|
||||
|
||||
if hasTools {
|
||||
parsed := toolcall.ExtractToolCalls(content, toolNames)
|
||||
if parsed.HasToolCalls() {
|
||||
msg := map[string]interface{}{"role": "assistant"}
|
||||
if parsed.TextContent != "" {
|
||||
msg["content"] = parsed.TextContent
|
||||
} else {
|
||||
msg["content"] = nil
|
||||
}
|
||||
var tcArr []map[string]interface{}
|
||||
for _, tc := range parsed.ToolCalls {
|
||||
callID := "call_" + uuid.New().String()[:8]
|
||||
tcArr = append(tcArr, map[string]interface{}{
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": tc.Name,
|
||||
"arguments": tc.Arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
msg["tool_calls"] = tcArr
|
||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{"index": 0, "message": msg, "finish_reason": "tool_calls"},
|
||||
},
|
||||
"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}, truncatedHeaders)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
||||
"id": id,
|
||||
|
|
@ -233,16 +469,3 @@ func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.Br
|
|||
}, truncatedHeaders)
|
||||
}
|
||||
|
||||
// trimSpace returns s with every leading and trailing ASCII space, tab,
// line feed and carriage return removed. Other whitespace (for example
// vertical tab, form feed or Unicode spaces) is left untouched, which is the
// difference from strings.TrimSpace.
func trimSpace(s string) string {
	// All cutset characters are single-byte ASCII, so this is byte-for-byte
	// equivalent to scanning inward from both ends of the string.
	return strings.Trim(s, " \t\n\r")
}
|
||||
|
|
|
|||
|
|
@ -65,6 +65,11 @@ type TrafficMessage struct {
|
|||
Content string
|
||||
}
|
||||
|
||||
func LogDebug(format string, args ...interface{}) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
fmt.Printf("%s %s[DEBUG]%s %s\n", ts(), cGray, cReset, msg)
|
||||
}
|
||||
|
||||
func LogServerStart(version, scheme, host string, port int, cfg config.BridgeConfig) {
|
||||
fmt.Printf("\n%s%s╔══════════════════════════════════════════╗%s\n", cBold, cBCyan, cReset)
|
||||
fmt.Printf("%s%s cursor-api-proxy %sv%s%s%s%s ready%s\n",
|
||||
|
|
@ -76,6 +81,7 @@ func LogServerStart(version, scheme, host string, port int, cfg config.BridgeCon
|
|||
fmt.Printf(" %s▸%s workspace %s%s%s\n", cCyan, cReset, cDim, cfg.Workspace, cReset)
|
||||
fmt.Printf(" %s▸%s model %s%s%s\n", cCyan, cReset, cDim, cfg.DefaultModel, cReset)
|
||||
fmt.Printf(" %s▸%s mode %s%s%s\n", cCyan, cReset, cDim, cfg.Mode, cReset)
|
||||
fmt.Printf(" %s▸%s timeout %s%d ms%s\n", cCyan, cReset, cDim, cfg.TimeoutMs, cReset)
|
||||
|
||||
flags := []string{}
|
||||
if cfg.Force {
|
||||
|
|
@ -109,6 +115,46 @@ func LogShutdown(sig string) {
|
|||
fmt.Printf("\n%s %s⊘ %s received — shutting down gracefully…%s\n", tsDate(), cYellow, sig, cReset)
|
||||
}
|
||||
|
||||
func LogRequestStart(method, pathname, model string, timeoutMs int, isStream bool) {
|
||||
modeTag := fmt.Sprintf("%ssync%s", cDim, cReset)
|
||||
if isStream {
|
||||
modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset)
|
||||
}
|
||||
fmt.Printf("%s %s▶%s %s %s %s timeout:%dms %s\n",
|
||||
ts(), cBCyan, cReset, method, pathname, model, timeoutMs, modeTag)
|
||||
}
|
||||
|
||||
func LogRequestDone(method, pathname, model string, latencyMs int64, code int) {
|
||||
statusColor := cBGreen
|
||||
if code != 0 {
|
||||
statusColor = cRed
|
||||
}
|
||||
fmt.Printf("%s %s■%s %s %s %s %s%dms exit:%d%s\n",
|
||||
ts(), statusColor, cReset, method, pathname, model, cDim, latencyMs, code, cReset)
|
||||
}
|
||||
|
||||
func LogRequestTimeout(method, pathname, model string, timeoutMs int) {
|
||||
fmt.Printf("%s %s⏱%s %s %s %s %stimed-out after %dms%s\n",
|
||||
ts(), cRed, cReset, method, pathname, model, cRed, timeoutMs, cReset)
|
||||
}
|
||||
|
||||
func LogClientDisconnect(method, pathname, model string, latencyMs int64) {
|
||||
fmt.Printf("%s %s⚡%s %s %s %s %sclient disconnected after %dms%s\n",
|
||||
ts(), cYellow, cReset, method, pathname, model, cYellow, latencyMs, cReset)
|
||||
}
|
||||
|
||||
func LogStreamChunk(model string, text string, chunkNum int) {
|
||||
preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 120)
|
||||
fmt.Printf("%s %s▸%s #%d %s%s%s\n",
|
||||
ts(), cDim, cReset, chunkNum, cWhite, preview, cReset)
|
||||
}
|
||||
|
||||
func LogRawLine(line string) {
|
||||
preview := truncate(strings.ReplaceAll(line, "\n", "↵ "), 200)
|
||||
fmt.Printf("%s %s│%s %sraw%s %s\n",
|
||||
ts(), cGray, cReset, cDim, cReset, preview)
|
||||
}
|
||||
|
||||
func LogIncoming(method, pathname, remoteAddress string) {
|
||||
methodColor := cBCyan
|
||||
switch method {
|
||||
|
|
|
|||
|
|
@ -1,26 +1,29 @@
|
|||
package models
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// anthropicToCursor maps Anthropic-style model names (as sent by clients) to
// the Cursor model ID that should serve them. Both the dashed ("4-6") and
// dotted ("4.6") version spellings are listed for the main models. Each key
// appears exactly once: Go rejects duplicate constant keys in a map literal.
var anthropicToCursor = map[string]string{
	// Explicit versions.
	"claude-opus-4-6":   "opus-4.6",
	"claude-opus-4.6":   "opus-4.6",
	"claude-sonnet-4-6": "sonnet-4.6",
	"claude-sonnet-4.6": "sonnet-4.6",
	"claude-opus-4-5":   "opus-4.5",
	"claude-opus-4.5":   "opus-4.5",
	"claude-sonnet-4-5": "sonnet-4.5",
	"claude-sonnet-4.5": "sonnet-4.5",

	// Major-version-only names resolve to the 4.6 models.
	"claude-opus-4":   "opus-4.6",
	"claude-sonnet-4": "sonnet-4.6",

	// Haiku names are served by Sonnet models.
	"claude-haiku-4-5-20251001": "sonnet-4.5",
	"claude-haiku-4-5":          "sonnet-4.5",
	"claude-haiku-4-6":          "sonnet-4.6",
	"claude-haiku-4":            "sonnet-4.5",

	// Thinking variants.
	"claude-opus-4-6-thinking":   "opus-4.6-thinking",
	"claude-sonnet-4-6-thinking": "sonnet-4.6-thinking",
	"claude-opus-4-5-thinking":   "opus-4.5-thinking",
	"claude-sonnet-4-5-thinking": "sonnet-4.5-thinking",
}
|
||||
|
||||
type ModelAlias struct {
|
||||
|
|
@ -40,6 +43,40 @@ var cursorToAnthropicAlias = []ModelAlias{
|
|||
{"sonnet-4.5-thinking", "claude-sonnet-4-5-thinking", "Claude 4.5 Sonnet (Thinking)"},
|
||||
}
|
||||
|
||||
var (
	// cursorModelPattern recognises Cursor model IDs of the form
	// "<family>-<major>.<minor>" with an optional "-thinking" suffix,
	// e.g. "opus-4.6" or "sonnet-4.7-thinking". Capture groups: family,
	// major, minor, suffix.
	cursorModelPattern = regexp.MustCompile(`^([a-zA-Z]+)-(\d+)\.(\d+)(-thinking)?$`)

	// reverseDynamicPattern recognises the dynamically generated Anthropic
	// aliases, "claude-<family>-<major>-<minor>" with an optional
	// "-thinking" suffix, e.g. "claude-opus-4-7" or
	// "claude-sonnet-4-7-thinking". Capture groups mirror cursorModelPattern.
	reverseDynamicPattern = regexp.MustCompile(`^claude-([a-zA-Z]+)-(\d+)-(\d+)(-thinking)?$`)
)
|
||||
|
||||
func generateDynamicAlias(cursorID string) (AnthropicAlias, bool) {
|
||||
m := cursorModelPattern.FindStringSubmatch(cursorID)
|
||||
if m == nil {
|
||||
return AnthropicAlias{}, false
|
||||
}
|
||||
family := m[1]
|
||||
major := m[2]
|
||||
minor := m[3]
|
||||
thinking := m[4]
|
||||
|
||||
anthropicID := "claude-" + family + "-" + major + "-" + minor + thinking
|
||||
capFamily := strings.ToUpper(family[:1]) + family[1:]
|
||||
name := capFamily + " " + major + "." + minor
|
||||
if thinking == "-thinking" {
|
||||
name += " (Thinking)"
|
||||
}
|
||||
return AnthropicAlias{ID: anthropicID, Name: name}, true
|
||||
}
|
||||
|
||||
func reverseDynamicAlias(anthropicID string) (string, bool) {
|
||||
m := reverseDynamicPattern.FindStringSubmatch(anthropicID)
|
||||
if m == nil {
|
||||
return "", false
|
||||
}
|
||||
return m[1] + "-" + m[2] + "." + m[3] + m[4], true
|
||||
}
|
||||
|
||||
func ResolveToCursorModel(requested string) string {
|
||||
if strings.TrimSpace(requested) == "" {
|
||||
return ""
|
||||
|
|
@ -48,6 +85,9 @@ func ResolveToCursorModel(requested string) string {
|
|||
if v, ok := anthropicToCursor[key]; ok {
|
||||
return v
|
||||
}
|
||||
if v, ok := reverseDynamicAlias(key); ok {
|
||||
return v
|
||||
}
|
||||
return strings.TrimSpace(requested)
|
||||
}
|
||||
|
||||
|
|
@ -61,11 +101,23 @@ func GetAnthropicModelAliases(availableCursorIDs []string) []AnthropicAlias {
|
|||
for _, id := range availableCursorIDs {
|
||||
set[id] = true
|
||||
}
|
||||
|
||||
staticSet := make(map[string]bool, len(cursorToAnthropicAlias))
|
||||
var result []AnthropicAlias
|
||||
for _, a := range cursorToAnthropicAlias {
|
||||
if set[a.CursorID] {
|
||||
staticSet[a.CursorID] = true
|
||||
result = append(result, AnthropicAlias{ID: a.AnthropicID, Name: a.Name})
|
||||
}
|
||||
}
|
||||
|
||||
for _, id := range availableCursorIDs {
|
||||
if staticSet[id] {
|
||||
continue
|
||||
}
|
||||
if alias, ok := generateDynamicAlias(id); ok {
|
||||
result = append(result, alias)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,163 @@
|
|||
package models
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGetAnthropicModelAliases_StaticOnly(t *testing.T) {
|
||||
aliases := GetAnthropicModelAliases([]string{"sonnet-4.6", "opus-4.5"})
|
||||
if len(aliases) != 2 {
|
||||
t.Fatalf("expected 2 aliases, got %d: %v", len(aliases), aliases)
|
||||
}
|
||||
ids := map[string]string{}
|
||||
for _, a := range aliases {
|
||||
ids[a.ID] = a.Name
|
||||
}
|
||||
if ids["claude-sonnet-4-6"] != "Claude 4.6 Sonnet" {
|
||||
t.Errorf("unexpected name for claude-sonnet-4-6: %s", ids["claude-sonnet-4-6"])
|
||||
}
|
||||
if ids["claude-opus-4-5"] != "Claude 4.5 Opus" {
|
||||
t.Errorf("unexpected name for claude-opus-4-5: %s", ids["claude-opus-4-5"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAnthropicModelAliases_DynamicFallback(t *testing.T) {
|
||||
aliases := GetAnthropicModelAliases([]string{"sonnet-4.7", "opus-5.0-thinking", "gpt-4o"})
|
||||
ids := map[string]string{}
|
||||
for _, a := range aliases {
|
||||
ids[a.ID] = a.Name
|
||||
}
|
||||
if ids["claude-sonnet-4-7"] != "Sonnet 4.7" {
|
||||
t.Errorf("unexpected name for claude-sonnet-4-7: %s", ids["claude-sonnet-4-7"])
|
||||
}
|
||||
if ids["claude-opus-5-0-thinking"] != "Opus 5.0 (Thinking)" {
|
||||
t.Errorf("unexpected name for claude-opus-5-0-thinking: %s", ids["claude-opus-5-0-thinking"])
|
||||
}
|
||||
if _, ok := ids["claude-gpt-4o"]; ok {
|
||||
t.Errorf("gpt-4o should not generate a claude alias")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAnthropicModelAliases_Mixed(t *testing.T) {
|
||||
aliases := GetAnthropicModelAliases([]string{"sonnet-4.6", "opus-4.7", "gpt-4o"})
|
||||
ids := map[string]string{}
|
||||
for _, a := range aliases {
|
||||
ids[a.ID] = a.Name
|
||||
}
|
||||
// static entry keeps its custom name
|
||||
if ids["claude-sonnet-4-6"] != "Claude 4.6 Sonnet" {
|
||||
t.Errorf("static alias should keep original name, got: %s", ids["claude-sonnet-4-6"])
|
||||
}
|
||||
// dynamic entry uses auto-generated name
|
||||
if ids["claude-opus-4-7"] != "Opus 4.7" {
|
||||
t.Errorf("dynamic alias name mismatch: %s", ids["claude-opus-4-7"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAnthropicModelAliases_UnknownPattern(t *testing.T) {
|
||||
aliases := GetAnthropicModelAliases([]string{"some-unknown-model"})
|
||||
if len(aliases) != 0 {
|
||||
t.Fatalf("expected 0 aliases for unknown pattern, got %d: %v", len(aliases), aliases)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToCursorModel_Static(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"claude-opus-4-6", "opus-4.6"},
|
||||
{"claude-opus-4.6", "opus-4.6"},
|
||||
{"claude-sonnet-4-5", "sonnet-4.5"},
|
||||
{"claude-opus-4-6-thinking", "opus-4.6-thinking"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got := ResolveToCursorModel(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("ResolveToCursorModel(%q) = %q, want %q", tc.input, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToCursorModel_DynamicFallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"claude-opus-4-7", "opus-4.7"},
|
||||
{"claude-sonnet-5-0", "sonnet-5.0"},
|
||||
{"claude-opus-4-7-thinking", "opus-4.7-thinking"},
|
||||
{"claude-sonnet-5-0-thinking", "sonnet-5.0-thinking"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got := ResolveToCursorModel(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("ResolveToCursorModel(%q) = %q, want %q", tc.input, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToCursorModel_Passthrough(t *testing.T) {
|
||||
tests := []string{"sonnet-4.6", "gpt-4o", "custom-model"}
|
||||
for _, input := range tests {
|
||||
got := ResolveToCursorModel(input)
|
||||
if got != input {
|
||||
t.Errorf("ResolveToCursorModel(%q) = %q, want passthrough %q", input, got, input)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToCursorModel_Empty(t *testing.T) {
|
||||
if got := ResolveToCursorModel(""); got != "" {
|
||||
t.Errorf("ResolveToCursorModel(\"\") = %q, want empty", got)
|
||||
}
|
||||
if got := ResolveToCursorModel(" "); got != "" {
|
||||
t.Errorf("ResolveToCursorModel(\" \") = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateDynamicAlias(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want AnthropicAlias
|
||||
ok bool
|
||||
}{
|
||||
{"opus-4.7", AnthropicAlias{"claude-opus-4-7", "Opus 4.7"}, true},
|
||||
{"sonnet-5.0-thinking", AnthropicAlias{"claude-sonnet-5-0-thinking", "Sonnet 5.0 (Thinking)"}, true},
|
||||
{"gpt-4o", AnthropicAlias{}, false},
|
||||
{"invalid", AnthropicAlias{}, false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got, ok := generateDynamicAlias(tc.input)
|
||||
if ok != tc.ok {
|
||||
t.Errorf("generateDynamicAlias(%q) ok = %v, want %v", tc.input, ok, tc.ok)
|
||||
continue
|
||||
}
|
||||
if ok && (got.ID != tc.want.ID || got.Name != tc.want.Name) {
|
||||
t.Errorf("generateDynamicAlias(%q) = {%q, %q}, want {%q, %q}", tc.input, got.ID, got.Name, tc.want.ID, tc.want.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseDynamicAlias(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
ok bool
|
||||
}{
|
||||
{"claude-opus-4-7", "opus-4.7", true},
|
||||
{"claude-sonnet-5-0-thinking", "sonnet-5.0-thinking", true},
|
||||
{"claude-opus-4-6", "opus-4.6", true},
|
||||
{"claude-opus-4.6", "", false},
|
||||
{"claude-haiku-4-5-20251001", "", false},
|
||||
{"some-model", "", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got, ok := reverseDynamicAlias(tc.input)
|
||||
if ok != tc.ok {
|
||||
t.Errorf("reverseDynamicAlias(%q) ok = %v, want %v", tc.input, ok, tc.ok)
|
||||
continue
|
||||
}
|
||||
if ok && got != tc.want {
|
||||
t.Errorf("reverseDynamicAlias(%q) = %q, want %q", tc.input, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2,6 +2,7 @@ package openai
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
|
@ -129,10 +130,28 @@ func ToolsToSystemText(tools []interface{}, functions []interface{}) string {
|
|||
if b, err := json.MarshalIndent(p, "", " "); err == nil {
|
||||
params = string(b)
|
||||
}
|
||||
} else if p := fn["input_schema"]; p != nil {
|
||||
if b, err := json.MarshalIndent(p, "", " "); err == nil {
|
||||
params = string(b)
|
||||
}
|
||||
}
|
||||
lines = append(lines, "Function: "+name+"\nDescription: "+desc+"\nParameters: "+params)
|
||||
}
|
||||
|
||||
lines = append(lines, "",
|
||||
"When you want to call a tool, use this EXACT format:",
|
||||
"",
|
||||
"<tool_call>",
|
||||
`{"name": "function_name", "arguments": {"param1": "value1"}}`,
|
||||
"</tool_call>",
|
||||
"",
|
||||
"Rules:",
|
||||
"- Write your reasoning BEFORE the tool call",
|
||||
"- You may make multiple tool calls by using multiple <tool_call> blocks",
|
||||
"- STOP writing after the last </tool_call> tag",
|
||||
"- If no tool is needed, respond normally without <tool_call> tags",
|
||||
)
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
|
|
@ -152,18 +171,58 @@ func BuildPromptFromMessages(messages []interface{}) string {
|
|||
}
|
||||
role, _ := m["role"].(string)
|
||||
text := MessageContentToText(m["content"])
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
switch role {
|
||||
case "system", "developer":
|
||||
systemParts = append(systemParts, text)
|
||||
if text != "" {
|
||||
systemParts = append(systemParts, text)
|
||||
}
|
||||
case "user":
|
||||
convo = append(convo, "User: "+text)
|
||||
if text != "" {
|
||||
convo = append(convo, "User: "+text)
|
||||
}
|
||||
case "assistant":
|
||||
convo = append(convo, "Assistant: "+text)
|
||||
toolCalls, _ := m["tool_calls"].([]interface{})
|
||||
if len(toolCalls) > 0 {
|
||||
var parts []string
|
||||
if text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
for _, tc := range toolCalls {
|
||||
tcMap, ok := tc.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
fn, _ := tcMap["function"].(map[string]interface{})
|
||||
if fn == nil {
|
||||
continue
|
||||
}
|
||||
name, _ := fn["name"].(string)
|
||||
args, _ := fn["arguments"].(string)
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("<tool_call>\n{\"name\": \"%s\", \"arguments\": %s}\n</tool_call>", name, args))
|
||||
}
|
||||
if len(parts) > 0 {
|
||||
convo = append(convo, "Assistant: "+strings.Join(parts, "\n"))
|
||||
}
|
||||
} else if text != "" {
|
||||
convo = append(convo, "Assistant: "+text)
|
||||
}
|
||||
case "tool", "function":
|
||||
convo = append(convo, "Tool: "+text)
|
||||
name, _ := m["name"].(string)
|
||||
toolCallID, _ := m["tool_call_id"].(string)
|
||||
label := "Tool result"
|
||||
if name != "" {
|
||||
label = "Tool result (" + name + ")"
|
||||
}
|
||||
if toolCallID != "" {
|
||||
label += " [id=" + toolCallID + "]"
|
||||
}
|
||||
if text != "" {
|
||||
convo = append(convo, label+": "+text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,11 +4,24 @@ import "encoding/json"
|
|||
|
||||
type StreamParser func(line string)
|
||||
|
||||
func CreateStreamParser(onText func(string), onDone func()) StreamParser {
|
||||
accumulated := ""
|
||||
type Parser struct {
|
||||
Parse StreamParser
|
||||
Flush func()
|
||||
}
|
||||
|
||||
// CreateStreamParser 建立串流解析器(向後相容,不傳遞 thinking)
|
||||
func CreateStreamParser(onText func(string), onDone func()) Parser {
|
||||
return CreateStreamParserWithThinking(onText, nil, onDone)
|
||||
}
|
||||
|
||||
// CreateStreamParserWithThinking 建立串流解析器,支援思考過程輸出。
|
||||
// onThinking 可為 nil,表示忽略思考過程。
|
||||
func CreateStreamParserWithThinking(onText func(string), onThinking func(string), onDone func()) Parser {
|
||||
accumulatedText := ""
|
||||
accumulatedThinking := ""
|
||||
done := false
|
||||
|
||||
return func(line string) {
|
||||
parse := func(line string) {
|
||||
if done {
|
||||
return
|
||||
}
|
||||
|
|
@ -18,8 +31,9 @@ func CreateStreamParser(onText func(string), onDone func()) StreamParser {
|
|||
Subtype string `json:"subtype"`
|
||||
Message *struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
Thinking string `json:"thinking"`
|
||||
} `json:"content"`
|
||||
} `json:"message"`
|
||||
}
|
||||
|
|
@ -29,27 +43,52 @@ func CreateStreamParser(onText func(string), onDone func()) StreamParser {
|
|||
}
|
||||
|
||||
if obj.Type == "assistant" && obj.Message != nil {
|
||||
text := ""
|
||||
fullText := ""
|
||||
fullThinking := ""
|
||||
for _, p := range obj.Message.Content {
|
||||
if p.Type == "text" && p.Text != "" {
|
||||
text += p.Text
|
||||
switch p.Type {
|
||||
case "text":
|
||||
if p.Text != "" {
|
||||
fullText += p.Text
|
||||
}
|
||||
case "thinking":
|
||||
if p.Thinking != "" {
|
||||
fullThinking += p.Thinking
|
||||
}
|
||||
}
|
||||
}
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
if text == accumulated {
|
||||
return
|
||||
}
|
||||
if len(accumulated) > 0 && len(text) > len(accumulated) && text[:len(accumulated)] == accumulated {
|
||||
delta := text[len(accumulated):]
|
||||
if delta != "" {
|
||||
onText(delta)
|
||||
|
||||
// 處理思考過程 delta
|
||||
if onThinking != nil && fullThinking != "" {
|
||||
if fullThinking == accumulatedThinking {
|
||||
// 重複的完整思考文字,跳過
|
||||
} else if len(fullThinking) > len(accumulatedThinking) && fullThinking[:len(accumulatedThinking)] == accumulatedThinking {
|
||||
delta := fullThinking[len(accumulatedThinking):]
|
||||
onThinking(delta)
|
||||
accumulatedThinking = fullThinking
|
||||
} else {
|
||||
onThinking(fullThinking)
|
||||
accumulatedThinking += fullThinking
|
||||
}
|
||||
accumulated = text
|
||||
}
|
||||
|
||||
// 處理一般文字 delta
|
||||
if fullText == "" {
|
||||
return
|
||||
}
|
||||
// 若此訊息文字等於已累積內容(重複的完整文字),跳過
|
||||
if fullText == accumulatedText {
|
||||
return
|
||||
}
|
||||
// 若此訊息是已累積內容的延伸,只輸出新的 delta
|
||||
if len(fullText) > len(accumulatedText) && fullText[:len(accumulatedText)] == accumulatedText {
|
||||
delta := fullText[len(accumulatedText):]
|
||||
onText(delta)
|
||||
accumulatedText = fullText
|
||||
} else {
|
||||
onText(text)
|
||||
accumulated += text
|
||||
// 獨立的 token fragment(一般情況),直接輸出
|
||||
onText(fullText)
|
||||
accumulatedText += fullText
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -58,4 +97,13 @@ func CreateStreamParser(onText func(string), onDone func()) StreamParser {
|
|||
onDone()
|
||||
}
|
||||
}
|
||||
|
||||
flush := func() {
|
||||
if !done {
|
||||
done = true
|
||||
onDone()
|
||||
}
|
||||
}
|
||||
|
||||
return Parser{Parse: parse, Flush: flush}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,57 +23,57 @@ func makeResultLine() string {
|
|||
return string(b)
|
||||
}
|
||||
|
||||
func TestStreamParserIncrementalDeltas(t *testing.T) {
|
||||
func TestStreamParserFragmentMode(t *testing.T) {
|
||||
// cursor --stream-partial-output 模式:每個訊息是獨立 token fragment
|
||||
var texts []string
|
||||
doneCount := 0
|
||||
parse := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
parse(makeAssistantLine("Hello"))
|
||||
if len(texts) != 1 || texts[0] != "Hello" {
|
||||
t.Fatalf("expected [Hello], got %v", texts)
|
||||
}
|
||||
|
||||
parse(makeAssistantLine("Hello world"))
|
||||
if len(texts) != 2 || texts[1] != " world" {
|
||||
t.Fatalf("expected second call with ' world', got %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserDeduplicatesFinalMessage(t *testing.T) {
|
||||
var texts []string
|
||||
parse := CreateStreamParser(
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
parse(makeAssistantLine("Hi"))
|
||||
parse(makeAssistantLine("Hi there"))
|
||||
if len(texts) != 2 {
|
||||
t.Fatalf("expected 2 calls, got %d: %v", len(texts), texts)
|
||||
p.Parse(makeAssistantLine("你"))
|
||||
p.Parse(makeAssistantLine("好!有"))
|
||||
p.Parse(makeAssistantLine("什"))
|
||||
p.Parse(makeAssistantLine("麼"))
|
||||
|
||||
if len(texts) != 4 {
|
||||
t.Fatalf("expected 4 fragments, got %d: %v", len(texts), texts)
|
||||
}
|
||||
if texts[0] != "Hi" || texts[1] != " there" {
|
||||
if texts[0] != "你" || texts[1] != "好!有" || texts[2] != "什" || texts[3] != "麼" {
|
||||
t.Fatalf("unexpected texts: %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserDeduplicatesFinalFullText(t *testing.T) {
|
||||
// 最後一個訊息是完整的累積文字,應被跳過(去重)
|
||||
var texts []string
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
p.Parse(makeAssistantLine("Hello"))
|
||||
p.Parse(makeAssistantLine(" world"))
|
||||
// 最後一個是完整累積文字,應被去重
|
||||
p.Parse(makeAssistantLine("Hello world"))
|
||||
|
||||
// Final duplicate: full accumulated text again
|
||||
parse(makeAssistantLine("Hi there"))
|
||||
if len(texts) != 2 {
|
||||
t.Fatalf("expected no new call after duplicate, got %d: %v", len(texts), texts)
|
||||
t.Fatalf("expected 2 fragments (final full text deduplicated), got %d: %v", len(texts), texts)
|
||||
}
|
||||
if texts[0] != "Hello" || texts[1] != " world" {
|
||||
t.Fatalf("unexpected texts: %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserCallsOnDone(t *testing.T) {
|
||||
var texts []string
|
||||
doneCount := 0
|
||||
parse := CreateStreamParser(
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
parse(makeResultLine())
|
||||
p.Parse(makeResultLine())
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onDone called once, got %d", doneCount)
|
||||
}
|
||||
|
|
@ -85,13 +85,13 @@ func TestStreamParserCallsOnDone(t *testing.T) {
|
|||
func TestStreamParserIgnoresLinesAfterDone(t *testing.T) {
|
||||
var texts []string
|
||||
doneCount := 0
|
||||
parse := CreateStreamParser(
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
parse(makeResultLine())
|
||||
parse(makeAssistantLine("late"))
|
||||
p.Parse(makeResultLine())
|
||||
p.Parse(makeAssistantLine("late"))
|
||||
if len(texts) != 0 {
|
||||
t.Fatalf("expected no text after done, got %v", texts)
|
||||
}
|
||||
|
|
@ -102,19 +102,19 @@ func TestStreamParserIgnoresLinesAfterDone(t *testing.T) {
|
|||
|
||||
func TestStreamParserIgnoresNonAssistantLines(t *testing.T) {
|
||||
var texts []string
|
||||
parse := CreateStreamParser(
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
b1, _ := json.Marshal(map[string]interface{}{"type": "user", "message": map[string]interface{}{}})
|
||||
parse(string(b1))
|
||||
p.Parse(string(b1))
|
||||
b2, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "assistant",
|
||||
"message": map[string]interface{}{"content": []interface{}{}},
|
||||
})
|
||||
parse(string(b2))
|
||||
parse(`{"type":"assistant","message":{"content":[{"type":"code","text":"x"}]}}`)
|
||||
p.Parse(string(b2))
|
||||
p.Parse(`{"type":"assistant","message":{"content":[{"type":"code","text":"x"}]}}`)
|
||||
|
||||
if len(texts) != 0 {
|
||||
t.Fatalf("expected no texts, got %v", texts)
|
||||
|
|
@ -124,14 +124,14 @@ func TestStreamParserIgnoresNonAssistantLines(t *testing.T) {
|
|||
func TestStreamParserIgnoresParseErrors(t *testing.T) {
|
||||
var texts []string
|
||||
doneCount := 0
|
||||
parse := CreateStreamParser(
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
parse("not json")
|
||||
parse("{")
|
||||
parse("")
|
||||
p.Parse("not json")
|
||||
p.Parse("{")
|
||||
p.Parse("")
|
||||
|
||||
if len(texts) != 0 || doneCount != 0 {
|
||||
t.Fatalf("expected nothing, got texts=%v done=%d", texts, doneCount)
|
||||
|
|
@ -140,7 +140,7 @@ func TestStreamParserIgnoresParseErrors(t *testing.T) {
|
|||
|
||||
func TestStreamParserJoinsMultipleTextParts(t *testing.T) {
|
||||
var texts []string
|
||||
parse := CreateStreamParser(
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() {},
|
||||
)
|
||||
|
|
@ -156,9 +156,125 @@ func TestStreamParserJoinsMultipleTextParts(t *testing.T) {
|
|||
},
|
||||
}
|
||||
b, _ := json.Marshal(obj)
|
||||
parse(string(b))
|
||||
p.Parse(string(b))
|
||||
|
||||
if len(texts) != 1 || texts[0] != "Hello world" {
|
||||
t.Fatalf("expected ['Hello world'], got %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserFlushTriggersDone(t *testing.T) {
|
||||
var texts []string
|
||||
doneCount := 0
|
||||
p := CreateStreamParser(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
p.Parse(makeAssistantLine("Hello"))
|
||||
// agent 結束但沒有 result/success,手動 flush
|
||||
p.Flush()
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onDone called once after Flush, got %d", doneCount)
|
||||
}
|
||||
// 再 flush 不應重複觸發
|
||||
p.Flush()
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onDone called only once, got %d", doneCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserFlushAfterDoneIsNoop(t *testing.T) {
|
||||
doneCount := 0
|
||||
p := CreateStreamParser(
|
||||
func(text string) {},
|
||||
func() { doneCount++ },
|
||||
)
|
||||
|
||||
p.Parse(makeResultLine())
|
||||
p.Flush()
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onDone called once, got %d", doneCount)
|
||||
}
|
||||
}
|
||||
|
||||
// makeThinkingLine returns the JSON encoding of an "assistant" message whose
// content holds a single part of type "thinking" carrying the given text.
func makeThinkingLine(thinking string) string {
	part := map[string]interface{}{"type": "thinking", "thinking": thinking}
	message := map[string]interface{}{
		"content": []map[string]interface{}{part},
	}
	// The marshal error is ignored, as in the other line builders of this file.
	encoded, _ := json.Marshal(map[string]interface{}{
		"type":    "assistant",
		"message": message,
	})
	return string(encoded)
}
|
||||
|
||||
// makeThinkingAndTextLine returns the JSON encoding of an "assistant" message
// whose content holds a "thinking" part followed by a "text" part.
func makeThinkingAndTextLine(thinking, text string) string {
	parts := []map[string]interface{}{
		{"type": "thinking", "thinking": thinking},
		{"type": "text", "text": text},
	}
	message := map[string]interface{}{"content": parts}
	// The marshal error is ignored, as in the other line builders of this file.
	encoded, _ := json.Marshal(map[string]interface{}{
		"type":    "assistant",
		"message": message,
	})
	return string(encoded)
}
|
||||
|
||||
func TestStreamParserWithThinkingCallsOnThinking(t *testing.T) {
|
||||
var texts []string
|
||||
var thinkings []string
|
||||
p := CreateStreamParserWithThinking(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
func(thinking string) { thinkings = append(thinkings, thinking) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
p.Parse(makeThinkingLine("思考中..."))
|
||||
p.Parse(makeAssistantLine("回答"))
|
||||
|
||||
if len(thinkings) != 1 || thinkings[0] != "思考中..." {
|
||||
t.Fatalf("expected thinkings=['思考中...'], got %v", thinkings)
|
||||
}
|
||||
if len(texts) != 1 || texts[0] != "回答" {
|
||||
t.Fatalf("expected texts=['回答'], got %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserWithThinkingNilOnThinkingIgnoresThinking(t *testing.T) {
|
||||
var texts []string
|
||||
p := CreateStreamParserWithThinking(
|
||||
func(text string) { texts = append(texts, text) },
|
||||
nil,
|
||||
func() {},
|
||||
)
|
||||
|
||||
p.Parse(makeThinkingLine("忽略的思考"))
|
||||
p.Parse(makeAssistantLine("文字"))
|
||||
|
||||
if len(texts) != 1 || texts[0] != "文字" {
|
||||
t.Fatalf("expected texts=['文字'], got %v", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParserWithThinkingDeduplication(t *testing.T) {
|
||||
var thinkings []string
|
||||
p := CreateStreamParserWithThinking(
|
||||
func(text string) {},
|
||||
func(thinking string) { thinkings = append(thinkings, thinking) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
p.Parse(makeThinkingLine("A"))
|
||||
p.Parse(makeThinkingLine("B"))
|
||||
// 重複的完整思考,應被跳過
|
||||
p.Parse(makeThinkingLine("AB"))
|
||||
|
||||
if len(thinkings) != 2 || thinkings[0] != "A" || thinkings[1] != "B" {
|
||||
t.Fatalf("expected thinkings=['A','B'], got %v", thinkings)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -180,6 +180,31 @@ func (p *AccountPool) Count() int {
|
|||
return len(p.accounts)
|
||||
}
|
||||
|
||||
|
||||
// ─── PoolHandle interface ──────────────────────────────────────────────────
|
||||
// PoolHandle 讓 handler 可以注入獨立的 pool 實例,避免多 port 模式共用全域 pool。
|
||||
|
||||
type PoolHandle interface {
|
||||
GetNextConfigDir() string
|
||||
ReportRequestStart(configDir string)
|
||||
ReportRequestEnd(configDir string)
|
||||
ReportRequestSuccess(configDir string, latencyMs int64)
|
||||
ReportRequestError(configDir string, latencyMs int64)
|
||||
ReportRateLimit(configDir string, penaltyMs int64)
|
||||
GetStats() []AccountStat
|
||||
}
|
||||
|
||||
// GlobalPoolHandle 包裝全域函式以實作 PoolHandle 介面(單 port 模式使用)
|
||||
type GlobalPoolHandle struct{}
|
||||
|
||||
func (GlobalPoolHandle) GetNextConfigDir() string { return GetNextAccountConfigDir() }
|
||||
func (GlobalPoolHandle) ReportRequestStart(d string) { ReportRequestStart(d) }
|
||||
func (GlobalPoolHandle) ReportRequestEnd(d string) { ReportRequestEnd(d) }
|
||||
func (GlobalPoolHandle) ReportRequestSuccess(d string, l int64) { ReportRequestSuccess(d, l) }
|
||||
func (GlobalPoolHandle) ReportRequestError(d string, l int64) { ReportRequestError(d, l) }
|
||||
func (GlobalPoolHandle) ReportRateLimit(d string, p int64) { ReportRateLimit(d, p) }
|
||||
func (GlobalPoolHandle) GetStats() []AccountStat { return GetAccountStats() }
|
||||
|
||||
// ─── Global pool ───────────────────────────────────────────────────────────
|
||||
|
||||
var (
|
||||
|
|
|
|||
|
|
@ -215,6 +215,7 @@ func RunStreaming(cmdStr string, args []string, opts RunStreamingOptions) (Strea
|
|||
go func() {
|
||||
defer wg.Done()
|
||||
scanner := bufio.NewScanner(stdoutPipe)
|
||||
scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.TrimSpace(line) != "" {
|
||||
|
|
@ -227,6 +228,7 @@ func RunStreaming(cmdStr string, args []string, opts RunStreamingOptions) (Strea
|
|||
go func() {
|
||||
defer wg.Done()
|
||||
scanner := bufio.NewScanner(stderrPipe)
|
||||
scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024)
|
||||
for scanner.Scan() {
|
||||
stderrBuf.WriteString(scanner.Text())
|
||||
stderrBuf.WriteString("\n")
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"cursor-api-proxy/internal/handlers"
|
||||
"cursor-api-proxy/internal/httputil"
|
||||
"cursor-api-proxy/internal/logger"
|
||||
"cursor-api-proxy/internal/pool"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
|
@ -16,6 +17,7 @@ type RouterOptions struct {
|
|||
Config config.BridgeConfig
|
||||
ModelCache *handlers.ModelCacheRef
|
||||
LastModel *string
|
||||
Pool pool.PoolHandle
|
||||
}
|
||||
|
||||
func NewRouter(opts RouterOptions) http.HandlerFunc {
|
||||
|
|
@ -59,7 +61,7 @@ func NewRouter(opts RouterOptions) http.HandlerFunc {
|
|||
}, nil)
|
||||
return
|
||||
}
|
||||
handlers.HandleChatCompletions(w, r, cfg, opts.LastModel, raw, method, pathname, remoteAddress)
|
||||
handlers.HandleChatCompletions(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress)
|
||||
|
||||
case method == "POST" && pathname == "/v1/messages":
|
||||
raw, err := httputil.ReadBody(r)
|
||||
|
|
@ -69,7 +71,7 @@ func NewRouter(opts RouterOptions) http.HandlerFunc {
|
|||
}, nil)
|
||||
return
|
||||
}
|
||||
handlers.HandleAnthropicMessages(w, r, cfg, opts.LastModel, raw, method, pathname, remoteAddress)
|
||||
handlers.HandleAnthropicMessages(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress)
|
||||
|
||||
case (method == "POST" || method == "GET") && pathname == "/v1/completions":
|
||||
httputil.WriteJSON(w, 404, map[string]interface{}{
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import (
|
|||
type ServerOptions struct {
|
||||
Version string
|
||||
Config config.BridgeConfig
|
||||
Pool pool.PoolHandle
|
||||
}
|
||||
|
||||
func StartBridgeServer(opts ServerOptions) []*http.Server {
|
||||
|
|
@ -34,8 +35,8 @@ func StartBridgeServer(opts ServerOptions) []*http.Server {
|
|||
subCfg.Port = port
|
||||
subCfg.ConfigDirs = []string{dir}
|
||||
subCfg.MultiPort = false
|
||||
pool.InitAccountPool([]string{dir})
|
||||
srv := startSingleServer(ServerOptions{Version: opts.Version, Config: subCfg})
|
||||
subPool := pool.NewAccountPool([]string{dir})
|
||||
srv := startSingleServer(ServerOptions{Version: opts.Version, Config: subCfg, Pool: subPool})
|
||||
servers = append(servers, srv)
|
||||
}
|
||||
return servers
|
||||
|
|
@ -53,11 +54,16 @@ func startSingleServer(opts ServerOptions) *http.Server {
|
|||
modelCache := &handlers.ModelCacheRef{}
|
||||
lastModel := cfg.DefaultModel
|
||||
|
||||
ph := opts.Pool
|
||||
if ph == nil {
|
||||
ph = pool.GlobalPoolHandle{}
|
||||
}
|
||||
handler := router.NewRouter(router.RouterOptions{
|
||||
Version: opts.Version,
|
||||
Config: cfg,
|
||||
ModelCache: modelCache,
|
||||
LastModel: &lastModel,
|
||||
Pool: ph,
|
||||
})
|
||||
handler = router.WrapWithRecovery(cfg.SessionsLogPath, handler)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,153 @@
|
|||
package toolcall
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ToolCall is one tool invocation extracted from model output.
type ToolCall struct {
	// Name is the tool (function) name.
	Name string
	// Arguments holds the call arguments as a JSON string.
	Arguments string
}

// ParsedResponse is the result of splitting model output into plain text and
// the tool calls found in it.
type ParsedResponse struct {
	// TextContent is the text left over after extracted tool calls are removed.
	TextContent string
	// ToolCalls lists the extracted calls in the order they appeared.
	ToolCalls []ToolCall
}

// HasToolCalls reports whether at least one tool call was extracted.
func (p *ParsedResponse) HasToolCalls() bool {
	return len(p.ToolCalls) != 0
}
|
||||
|
||||
// Patterns used to locate tool calls in model output. All use (?s) so that
// "." also matches newlines.
var (
	// toolCallTagRe captures the JSON object inside a <tool_call> block.
	toolCallTagRe = regexp.MustCompile(`(?s)<tool_call>\s*(\{.*?\})\s*</tool_call>`)
	// antmlFunctionCallsRe captures the body of a <function_calls> block.
	antmlFunctionCallsRe = regexp.MustCompile(`(?s)<function_calls>\s*(.*?)\s*</function_calls>`)
	// antmlInvokeRe captures the name attribute and body of an <invoke> element.
	antmlInvokeRe = regexp.MustCompile(`(?s)<invoke\s+name="([^"]+)">\s*(.*?)\s*</invoke>`)
	// antmlParamRe captures the name attribute and raw value of a <parameter> element.
	antmlParamRe = regexp.MustCompile(`(?s)<parameter\s+name="([^"]+)">(.*?)</parameter>`)
)
|
||||
|
||||
func ExtractToolCalls(text string, toolNames map[string]bool) *ParsedResponse {
|
||||
result := &ParsedResponse{}
|
||||
|
||||
if locs := toolCallTagRe.FindAllStringSubmatchIndex(text, -1); len(locs) > 0 {
|
||||
var calls []ToolCall
|
||||
var textParts []string
|
||||
last := 0
|
||||
for _, loc := range locs {
|
||||
if loc[0] > last {
|
||||
textParts = append(textParts, text[last:loc[0]])
|
||||
}
|
||||
jsonStr := text[loc[2]:loc[3]]
|
||||
if tc := parseToolCallJSON(jsonStr, toolNames); tc != nil {
|
||||
calls = append(calls, *tc)
|
||||
} else {
|
||||
textParts = append(textParts, text[loc[0]:loc[1]])
|
||||
}
|
||||
last = loc[1]
|
||||
}
|
||||
if last < len(text) {
|
||||
textParts = append(textParts, text[last:])
|
||||
}
|
||||
if len(calls) > 0 {
|
||||
result.TextContent = strings.TrimSpace(strings.Join(textParts, ""))
|
||||
result.ToolCalls = calls
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
if locs := antmlFunctionCallsRe.FindAllStringSubmatchIndex(text, -1); len(locs) > 0 {
|
||||
var calls []ToolCall
|
||||
var textParts []string
|
||||
last := 0
|
||||
for _, loc := range locs {
|
||||
if loc[0] > last {
|
||||
textParts = append(textParts, text[last:loc[0]])
|
||||
}
|
||||
block := text[loc[2]:loc[3]]
|
||||
invokes := antmlInvokeRe.FindAllStringSubmatch(block, -1)
|
||||
for _, inv := range invokes {
|
||||
name := inv[1]
|
||||
if toolNames != nil && len(toolNames) > 0 && !toolNames[name] {
|
||||
continue
|
||||
}
|
||||
body := inv[2]
|
||||
args := map[string]interface{}{}
|
||||
params := antmlParamRe.FindAllStringSubmatch(body, -1)
|
||||
for _, p := range params {
|
||||
paramName := p[1]
|
||||
paramValue := strings.TrimSpace(p[2])
|
||||
var jsonVal interface{}
|
||||
if err := json.Unmarshal([]byte(paramValue), &jsonVal); err == nil {
|
||||
args[paramName] = jsonVal
|
||||
} else {
|
||||
args[paramName] = paramValue
|
||||
}
|
||||
}
|
||||
argsJSON, _ := json.Marshal(args)
|
||||
calls = append(calls, ToolCall{Name: name, Arguments: string(argsJSON)})
|
||||
}
|
||||
last = loc[1]
|
||||
}
|
||||
if last < len(text) {
|
||||
textParts = append(textParts, text[last:])
|
||||
}
|
||||
if len(calls) > 0 {
|
||||
result.TextContent = strings.TrimSpace(strings.Join(textParts, ""))
|
||||
result.ToolCalls = calls
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
result.TextContent = text
|
||||
return result
|
||||
}
|
||||
|
||||
func parseToolCallJSON(jsonStr string, toolNames map[string]bool) *ToolCall {
|
||||
var raw map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
name, _ := raw["name"].(string)
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
if toolNames != nil && len(toolNames) > 0 && !toolNames[name] {
|
||||
return nil
|
||||
}
|
||||
var argsStr string
|
||||
switch a := raw["arguments"].(type) {
|
||||
case string:
|
||||
argsStr = a
|
||||
case map[string]interface{}, []interface{}:
|
||||
b, _ := json.Marshal(a)
|
||||
argsStr = string(b)
|
||||
default:
|
||||
if p, ok := raw["parameters"]; ok {
|
||||
b, _ := json.Marshal(p)
|
||||
argsStr = string(b)
|
||||
} else {
|
||||
argsStr = "{}"
|
||||
}
|
||||
}
|
||||
return &ToolCall{Name: name, Arguments: argsStr}
|
||||
}
|
||||
|
||||
// CollectToolNames builds the set of tool names declared in a tools list.
//
// Entries that are not JSON objects are ignored. For an entry whose "type" is
// "function", the string "name" inside its "function" object is recorded; for
// any entry, a string "name" at the top level is recorded as well. The
// returned map is never nil.
func CollectToolNames(tools []interface{}) map[string]bool {
	collected := make(map[string]bool)
	for _, entry := range tools {
		def, isObject := entry.(map[string]interface{})
		if !isObject {
			continue
		}

		if kind, _ := def["type"].(string); kind == "function" {
			// Indexing a nil map is safe, so a missing "function" simply finds no name.
			nested, _ := def["function"].(map[string]interface{})
			if nestedName, found := nested["name"].(string); found {
				collected[nestedName] = true
			}
		}

		if topName, found := def["name"].(string); found {
			collected[topName] = true
		}
	}
	return collected
}
|
||||
|
|
@ -6,6 +6,13 @@ import (
|
|||
)
|
||||
|
||||
// Notices prepended to a prompt whose earlier part was dropped to fit the
// platform's command-line length limit.
const (
	WinPromptOmissionPrefix   = "[Earlier messages omitted: Windows command-line length limit.]\n\n"
	LinuxPromptOmissionPrefix = "[Earlier messages omitted: Linux ARG_MAX command-line length limit.]\n\n"
)
|
||||
|
||||
// safeLinuxArgMax returns the command-line byte budget assumed on Linux:
// 1.5 MiB. The real ARG_MAX is typically 2 MiB; the lower figure leaves
// headroom for environment variables.
func safeLinuxArgMax() int {
	return 1536 << 10 // 1536 KiB
}
|
||||
|
||||
type FitPromptResult struct {
|
||||
OK bool
|
||||
|
|
@ -41,16 +48,7 @@ func estimateCmdlineLength(resolved env.AgentCommand) int {
|
|||
|
||||
func FitPromptToWinCmdline(agentBin string, fixedArgs []string, prompt string, maxCmdline int, cwd string) FitPromptResult {
|
||||
if runtime.GOOS != "windows" {
|
||||
args := make([]string, len(fixedArgs)+1)
|
||||
copy(args, fixedArgs)
|
||||
args[len(fixedArgs)] = prompt
|
||||
return FitPromptResult{
|
||||
OK: true,
|
||||
Args: args,
|
||||
Truncated: false,
|
||||
OriginalLength: len(prompt),
|
||||
FinalPromptLength: len(prompt),
|
||||
}
|
||||
return fitPromptLinux(fixedArgs, prompt)
|
||||
}
|
||||
|
||||
e := env.OsEnvToMap()
|
||||
|
|
@ -125,8 +123,59 @@ func FitPromptToWinCmdline(agentBin string, fixedArgs []string, prompt string, m
|
|||
}
|
||||
}
|
||||
|
||||
// fitPromptLinux handles Linux ARG_MAX truncation.
|
||||
func fitPromptLinux(fixedArgs []string, prompt string) FitPromptResult {
|
||||
argMax := safeLinuxArgMax()
|
||||
|
||||
// Estimate total cmdline size: sum of all fixed args + prompt + null terminators
|
||||
fixedLen := 0
|
||||
for _, a := range fixedArgs {
|
||||
fixedLen += len(a) + 1
|
||||
}
|
||||
totalLen := fixedLen + len(prompt) + 1
|
||||
|
||||
if totalLen <= argMax {
|
||||
args := make([]string, len(fixedArgs)+1)
|
||||
copy(args, fixedArgs)
|
||||
args[len(fixedArgs)] = prompt
|
||||
return FitPromptResult{
|
||||
OK: true,
|
||||
Args: args,
|
||||
Truncated: false,
|
||||
OriginalLength: len(prompt),
|
||||
FinalPromptLength: len(prompt),
|
||||
}
|
||||
}
|
||||
|
||||
// Need to truncate: keep the tail of the prompt (most recent messages)
|
||||
prefix := LinuxPromptOmissionPrefix
|
||||
available := argMax - fixedLen - len(prefix) - 1
|
||||
if available < 0 {
|
||||
available = 0
|
||||
}
|
||||
|
||||
var finalPrompt string
|
||||
if available <= 0 {
|
||||
finalPrompt = prefix
|
||||
} else if available >= len(prompt) {
|
||||
finalPrompt = prefix + prompt
|
||||
} else {
|
||||
finalPrompt = prefix + prompt[len(prompt)-available:]
|
||||
}
|
||||
|
||||
args := make([]string, len(fixedArgs)+1)
|
||||
copy(args, fixedArgs)
|
||||
args[len(fixedArgs)] = finalPrompt
|
||||
return FitPromptResult{
|
||||
OK: true,
|
||||
Args: args,
|
||||
Truncated: true,
|
||||
OriginalLength: len(prompt),
|
||||
FinalPromptLength: len(finalPrompt),
|
||||
}
|
||||
}
|
||||
|
||||
// WarnPromptTruncated is a no-op hook for reporting prompt truncation. It
// prints nothing (so this file does not need to import fmt); callers that
// want a warning should log it themselves.
func WarnPromptTruncated(originalLength, finalLength int) {
	// Both values are deliberately discarded.
	_, _ = originalLength, finalLength
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue