remove old version
This commit is contained in:
parent
85f695250e
commit
bf18e7958e
319
Makefile
319
Makefile
|
|
@ -1,319 +0,0 @@
|
||||||
# ──────────────────────────────────────────────
|
|
||||||
# cursor-api-proxy — 設定與建置
|
|
||||||
# 編輯下方變數,然後執行 make env 產生 .env 檔
|
|
||||||
# ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
# ── 伺服器設定 ─────────────────────────────────
|
|
||||||
HOST ?= 127.0.0.1
|
|
||||||
PORT ?= 8766
|
|
||||||
API_KEY ?=
|
|
||||||
TIMEOUT_MS ?= 3600000
|
|
||||||
MULTI_PORT ?= false
|
|
||||||
VERBOSE ?= false
|
|
||||||
|
|
||||||
# ── Agent / 模型設定 ──────────────────────────
|
|
||||||
AGENT_BIN ?= agent
|
|
||||||
AGENT_NODE ?=
|
|
||||||
AGENT_SCRIPT ?=
|
|
||||||
DEFAULT_MODEL ?= auto
|
|
||||||
STRICT_MODEL ?= true
|
|
||||||
MAX_MODE ?= false
|
|
||||||
FORCE ?= false
|
|
||||||
APPROVE_MCPS ?= false
|
|
||||||
|
|
||||||
# ── 工作區與帳號 ──────────────────────────────
|
|
||||||
WORKSPACE ?=
|
|
||||||
CHAT_ONLY_WORKSPACE ?= true
|
|
||||||
CONFIG_DIRS ?=
|
|
||||||
|
|
||||||
# ── OpenCode 模型設定 ────────────────────
|
|
||||||
OPENCODE_MODEL ?= cursor/claude-4.6-sonnet-medium
|
|
||||||
OPENCODE_SMALL_MODEL ?= cursor/gpt-5.4-nano-medium
|
|
||||||
|
|
||||||
# ── Cursor / Claude Code(~/.claude)────────────────
|
|
||||||
CLAUDE_SETTINGS ?= $(HOME)/.claude/settings.json
|
|
||||||
CLAUDE_JSON ?= $(HOME)/.claude.json
|
|
||||||
ANTHROPIC_AUTH_TOKEN ?=
|
|
||||||
ANTHROPIC_DEFAULT_SONNET_MODEL ?= claude-4.6-sonnet-medium
|
|
||||||
ANTHROPIC_DEFAULT_OPUS_MODEL ?= claude-4.6-opus-max
|
|
||||||
ANTHROPIC_DEFAULT_HAIKU_MODEL ?= gemini-3-flash
|
|
||||||
ANTHROPIC_BASE_HOST ?= $(HOST)
|
|
||||||
|
|
||||||
# ── TLS / HTTPS ───────────────────────────────
|
|
||||||
TLS_CERT ?=
|
|
||||||
TLS_KEY ?=
|
|
||||||
|
|
||||||
# ── Gemini Web Provider ───────────────────────
|
|
||||||
PROVIDER ?= cursor
|
|
||||||
GEMINI_ACCOUNT_DIR ?=
|
|
||||||
GEMINI_BROWSER_VISIBLE ?= false
|
|
||||||
GEMINI_MAX_SESSIONS ?= 3
|
|
||||||
|
|
||||||
# ── 記錄 ──────────────────────────────────────
|
|
||||||
SESSIONS_LOG ?=
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
ENV_FILE ?= .env
|
|
||||||
|
|
||||||
OPENCODE_CONFIG ?= $(HOME)/.config/opencode/opencode.json
|
|
||||||
|
|
||||||
# ── Docker 設定 ───────────────────────────────
|
|
||||||
DOCKER_IMAGE ?= cursor-api-proxy
|
|
||||||
DOCKER_TAG ?= latest
|
|
||||||
DOCKER_COMPOSE ?= docker compose
|
|
||||||
|
|
||||||
.PHONY: env run build clean help opencode opencode-models pm2 pm2-stop pm2-logs claude-code pm2-claude-code \
|
|
||||||
claude-settings claude-onboarding claude-cursor-setup \
|
|
||||||
docker-build docker-up docker-down docker-logs docker-restart docker-shell docker-env docker-setup
|
|
||||||
|
|
||||||
## 產生 .env 檔(預設輸出至 .env,可用 ENV_FILE=xxx 覆寫)
|
|
||||||
env:
|
|
||||||
@printf '# 由 make env 自動產生,請勿手動編輯\n' > $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_BRIDGE_HOST=%s\n' "$(HOST)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_BRIDGE_PORT=%s\n' "$(PORT)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_BRIDGE_API_KEY=%s\n' "$(API_KEY)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_BRIDGE_TIMEOUT_MS=%s\n' "$(TIMEOUT_MS)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_BRIDGE_MULTI_PORT=%s\n' "$(MULTI_PORT)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_BRIDGE_VERBOSE=%s\n' "$(VERBOSE)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_AGENT_BIN=%s\n' "$(AGENT_BIN)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_AGENT_NODE=%s\n' "$(AGENT_NODE)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_AGENT_SCRIPT=%s\n' "$(AGENT_SCRIPT)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_BRIDGE_DEFAULT_MODEL=%s\n' "$(DEFAULT_MODEL)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_BRIDGE_STRICT_MODEL=%s\n' "$(STRICT_MODEL)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_BRIDGE_MAX_MODE=%s\n' "$(MAX_MODE)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_BRIDGE_FORCE=%s\n' "$(FORCE)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_BRIDGE_APPROVE_MCPS=%s\n' "$(APPROVE_MCPS)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_BRIDGE_WORKSPACE=%s\n' "$(WORKSPACE)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_BRIDGE_CHAT_ONLY_WORKSPACE=%s\n' "$(CHAT_ONLY_WORKSPACE)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_CONFIG_DIRS=%s\n' "$(CONFIG_DIRS)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_BRIDGE_TLS_CERT=%s\n' "$(TLS_CERT)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_BRIDGE_TLS_KEY=%s\n' "$(TLS_KEY)" >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_BRIDGE_SESSIONS_LOG=%s\n' "$(SESSIONS_LOG)" >> $(ENV_FILE)
|
|
||||||
@printf '# ── Provider 設定 ───────────────────────────\n' >> $(ENV_FILE)
|
|
||||||
@printf 'CURSOR_BRIDGE_PROVIDER=%s\n' "$(PROVIDER)" >> $(ENV_FILE)
|
|
||||||
@printf '# Gemini Web Provider 設定(當 PROVIDER=gemini-web 時使用)\n' >> $(ENV_FILE)
|
|
||||||
@printf 'GEMINI_ACCOUNT_DIR=%s\n' "$(GEMINI_ACCOUNT_DIR)" >> $(ENV_FILE)
|
|
||||||
@printf 'GEMINI_BROWSER_VISIBLE=%s\n' "$(GEMINI_BROWSER_VISIBLE)" >> $(ENV_FILE)
|
|
||||||
@printf 'GEMINI_MAX_SESSIONS=%s\n' "$(GEMINI_MAX_SESSIONS)" >> $(ENV_FILE)
|
|
||||||
@echo "已產生 $(ENV_FILE)"
|
|
||||||
|
|
||||||
## 編譯二進位檔
|
|
||||||
build:
|
|
||||||
go build -o cursor-api-proxy .
|
|
||||||
|
|
||||||
## 載入 .env 後直接執行(需先執行 make env 或已有 .env)
|
|
||||||
run: build
|
|
||||||
@if [ -f $(ENV_FILE) ]; then \
|
|
||||||
set -a && . ./$(ENV_FILE) && set +a && ./cursor-api-proxy; \
|
|
||||||
else \
|
|
||||||
echo "找不到 $(ENV_FILE),請先執行 make env"; exit 1; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
## 清除產出物
|
|
||||||
clean:
|
|
||||||
rm -f cursor-api-proxy $(ENV_FILE)
|
|
||||||
|
|
||||||
## 設定 OpenCode 使用此代理(更新 opencode.json 的 cursor 與 gemini-web provider)
|
|
||||||
opencode: build
|
|
||||||
@if [ ! -f "$(OPENCODE_CONFIG)" ]; then \
|
|
||||||
echo "找不到 $(OPENCODE_CONFIG),建立新設定檔"; \
|
|
||||||
mkdir -p $$(dirname "$(OPENCODE_CONFIG)"); \
|
|
||||||
printf '{\n "model": "$(OPENCODE_MODEL)",\n "small_model": "$(OPENCODE_SMALL_MODEL)",\n "provider": {\n "cursor": {\n "npm": "@ai-sdk/openai-compatible",\n "name": "Cursor Agent",\n "options": {\n "baseURL": "http://$(HOST):$(PORT)/v1",\n "apiKey": "unused"\n },\n "models": { "auto": { "name": "Cursor Auto" } }\n },\n "gemini-web": {\n "npm": "@ai-sdk/openai-compatible",\n "name": "Gemini Web",\n "options": {\n "baseURL": "http://$(HOST):$(PORT)/v1",\n "apiKey": "unused"\n },\n "models": {\n "gemini-2.0-flash": { "name": "Gemini 2.0 Flash" },\n "gemini-2.5-pro": { "name": "Gemini 2.5 Pro" },\n "gemini-2.5-pro-thinking": { "name": "Gemini 2.5 Pro Thinking" }\n }\n }\n }\n}\n' > "$(OPENCODE_CONFIG)"; \
|
|
||||||
echo "已建立 $(OPENCODE_CONFIG)(包含 cursor 與 gemini-web provider)"; \
|
|
||||||
elif [ -n "$(API_KEY)" ]; then \
|
|
||||||
jq --arg model "$(OPENCODE_MODEL)" --arg small "$(OPENCODE_SMALL_MODEL)" --arg base "http://$(HOST):$(PORT)/v1" --arg key "$(API_KEY)" '.model = $$model | .small_model = $$small | .provider.cursor.options.baseURL = $$base | .provider.cursor.options.apiKey = $$key | .provider["gemini-web"].options.baseURL = $$base | .provider["gemini-web"].options.apiKey = $$key' "$(OPENCODE_CONFIG)" > "$(OPENCODE_CONFIG).tmp" && mv "$(OPENCODE_CONFIG).tmp" "$(OPENCODE_CONFIG)"; \
|
|
||||||
echo "已更新 $(OPENCODE_CONFIG)(model=$(OPENCODE_MODEL), small_model=$(OPENCODE_SMALL_MODEL), baseURL → http://$(HOST):$(PORT)/v1,apiKey 已設定)"; \
|
|
||||||
else \
|
|
||||||
jq --arg model "$(OPENCODE_MODEL)" --arg small "$(OPENCODE_SMALL_MODEL)" --arg base "http://$(HOST):$(PORT)/v1" '.model = $$model | .small_model = $$small | .provider.cursor.options.baseURL = $$base | .provider["gemini-web"].options.baseURL = $$base' "$(OPENCODE_CONFIG)" > "$(OPENCODE_CONFIG).tmp" && mv "$(OPENCODE_CONFIG).tmp" "$(OPENCODE_CONFIG)"; \
|
|
||||||
echo "已更新 $(OPENCODE_CONFIG)(model=$(OPENCODE_MODEL), small_model=$(OPENCODE_SMALL_MODEL), baseURL → http://$(HOST):$(PORT)/v1)"; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
## 啟動代理並用 curl 同步模型列表到 opencode.json
|
|
||||||
opencode-models: opencode
|
|
||||||
@echo "啟動代理以取得模型列表..."
|
|
||||||
@set -a && . ./$(ENV_FILE) 2>/dev/null; set +a; \
|
|
||||||
./cursor-api-proxy & PID=$$!; \
|
|
||||||
sleep 2; \
|
|
||||||
MODELS=$$(curl -s http://$(HOST):$(PORT)/v1/models | jq '[.data[].id]'); \
|
|
||||||
kill $$PID 2>/dev/null; wait $$PID 2>/dev/null; \
|
|
||||||
if [ -n "$$MODELS" ] && [ "$$MODELS" != "null" ]; then \
|
|
||||||
jq --argjson ids "$$MODELS" 'reduce $ids[] as $id (.; .provider.cursor.models[$id] = { name: $id } | .provider["gemini-web"].models[$id] = { name: $id })' "$(OPENCODE_CONFIG)" > "$(OPENCODE_CONFIG).tmp" && mv "$(OPENCODE_CONFIG).tmp" "$(OPENCODE_CONFIG)"; \
|
|
||||||
echo "已同步模型列表到 $(OPENCODE_CONFIG)(cursor 與 gemini-web)"; \
|
|
||||||
else \
|
|
||||||
echo "無法取得模型列表,請確認代理已啟動"; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
## 編譯並用 pm2 啟動
|
|
||||||
pm2: build
|
|
||||||
@if [ -f "$(ENV_FILE)" ]; then \
|
|
||||||
env $$(cat $(ENV_FILE) | grep -v '^#' | xargs) CURSOR_BRIDGE_HOST=$(HOST) CURSOR_BRIDGE_PORT=$(PORT) pm2 start ./cursor-api-proxy --name cursor-api-proxy --update-env; \
|
|
||||||
else \
|
|
||||||
CURSOR_BRIDGE_HOST=$(HOST) CURSOR_BRIDGE_PORT=$(PORT) pm2 start ./cursor-api-proxy --name cursor-api-proxy; \
|
|
||||||
fi
|
|
||||||
@pm2 save
|
|
||||||
@echo "pm2 已啟動 cursor-api-proxy(http://$(HOST):$(PORT))"
|
|
||||||
|
|
||||||
## 用 pm2 啟動 OpenCode 代理(設定 + 啟動一步完成)
|
|
||||||
pm2-opencode: opencode pm2
|
|
||||||
@echo "OpenCode 設定已更新並用 pm2 啟動代理"
|
|
||||||
|
|
||||||
## 寫入 ~/.claude/settings.json(ANTHROPIC_BASE_URL、三個 DEFAULT_* 模型;需 jq)
|
|
||||||
claude-settings:
|
|
||||||
@command -v jq >/dev/null 2>&1 || { echo "需要 jq"; exit 1; }
|
|
||||||
@mkdir -p $$(dirname "$(CLAUDE_SETTINGS)")
|
|
||||||
@jq -n \
|
|
||||||
--arg base "http://$(ANTHROPIC_BASE_HOST):$(PORT)" \
|
|
||||||
--arg token "$(ANTHROPIC_AUTH_TOKEN)" \
|
|
||||||
--arg sonnet "$(ANTHROPIC_DEFAULT_SONNET_MODEL)" \
|
|
||||||
--arg opus "$(ANTHROPIC_DEFAULT_OPUS_MODEL)" \
|
|
||||||
--arg haiku "$(ANTHROPIC_DEFAULT_HAIKU_MODEL)" \
|
|
||||||
'{ env: { ANTHROPIC_BASE_URL: $$base, ANTHROPIC_AUTH_TOKEN: $$token, ANTHROPIC_DEFAULT_SONNET_MODEL: $$sonnet, ANTHROPIC_DEFAULT_OPUS_MODEL: $$opus, ANTHROPIC_DEFAULT_HAIKU_MODEL: $$haiku } }' \
|
|
||||||
> "$(CLAUDE_SETTINGS).tmp" && mv "$(CLAUDE_SETTINGS).tmp" "$(CLAUDE_SETTINGS)"
|
|
||||||
@echo "已寫入 $(CLAUDE_SETTINGS)(BASE_URL=http://$(ANTHROPIC_BASE_HOST):$(PORT))"
|
|
||||||
|
|
||||||
## 將 ~/.claude.json 的 hasCompletedOnboarding 設為 true(繞過初次引導;需 jq)
|
|
||||||
claude-onboarding:
|
|
||||||
@command -v jq >/dev/null 2>&1 || { echo "需要 jq"; exit 1; }
|
|
||||||
@test -f "$(CLAUDE_JSON)" || { echo "找不到 $(CLAUDE_JSON)"; exit 1; }
|
|
||||||
@jq '.hasCompletedOnboarding = true' "$(CLAUDE_JSON)" > "$(CLAUDE_JSON).tmp" && mv "$(CLAUDE_JSON).tmp" "$(CLAUDE_JSON)"
|
|
||||||
@echo "已設定 $(CLAUDE_JSON) hasCompletedOnboarding=true"
|
|
||||||
|
|
||||||
## 一次執行 claude-settings + claude-onboarding
|
|
||||||
claude-cursor-setup: claude-settings claude-onboarding
|
|
||||||
@echo "Cursor/Claude Code 本機設定已套用"
|
|
||||||
|
|
||||||
## 編譯並用 pm2 啟動 + 設定 Claude Code 環境變數
|
|
||||||
pm2-claude-code: pm2
|
|
||||||
@echo ""
|
|
||||||
@echo "Claude Code 設定:將以下指令加入你的 shell 啟動檔(~/.bashrc 或 ~/.zshrc):"
|
|
||||||
@echo ""
|
|
||||||
@echo " export ANTHROPIC_BASE_URL=http://$(HOST):$(PORT)"
|
|
||||||
@echo " export ANTHROPIC_API_KEY=$(if $(API_KEY),$(API_KEY),dummy-key)"
|
|
||||||
@echo ""
|
|
||||||
@echo "或在當前 shell 執行:"
|
|
||||||
@echo ""
|
|
||||||
@echo " export ANTHROPIC_BASE_URL=http://$(HOST):$(PORT)"
|
|
||||||
@echo " export ANTHROPIC_API_KEY=$(if $(API_KEY),$(API_KEY),dummy-key)"
|
|
||||||
@echo " claude"
|
|
||||||
@echo ""
|
|
||||||
|
|
||||||
## 停止 pm2 中的代理
|
|
||||||
pm2-stop:
|
|
||||||
pm2 stop cursor-api-proxy 2>/dev/null || echo "cursor-api-proxy 未在執行"
|
|
||||||
|
|
||||||
## 查看 pm2 日誌
|
|
||||||
pm2-logs:
|
|
||||||
pm2 logs cursor-api-proxy
|
|
||||||
|
|
||||||
|
|
||||||
## ──────────────────────────────────────────────────
|
|
||||||
## Docker Compose 指令
|
|
||||||
## ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
## 複製 .env.example 為 .env 並自動偵測本機 agent 路徑(首次設定)
|
|
||||||
docker-env:
|
|
||||||
@if [ -f .env ]; then \
|
|
||||||
echo ".env 已存在,若要重置請手動刪除後再執行"; \
|
|
||||||
else \
|
|
||||||
cp .env.example .env; \
|
|
||||||
DETECTED_AGENT=$$(which agent 2>/dev/null || echo ""); \
|
|
||||||
if [ -n "$$DETECTED_AGENT" ]; then \
|
|
||||||
REAL_AGENT=$$(readlink -f "$$DETECTED_AGENT"); \
|
|
||||||
sed -i "s|^CURSOR_AGENT_HOST_BIN=.*|CURSOR_AGENT_HOST_BIN=$$REAL_AGENT|" .env; \
|
|
||||||
echo " 已偵測 agent: $$REAL_AGENT"; \
|
|
||||||
else \
|
|
||||||
echo " 警告:找不到 agent,請手動設定 CURSOR_AGENT_HOST_BIN"; \
|
|
||||||
fi; \
|
|
||||||
echo "已建立 .env,請確認設定後執行 make docker-up"; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
## 建置 Docker 映像檔
|
|
||||||
docker-build:
|
|
||||||
$(DOCKER_COMPOSE) build
|
|
||||||
|
|
||||||
## 啟動 Docker Compose(背景執行)
|
|
||||||
docker-up:
|
|
||||||
@if [ ! -f .env ]; then \
|
|
||||||
echo "找不到 .env,請先執行:make docker-env"; exit 1; \
|
|
||||||
fi
|
|
||||||
$(DOCKER_COMPOSE) up -d
|
|
||||||
@echo "cursor-api-proxy 已啟動(http://0.0.0.0:$(PORT))"
|
|
||||||
|
|
||||||
## 首次設定並啟動(複製 .env + build + up,一步完成)
|
|
||||||
docker-setup:
|
|
||||||
@if [ ! -f .env ]; then \
|
|
||||||
cp .env.example .env; \
|
|
||||||
echo "已建立 .env,請先編輯填入必要設定(CURSOR_AGENT_HOST_BIN、CURSOR_ACCOUNTS_DIR),然後重新執行 make docker-setup"; \
|
|
||||||
exit 1; \
|
|
||||||
fi
|
|
||||||
$(DOCKER_COMPOSE) build
|
|
||||||
$(DOCKER_COMPOSE) up -d
|
|
||||||
@echo "cursor-api-proxy 已啟動(http://0.0.0.0:$(PORT))"
|
|
||||||
@echo "查看日誌:make docker-logs"
|
|
||||||
|
|
||||||
## 停止並移除容器
|
|
||||||
docker-down:
|
|
||||||
$(DOCKER_COMPOSE) down
|
|
||||||
|
|
||||||
## 查看容器日誌(即時跟蹤)
|
|
||||||
docker-logs:
|
|
||||||
$(DOCKER_COMPOSE) logs -f cursor-api-proxy
|
|
||||||
|
|
||||||
## 重新建置並啟動容器
|
|
||||||
docker-restart:
|
|
||||||
$(DOCKER_COMPOSE) down
|
|
||||||
$(DOCKER_COMPOSE) build
|
|
||||||
$(DOCKER_COMPOSE) up -d
|
|
||||||
@echo "cursor-api-proxy 已重新啟動"
|
|
||||||
|
|
||||||
## 進入容器 shell(除錯用)
|
|
||||||
docker-shell:
|
|
||||||
$(DOCKER_COMPOSE) exec cursor-api-proxy sh
|
|
||||||
|
|
||||||
## 顯示說明
|
|
||||||
help:
|
|
||||||
@echo "可用目標:"
|
|
||||||
@echo " make env 產生 .env(先在 Makefile 頂端填好變數)"
|
|
||||||
@echo " make build 編譯 cursor-api-proxy 二進位檔"
|
|
||||||
@echo " make run 編譯並載入 .env 執行"
|
|
||||||
@echo " make pm2 編譯並用 pm2 啟動代理"
|
|
||||||
@echo " make pm2-stop 停止 pm2 中的代理"
|
|
||||||
@echo " make pm2-logs 查看 pm2 日誌"
|
|
||||||
@echo " make pm2-claude-code 啟動代理 + 輸出 Claude Code 設定指令"
|
|
||||||
@echo " make opencode 編譯並設定 OpenCode(更新 opencode.json)"
|
|
||||||
@echo " make pm2-opencode 設定 OpenCode + 啟動代理"
|
|
||||||
@echo " make opencode-models 編譯、設定 OpenCode 並同步模型列表"
|
|
||||||
@echo " make claude-settings 寫入 ~/.claude/settings.json(模型與 BASE_URL)"
|
|
||||||
@echo " make claude-onboarding 設定 ~/.claude.json hasCompletedOnboarding=true"
|
|
||||||
@echo " make claude-cursor-setup 同上兩步一次完成"
|
|
||||||
@echo " make clean 刪除二進位檔與 .env"
|
|
||||||
@echo ""
|
|
||||||
@echo "Docker Compose 指令:"
|
|
||||||
@echo " make docker-env 複製 .env.example 為 .env(首次設定)"
|
|
||||||
@echo " make docker-setup 首次設定並啟動(自動複製 .env + build + up)"
|
|
||||||
@echo " make docker-build 建置 Docker 映像檔"
|
|
||||||
@echo " make docker-up 啟動容器(背景執行,需已有 .env)"
|
|
||||||
@echo " make docker-down 停止並移除容器"
|
|
||||||
@echo " make docker-logs 查看容器即時日誌"
|
|
||||||
@echo " make docker-restart 重新建置並啟動容器"
|
|
||||||
@echo " make docker-shell 進入容器 shell(除錯用)"
|
|
||||||
@echo ""
|
|
||||||
@echo "Provider 設定範例:"
|
|
||||||
@echo " make env PROVIDER=cursor # 使用 Cursor(預設)"
|
|
||||||
@echo " make env PROVIDER=gemini-web # 使用 Gemini Web"
|
|
||||||
@echo " make env PROVIDER=gemini-web GEMINI_ACCOUNT_DIR=/path/to/sessions"
|
|
||||||
@echo " make env PROVIDER=gemini-web GEMINI_BROWSER_VISIBLE=true"
|
|
||||||
@echo " make opencode # 設定 OpenCode(含 cursor 與 gemini-web provider)"
|
|
||||||
@echo ""
|
|
||||||
@echo "覆寫範例:"
|
|
||||||
@echo " make env PORT=9000 API_KEY=mysecret TIMEOUT_MS=60000"
|
|
||||||
@echo " make pm2-claude-code PORT=8765 API_KEY=mykey"
|
|
||||||
@echo " make pm2-opencode PORT=8765"
|
|
||||||
@echo " make claude-settings PORT=8766 ANTHROPIC_BASE_HOST=localhost ANTHROPIC_DEFAULT_OPUS_MODEL=claude-4.6-opus-high"
|
|
||||||
@echo ""
|
|
||||||
@echo "使用 Gemini Web Provider:"
|
|
||||||
@echo " 1. make env PROVIDER=gemini-web"
|
|
||||||
@echo " 2. gemini-login my-session # 登入並儲存 session"
|
|
||||||
@echo " 3. make run # 啟動代理"
|
|
||||||
@echo " 4. 在 OpenCode 設定 model: gemini/gemini-2.5-pro"
|
|
||||||
152
README.md
152
README.md
|
|
@ -1,152 +0,0 @@
|
||||||
# Cursor API Proxy
|
|
||||||
|
|
||||||
一個代理伺服器,讓你用標準 OpenAI / Anthropic API 存取 Cursor CLI 模型。
|
|
||||||
|
|
||||||
可接入:
|
|
||||||
- Claude Code
|
|
||||||
- OpenCode
|
|
||||||
- 任何支援 OpenAI / Anthropic API 的工具
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 功能
|
|
||||||
|
|
||||||
- API 相容(OpenAI / Anthropic)
|
|
||||||
- 多帳號管理
|
|
||||||
- 模型自動對映(轉成 claude-*)
|
|
||||||
- 支援區網存取(0.0.0.0)
|
|
||||||
- 連線池優化
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 快速開始(本機)
|
|
||||||
### 看幫助
|
|
||||||
```bash
|
|
||||||
make help
|
|
||||||
```
|
|
||||||
### 安裝依賴
|
|
||||||
|
|
||||||
```bash
|
|
||||||
curl https://cursor.com/install -fsS | bash
|
|
||||||
curl -fsSL https://claude.ai/install.sh | bash
|
|
||||||
```
|
|
||||||
|
|
||||||
### 下載與建置
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone https://code.30cm.net/daniel.w/opencode-cursor-agent.git
|
|
||||||
cd cursor-api-proxy-go
|
|
||||||
go build -o cursor-api-proxy .
|
|
||||||
```
|
|
||||||
|
|
||||||
### 登入
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./cursor-api-proxy login myaccount
|
|
||||||
```
|
|
||||||
|
|
||||||
### 啟動
|
|
||||||
|
|
||||||
```bash
|
|
||||||
make env PORT=8766 API_KEY=mysecret
|
|
||||||
make pm2
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Claude Code 設定
|
|
||||||
|
|
||||||
```bash
|
|
||||||
make claude-settings PORT=8766
|
|
||||||
make claude-onboarding
|
|
||||||
claude
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## OpenCode 設定
|
|
||||||
|
|
||||||
編輯:
|
|
||||||
|
|
||||||
~/.config/opencode/opencode.json
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"provider": {
|
|
||||||
"cursor": {
|
|
||||||
"options": {
|
|
||||||
"baseURL": "http://127.0.0.1:8766/v1"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Docker(簡化版)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
make docker-setup
|
|
||||||
vim .env
|
|
||||||
make docker-setup
|
|
||||||
```
|
|
||||||
|
|
||||||
檢查:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
curl http://localhost:8766/health
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 常用指令
|
|
||||||
|
|
||||||
```bash
|
|
||||||
make docker-up
|
|
||||||
make docker-down
|
|
||||||
make docker-logs
|
|
||||||
make docker-restart
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## API
|
|
||||||
|
|
||||||
| 路徑 | 方法 |
|
|
||||||
|------|------|
|
|
||||||
| /v1/chat/completions | POST |
|
|
||||||
| /v1/messages | POST |
|
|
||||||
| /v1/models | GET |
|
|
||||||
| /health | GET |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 環境變數(核心)
|
|
||||||
|
|
||||||
| 變數 | 預設 |
|
|
||||||
|------|------|
|
|
||||||
| CURSOR_BRIDGE_HOST | 127.0.0.1 |
|
|
||||||
| CURSOR_BRIDGE_PORT | 8766 |
|
|
||||||
| CURSOR_BRIDGE_TIMEOUT_MS | 3600000 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 帳號操作
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./cursor-api-proxy login myaccount
|
|
||||||
./cursor-api-proxy accounts
|
|
||||||
./cursor-api-proxy logout myaccount
|
|
||||||
./cursor-api-proxy reset-hwid
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 備註
|
|
||||||
|
|
||||||
- Docker 預設開放區網(0.0.0.0)
|
|
||||||
- 模型自動同步
|
|
||||||
- 支援多模型切換
|
|
||||||
|
|
||||||
---
|
|
||||||
196
cmd/accounts.go
196
cmd/accounts.go
|
|
@ -1,196 +0,0 @@
|
||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cursor-api-proxy/internal/agent"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
)
|
|
||||||
|
|
||||||
type AccountInfo struct {
|
|
||||||
Name string
|
|
||||||
ConfigDir string
|
|
||||||
Authenticated bool
|
|
||||||
Email string
|
|
||||||
DisplayName string
|
|
||||||
AuthID string
|
|
||||||
Plan string
|
|
||||||
SubscriptionStatus string
|
|
||||||
ExpiresAt string
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReadAccountInfo(name, configDir string) AccountInfo {
|
|
||||||
info := AccountInfo{Name: name, ConfigDir: configDir}
|
|
||||||
|
|
||||||
configFile := filepath.Join(configDir, "cli-config.json")
|
|
||||||
data, err := os.ReadFile(configFile)
|
|
||||||
if err != nil {
|
|
||||||
return info
|
|
||||||
}
|
|
||||||
|
|
||||||
var raw struct {
|
|
||||||
AuthInfo *struct {
|
|
||||||
Email string `json:"email"`
|
|
||||||
DisplayName string `json:"displayName"`
|
|
||||||
AuthID string `json:"authId"`
|
|
||||||
} `json:"authInfo"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(data, &raw); err == nil && raw.AuthInfo != nil {
|
|
||||||
info.Authenticated = true
|
|
||||||
info.Email = raw.AuthInfo.Email
|
|
||||||
info.DisplayName = raw.AuthInfo.DisplayName
|
|
||||||
info.AuthID = raw.AuthInfo.AuthID
|
|
||||||
}
|
|
||||||
|
|
||||||
statsigFile := filepath.Join(configDir, "statsig-cache.json")
|
|
||||||
statsigData, err := os.ReadFile(statsigFile)
|
|
||||||
if err != nil {
|
|
||||||
return info
|
|
||||||
}
|
|
||||||
|
|
||||||
var statsigWrapper struct {
|
|
||||||
Data string `json:"data"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(statsigData, &statsigWrapper); err != nil || statsigWrapper.Data == "" {
|
|
||||||
return info
|
|
||||||
}
|
|
||||||
|
|
||||||
var statsig struct {
|
|
||||||
User *struct {
|
|
||||||
Custom *struct {
|
|
||||||
IsEnterpriseUser bool `json:"isEnterpriseUser"`
|
|
||||||
StripeSubscriptionStatus string `json:"stripeSubscriptionStatus"`
|
|
||||||
StripeMembershipExpiration string `json:"stripeMembershipExpiration"`
|
|
||||||
} `json:"custom"`
|
|
||||||
} `json:"user"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal([]byte(statsigWrapper.Data), &statsig); err != nil {
|
|
||||||
return info
|
|
||||||
}
|
|
||||||
|
|
||||||
if statsig.User != nil && statsig.User.Custom != nil {
|
|
||||||
c := statsig.User.Custom
|
|
||||||
if c.IsEnterpriseUser {
|
|
||||||
info.Plan = "Enterprise"
|
|
||||||
} else if c.StripeSubscriptionStatus == "active" {
|
|
||||||
info.Plan = "Pro"
|
|
||||||
} else {
|
|
||||||
info.Plan = "Free"
|
|
||||||
}
|
|
||||||
info.SubscriptionStatus = c.StripeSubscriptionStatus
|
|
||||||
info.ExpiresAt = c.StripeMembershipExpiration
|
|
||||||
}
|
|
||||||
|
|
||||||
return info
|
|
||||||
}
|
|
||||||
|
|
||||||
func HandleAccountsList() error {
|
|
||||||
accountsDir := agent.AccountsDir()
|
|
||||||
|
|
||||||
entries, err := os.ReadDir(accountsDir)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("No accounts found. Use 'cursor-api-proxy login' to add one.")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var names []string
|
|
||||||
for _, e := range entries {
|
|
||||||
if e.IsDir() {
|
|
||||||
names = append(names, e.Name())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(names) == 0 {
|
|
||||||
fmt.Println("No accounts found. Use 'cursor-api-proxy login' to add one.")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Print("Cursor Accounts:\n\n")
|
|
||||||
|
|
||||||
keychainToken := agent.ReadKeychainToken()
|
|
||||||
|
|
||||||
for i, name := range names {
|
|
||||||
configDir := filepath.Join(accountsDir, name)
|
|
||||||
info := ReadAccountInfo(name, configDir)
|
|
||||||
|
|
||||||
fmt.Printf(" %d. %s\n", i+1, name)
|
|
||||||
|
|
||||||
if info.Authenticated {
|
|
||||||
cachedToken := agent.ReadCachedToken(configDir)
|
|
||||||
keychainMatchesAccount := keychainToken != "" && info.AuthID != "" && TokenSub(keychainToken) == info.AuthID
|
|
||||||
token := cachedToken
|
|
||||||
if token == "" && keychainMatchesAccount {
|
|
||||||
token = keychainToken
|
|
||||||
}
|
|
||||||
|
|
||||||
var liveProfile *StripeProfile
|
|
||||||
var liveUsage *UsageData
|
|
||||||
if token != "" {
|
|
||||||
liveUsage, _ = FetchAccountUsage(token)
|
|
||||||
liveProfile, _ = FetchStripeProfile(token)
|
|
||||||
}
|
|
||||||
|
|
||||||
if info.Email != "" {
|
|
||||||
display := ""
|
|
||||||
if info.DisplayName != "" {
|
|
||||||
display = " (" + info.DisplayName + ")"
|
|
||||||
}
|
|
||||||
fmt.Printf(" %s%s\n", info.Email, display)
|
|
||||||
}
|
|
||||||
|
|
||||||
if info.Plan != "" && liveProfile == nil {
|
|
||||||
canceled := ""
|
|
||||||
if info.SubscriptionStatus == "canceled" {
|
|
||||||
canceled = " · canceled"
|
|
||||||
}
|
|
||||||
expiry := ""
|
|
||||||
if info.ExpiresAt != "" {
|
|
||||||
expiry = " · expires " + info.ExpiresAt
|
|
||||||
}
|
|
||||||
fmt.Printf(" %s%s%s\n", info.Plan, canceled, expiry)
|
|
||||||
}
|
|
||||||
fmt.Println(" Authenticated")
|
|
||||||
|
|
||||||
if liveProfile != nil {
|
|
||||||
fmt.Printf(" %s\n", DescribePlan(liveProfile))
|
|
||||||
}
|
|
||||||
if liveUsage != nil {
|
|
||||||
for _, line := range FormatUsageSummary(liveUsage) {
|
|
||||||
fmt.Println(line)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
fmt.Println(" Not authenticated")
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("")
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("Tip: run 'cursor-api-proxy logout <name>' to remove an account.")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func HandleLogout(accountName string) error {
|
|
||||||
if accountName == "" {
|
|
||||||
fmt.Fprintln(os.Stderr, "Error: Please specify the account name to remove.")
|
|
||||||
fmt.Fprintln(os.Stderr, "Usage: cursor-api-proxy logout <account-name>")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
accountsDir := agent.AccountsDir()
|
|
||||||
configDir := filepath.Join(accountsDir, accountName)
|
|
||||||
|
|
||||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
|
||||||
fmt.Fprintf(os.Stderr, "Account '%s' not found.\n", accountName)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.RemoveAll(configDir); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error removing account: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Account '%s' removed.\n", accountName)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
118
cmd/args.go
118
cmd/args.go
|
|
@ -1,118 +0,0 @@
|
||||||
package cmd
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
type ParsedArgs struct {
|
|
||||||
Tailscale bool
|
|
||||||
Help bool
|
|
||||||
Login bool
|
|
||||||
AccountsList bool
|
|
||||||
Logout bool
|
|
||||||
AccountName string
|
|
||||||
Proxies []string
|
|
||||||
ResetHwid bool
|
|
||||||
DeepClean bool
|
|
||||||
DryRun bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseArgs(argv []string) (ParsedArgs, error) {
|
|
||||||
var args ParsedArgs
|
|
||||||
|
|
||||||
for i := 0; i < len(argv); i++ {
|
|
||||||
arg := argv[i]
|
|
||||||
|
|
||||||
switch arg {
|
|
||||||
case "login", "add-account":
|
|
||||||
args.Login = true
|
|
||||||
if i+1 < len(argv) && len(argv[i+1]) > 0 && argv[i+1][0] != '-' {
|
|
||||||
i++
|
|
||||||
args.AccountName = argv[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
case "logout", "remove-account":
|
|
||||||
args.Logout = true
|
|
||||||
if i+1 < len(argv) && len(argv[i+1]) > 0 && argv[i+1][0] != '-' {
|
|
||||||
i++
|
|
||||||
args.AccountName = argv[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
case "accounts", "list-accounts":
|
|
||||||
args.AccountsList = true
|
|
||||||
|
|
||||||
case "reset-hwid", "reset":
|
|
||||||
args.ResetHwid = true
|
|
||||||
|
|
||||||
case "--deep-clean":
|
|
||||||
args.DeepClean = true
|
|
||||||
|
|
||||||
case "--dry-run":
|
|
||||||
args.DryRun = true
|
|
||||||
|
|
||||||
case "--tailscale":
|
|
||||||
args.Tailscale = true
|
|
||||||
|
|
||||||
case "--help", "-h":
|
|
||||||
args.Help = true
|
|
||||||
|
|
||||||
default:
|
|
||||||
if len(arg) > len("--proxy=") && arg[:len("--proxy=")] == "--proxy=" {
|
|
||||||
raw := arg[len("--proxy="):]
|
|
||||||
parts := splitComma(raw)
|
|
||||||
for _, p := range parts {
|
|
||||||
if p != "" {
|
|
||||||
args.Proxies = append(args.Proxies, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return args, fmt.Errorf("Unknown argument: %s", arg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return args, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func splitComma(s string) []string {
|
|
||||||
var result []string
|
|
||||||
start := 0
|
|
||||||
for i := 0; i <= len(s); i++ {
|
|
||||||
if i == len(s) || s[i] == ',' {
|
|
||||||
part := trim(s[start:i])
|
|
||||||
if part != "" {
|
|
||||||
result = append(result, part)
|
|
||||||
}
|
|
||||||
start = i + 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func trim(s string) string {
|
|
||||||
start := 0
|
|
||||||
end := len(s)
|
|
||||||
for start < end && (s[start] == ' ' || s[start] == '\t') {
|
|
||||||
start++
|
|
||||||
}
|
|
||||||
for end > start && (s[end-1] == ' ' || s[end-1] == '\t') {
|
|
||||||
end--
|
|
||||||
}
|
|
||||||
return s[start:end]
|
|
||||||
}
|
|
||||||
|
|
||||||
func PrintHelp(version string) {
|
|
||||||
fmt.Printf("cursor-api-proxy v%s\n\n", version)
|
|
||||||
fmt.Println("Usage:")
|
|
||||||
fmt.Println(" cursor-api-proxy [options]")
|
|
||||||
fmt.Println("")
|
|
||||||
fmt.Println("Commands:")
|
|
||||||
fmt.Println(" login [name] Log into a Cursor account (saved to ~/.cursor-api-proxy/accounts/)")
|
|
||||||
fmt.Println(" login [name] --proxy=... Same, but with a proxy from a comma-separated list")
|
|
||||||
fmt.Println(" logout <name> Remove a saved Cursor account")
|
|
||||||
fmt.Println(" accounts List saved accounts with plan info")
|
|
||||||
fmt.Println(" reset-hwid Reset Cursor machine/telemetry IDs (anti-ban)")
|
|
||||||
fmt.Println(" reset-hwid --deep-clean Also wipe session storage and cookies")
|
|
||||||
fmt.Println("")
|
|
||||||
fmt.Println("Options:")
|
|
||||||
fmt.Println(" --tailscale Bind to 0.0.0.0 for tailnet/LAN access")
|
|
||||||
fmt.Println(" -h, --help Show this help message")
|
|
||||||
}
|
|
||||||
|
|
@ -1,61 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
"cursor-api-proxy/internal/env"
|
|
||||||
"cursor-api-proxy/internal/providers/geminiweb"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
accountName := ""
|
|
||||||
visible := false
|
|
||||||
|
|
||||||
// 解析命令列參數
|
|
||||||
for i := 1; i < len(os.Args); i++ {
|
|
||||||
arg := os.Args[i]
|
|
||||||
if arg == "--visible" || arg == "-v" {
|
|
||||||
visible = true
|
|
||||||
} else if arg == "--help" || arg == "-h" {
|
|
||||||
printHelp()
|
|
||||||
os.Exit(0)
|
|
||||||
} else if !strings.HasPrefix(arg, "-") {
|
|
||||||
accountName = arg
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
e := env.OsEnvToMap()
|
|
||||||
loaded := env.LoadEnvConfig(e, "")
|
|
||||||
cfg := config.LoadBridgeConfig(e, "")
|
|
||||||
|
|
||||||
cfg.GeminiAccountDir = loaded.GeminiAccountDir
|
|
||||||
// 命令列參數優先於環境變數
|
|
||||||
cfg.GeminiBrowserVisible = visible || loaded.GeminiBrowserVisible
|
|
||||||
|
|
||||||
fmt.Printf("Session 儲存位置: %s\n", cfg.GeminiAccountDir)
|
|
||||||
fmt.Printf("瀏覽器可見: %v\n", cfg.GeminiBrowserVisible)
|
|
||||||
|
|
||||||
if err := geminiweb.RunLogin(cfg, accountName); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func printHelp() {
|
|
||||||
fmt.Println("使用方法: gemini-login [options] [session-name]")
|
|
||||||
fmt.Println("")
|
|
||||||
fmt.Println("選項:")
|
|
||||||
fmt.Println(" --visible, -v 顯示瀏覽器視窗(預設隱藏)")
|
|
||||||
fmt.Println(" --help, -h 顯示此說明")
|
|
||||||
fmt.Println("")
|
|
||||||
fmt.Println("環境變數:")
|
|
||||||
fmt.Println(" GEMINI_ACCOUNT_DIR Session 儲存目錄(預設: ~/.cursor-api-proxy/gemini-accounts)")
|
|
||||||
fmt.Println(" GEMINI_BROWSER_VISIBLE 是否顯示瀏覽器(true/false,預設: false)")
|
|
||||||
fmt.Println("")
|
|
||||||
fmt.Println("範例:")
|
|
||||||
fmt.Println(" gemini-login my-session")
|
|
||||||
fmt.Println(" gemini-login --visible my-session")
|
|
||||||
fmt.Println(" GEMINI_BROWSER_VISIBLE=true gemini-login")
|
|
||||||
}
|
|
||||||
125
cmd/login.go
125
cmd/login.go
|
|
@ -1,125 +0,0 @@
|
||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"cursor-api-proxy/internal/agent"
|
|
||||||
"cursor-api-proxy/internal/env"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"os/signal"
|
|
||||||
"path/filepath"
|
|
||||||
"regexp"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var loginURLRe = regexp.MustCompile(`https://cursor\.com/loginDeepControl.*?redirectTarget=cli`)
|
|
||||||
|
|
||||||
func HandleLogin(accountName string, proxies []string) error {
|
|
||||||
e := env.OsEnvToMap()
|
|
||||||
loaded := env.LoadEnvConfig(e, "")
|
|
||||||
agentBin := loaded.AgentBin
|
|
||||||
|
|
||||||
if accountName == "" {
|
|
||||||
accountName = fmt.Sprintf("account-%d", time.Now().UnixMilli()%10000)
|
|
||||||
}
|
|
||||||
|
|
||||||
accountsDir := agent.AccountsDir()
|
|
||||||
configDir := filepath.Join(accountsDir, accountName)
|
|
||||||
dirWasNew := !fileExists(configDir)
|
|
||||||
|
|
||||||
if err := os.MkdirAll(accountsDir, 0755); err != nil {
|
|
||||||
return fmt.Errorf("failed to create accounts dir: %w", err)
|
|
||||||
}
|
|
||||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
|
||||||
return fmt.Errorf("failed to create config dir: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Logging into Cursor account: %s\n", accountName)
|
|
||||||
fmt.Printf("Config: %s\n\n", configDir)
|
|
||||||
fmt.Println("Run the login command — complete the login in your browser.")
|
|
||||||
fmt.Println("")
|
|
||||||
|
|
||||||
cleanupDir := func() {
|
|
||||||
if dirWasNew {
|
|
||||||
_ = os.RemoveAll(configDir)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cmdEnv := make([]string, 0, len(e)+2)
|
|
||||||
for k, v := range e {
|
|
||||||
cmdEnv = append(cmdEnv, k+"="+v)
|
|
||||||
}
|
|
||||||
cmdEnv = append(cmdEnv, "CURSOR_CONFIG_DIR="+configDir)
|
|
||||||
cmdEnv = append(cmdEnv, "NO_OPEN_BROWSER=1")
|
|
||||||
|
|
||||||
child := exec.Command(agentBin, "login")
|
|
||||||
child.Env = cmdEnv
|
|
||||||
child.Stdin = os.Stdin
|
|
||||||
child.Stderr = os.Stderr
|
|
||||||
|
|
||||||
stdoutPipe, err := child.StdoutPipe()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create stdout pipe: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := child.Start(); err != nil {
|
|
||||||
cleanupDir()
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return fmt.Errorf("could not find '%s'. Make sure the Cursor CLI is installed", agentBin)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("error launching agent login: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle cancellation signals
|
|
||||||
sigCh := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
|
|
||||||
go func() {
|
|
||||||
sig := <-sigCh
|
|
||||||
_ = child.Process.Kill()
|
|
||||||
cleanupDir()
|
|
||||||
if sig == syscall.SIGINT {
|
|
||||||
fmt.Println("\n\nLogin cancelled.")
|
|
||||||
}
|
|
||||||
os.Exit(0)
|
|
||||||
}()
|
|
||||||
defer signal.Stop(sigCh)
|
|
||||||
|
|
||||||
var stdoutBuf string
|
|
||||||
scanner := bufio.NewScanner(stdoutPipe)
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
fmt.Println(line)
|
|
||||||
stdoutBuf += line + "\n"
|
|
||||||
|
|
||||||
if loginURLRe.MatchString(stdoutBuf) {
|
|
||||||
match := loginURLRe.FindString(stdoutBuf)
|
|
||||||
if match != "" {
|
|
||||||
fmt.Printf("\nOpen this URL in your browser (incognito recommended):\n%s\n\n", match)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := child.Wait(); err != nil {
|
|
||||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
|
||||||
cleanupDir()
|
|
||||||
return fmt.Errorf("login failed with code %d", exitErr.ExitCode())
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cache keychain token for this account
|
|
||||||
token := agent.ReadKeychainToken()
|
|
||||||
if token != "" {
|
|
||||||
agent.WriteCachedToken(configDir, token)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("\nAccount '%s' saved — it will be auto-discovered when you start the proxy.\n", accountName)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func fileExists(path string) bool {
|
|
||||||
_, err := os.Stat(path)
|
|
||||||
return err == nil
|
|
||||||
}
|
|
||||||
261
cmd/resethwid.go
261
cmd/resethwid.go
|
|
@ -1,261 +0,0 @@
|
||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
|
||||||
"crypto/sha512"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
func sha256hex() string {
|
|
||||||
b := make([]byte, 32)
|
|
||||||
_, _ = rand.Read(b)
|
|
||||||
h := sha256.Sum256(b)
|
|
||||||
return hex.EncodeToString(h[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
func sha512hex() string {
|
|
||||||
b := make([]byte, 64)
|
|
||||||
_, _ = rand.Read(b)
|
|
||||||
h := sha512.Sum512(b)
|
|
||||||
return hex.EncodeToString(h[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
func newUUID() string {
|
|
||||||
return uuid.New().String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func log(icon, msg string) {
|
|
||||||
fmt.Printf(" %s %s\n", icon, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getCursorGlobalStorage() string {
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "darwin":
|
|
||||||
home, _ := os.UserHomeDir()
|
|
||||||
return filepath.Join(home, "Library", "Application Support", "Cursor", "User", "globalStorage")
|
|
||||||
case "windows":
|
|
||||||
appdata := os.Getenv("APPDATA")
|
|
||||||
return filepath.Join(appdata, "Cursor", "User", "globalStorage")
|
|
||||||
default:
|
|
||||||
xdg := os.Getenv("XDG_CONFIG_HOME")
|
|
||||||
if xdg == "" {
|
|
||||||
home, _ := os.UserHomeDir()
|
|
||||||
xdg = filepath.Join(home, ".config")
|
|
||||||
}
|
|
||||||
return filepath.Join(xdg, "Cursor", "User", "globalStorage")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getCursorRoot() string {
|
|
||||||
gs := getCursorGlobalStorage()
|
|
||||||
return filepath.Dir(filepath.Dir(gs))
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateNewIDs() map[string]string {
|
|
||||||
return map[string]string{
|
|
||||||
"telemetry.machineId": sha256hex(),
|
|
||||||
"telemetry.macMachineId": sha512hex(),
|
|
||||||
"telemetry.devDeviceId": newUUID(),
|
|
||||||
"telemetry.sqmId": "{" + fmt.Sprintf("%s", newUUID()+"") + "}",
|
|
||||||
"storage.serviceMachineId": newUUID(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func killCursor() {
|
|
||||||
log("", "Stopping Cursor processes...")
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "windows":
|
|
||||||
exec.Command("taskkill", "/F", "/IM", "Cursor.exe").Run()
|
|
||||||
default:
|
|
||||||
exec.Command("pkill", "-x", "Cursor").Run()
|
|
||||||
exec.Command("pkill", "-f", "Cursor.app").Run()
|
|
||||||
}
|
|
||||||
log("", "Cursor stopped (or was not running)")
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateStorageJSON(storagePath string, ids map[string]string) {
|
|
||||||
if _, err := os.Stat(storagePath); os.IsNotExist(err) {
|
|
||||||
log("", fmt.Sprintf("storage.json not found: %s", storagePath))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
exec.Command("chflags", "nouchg", storagePath).Run()
|
|
||||||
exec.Command("chmod", "644", storagePath).Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := os.ReadFile(storagePath)
|
|
||||||
if err != nil {
|
|
||||||
log("", fmt.Sprintf("storage.json read error: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var obj map[string]interface{}
|
|
||||||
if err := json.Unmarshal(data, &obj); err != nil {
|
|
||||||
log("", fmt.Sprintf("storage.json parse error: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range ids {
|
|
||||||
obj[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
out, err := json.MarshalIndent(obj, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
log("", fmt.Sprintf("storage.json marshal error: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.WriteFile(storagePath, out, 0644); err != nil {
|
|
||||||
log("", fmt.Sprintf("storage.json write error: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log("", "storage.json updated")
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateStateVscdb(dbPath string, ids map[string]string) {
|
|
||||||
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
|
|
||||||
log("", fmt.Sprintf("state.vscdb not found: %s", dbPath))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
exec.Command("chflags", "nouchg", dbPath).Run()
|
|
||||||
exec.Command("chmod", "644", dbPath).Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := updateVscdbPureGo(dbPath, ids); err != nil {
|
|
||||||
log("", fmt.Sprintf("state.vscdb error: %v", err))
|
|
||||||
} else {
|
|
||||||
log("", "state.vscdb updated")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateMachineIDFile(machineID, cursorRoot string) {
|
|
||||||
var candidates []string
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
candidates = []string{
|
|
||||||
filepath.Join(cursorRoot, "machineid"),
|
|
||||||
filepath.Join(cursorRoot, "machineId"),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
candidates = []string{filepath.Join(cursorRoot, "machineId")}
|
|
||||||
}
|
|
||||||
|
|
||||||
filePath := candidates[0]
|
|
||||||
for _, c := range candidates {
|
|
||||||
if _, err := os.Stat(c); err == nil {
|
|
||||||
filePath = c
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
|
|
||||||
log("", fmt.Sprintf("machineId dir error: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
if _, err := os.Stat(filePath); err == nil {
|
|
||||||
exec.Command("chflags", "nouchg", filePath).Run()
|
|
||||||
exec.Command("chmod", "644", filePath).Run()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.WriteFile(filePath, []byte(machineID+"\n"), 0644); err != nil {
|
|
||||||
log("", fmt.Sprintf("machineId write error: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log("", fmt.Sprintf("machineId file updated (%s)", filepath.Base(filePath)))
|
|
||||||
}
|
|
||||||
|
|
||||||
var dirsToWipe = []string{
|
|
||||||
"Session Storage", "Local Storage", "IndexedDB", "Cache", "Code Cache",
|
|
||||||
"GPUCache", "Service Worker", "Network", "Cookies", "Cookies-journal",
|
|
||||||
}
|
|
||||||
|
|
||||||
func deepClean(cursorRoot string) {
|
|
||||||
log("", "Deep-cleaning session data...")
|
|
||||||
wiped := 0
|
|
||||||
for _, name := range dirsToWipe {
|
|
||||||
target := filepath.Join(cursorRoot, name)
|
|
||||||
if _, err := os.Stat(target); os.IsNotExist(err) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
info, err := os.Stat(target)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if info.IsDir() {
|
|
||||||
if err := os.RemoveAll(target); err == nil {
|
|
||||||
wiped++
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := os.Remove(target); err == nil {
|
|
||||||
wiped++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
log("", fmt.Sprintf("Wiped %d cache/session items", wiped))
|
|
||||||
}
|
|
||||||
|
|
||||||
func HandleResetHwid(doDeepClean, dryRun bool) error {
|
|
||||||
fmt.Print("\nCursor HWID Reset\n\n")
|
|
||||||
fmt.Println(" Resets all machine / telemetry IDs so Cursor sees a fresh install.")
|
|
||||||
fmt.Print(" Cursor must be closed — it will be killed automatically.\n\n")
|
|
||||||
|
|
||||||
globalStorage := getCursorGlobalStorage()
|
|
||||||
cursorRoot := getCursorRoot()
|
|
||||||
|
|
||||||
if _, err := os.Stat(globalStorage); os.IsNotExist(err) {
|
|
||||||
fmt.Printf("Cursor config not found at:\n %s\n", globalStorage)
|
|
||||||
fmt.Println(" Make sure Cursor is installed and has been run at least once.")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
if dryRun {
|
|
||||||
fmt.Println(" [DRY RUN] Would reset IDs in:")
|
|
||||||
fmt.Printf(" %s\n", filepath.Join(globalStorage, "storage.json"))
|
|
||||||
fmt.Printf(" %s\n", filepath.Join(globalStorage, "state.vscdb"))
|
|
||||||
fmt.Printf(" %s\n", filepath.Join(cursorRoot, "machineId"))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
killCursor()
|
|
||||||
|
|
||||||
time.Sleep(800 * time.Millisecond)
|
|
||||||
|
|
||||||
newIDs := generateNewIDs()
|
|
||||||
log("", "Generated new IDs:")
|
|
||||||
for k, v := range newIDs {
|
|
||||||
fmt.Printf(" %s: %s\n", k, v)
|
|
||||||
}
|
|
||||||
fmt.Println()
|
|
||||||
|
|
||||||
log("", "Updating storage.json...")
|
|
||||||
updateStorageJSON(filepath.Join(globalStorage, "storage.json"), newIDs)
|
|
||||||
|
|
||||||
log("", "Updating state.vscdb...")
|
|
||||||
updateStateVscdb(filepath.Join(globalStorage, "state.vscdb"), newIDs)
|
|
||||||
|
|
||||||
log("", "Updating machineId file...")
|
|
||||||
updateMachineIDFile(newIDs["telemetry.machineId"], cursorRoot)
|
|
||||||
|
|
||||||
if doDeepClean {
|
|
||||||
fmt.Println()
|
|
||||||
deepClean(cursorRoot)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Print("\nHWID reset complete. You can now restart Cursor.\n\n")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,29 +0,0 @@
|
||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
_ "modernc.org/sqlite"
|
|
||||||
)
|
|
||||||
|
|
||||||
func updateVscdbPureGo(dbPath string, ids map[string]string) error {
|
|
||||||
db, err := sql.Open("sqlite", dbPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("open db: %w", err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS ItemTable (key TEXT PRIMARY KEY, value TEXT NOT NULL)`)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create table: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range ids {
|
|
||||||
_, err = db.Exec(`INSERT OR REPLACE INTO ItemTable (key, value) VALUES (?, ?)`, k, v)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("insert %s: %w", k, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
255
cmd/usage.go
255
cmd/usage.go
|
|
@ -1,255 +0,0 @@
|
||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ModelUsage struct {
|
|
||||||
NumRequests int `json:"numRequests"`
|
|
||||||
NumRequestsTotal int `json:"numRequestsTotal"`
|
|
||||||
NumTokens int `json:"numTokens"`
|
|
||||||
MaxTokenUsage *int `json:"maxTokenUsage"`
|
|
||||||
MaxRequestUsage *int `json:"maxRequestUsage"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type UsageData struct {
|
|
||||||
StartOfMonth string `json:"startOfMonth"`
|
|
||||||
Models map[string]ModelUsage `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type StripeProfile struct {
|
|
||||||
MembershipType string `json:"membershipType"`
|
|
||||||
SubscriptionStatus string `json:"subscriptionStatus"`
|
|
||||||
DaysRemainingOnTrial *int `json:"daysRemainingOnTrial"`
|
|
||||||
IsTeamMember bool `json:"isTeamMember"`
|
|
||||||
IsYearlyPlan bool `json:"isYearlyPlan"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodeJWTPayload(token string) map[string]interface{} {
|
|
||||||
parts := strings.Split(token, ".")
|
|
||||||
if len(parts) < 2 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
padded := strings.ReplaceAll(parts[1], "-", "+")
|
|
||||||
padded = strings.ReplaceAll(padded, "_", "/")
|
|
||||||
data, err := base64.StdEncoding.DecodeString(padded + strings.Repeat("=", (4-len(padded)%4)%4))
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var result map[string]interface{}
|
|
||||||
if err := json.Unmarshal(data, &result); err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func TokenSub(token string) string {
|
|
||||||
payload := DecodeJWTPayload(token)
|
|
||||||
if payload == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
if sub, ok := payload["sub"].(string); ok {
|
|
||||||
return sub
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func apiGet(path, token string) (map[string]interface{}, error) {
|
|
||||||
client := &http.Client{Timeout: 8 * time.Second}
|
|
||||||
req, err := http.NewRequest("GET", "https://api2.cursor.sh"+path, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", "Bearer "+token)
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
data, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var result map[string]interface{}
|
|
||||||
if err := json.Unmarshal(data, &result); err != nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func FetchAccountUsage(token string) (*UsageData, error) {
|
|
||||||
raw, err := apiGet("/auth/usage", token)
|
|
||||||
if err != nil || raw == nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
startOfMonth, _ := raw["startOfMonth"].(string)
|
|
||||||
usage := &UsageData{
|
|
||||||
StartOfMonth: startOfMonth,
|
|
||||||
Models: make(map[string]ModelUsage),
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range raw {
|
|
||||||
if k == "startOfMonth" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data, err := json.Marshal(v)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
var mu ModelUsage
|
|
||||||
if err := json.Unmarshal(data, &mu); err == nil {
|
|
||||||
usage.Models[k] = mu
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return usage, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func FetchStripeProfile(token string) (*StripeProfile, error) {
|
|
||||||
raw, err := apiGet("/auth/full_stripe_profile", token)
|
|
||||||
if err != nil || raw == nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
profile := &StripeProfile{
|
|
||||||
MembershipType: fmt.Sprintf("%v", raw["membershipType"]),
|
|
||||||
SubscriptionStatus: fmt.Sprintf("%v", raw["subscriptionStatus"]),
|
|
||||||
IsTeamMember: raw["isTeamMember"] == true,
|
|
||||||
IsYearlyPlan: raw["isYearlyPlan"] == true,
|
|
||||||
}
|
|
||||||
if d, ok := raw["daysRemainingOnTrial"].(float64); ok {
|
|
||||||
di := int(d)
|
|
||||||
profile.DaysRemainingOnTrial = &di
|
|
||||||
}
|
|
||||||
return profile, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DescribePlan(profile *StripeProfile) string {
|
|
||||||
if profile == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
switch profile.MembershipType {
|
|
||||||
case "free_trial":
|
|
||||||
days := 0
|
|
||||||
if profile.DaysRemainingOnTrial != nil {
|
|
||||||
days = *profile.DaysRemainingOnTrial
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("Pro Trial (%dd left) — unlimited fast requests", days)
|
|
||||||
case "pro":
|
|
||||||
return "Pro — extended limits"
|
|
||||||
case "pro_plus":
|
|
||||||
return "Pro+ — extended limits"
|
|
||||||
case "ultra":
|
|
||||||
return "Ultra — extended limits"
|
|
||||||
case "free", "hobby":
|
|
||||||
return "Hobby (free) — limited agent requests"
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("%s · %s", profile.MembershipType, profile.SubscriptionStatus)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var modelLabels = map[string]string{
|
|
||||||
"gpt-4": "Fast Premium Requests",
|
|
||||||
"claude-sonnet-4-6": "Claude Sonnet 4.6",
|
|
||||||
"claude-sonnet-4-5-20250929-v1": "Claude Sonnet 4.5",
|
|
||||||
"claude-sonnet-4-20250514-v1": "Claude Sonnet 4",
|
|
||||||
"claude-opus-4-6-v1": "Claude Opus 4.6",
|
|
||||||
"claude-opus-4-5-20251101-v1": "Claude Opus 4.5",
|
|
||||||
"claude-opus-4-1-20250805-v1": "Claude Opus 4.1",
|
|
||||||
"claude-opus-4-20250514-v1": "Claude Opus 4",
|
|
||||||
"claude-haiku-4-5-20251001-v1": "Claude Haiku 4.5",
|
|
||||||
"claude-3-5-haiku-20241022-v1": "Claude 3.5 Haiku",
|
|
||||||
"gpt-5": "GPT-5",
|
|
||||||
"gpt-4o": "GPT-4o",
|
|
||||||
"o1": "o1",
|
|
||||||
"o3-mini": "o3-mini",
|
|
||||||
"cursor-small": "Cursor Small (free)",
|
|
||||||
}
|
|
||||||
|
|
||||||
func modelLabel(key string) string {
|
|
||||||
if label, ok := modelLabels[key]; ok {
|
|
||||||
return label
|
|
||||||
}
|
|
||||||
return key
|
|
||||||
}
|
|
||||||
|
|
||||||
func FormatUsageSummary(usage *UsageData) []string {
|
|
||||||
if usage == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var lines []string
|
|
||||||
|
|
||||||
start := "?"
|
|
||||||
if usage.StartOfMonth != "" {
|
|
||||||
if t, err := time.Parse(time.RFC3339, usage.StartOfMonth); err == nil {
|
|
||||||
start = t.Format("2006-01-02")
|
|
||||||
} else {
|
|
||||||
start = usage.StartOfMonth
|
|
||||||
}
|
|
||||||
}
|
|
||||||
lines = append(lines, fmt.Sprintf(" Billing period from %s", start))
|
|
||||||
|
|
||||||
if len(usage.Models) == 0 {
|
|
||||||
lines = append(lines, " No requests this billing period")
|
|
||||||
return lines
|
|
||||||
}
|
|
||||||
|
|
||||||
type entry struct {
|
|
||||||
key string
|
|
||||||
usage ModelUsage
|
|
||||||
}
|
|
||||||
var entries []entry
|
|
||||||
for k, v := range usage.Models {
|
|
||||||
entries = append(entries, entry{k, v})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sort: entries with limits first, then by usage descending
|
|
||||||
for i := 1; i < len(entries); i++ {
|
|
||||||
for j := i; j > 0; j-- {
|
|
||||||
a, b := entries[j-1], entries[j]
|
|
||||||
aHasLimit := a.usage.MaxRequestUsage != nil
|
|
||||||
bHasLimit := b.usage.MaxRequestUsage != nil
|
|
||||||
if !aHasLimit && bHasLimit {
|
|
||||||
entries[j-1], entries[j] = entries[j], entries[j-1]
|
|
||||||
} else if aHasLimit == bHasLimit && a.usage.NumRequests < b.usage.NumRequests {
|
|
||||||
entries[j-1], entries[j] = entries[j], entries[j-1]
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, e := range entries {
|
|
||||||
used := e.usage.NumRequests
|
|
||||||
max := e.usage.MaxRequestUsage
|
|
||||||
label := modelLabel(e.key)
|
|
||||||
if max != nil && *max > 0 {
|
|
||||||
pct := int(float64(used) / float64(*max) * 100)
|
|
||||||
bar := makeBar(used, *max, 12)
|
|
||||||
lines = append(lines, fmt.Sprintf(" %s: %d/%d (%d%%) [%s]", label, used, *max, pct, bar))
|
|
||||||
} else if used > 0 {
|
|
||||||
lines = append(lines, fmt.Sprintf(" %s: %d requests", label, used))
|
|
||||||
} else {
|
|
||||||
lines = append(lines, fmt.Sprintf(" %s: 0 requests (unlimited)", label))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return lines
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeBar(used, max, width int) string {
|
|
||||||
fill := int(float64(used) / float64(max) * float64(width))
|
|
||||||
if fill > width {
|
|
||||||
fill = width
|
|
||||||
}
|
|
||||||
return strings.Repeat("█", fill) + strings.Repeat("░", width-fill)
|
|
||||||
}
|
|
||||||
29
go.mod
29
go.mod
|
|
@ -1,29 +0,0 @@
|
||||||
module cursor-api-proxy
|
|
||||||
|
|
||||||
go 1.25.0
|
|
||||||
|
|
||||||
require (
|
|
||||||
github.com/google/uuid v1.6.0
|
|
||||||
modernc.org/sqlite v1.48.0
|
|
||||||
)
|
|
||||||
|
|
||||||
require (
|
|
||||||
github.com/deckarep/golang-set/v2 v2.8.0 // indirect
|
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
|
||||||
github.com/go-jose/go-jose/v3 v3.0.5 // indirect
|
|
||||||
github.com/go-rod/rod v0.116.2 // indirect
|
|
||||||
github.com/go-stack/stack v1.8.1 // indirect
|
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
|
||||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
|
||||||
github.com/playwright-community/playwright-go v0.5700.1 // indirect
|
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
|
||||||
github.com/ysmood/fetchup v0.2.3 // indirect
|
|
||||||
github.com/ysmood/goob v0.4.0 // indirect
|
|
||||||
github.com/ysmood/got v0.40.0 // indirect
|
|
||||||
github.com/ysmood/gson v0.7.3 // indirect
|
|
||||||
github.com/ysmood/leakless v0.9.0 // indirect
|
|
||||||
golang.org/x/sys v0.42.0 // indirect
|
|
||||||
modernc.org/libc v1.70.0 // indirect
|
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
|
||||||
modernc.org/memory v1.11.0 // indirect
|
|
||||||
)
|
|
||||||
117
go.sum
117
go.sum
|
|
@ -1,117 +0,0 @@
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
|
||||||
github.com/deckarep/golang-set/v2 v2.8.0 h1:swm0rlPCmdWn9mESxKOjWk8hXSqoxOp+ZlfuyaAdFlQ=
|
|
||||||
github.com/deckarep/golang-set/v2 v2.8.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4=
|
|
||||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
|
||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
|
||||||
github.com/go-jose/go-jose/v3 v3.0.5 h1:BLLJWbC4nMZOfuPVxoZIxeYsn6Nl2r1fITaJ78UQlVQ=
|
|
||||||
github.com/go-jose/go-jose/v3 v3.0.5/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ=
|
|
||||||
github.com/go-rod/rod v0.116.2 h1:A5t2Ky2A+5eD/ZJQr1EfsQSe5rms5Xof/qj296e+ZqA=
|
|
||||||
github.com/go-rod/rod v0.116.2/go.mod h1:H+CMO9SCNc2TJ2WfrG+pKhITz57uGNYU43qYHh438Mg=
|
|
||||||
github.com/go-stack/stack v1.8.1 h1:ntEHSVwIt7PNXNpgPmVfMrNhLtgjlmnZha2kOpuRiDw=
|
|
||||||
github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4=
|
|
||||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
|
||||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
|
||||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
|
||||||
github.com/playwright-community/playwright-go v0.5700.1 h1:PNFb1byWqrTT720rEO0JL88C6Ju0EmUnR5deFLvtP/U=
|
|
||||||
github.com/playwright-community/playwright-go v0.5700.1/go.mod h1:MlSn1dZrx8rszbCxY6x3qK89ZesJUYVx21B2JnkoNF0=
|
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
|
||||||
github.com/ysmood/fetchup v0.2.3 h1:ulX+SonA0Vma5zUFXtv52Kzip/xe7aj4vqT5AJwQ+ZQ=
|
|
||||||
github.com/ysmood/fetchup v0.2.3/go.mod h1:xhibcRKziSvol0H1/pj33dnKrYyI2ebIvz5cOOkYGns=
|
|
||||||
github.com/ysmood/goob v0.4.0 h1:HsxXhyLBeGzWXnqVKtmT9qM7EuVs/XOgkX7T6r1o1AQ=
|
|
||||||
github.com/ysmood/goob v0.4.0/go.mod h1:u6yx7ZhS4Exf2MwciFr6nIM8knHQIE22lFpWHnfql18=
|
|
||||||
github.com/ysmood/got v0.40.0 h1:ZQk1B55zIvS7zflRrkGfPDrPG3d7+JOza1ZkNxcc74Q=
|
|
||||||
github.com/ysmood/got v0.40.0/go.mod h1:W7DdpuX6skL3NszLmAsC5hT7JAhuLZhByVzHTq874Qg=
|
|
||||||
github.com/ysmood/gotrace v0.6.0/go.mod h1:TzhIG7nHDry5//eYZDYcTzuJLYQIkykJzCRIo4/dzQM=
|
|
||||||
github.com/ysmood/gson v0.7.3 h1:QFkWbTH8MxyUTKPkVWAENJhxqdBa4lYTQWqZCiLG6kE=
|
|
||||||
github.com/ysmood/gson v0.7.3/go.mod h1:3Kzs5zDl21g5F/BlLTNcuAGAYLKt2lV5G8D1zF3RNmg=
|
|
||||||
github.com/ysmood/leakless v0.9.0 h1:qxCG5VirSBvmi3uynXFkcnLMzkphdh3xx5FtrORwDCU=
|
|
||||||
github.com/ysmood/leakless v0.9.0/go.mod h1:R8iAXPRaG97QJwqxs74RdwzcRHT1SWCGTNqY8q0JvMQ=
|
|
||||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
|
||||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
|
||||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
|
||||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
|
||||||
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
|
||||||
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
|
|
||||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
|
||||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
|
||||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
|
||||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
|
||||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
|
||||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
|
||||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
|
||||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
|
||||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
|
||||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
|
||||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
|
||||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
|
||||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
|
||||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
|
||||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
|
||||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
|
||||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
|
||||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
|
||||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
|
||||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
|
||||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
|
||||||
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
|
|
||||||
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
|
|
||||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
|
||||||
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
|
||||||
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
|
||||||
modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw=
|
|
||||||
modernc.org/ccgo/v4 v4.32.0/go.mod h1:6F08EBCx5uQc38kMGl+0Nm0oWczoo1c7cgpzEry7Uc0=
|
|
||||||
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
|
|
||||||
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
|
|
||||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
|
||||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
|
||||||
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
|
|
||||||
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
|
||||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
|
||||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
|
||||||
modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw=
|
|
||||||
modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo=
|
|
||||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
|
||||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
|
||||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
|
||||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
|
||||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
|
||||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
|
||||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
|
||||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
|
||||||
modernc.org/sqlite v1.48.0 h1:ElZyLop3Q2mHYk5IFPPXADejZrlHu7APbpB0sF78bq4=
|
|
||||||
modernc.org/sqlite v1.48.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
|
|
||||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
|
||||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
|
||||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
|
||||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
|
||||||
|
|
@ -1,28 +0,0 @@
|
||||||
package agent
|
|
||||||
|
|
||||||
import "cursor-api-proxy/internal/config"
|
|
||||||
|
|
||||||
func BuildAgentFixedArgs(cfg config.BridgeConfig, workspaceDir, model string, stream bool) []string {
|
|
||||||
args := []string{"--print"}
|
|
||||||
if cfg.ApproveMcps {
|
|
||||||
args = append(args, "--approve-mcps")
|
|
||||||
}
|
|
||||||
if cfg.Force {
|
|
||||||
args = append(args, "--force")
|
|
||||||
}
|
|
||||||
if cfg.ChatOnlyWorkspace {
|
|
||||||
args = append(args, "--trust")
|
|
||||||
}
|
|
||||||
args = append(args, "--workspace", workspaceDir)
|
|
||||||
args = append(args, "--model", model)
|
|
||||||
if stream {
|
|
||||||
args = append(args, "--stream-partial-output", "--output-format", "stream-json")
|
|
||||||
} else {
|
|
||||||
args = append(args, "--output-format", "text")
|
|
||||||
}
|
|
||||||
return args
|
|
||||||
}
|
|
||||||
|
|
||||||
func BuildAgentCmdArgs(cfg config.BridgeConfig, workspaceDir, model, prompt string, stream bool) []string {
|
|
||||||
return append(BuildAgentFixedArgs(cfg, workspaceDir, model, stream), prompt)
|
|
||||||
}
|
|
||||||
|
|
@ -1,85 +0,0 @@
|
||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
)
|
|
||||||
|
|
||||||
func getCandidates(agentScriptPath, configDirOverride string) []string {
|
|
||||||
if configDirOverride != "" {
|
|
||||||
return []string{filepath.Join(configDirOverride, "cli-config.json")}
|
|
||||||
}
|
|
||||||
|
|
||||||
var result []string
|
|
||||||
|
|
||||||
if dir := os.Getenv("CURSOR_CONFIG_DIR"); dir != "" {
|
|
||||||
result = append(result, filepath.Join(dir, "cli-config.json"))
|
|
||||||
}
|
|
||||||
|
|
||||||
if agentScriptPath != "" {
|
|
||||||
agentDir := filepath.Dir(agentScriptPath)
|
|
||||||
result = append(result, filepath.Join(agentDir, "..", "data", "config", "cli-config.json"))
|
|
||||||
}
|
|
||||||
|
|
||||||
home := os.Getenv("HOME")
|
|
||||||
if home == "" {
|
|
||||||
home = os.Getenv("USERPROFILE")
|
|
||||||
}
|
|
||||||
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "windows":
|
|
||||||
local := os.Getenv("LOCALAPPDATA")
|
|
||||||
if local == "" {
|
|
||||||
local = filepath.Join(home, "AppData", "Local")
|
|
||||||
}
|
|
||||||
result = append(result, filepath.Join(local, "cursor-agent", "cli-config.json"))
|
|
||||||
case "darwin":
|
|
||||||
result = append(result, filepath.Join(home, "Library", "Application Support", "cursor-agent", "cli-config.json"))
|
|
||||||
default:
|
|
||||||
xdg := os.Getenv("XDG_CONFIG_HOME")
|
|
||||||
if xdg == "" {
|
|
||||||
xdg = filepath.Join(home, ".config")
|
|
||||||
}
|
|
||||||
result = append(result, filepath.Join(xdg, "cursor-agent", "cli-config.json"))
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func RunMaxModePreflight(agentScriptPath, configDirOverride string) {
|
|
||||||
for _, candidate := range getCandidates(agentScriptPath, configDirOverride) {
|
|
||||||
data, err := os.ReadFile(candidate)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Strip BOM if present
|
|
||||||
if len(data) >= 3 && data[0] == 0xEF && data[1] == 0xBB && data[2] == 0xBF {
|
|
||||||
data = data[3:]
|
|
||||||
}
|
|
||||||
|
|
||||||
var raw map[string]interface{}
|
|
||||||
if err := json.Unmarshal(data, &raw); err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if raw == nil || len(raw) <= 1 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
raw["maxMode"] = true
|
|
||||||
if model, ok := raw["model"].(map[string]interface{}); ok {
|
|
||||||
model["maxMode"] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
out, err := json.MarshalIndent(raw, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := os.WriteFile(candidate, out, 0644); err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,72 +0,0 @@
|
||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
"cursor-api-proxy/internal/process"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
process.MaxModeFn = RunMaxModePreflight
|
|
||||||
}
|
|
||||||
|
|
||||||
func cacheTokenForAccount(configDir string) {
|
|
||||||
if configDir == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
token := ReadKeychainToken()
|
|
||||||
if token != "" {
|
|
||||||
WriteCachedToken(configDir, token)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func AccountsDir() string {
|
|
||||||
home := os.Getenv("HOME")
|
|
||||||
if home == "" {
|
|
||||||
home = os.Getenv("USERPROFILE")
|
|
||||||
}
|
|
||||||
return filepath.Join(home, ".cursor-api-proxy", "accounts")
|
|
||||||
}
|
|
||||||
|
|
||||||
func RunAgentSync(cfg config.BridgeConfig, workspaceDir string, cmdArgs []string, tempDir, configDir string, ctx context.Context) (process.RunResult, error) {
|
|
||||||
opts := process.RunOptions{
|
|
||||||
Cwd: workspaceDir,
|
|
||||||
TimeoutMs: cfg.TimeoutMs,
|
|
||||||
MaxMode: cfg.MaxMode,
|
|
||||||
ConfigDir: configDir,
|
|
||||||
Ctx: ctx,
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := process.Run(cfg.AgentBin, cmdArgs, opts)
|
|
||||||
|
|
||||||
cacheTokenForAccount(configDir)
|
|
||||||
if tempDir != "" {
|
|
||||||
os.RemoveAll(tempDir)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func RunAgentStreamWithContext(cfg config.BridgeConfig, workspaceDir string, cmdArgs []string, onLine func(string), tempDir, configDir string, ctx context.Context) (process.StreamResult, error) {
|
|
||||||
opts := process.RunStreamingOptions{
|
|
||||||
RunOptions: process.RunOptions{
|
|
||||||
Cwd: workspaceDir,
|
|
||||||
TimeoutMs: cfg.TimeoutMs,
|
|
||||||
MaxMode: cfg.MaxMode,
|
|
||||||
ConfigDir: configDir,
|
|
||||||
Ctx: ctx,
|
|
||||||
},
|
|
||||||
OnLine: onLine,
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := process.RunStreaming(cfg.AgentBin, cmdArgs, opts)
|
|
||||||
|
|
||||||
cacheTokenForAccount(configDir)
|
|
||||||
if tempDir != "" {
|
|
||||||
os.RemoveAll(tempDir)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
|
|
@ -1,36 +0,0 @@
|
||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
const tokenFile = ".cursor-token"
|
|
||||||
|
|
||||||
func ReadCachedToken(configDir string) string {
|
|
||||||
p := filepath.Join(configDir, tokenFile)
|
|
||||||
data, err := os.ReadFile(p)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return strings.TrimSpace(string(data))
|
|
||||||
}
|
|
||||||
|
|
||||||
func WriteCachedToken(configDir, token string) {
|
|
||||||
p := filepath.Join(configDir, tokenFile)
|
|
||||||
_ = os.WriteFile(p, []byte(token), 0600)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReadKeychainToken() string {
|
|
||||||
if runtime.GOOS != "darwin" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
out, err := exec.Command("security", "find-generic-password", "-s", "cursor-access-token", "-w").Output()
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return strings.TrimSpace(string(out))
|
|
||||||
}
|
|
||||||
|
|
@ -1,174 +0,0 @@
|
||||||
package anthropic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cursor-api-proxy/internal/openai"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MessageParam struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content interface{} `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type MessagesRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
MaxTokens int `json:"max_tokens"`
|
|
||||||
Messages []MessageParam `json:"messages"`
|
|
||||||
System interface{} `json:"system"`
|
|
||||||
Stream bool `json:"stream"`
|
|
||||||
Tools []interface{} `json:"tools"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func systemToText(system interface{}) string {
|
|
||||||
if system == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
switch v := system.(type) {
|
|
||||||
case string:
|
|
||||||
return strings.TrimSpace(v)
|
|
||||||
case []interface{}:
|
|
||||||
var parts []string
|
|
||||||
for _, p := range v {
|
|
||||||
if m, ok := p.(map[string]interface{}); ok {
|
|
||||||
if m["type"] == "text" {
|
|
||||||
if t, ok := m["text"].(string); ok {
|
|
||||||
parts = append(parts, t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.Join(parts, "\n")
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func anthropicBlockToText(p interface{}) string {
|
|
||||||
if p == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
switch v := p.(type) {
|
|
||||||
case string:
|
|
||||||
return v
|
|
||||||
case map[string]interface{}:
|
|
||||||
typ, _ := v["type"].(string)
|
|
||||||
switch typ {
|
|
||||||
case "text":
|
|
||||||
if t, ok := v["text"].(string); ok {
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
case "image":
|
|
||||||
if src, ok := v["source"].(map[string]interface{}); ok {
|
|
||||||
srcType, _ := src["type"].(string)
|
|
||||||
switch srcType {
|
|
||||||
case "base64":
|
|
||||||
mt, _ := src["media_type"].(string)
|
|
||||||
if mt == "" {
|
|
||||||
mt = "image"
|
|
||||||
}
|
|
||||||
return "[Image: base64 " + mt + "]"
|
|
||||||
case "url":
|
|
||||||
url, _ := src["url"].(string)
|
|
||||||
return "[Image: " + url + "]"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "[Image]"
|
|
||||||
case "document":
|
|
||||||
title, _ := v["title"].(string)
|
|
||||||
if title == "" {
|
|
||||||
if src, ok := v["source"].(map[string]interface{}); ok {
|
|
||||||
title, _ = src["url"].(string)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if title != "" {
|
|
||||||
return "[Document: " + title + "]"
|
|
||||||
}
|
|
||||||
return "[Document]"
|
|
||||||
case "tool_use":
|
|
||||||
name, _ := v["name"].(string)
|
|
||||||
id, _ := v["id"].(string)
|
|
||||||
input := v["input"]
|
|
||||||
inputJSON, _ := json.Marshal(input)
|
|
||||||
if inputJSON == nil {
|
|
||||||
inputJSON = []byte("{}")
|
|
||||||
}
|
|
||||||
tag := fmt.Sprintf("<tool_call>\n{\"name\": \"%s\", \"arguments\": %s}\n</tool_call>", name, string(inputJSON))
|
|
||||||
if id != "" {
|
|
||||||
tag = fmt.Sprintf("[tool_use_id=%s] ", id) + tag
|
|
||||||
}
|
|
||||||
return tag
|
|
||||||
case "tool_result":
|
|
||||||
toolUseID, _ := v["tool_use_id"].(string)
|
|
||||||
content := v["content"]
|
|
||||||
var contentText string
|
|
||||||
switch c := content.(type) {
|
|
||||||
case string:
|
|
||||||
contentText = c
|
|
||||||
case []interface{}:
|
|
||||||
var parts []string
|
|
||||||
for _, block := range c {
|
|
||||||
if bm, ok := block.(map[string]interface{}); ok {
|
|
||||||
if bm["type"] == "text" {
|
|
||||||
if t, ok := bm["text"].(string); ok {
|
|
||||||
parts = append(parts, t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
contentText = strings.Join(parts, "\n")
|
|
||||||
}
|
|
||||||
label := "Tool result"
|
|
||||||
if toolUseID != "" {
|
|
||||||
label += " [id=" + toolUseID + "]"
|
|
||||||
}
|
|
||||||
return label + ": " + contentText
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func anthropicContentToText(content interface{}) string {
|
|
||||||
switch v := content.(type) {
|
|
||||||
case string:
|
|
||||||
return v
|
|
||||||
case []interface{}:
|
|
||||||
var parts []string
|
|
||||||
for _, p := range v {
|
|
||||||
if t := anthropicBlockToText(p); t != "" {
|
|
||||||
parts = append(parts, t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.Join(parts, " ")
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func BuildPromptFromAnthropicMessages(messages []MessageParam, system interface{}) string {
|
|
||||||
var oaiMessages []interface{}
|
|
||||||
|
|
||||||
systemText := systemToText(system)
|
|
||||||
if systemText != "" {
|
|
||||||
oaiMessages = append(oaiMessages, map[string]interface{}{
|
|
||||||
"role": "system",
|
|
||||||
"content": systemText,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, m := range messages {
|
|
||||||
text := anthropicContentToText(m.Content)
|
|
||||||
if text == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
role := m.Role
|
|
||||||
if role != "user" && role != "assistant" {
|
|
||||||
role = "user"
|
|
||||||
}
|
|
||||||
oaiMessages = append(oaiMessages, map[string]interface{}{
|
|
||||||
"role": role,
|
|
||||||
"content": text,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return openai.BuildPromptFromMessages(oaiMessages)
|
|
||||||
}
|
|
||||||
|
|
@ -1,109 +0,0 @@
|
||||||
package anthropic_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cursor-api-proxy/internal/anthropic"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestBuildPromptFromAnthropicMessages_Simple(t *testing.T) {
|
|
||||||
messages := []anthropic.MessageParam{
|
|
||||||
{Role: "user", Content: "Hello"},
|
|
||||||
{Role: "assistant", Content: "Hi there"},
|
|
||||||
}
|
|
||||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
|
||||||
if !strings.Contains(prompt, "Hello") {
|
|
||||||
t.Errorf("prompt missing user message: %q", prompt)
|
|
||||||
}
|
|
||||||
if !strings.Contains(prompt, "Hi there") {
|
|
||||||
t.Errorf("prompt missing assistant message: %q", prompt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildPromptFromAnthropicMessages_WithSystem(t *testing.T) {
|
|
||||||
messages := []anthropic.MessageParam{
|
|
||||||
{Role: "user", Content: "ping"},
|
|
||||||
}
|
|
||||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, "You are a helpful bot.")
|
|
||||||
if !strings.Contains(prompt, "You are a helpful bot.") {
|
|
||||||
t.Errorf("prompt missing system: %q", prompt)
|
|
||||||
}
|
|
||||||
if !strings.Contains(prompt, "ping") {
|
|
||||||
t.Errorf("prompt missing user: %q", prompt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildPromptFromAnthropicMessages_SystemArray(t *testing.T) {
|
|
||||||
system := []interface{}{
|
|
||||||
map[string]interface{}{"type": "text", "text": "Part A"},
|
|
||||||
map[string]interface{}{"type": "text", "text": "Part B"},
|
|
||||||
}
|
|
||||||
messages := []anthropic.MessageParam{
|
|
||||||
{Role: "user", Content: "test"},
|
|
||||||
}
|
|
||||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, system)
|
|
||||||
if !strings.Contains(prompt, "Part A") {
|
|
||||||
t.Errorf("prompt missing Part A: %q", prompt)
|
|
||||||
}
|
|
||||||
if !strings.Contains(prompt, "Part B") {
|
|
||||||
t.Errorf("prompt missing Part B: %q", prompt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildPromptFromAnthropicMessages_ContentBlocks(t *testing.T) {
|
|
||||||
content := []interface{}{
|
|
||||||
map[string]interface{}{"type": "text", "text": "block one"},
|
|
||||||
map[string]interface{}{"type": "text", "text": "block two"},
|
|
||||||
}
|
|
||||||
messages := []anthropic.MessageParam{
|
|
||||||
{Role: "user", Content: content},
|
|
||||||
}
|
|
||||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
|
||||||
if !strings.Contains(prompt, "block one") {
|
|
||||||
t.Errorf("prompt missing 'block one': %q", prompt)
|
|
||||||
}
|
|
||||||
if !strings.Contains(prompt, "block two") {
|
|
||||||
t.Errorf("prompt missing 'block two': %q", prompt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildPromptFromAnthropicMessages_ImageBlock(t *testing.T) {
|
|
||||||
content := []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"type": "image",
|
|
||||||
"source": map[string]interface{}{
|
|
||||||
"type": "base64",
|
|
||||||
"media_type": "image/png",
|
|
||||||
"data": "abc123",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
messages := []anthropic.MessageParam{
|
|
||||||
{Role: "user", Content: content},
|
|
||||||
}
|
|
||||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
|
||||||
if !strings.Contains(prompt, "[Image") {
|
|
||||||
t.Errorf("prompt missing [Image]: %q", prompt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildPromptFromAnthropicMessages_EmptyContentSkipped(t *testing.T) {
|
|
||||||
messages := []anthropic.MessageParam{
|
|
||||||
{Role: "user", Content: ""},
|
|
||||||
{Role: "assistant", Content: "response"},
|
|
||||||
}
|
|
||||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
|
||||||
if !strings.Contains(prompt, "response") {
|
|
||||||
t.Errorf("prompt missing 'response': %q", prompt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildPromptFromAnthropicMessages_UnknownRoleBecomesUser(t *testing.T) {
|
|
||||||
messages := []anthropic.MessageParam{
|
|
||||||
{Role: "system", Content: "system-as-user"},
|
|
||||||
}
|
|
||||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
|
||||||
if !strings.Contains(prompt, "system-as-user") {
|
|
||||||
t.Errorf("prompt missing 'system-as-user': %q", prompt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,40 +0,0 @@
|
||||||
package apitypes
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
Role string
|
|
||||||
Content string
|
|
||||||
}
|
|
||||||
|
|
||||||
type Tool struct {
|
|
||||||
Type string
|
|
||||||
Function ToolFunction
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolFunction struct {
|
|
||||||
Name string
|
|
||||||
Description string
|
|
||||||
Parameters interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolCall struct {
|
|
||||||
ID string
|
|
||||||
Name string
|
|
||||||
Arguments string
|
|
||||||
}
|
|
||||||
|
|
||||||
type StreamChunk struct {
|
|
||||||
Type ChunkType
|
|
||||||
Text string
|
|
||||||
Thinking string
|
|
||||||
ToolCall *ToolCall
|
|
||||||
Done bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChunkType int
|
|
||||||
|
|
||||||
const (
|
|
||||||
ChunkText ChunkType = iota
|
|
||||||
ChunkThinking
|
|
||||||
ChunkToolCall
|
|
||||||
ChunkDone
|
|
||||||
)
|
|
||||||
|
|
@ -1,62 +0,0 @@
|
||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cursor-api-proxy/internal/env"
|
|
||||||
)
|
|
||||||
|
|
||||||
type BridgeConfig struct {
|
|
||||||
AgentBin string
|
|
||||||
Host string
|
|
||||||
Port int
|
|
||||||
RequiredKey string
|
|
||||||
DefaultModel string
|
|
||||||
Mode string
|
|
||||||
Provider string
|
|
||||||
Force bool
|
|
||||||
ApproveMcps bool
|
|
||||||
StrictModel bool
|
|
||||||
Workspace string
|
|
||||||
TimeoutMs int
|
|
||||||
TLSCertPath string
|
|
||||||
TLSKeyPath string
|
|
||||||
SessionsLogPath string
|
|
||||||
ChatOnlyWorkspace bool
|
|
||||||
Verbose bool
|
|
||||||
MaxMode bool
|
|
||||||
ConfigDirs []string
|
|
||||||
MultiPort bool
|
|
||||||
WinCmdlineMax int
|
|
||||||
GeminiAccountDir string
|
|
||||||
GeminiBrowserVisible bool
|
|
||||||
GeminiMaxSessions int
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadBridgeConfig(e env.EnvSource, cwd string) BridgeConfig {
|
|
||||||
loaded := env.LoadEnvConfig(e, cwd)
|
|
||||||
return BridgeConfig{
|
|
||||||
AgentBin: loaded.AgentBin,
|
|
||||||
Host: loaded.Host,
|
|
||||||
Port: loaded.Port,
|
|
||||||
RequiredKey: loaded.RequiredKey,
|
|
||||||
DefaultModel: loaded.DefaultModel,
|
|
||||||
Mode: "ask",
|
|
||||||
Provider: loaded.Provider,
|
|
||||||
Force: loaded.Force,
|
|
||||||
ApproveMcps: loaded.ApproveMcps,
|
|
||||||
StrictModel: loaded.StrictModel,
|
|
||||||
Workspace: loaded.Workspace,
|
|
||||||
TimeoutMs: loaded.TimeoutMs,
|
|
||||||
TLSCertPath: loaded.TLSCertPath,
|
|
||||||
TLSKeyPath: loaded.TLSKeyPath,
|
|
||||||
SessionsLogPath: loaded.SessionsLogPath,
|
|
||||||
ChatOnlyWorkspace: loaded.ChatOnlyWorkspace,
|
|
||||||
Verbose: loaded.Verbose,
|
|
||||||
MaxMode: loaded.MaxMode,
|
|
||||||
ConfigDirs: loaded.ConfigDirs,
|
|
||||||
MultiPort: loaded.MultiPort,
|
|
||||||
WinCmdlineMax: loaded.WinCmdlineMax,
|
|
||||||
GeminiAccountDir: loaded.GeminiAccountDir,
|
|
||||||
GeminiBrowserVisible: loaded.GeminiBrowserVisible,
|
|
||||||
GeminiMaxSessions: loaded.GeminiMaxSessions,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,123 +0,0 @@
|
||||||
package config_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
"cursor-api-proxy/internal/env"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLoadBridgeConfig_Defaults(t *testing.T) {
|
|
||||||
cfg := config.LoadBridgeConfig(env.EnvSource{}, "/workspace")
|
|
||||||
|
|
||||||
if cfg.AgentBin != "agent" {
|
|
||||||
t.Errorf("AgentBin = %q, want %q", cfg.AgentBin, "agent")
|
|
||||||
}
|
|
||||||
if cfg.Host != "127.0.0.1" {
|
|
||||||
t.Errorf("Host = %q, want %q", cfg.Host, "127.0.0.1")
|
|
||||||
}
|
|
||||||
if cfg.Port != 8765 {
|
|
||||||
t.Errorf("Port = %d, want 8765", cfg.Port)
|
|
||||||
}
|
|
||||||
if cfg.RequiredKey != "" {
|
|
||||||
t.Errorf("RequiredKey = %q, want empty", cfg.RequiredKey)
|
|
||||||
}
|
|
||||||
if cfg.DefaultModel != "auto" {
|
|
||||||
t.Errorf("DefaultModel = %q, want %q", cfg.DefaultModel, "auto")
|
|
||||||
}
|
|
||||||
if cfg.Force {
|
|
||||||
t.Error("Force should be false")
|
|
||||||
}
|
|
||||||
if cfg.ApproveMcps {
|
|
||||||
t.Error("ApproveMcps should be false")
|
|
||||||
}
|
|
||||||
if !cfg.StrictModel {
|
|
||||||
t.Error("StrictModel should be true")
|
|
||||||
}
|
|
||||||
if cfg.Mode != "ask" {
|
|
||||||
t.Errorf("Mode = %q, want %q", cfg.Mode, "ask")
|
|
||||||
}
|
|
||||||
if cfg.Workspace != "/workspace" {
|
|
||||||
t.Errorf("Workspace = %q, want /workspace", cfg.Workspace)
|
|
||||||
}
|
|
||||||
if !cfg.ChatOnlyWorkspace {
|
|
||||||
t.Error("ChatOnlyWorkspace should be true")
|
|
||||||
}
|
|
||||||
if cfg.WinCmdlineMax != 30000 {
|
|
||||||
t.Errorf("WinCmdlineMax = %d, want 30000", cfg.WinCmdlineMax)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadBridgeConfig_FromEnv(t *testing.T) {
|
|
||||||
e := env.EnvSource{
|
|
||||||
"CURSOR_AGENT_BIN": "/usr/bin/agent",
|
|
||||||
"CURSOR_BRIDGE_HOST": "0.0.0.0",
|
|
||||||
"CURSOR_BRIDGE_PORT": "9999",
|
|
||||||
"CURSOR_BRIDGE_API_KEY": "sk-secret",
|
|
||||||
"CURSOR_BRIDGE_DEFAULT_MODEL": "org/claude-3-opus",
|
|
||||||
"CURSOR_BRIDGE_FORCE": "true",
|
|
||||||
"CURSOR_BRIDGE_APPROVE_MCPS": "yes",
|
|
||||||
"CURSOR_BRIDGE_STRICT_MODEL": "false",
|
|
||||||
"CURSOR_BRIDGE_WORKSPACE": "./my-workspace",
|
|
||||||
"CURSOR_BRIDGE_TIMEOUT_MS": "60000",
|
|
||||||
"CURSOR_BRIDGE_CHAT_ONLY_WORKSPACE": "false",
|
|
||||||
"CURSOR_BRIDGE_VERBOSE": "1",
|
|
||||||
"CURSOR_BRIDGE_TLS_CERT": "./certs/test.crt",
|
|
||||||
"CURSOR_BRIDGE_TLS_KEY": "./certs/test.key",
|
|
||||||
}
|
|
||||||
cfg := config.LoadBridgeConfig(e, "/tmp/project")
|
|
||||||
|
|
||||||
if cfg.AgentBin != "/usr/bin/agent" {
|
|
||||||
t.Errorf("AgentBin = %q, want /usr/bin/agent", cfg.AgentBin)
|
|
||||||
}
|
|
||||||
if cfg.Host != "0.0.0.0" {
|
|
||||||
t.Errorf("Host = %q, want 0.0.0.0", cfg.Host)
|
|
||||||
}
|
|
||||||
if cfg.Port != 9999 {
|
|
||||||
t.Errorf("Port = %d, want 9999", cfg.Port)
|
|
||||||
}
|
|
||||||
if cfg.RequiredKey != "sk-secret" {
|
|
||||||
t.Errorf("RequiredKey = %q, want sk-secret", cfg.RequiredKey)
|
|
||||||
}
|
|
||||||
if cfg.DefaultModel != "claude-3-opus" {
|
|
||||||
t.Errorf("DefaultModel = %q, want claude-3-opus", cfg.DefaultModel)
|
|
||||||
}
|
|
||||||
if !cfg.Force {
|
|
||||||
t.Error("Force should be true")
|
|
||||||
}
|
|
||||||
if !cfg.ApproveMcps {
|
|
||||||
t.Error("ApproveMcps should be true")
|
|
||||||
}
|
|
||||||
if cfg.StrictModel {
|
|
||||||
t.Error("StrictModel should be false")
|
|
||||||
}
|
|
||||||
if !filepath.IsAbs(cfg.Workspace) {
|
|
||||||
t.Errorf("Workspace should be absolute, got %q", cfg.Workspace)
|
|
||||||
}
|
|
||||||
if !strings.Contains(cfg.Workspace, "my-workspace") {
|
|
||||||
t.Errorf("Workspace %q should contain 'my-workspace'", cfg.Workspace)
|
|
||||||
}
|
|
||||||
if cfg.TimeoutMs != 60000 {
|
|
||||||
t.Errorf("TimeoutMs = %d, want 60000", cfg.TimeoutMs)
|
|
||||||
}
|
|
||||||
if cfg.ChatOnlyWorkspace {
|
|
||||||
t.Error("ChatOnlyWorkspace should be false")
|
|
||||||
}
|
|
||||||
if !cfg.Verbose {
|
|
||||||
t.Error("Verbose should be true")
|
|
||||||
}
|
|
||||||
if cfg.TLSCertPath != "/tmp/project/certs/test.crt" {
|
|
||||||
t.Errorf("TLSCertPath = %q, want /tmp/project/certs/test.crt", cfg.TLSCertPath)
|
|
||||||
}
|
|
||||||
if cfg.TLSKeyPath != "/tmp/project/certs/test.key" {
|
|
||||||
t.Errorf("TLSKeyPath = %q, want /tmp/project/certs/test.key", cfg.TLSKeyPath)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadBridgeConfig_WideHost(t *testing.T) {
|
|
||||||
cfg := config.LoadBridgeConfig(env.EnvSource{"CURSOR_BRIDGE_HOST": "0.0.0.0"}, "/workspace")
|
|
||||||
if cfg.Host != "0.0.0.0" {
|
|
||||||
t.Errorf("Host = %q, want 0.0.0.0", cfg.Host)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,381 +0,0 @@
|
||||||
package env
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
type EnvSource map[string]string
|
|
||||||
|
|
||||||
type LoadedEnv struct {
|
|
||||||
AgentBin string
|
|
||||||
AgentNode string
|
|
||||||
AgentScript string
|
|
||||||
CommandShell string
|
|
||||||
Host string
|
|
||||||
Port int
|
|
||||||
RequiredKey string
|
|
||||||
DefaultModel string
|
|
||||||
Provider string
|
|
||||||
Force bool
|
|
||||||
ApproveMcps bool
|
|
||||||
StrictModel bool
|
|
||||||
Workspace string
|
|
||||||
TimeoutMs int
|
|
||||||
TLSCertPath string
|
|
||||||
TLSKeyPath string
|
|
||||||
SessionsLogPath string
|
|
||||||
ChatOnlyWorkspace bool
|
|
||||||
Verbose bool
|
|
||||||
MaxMode bool
|
|
||||||
ConfigDirs []string
|
|
||||||
MultiPort bool
|
|
||||||
WinCmdlineMax int
|
|
||||||
GeminiAccountDir string
|
|
||||||
GeminiBrowserVisible bool
|
|
||||||
GeminiMaxSessions int
|
|
||||||
}
|
|
||||||
|
|
||||||
type AgentCommand struct {
|
|
||||||
Command string
|
|
||||||
Args []string
|
|
||||||
Env map[string]string
|
|
||||||
WindowsVerbatimArguments bool
|
|
||||||
AgentScriptPath string
|
|
||||||
ConfigDir string
|
|
||||||
}
|
|
||||||
|
|
||||||
func getEnvVal(e EnvSource, names []string) string {
|
|
||||||
for _, name := range names {
|
|
||||||
if v, ok := e[name]; ok && strings.TrimSpace(v) != "" {
|
|
||||||
return strings.TrimSpace(v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func envBool(e EnvSource, names []string, def bool) bool {
|
|
||||||
raw := getEnvVal(e, names)
|
|
||||||
if raw == "" {
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
switch strings.ToLower(raw) {
|
|
||||||
case "1", "true", "yes", "on":
|
|
||||||
return true
|
|
||||||
case "0", "false", "no", "off":
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
|
|
||||||
func envInt(e EnvSource, names []string, def int) int {
|
|
||||||
raw := getEnvVal(e, names)
|
|
||||||
if raw == "" {
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
v, err := strconv.Atoi(raw)
|
|
||||||
if err != nil {
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeModelId(raw string) string {
|
|
||||||
raw = strings.TrimSpace(raw)
|
|
||||||
if raw == "" {
|
|
||||||
return "auto"
|
|
||||||
}
|
|
||||||
parts := strings.Split(raw, "/")
|
|
||||||
last := parts[len(parts)-1]
|
|
||||||
if last == "" {
|
|
||||||
return "auto"
|
|
||||||
}
|
|
||||||
return last
|
|
||||||
}
|
|
||||||
|
|
||||||
func resolveAbs(raw, cwd string) string {
|
|
||||||
if raw == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
if filepath.IsAbs(raw) {
|
|
||||||
return raw
|
|
||||||
}
|
|
||||||
return filepath.Join(cwd, raw)
|
|
||||||
}
|
|
||||||
|
|
||||||
func isAuthenticatedAccountDir(dir string) bool {
|
|
||||||
data, err := os.ReadFile(filepath.Join(dir, "cli-config.json"))
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
var cfg struct {
|
|
||||||
AuthInfo *struct {
|
|
||||||
Email string `json:"email"`
|
|
||||||
} `json:"authInfo"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return cfg.AuthInfo != nil && cfg.AuthInfo.Email != ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func discoverAccountDirs(homeDir string) []string {
|
|
||||||
if homeDir == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
accountsDir := filepath.Join(homeDir, ".cursor-api-proxy", "accounts")
|
|
||||||
entries, err := os.ReadDir(accountsDir)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var dirs []string
|
|
||||||
for _, e := range entries {
|
|
||||||
if !e.IsDir() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
dir := filepath.Join(accountsDir, e.Name())
|
|
||||||
if isAuthenticatedAccountDir(dir) {
|
|
||||||
dirs = append(dirs, dir)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return dirs
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseDotEnv(path string) EnvSource {
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
m := make(EnvSource)
|
|
||||||
for _, line := range strings.Split(string(data), "\n") {
|
|
||||||
line = strings.TrimSpace(line)
|
|
||||||
if line == "" || strings.HasPrefix(line, "#") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
parts := strings.SplitN(line, "=", 2)
|
|
||||||
if len(parts) == 2 {
|
|
||||||
m[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func OsEnvToMap(cwdHint ...string) EnvSource {
|
|
||||||
m := make(EnvSource)
|
|
||||||
for _, kv := range os.Environ() {
|
|
||||||
parts := strings.SplitN(kv, "=", 2)
|
|
||||||
if len(parts) == 2 {
|
|
||||||
m[parts[0]] = parts[1]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cwd := ""
|
|
||||||
if len(cwdHint) > 0 && cwdHint[0] != "" {
|
|
||||||
cwd = cwdHint[0]
|
|
||||||
} else {
|
|
||||||
cwd, _ = os.Getwd()
|
|
||||||
}
|
|
||||||
|
|
||||||
if dotenv := parseDotEnv(filepath.Join(cwd, ".env")); dotenv != nil {
|
|
||||||
for k, v := range dotenv {
|
|
||||||
if _, exists := m[k]; !exists {
|
|
||||||
m[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadEnvConfig(e EnvSource, cwd string) LoadedEnv {
|
|
||||||
if e == nil {
|
|
||||||
e = OsEnvToMap()
|
|
||||||
}
|
|
||||||
if cwd == "" {
|
|
||||||
var err error
|
|
||||||
cwd, err = os.Getwd()
|
|
||||||
if err != nil {
|
|
||||||
cwd = "."
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
host := getEnvVal(e, []string{"CURSOR_BRIDGE_HOST"})
|
|
||||||
if host == "" {
|
|
||||||
host = "127.0.0.1"
|
|
||||||
}
|
|
||||||
port := envInt(e, []string{"CURSOR_BRIDGE_PORT"}, 8765)
|
|
||||||
if port <= 0 {
|
|
||||||
port = 8765
|
|
||||||
}
|
|
||||||
|
|
||||||
home := getEnvVal(e, []string{"HOME", "USERPROFILE"})
|
|
||||||
|
|
||||||
sessionsLogPath := func() string {
|
|
||||||
if p := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_SESSIONS_LOG"}), cwd); p != "" {
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
if home != "" {
|
|
||||||
return filepath.Join(home, ".cursor-api-proxy", "sessions.log")
|
|
||||||
}
|
|
||||||
return filepath.Join(cwd, "sessions.log")
|
|
||||||
}()
|
|
||||||
|
|
||||||
var configDirs []string
|
|
||||||
if raw := getEnvVal(e, []string{"CURSOR_CONFIG_DIRS", "CURSOR_ACCOUNT_DIRS"}); raw != "" {
|
|
||||||
for _, d := range strings.Split(raw, ",") {
|
|
||||||
d = strings.TrimSpace(d)
|
|
||||||
if d != "" {
|
|
||||||
if p := resolveAbs(d, cwd); p != "" {
|
|
||||||
configDirs = append(configDirs, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(configDirs) == 0 {
|
|
||||||
configDirs = discoverAccountDirs(home)
|
|
||||||
}
|
|
||||||
|
|
||||||
winMax := envInt(e, []string{"CURSOR_BRIDGE_WIN_CMDLINE_MAX"}, 30000)
|
|
||||||
if winMax < 4096 {
|
|
||||||
winMax = 4096
|
|
||||||
}
|
|
||||||
if winMax > 32700 {
|
|
||||||
winMax = 32700
|
|
||||||
}
|
|
||||||
|
|
||||||
agentBin := getEnvVal(e, []string{"CURSOR_AGENT_BIN", "CURSOR_CLI_BIN", "CURSOR_CLI_PATH"})
|
|
||||||
if agentBin == "" {
|
|
||||||
agentBin = "agent"
|
|
||||||
}
|
|
||||||
commandShell := getEnvVal(e, []string{"COMSPEC"})
|
|
||||||
if commandShell == "" {
|
|
||||||
commandShell = "cmd.exe"
|
|
||||||
}
|
|
||||||
workspace := resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_WORKSPACE"}), cwd)
|
|
||||||
if workspace == "" {
|
|
||||||
workspace = cwd
|
|
||||||
}
|
|
||||||
|
|
||||||
geminiAccountDir := getEnvVal(e, []string{"GEMINI_ACCOUNT_DIR"})
|
|
||||||
if geminiAccountDir == "" {
|
|
||||||
geminiAccountDir = filepath.Join(home, ".cursor-api-proxy", "gemini-accounts")
|
|
||||||
} else {
|
|
||||||
geminiAccountDir = resolveAbs(geminiAccountDir, cwd)
|
|
||||||
}
|
|
||||||
|
|
||||||
return LoadedEnv{
|
|
||||||
AgentBin: agentBin,
|
|
||||||
AgentNode: getEnvVal(e, []string{"CURSOR_AGENT_NODE"}),
|
|
||||||
AgentScript: getEnvVal(e, []string{"CURSOR_AGENT_SCRIPT"}),
|
|
||||||
CommandShell: commandShell,
|
|
||||||
Host: host,
|
|
||||||
Port: port,
|
|
||||||
RequiredKey: getEnvVal(e, []string{"CURSOR_BRIDGE_API_KEY"}),
|
|
||||||
DefaultModel: normalizeModelId(getEnvVal(e, []string{"CURSOR_BRIDGE_DEFAULT_MODEL"})),
|
|
||||||
Provider: getEnvVal(e, []string{"CURSOR_BRIDGE_PROVIDER"}),
|
|
||||||
Force: envBool(e, []string{"CURSOR_BRIDGE_FORCE"}, false),
|
|
||||||
ApproveMcps: envBool(e, []string{"CURSOR_BRIDGE_APPROVE_MCPS"}, false),
|
|
||||||
StrictModel: envBool(e, []string{"CURSOR_BRIDGE_STRICT_MODEL"}, true),
|
|
||||||
Workspace: workspace,
|
|
||||||
TimeoutMs: envInt(e, []string{"CURSOR_BRIDGE_TIMEOUT_MS"}, 300000),
|
|
||||||
TLSCertPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_CERT"}), cwd),
|
|
||||||
TLSKeyPath: resolveAbs(getEnvVal(e, []string{"CURSOR_BRIDGE_TLS_KEY"}), cwd),
|
|
||||||
SessionsLogPath: sessionsLogPath,
|
|
||||||
ChatOnlyWorkspace: envBool(e, []string{"CURSOR_BRIDGE_CHAT_ONLY_WORKSPACE"}, true),
|
|
||||||
Verbose: envBool(e, []string{"CURSOR_BRIDGE_VERBOSE"}, false),
|
|
||||||
MaxMode: envBool(e, []string{"CURSOR_BRIDGE_MAX_MODE"}, false),
|
|
||||||
ConfigDirs: configDirs,
|
|
||||||
MultiPort: envBool(e, []string{"CURSOR_BRIDGE_MULTI_PORT"}, false),
|
|
||||||
WinCmdlineMax: winMax,
|
|
||||||
GeminiAccountDir: geminiAccountDir,
|
|
||||||
GeminiBrowserVisible: envBool(e, []string{"GEMINI_BROWSER_VISIBLE"}, false),
|
|
||||||
GeminiMaxSessions: envInt(e, []string{"GEMINI_MAX_SESSIONS"}, 3),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ResolveAgentCommand(cmd string, args []string, e EnvSource, cwd string) AgentCommand {
|
|
||||||
if e == nil {
|
|
||||||
e = OsEnvToMap()
|
|
||||||
}
|
|
||||||
loaded := LoadEnvConfig(e, cwd)
|
|
||||||
|
|
||||||
cloneEnv := func() map[string]string {
|
|
||||||
m := make(map[string]string, len(e))
|
|
||||||
for k, v := range e {
|
|
||||||
m[k] = v
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
if loaded.AgentNode != "" && loaded.AgentScript != "" {
|
|
||||||
agentScriptPath := loaded.AgentScript
|
|
||||||
if !filepath.IsAbs(agentScriptPath) {
|
|
||||||
agentScriptPath = filepath.Join(cwd, agentScriptPath)
|
|
||||||
}
|
|
||||||
agentDir := filepath.Dir(agentScriptPath)
|
|
||||||
configDir := filepath.Join(agentDir, "..", "data", "config")
|
|
||||||
env2 := cloneEnv()
|
|
||||||
env2["CURSOR_INVOKED_AS"] = "agent.cmd"
|
|
||||||
ac := AgentCommand{
|
|
||||||
Command: loaded.AgentNode,
|
|
||||||
Args: append([]string{loaded.AgentScript}, args...),
|
|
||||||
Env: env2,
|
|
||||||
AgentScriptPath: agentScriptPath,
|
|
||||||
}
|
|
||||||
if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil {
|
|
||||||
ac.ConfigDir = configDir
|
|
||||||
}
|
|
||||||
return ac
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasSuffix(strings.ToLower(cmd), ".cmd") {
|
|
||||||
cmdResolved := cmd
|
|
||||||
if !filepath.IsAbs(cmd) {
|
|
||||||
cmdResolved = filepath.Join(cwd, cmd)
|
|
||||||
}
|
|
||||||
dir := filepath.Dir(cmdResolved)
|
|
||||||
nodeBin := filepath.Join(dir, "node.exe")
|
|
||||||
script := filepath.Join(dir, "index.js")
|
|
||||||
if _, err1 := os.Stat(nodeBin); err1 == nil {
|
|
||||||
if _, err2 := os.Stat(script); err2 == nil {
|
|
||||||
configDir := filepath.Join(dir, "..", "data", "config")
|
|
||||||
env2 := cloneEnv()
|
|
||||||
env2["CURSOR_INVOKED_AS"] = "agent.cmd"
|
|
||||||
ac := AgentCommand{
|
|
||||||
Command: nodeBin,
|
|
||||||
Args: append([]string{script}, args...),
|
|
||||||
Env: env2,
|
|
||||||
AgentScriptPath: script,
|
|
||||||
}
|
|
||||||
if _, err := os.Stat(filepath.Join(configDir, "cli-config.json")); err == nil {
|
|
||||||
ac.ConfigDir = configDir
|
|
||||||
}
|
|
||||||
return ac
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
quotedArgs := make([]string, len(args))
|
|
||||||
for i, a := range args {
|
|
||||||
if strings.Contains(a, " ") {
|
|
||||||
quotedArgs[i] = `"` + a + `"`
|
|
||||||
} else {
|
|
||||||
quotedArgs[i] = a
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cmdLine := `""` + cmd + `" ` + strings.Join(quotedArgs, " ") + `"`
|
|
||||||
return AgentCommand{
|
|
||||||
Command: loaded.CommandShell,
|
|
||||||
Args: []string{"/d", "/s", "/c", cmdLine},
|
|
||||||
Env: cloneEnv(),
|
|
||||||
WindowsVerbatimArguments: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return AgentCommand{Command: cmd, Args: args, Env: cloneEnv()}
|
|
||||||
}
|
|
||||||
|
|
@ -1,65 +0,0 @@
|
||||||
package env
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestLoadEnvConfigDefaults(t *testing.T) {
|
|
||||||
e := EnvSource{}
|
|
||||||
loaded := LoadEnvConfig(e, "/tmp")
|
|
||||||
|
|
||||||
if loaded.Host != "127.0.0.1" {
|
|
||||||
t.Errorf("expected 127.0.0.1, got %s", loaded.Host)
|
|
||||||
}
|
|
||||||
if loaded.Port != 8765 {
|
|
||||||
t.Errorf("expected 8765, got %d", loaded.Port)
|
|
||||||
}
|
|
||||||
if loaded.DefaultModel != "auto" {
|
|
||||||
t.Errorf("expected auto, got %s", loaded.DefaultModel)
|
|
||||||
}
|
|
||||||
if loaded.AgentBin != "agent" {
|
|
||||||
t.Errorf("expected agent, got %s", loaded.AgentBin)
|
|
||||||
}
|
|
||||||
if !loaded.StrictModel {
|
|
||||||
t.Error("expected strictModel=true by default")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadEnvConfigOverride(t *testing.T) {
|
|
||||||
e := EnvSource{
|
|
||||||
"CURSOR_BRIDGE_HOST": "0.0.0.0",
|
|
||||||
"CURSOR_BRIDGE_PORT": "9000",
|
|
||||||
"CURSOR_BRIDGE_DEFAULT_MODEL": "gpt-4",
|
|
||||||
"CURSOR_AGENT_BIN": "/usr/local/bin/agent",
|
|
||||||
}
|
|
||||||
loaded := LoadEnvConfig(e, "/tmp")
|
|
||||||
|
|
||||||
if loaded.Host != "0.0.0.0" {
|
|
||||||
t.Errorf("expected 0.0.0.0, got %s", loaded.Host)
|
|
||||||
}
|
|
||||||
if loaded.Port != 9000 {
|
|
||||||
t.Errorf("expected 9000, got %d", loaded.Port)
|
|
||||||
}
|
|
||||||
if loaded.DefaultModel != "gpt-4" {
|
|
||||||
t.Errorf("expected gpt-4, got %s", loaded.DefaultModel)
|
|
||||||
}
|
|
||||||
if loaded.AgentBin != "/usr/local/bin/agent" {
|
|
||||||
t.Errorf("expected /usr/local/bin/agent, got %s", loaded.AgentBin)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNormalizeModelID(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
input string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"gpt-4", "gpt-4"},
|
|
||||||
{"openai/gpt-4", "gpt-4"},
|
|
||||||
{"", "auto"},
|
|
||||||
{" ", "auto"},
|
|
||||||
}
|
|
||||||
for _, tc := range tests {
|
|
||||||
got := normalizeModelId(tc.input)
|
|
||||||
if got != tc.want {
|
|
||||||
t.Errorf("normalizeModelId(%q) = %q, want %q", tc.input, got, tc.want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,577 +0,0 @@
|
||||||
package handlers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"cursor-api-proxy/internal/agent"
|
|
||||||
"cursor-api-proxy/internal/anthropic"
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
"cursor-api-proxy/internal/httputil"
|
|
||||||
"cursor-api-proxy/internal/logger"
|
|
||||||
"cursor-api-proxy/internal/models"
|
|
||||||
"cursor-api-proxy/internal/openai"
|
|
||||||
"cursor-api-proxy/internal/parser"
|
|
||||||
"cursor-api-proxy/internal/pool"
|
|
||||||
"cursor-api-proxy/internal/sanitize"
|
|
||||||
"cursor-api-proxy/internal/toolcall"
|
|
||||||
"cursor-api-proxy/internal/winlimit"
|
|
||||||
"cursor-api-proxy/internal/workspace"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, ph pool.PoolHandle, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
|
|
||||||
var req anthropic.MessagesRequest
|
|
||||||
if err := json.Unmarshal([]byte(rawBody), &req); err != nil {
|
|
||||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
|
||||||
"error": map[string]string{"type": "invalid_request_error", "message": "invalid JSON body"},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
requested := openai.NormalizeModelID(req.Model)
|
|
||||||
model := ResolveModel(requested, lastModelRef, cfg)
|
|
||||||
|
|
||||||
var rawMap map[string]interface{}
|
|
||||||
_ = json.Unmarshal([]byte(rawBody), &rawMap)
|
|
||||||
|
|
||||||
cleanSystem := sanitize.SanitizeSystem(req.System)
|
|
||||||
|
|
||||||
rawMessages := make([]interface{}, len(req.Messages))
|
|
||||||
for i, m := range req.Messages {
|
|
||||||
rawMessages[i] = map[string]interface{}{"role": m.Role, "content": m.Content}
|
|
||||||
}
|
|
||||||
cleanRawMessages := sanitize.SanitizeMessages(rawMessages)
|
|
||||||
|
|
||||||
var cleanMessages []anthropic.MessageParam
|
|
||||||
for _, raw := range cleanRawMessages {
|
|
||||||
if m, ok := raw.(map[string]interface{}); ok {
|
|
||||||
role, _ := m["role"].(string)
|
|
||||||
cleanMessages = append(cleanMessages, anthropic.MessageParam{Role: role, Content: m["content"]})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
toolsText := openai.ToolsToSystemText(req.Tools, nil)
|
|
||||||
var systemWithTools interface{}
|
|
||||||
if toolsText != "" {
|
|
||||||
sysStr := ""
|
|
||||||
switch v := cleanSystem.(type) {
|
|
||||||
case string:
|
|
||||||
sysStr = v
|
|
||||||
}
|
|
||||||
if sysStr != "" {
|
|
||||||
systemWithTools = sysStr + "\n\n" + toolsText
|
|
||||||
} else {
|
|
||||||
systemWithTools = toolsText
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
systemWithTools = cleanSystem
|
|
||||||
}
|
|
||||||
|
|
||||||
prompt := anthropic.BuildPromptFromAnthropicMessages(cleanMessages, systemWithTools)
|
|
||||||
|
|
||||||
if req.MaxTokens == 0 {
|
|
||||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
|
||||||
"error": map[string]string{"type": "invalid_request_error", "message": "max_tokens is required"},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
cursorModel := models.ResolveToCursorModel(model)
|
|
||||||
if cursorModel == "" {
|
|
||||||
cursorModel = model
|
|
||||||
}
|
|
||||||
|
|
||||||
var trafficMsgs []logger.TrafficMessage
|
|
||||||
if s := systemToString(cleanSystem); s != "" {
|
|
||||||
trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: "system", Content: s})
|
|
||||||
}
|
|
||||||
for _, m := range cleanMessages {
|
|
||||||
text := contentToString(m.Content)
|
|
||||||
if text != "" {
|
|
||||||
trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: m.Role, Content: text})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, req.Stream)
|
|
||||||
|
|
||||||
headerWs := r.Header.Get("x-cursor-workspace")
|
|
||||||
ws := workspace.ResolveWorkspace(cfg, headerWs)
|
|
||||||
|
|
||||||
fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, req.Stream)
|
|
||||||
fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir)
|
|
||||||
|
|
||||||
if cfg.Verbose {
|
|
||||||
if len(prompt) > 200 {
|
|
||||||
logger.LogDebug("model=%s prompt_len=%d prompt_preview=%q", cursorModel, len(prompt), prompt[:200]+"...")
|
|
||||||
} else {
|
|
||||||
logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, len(prompt), prompt)
|
|
||||||
}
|
|
||||||
logger.LogDebug("cmd_args=%v", fit.Args)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !fit.OK {
|
|
||||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
|
||||||
"error": map[string]string{"type": "api_error", "message": fit.Error},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if fit.Truncated {
|
|
||||||
logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmdArgs := fit.Args
|
|
||||||
msgID := "msg_" + uuid.New().String()
|
|
||||||
|
|
||||||
var truncatedHeaders map[string]string
|
|
||||||
if fit.Truncated {
|
|
||||||
truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
|
|
||||||
}
|
|
||||||
|
|
||||||
hasTools := len(req.Tools) > 0
|
|
||||||
var toolNames map[string]bool
|
|
||||||
if hasTools {
|
|
||||||
toolNames = toolcall.CollectToolNames(req.Tools)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Stream {
|
|
||||||
httputil.WriteSSEHeaders(w, truncatedHeaders)
|
|
||||||
flusher, _ := w.(http.Flusher)
|
|
||||||
|
|
||||||
writeEvent := func(evt interface{}) {
|
|
||||||
data, _ := json.Marshal(evt)
|
|
||||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var accumulated string
|
|
||||||
var accumulatedThinking string
|
|
||||||
var chunkNum int
|
|
||||||
var p parser.Parser
|
|
||||||
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "message_start",
|
|
||||||
"message": map[string]interface{}{
|
|
||||||
"id": msgID,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"model": model,
|
|
||||||
"content": []interface{}{},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
if hasTools {
|
|
||||||
toolCallMarkerRe := regexp.MustCompile(`行政法规|<function_calls>`)
|
|
||||||
var toolCallMode bool
|
|
||||||
|
|
||||||
textBlockOpen := false
|
|
||||||
textBlockIndex := 0
|
|
||||||
thinkingOpen := false
|
|
||||||
thinkingBlockIndex := 0
|
|
||||||
blockCount := 0
|
|
||||||
|
|
||||||
p = parser.CreateStreamParserWithThinking(
|
|
||||||
func(text string) {
|
|
||||||
accumulated += text
|
|
||||||
chunkNum++
|
|
||||||
logger.LogStreamChunk(model, text, chunkNum)
|
|
||||||
|
|
||||||
if toolCallMode {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if toolCallMarkerRe.MatchString(text) {
|
|
||||||
if textBlockOpen {
|
|
||||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": textBlockIndex})
|
|
||||||
textBlockOpen = false
|
|
||||||
}
|
|
||||||
if thinkingOpen {
|
|
||||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": thinkingBlockIndex})
|
|
||||||
thinkingOpen = false
|
|
||||||
}
|
|
||||||
toolCallMode = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !textBlockOpen && !thinkingOpen {
|
|
||||||
textBlockIndex = blockCount
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": textBlockIndex,
|
|
||||||
"content_block": map[string]string{"type": "text", "text": ""},
|
|
||||||
})
|
|
||||||
textBlockOpen = true
|
|
||||||
blockCount++
|
|
||||||
}
|
|
||||||
if thinkingOpen {
|
|
||||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": thinkingBlockIndex})
|
|
||||||
thinkingOpen = false
|
|
||||||
}
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": textBlockIndex,
|
|
||||||
"delta": map[string]string{"type": "text_delta", "text": text},
|
|
||||||
})
|
|
||||||
},
|
|
||||||
func(thinking string) {
|
|
||||||
accumulatedThinking += thinking
|
|
||||||
chunkNum++
|
|
||||||
if toolCallMode {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !thinkingOpen {
|
|
||||||
thinkingBlockIndex = blockCount
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": thinkingBlockIndex,
|
|
||||||
"content_block": map[string]string{"type": "thinking", "thinking": ""},
|
|
||||||
})
|
|
||||||
thinkingOpen = true
|
|
||||||
blockCount++
|
|
||||||
}
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": thinkingBlockIndex,
|
|
||||||
"delta": map[string]string{"type": "thinking_delta", "thinking": thinking},
|
|
||||||
})
|
|
||||||
},
|
|
||||||
func() {
|
|
||||||
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
|
|
||||||
parsed := toolcall.ExtractToolCalls(accumulated, toolNames)
|
|
||||||
|
|
||||||
blockIndex := 0
|
|
||||||
if thinkingOpen {
|
|
||||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": thinkingBlockIndex})
|
|
||||||
thinkingOpen = false
|
|
||||||
}
|
|
||||||
|
|
||||||
if parsed.HasToolCalls() {
|
|
||||||
if textBlockOpen {
|
|
||||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": textBlockIndex})
|
|
||||||
blockIndex = textBlockIndex + 1
|
|
||||||
}
|
|
||||||
if parsed.TextContent != "" && !textBlockOpen && !toolCallMode {
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "content_block_start", "index": blockIndex,
|
|
||||||
"content_block": map[string]string{"type": "text", "text": ""},
|
|
||||||
})
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "content_block_delta", "index": blockIndex,
|
|
||||||
"delta": map[string]string{"type": "text_delta", "text": parsed.TextContent},
|
|
||||||
})
|
|
||||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
|
|
||||||
blockIndex++
|
|
||||||
}
|
|
||||||
for _, tc := range parsed.ToolCalls {
|
|
||||||
toolID := "toolu_" + uuid.New().String()[:12]
|
|
||||||
var inputObj interface{}
|
|
||||||
_ = json.Unmarshal([]byte(tc.Arguments), &inputObj)
|
|
||||||
if inputObj == nil {
|
|
||||||
inputObj = map[string]interface{}{}
|
|
||||||
}
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "content_block_start", "index": blockIndex,
|
|
||||||
"content_block": map[string]interface{}{
|
|
||||||
"type": "tool_use", "id": toolID, "name": tc.Name, "input": map[string]interface{}{},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "content_block_delta", "index": blockIndex,
|
|
||||||
"delta": map[string]interface{}{
|
|
||||||
"type": "input_json_delta", "partial_json": tc.Arguments,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
|
|
||||||
blockIndex++
|
|
||||||
}
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "message_delta",
|
|
||||||
"delta": map[string]interface{}{"stop_reason": "tool_use", "stop_sequence": nil},
|
|
||||||
"usage": map[string]int{"output_tokens": 0},
|
|
||||||
})
|
|
||||||
writeEvent(map[string]interface{}{"type": "message_stop"})
|
|
||||||
} else {
|
|
||||||
if textBlockOpen {
|
|
||||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": textBlockIndex})
|
|
||||||
} else if accumulated != "" {
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "content_block_start", "index": blockIndex,
|
|
||||||
"content_block": map[string]string{"type": "text", "text": ""},
|
|
||||||
})
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "content_block_delta", "index": blockIndex,
|
|
||||||
"delta": map[string]string{"type": "text_delta", "text": accumulated},
|
|
||||||
})
|
|
||||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
|
|
||||||
blockIndex++
|
|
||||||
} else {
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "content_block_start", "index": blockIndex,
|
|
||||||
"content_block": map[string]string{"type": "text", "text": ""},
|
|
||||||
})
|
|
||||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
|
|
||||||
blockIndex++
|
|
||||||
}
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "message_delta",
|
|
||||||
"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
|
|
||||||
"usage": map[string]int{"output_tokens": 0},
|
|
||||||
})
|
|
||||||
writeEvent(map[string]interface{}{"type": "message_stop"})
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
// 非 tools 模式:即時串流 thinking 和 text
|
|
||||||
blockCount := 0
|
|
||||||
thinkingOpen := false
|
|
||||||
textOpen := false
|
|
||||||
|
|
||||||
p = parser.CreateStreamParserWithThinking(
|
|
||||||
func(text string) {
|
|
||||||
accumulated += text
|
|
||||||
chunkNum++
|
|
||||||
logger.LogStreamChunk(model, text, chunkNum)
|
|
||||||
if thinkingOpen {
|
|
||||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
|
|
||||||
thinkingOpen = false
|
|
||||||
}
|
|
||||||
if !textOpen {
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": blockCount,
|
|
||||||
"content_block": map[string]string{"type": "text", "text": ""},
|
|
||||||
})
|
|
||||||
textOpen = true
|
|
||||||
blockCount++
|
|
||||||
}
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": blockCount - 1,
|
|
||||||
"delta": map[string]string{"type": "text_delta", "text": text},
|
|
||||||
})
|
|
||||||
},
|
|
||||||
func(thinking string) {
|
|
||||||
accumulatedThinking += thinking
|
|
||||||
chunkNum++
|
|
||||||
if !thinkingOpen {
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": blockCount,
|
|
||||||
"content_block": map[string]string{"type": "thinking", "thinking": ""},
|
|
||||||
})
|
|
||||||
thinkingOpen = true
|
|
||||||
blockCount++
|
|
||||||
}
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": blockCount - 1,
|
|
||||||
"delta": map[string]string{"type": "thinking_delta", "thinking": thinking},
|
|
||||||
})
|
|
||||||
},
|
|
||||||
func() {
|
|
||||||
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
|
|
||||||
if thinkingOpen {
|
|
||||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
|
|
||||||
thinkingOpen = false
|
|
||||||
}
|
|
||||||
if !textOpen {
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": blockCount,
|
|
||||||
"content_block": map[string]string{"type": "text", "text": ""},
|
|
||||||
})
|
|
||||||
blockCount++
|
|
||||||
}
|
|
||||||
writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
|
|
||||||
writeEvent(map[string]interface{}{
|
|
||||||
"type": "message_delta",
|
|
||||||
"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
|
|
||||||
"usage": map[string]int{"output_tokens": 0},
|
|
||||||
})
|
|
||||||
writeEvent(map[string]interface{}{"type": "message_stop"})
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
configDir := ph.GetNextConfigDir()
|
|
||||||
logger.LogAccountAssigned(configDir)
|
|
||||||
ph.ReportRequestStart(configDir)
|
|
||||||
logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true)
|
|
||||||
streamStart := time.Now().UnixMilli()
|
|
||||||
|
|
||||||
ctx := r.Context()
|
|
||||||
wrappedParser := func(line string) {
|
|
||||||
logger.LogRawLine(line)
|
|
||||||
p.Parse(line)
|
|
||||||
}
|
|
||||||
result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, ctx)
|
|
||||||
|
|
||||||
if ctx.Err() == nil {
|
|
||||||
p.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
latencyMs := time.Now().UnixMilli() - streamStart
|
|
||||||
ph.ReportRequestEnd(configDir)
|
|
||||||
|
|
||||||
if ctx.Err() == context.DeadlineExceeded {
|
|
||||||
logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
|
|
||||||
} else if ctx.Err() == context.Canceled {
|
|
||||||
logger.LogClientDisconnect(method, pathname, model, latencyMs)
|
|
||||||
} else if err == nil && isRateLimited(result.Stderr) {
|
|
||||||
ph.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil || (result.Code != 0 && ctx.Err() == nil) {
|
|
||||||
ph.ReportRequestError(configDir, latencyMs)
|
|
||||||
if err != nil {
|
|
||||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
|
|
||||||
} else {
|
|
||||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr)
|
|
||||||
}
|
|
||||||
logger.LogRequestDone(method, pathname, model, latencyMs, result.Code)
|
|
||||||
} else if ctx.Err() == nil {
|
|
||||||
ph.ReportRequestSuccess(configDir, latencyMs)
|
|
||||||
logger.LogRequestDone(method, pathname, model, latencyMs, 0)
|
|
||||||
}
|
|
||||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
configDir := ph.GetNextConfigDir()
|
|
||||||
logger.LogAccountAssigned(configDir)
|
|
||||||
ph.ReportRequestStart(configDir)
|
|
||||||
logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, false)
|
|
||||||
syncStart := time.Now().UnixMilli()
|
|
||||||
|
|
||||||
out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context())
|
|
||||||
syncLatency := time.Now().UnixMilli() - syncStart
|
|
||||||
ph.ReportRequestEnd(configDir)
|
|
||||||
|
|
||||||
ctx := r.Context()
|
|
||||||
if ctx.Err() == context.DeadlineExceeded {
|
|
||||||
logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
|
|
||||||
httputil.WriteJSON(w, 504, map[string]interface{}{
|
|
||||||
"error": map[string]string{"type": "api_error", "message": fmt.Sprintf("request timed out after %dms", cfg.TimeoutMs)},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if ctx.Err() == context.Canceled {
|
|
||||||
logger.LogClientDisconnect(method, pathname, model, syncLatency)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
ph.ReportRequestError(configDir, syncLatency)
|
|
||||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
|
||||||
logger.LogRequestDone(method, pathname, model, syncLatency, -1)
|
|
||||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
|
||||||
"error": map[string]string{"type": "api_error", "message": err.Error()},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if isRateLimited(out.Stderr) {
|
|
||||||
ph.ReportRateLimit(configDir, extractRetryAfterMs(out.Stderr))
|
|
||||||
}
|
|
||||||
|
|
||||||
if out.Code != 0 {
|
|
||||||
ph.ReportRequestError(configDir, syncLatency)
|
|
||||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
|
||||||
errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr)
|
|
||||||
logger.LogRequestDone(method, pathname, model, syncLatency, out.Code)
|
|
||||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
|
||||||
"error": map[string]string{"type": "api_error", "message": errMsg},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ph.ReportRequestSuccess(configDir, syncLatency)
|
|
||||||
content := strings.TrimSpace(out.Stdout)
|
|
||||||
logger.LogTrafficResponse(cfg.Verbose, model, content, false)
|
|
||||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
|
||||||
logger.LogRequestDone(method, pathname, model, syncLatency, 0)
|
|
||||||
|
|
||||||
if hasTools {
|
|
||||||
parsed := toolcall.ExtractToolCalls(content, toolNames)
|
|
||||||
if parsed.HasToolCalls() {
|
|
||||||
var contentBlocks []map[string]interface{}
|
|
||||||
if parsed.TextContent != "" {
|
|
||||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
|
||||||
"type": "text", "text": parsed.TextContent,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
for _, tc := range parsed.ToolCalls {
|
|
||||||
toolID := "toolu_" + uuid.New().String()[:12]
|
|
||||||
var inputObj interface{}
|
|
||||||
_ = json.Unmarshal([]byte(tc.Arguments), &inputObj)
|
|
||||||
if inputObj == nil {
|
|
||||||
inputObj = map[string]interface{}{}
|
|
||||||
}
|
|
||||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
|
||||||
"type": "tool_use", "id": toolID, "name": tc.Name, "input": inputObj,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
|
||||||
"id": msgID,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"content": contentBlocks,
|
|
||||||
"model": model,
|
|
||||||
"stop_reason": "tool_use",
|
|
||||||
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
|
|
||||||
}, truncatedHeaders)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
|
||||||
"id": msgID,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"content": []map[string]string{{"type": "text", "text": content}},
|
|
||||||
"model": model,
|
|
||||||
"stop_reason": "end_turn",
|
|
||||||
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
|
|
||||||
}, truncatedHeaders)
|
|
||||||
}
|
|
||||||
|
|
||||||
func systemToString(system interface{}) string {
|
|
||||||
switch v := system.(type) {
|
|
||||||
case string:
|
|
||||||
return v
|
|
||||||
case []interface{}:
|
|
||||||
result := ""
|
|
||||||
for _, p := range v {
|
|
||||||
if m, ok := p.(map[string]interface{}); ok && m["type"] == "text" {
|
|
||||||
if t, ok := m["text"].(string); ok {
|
|
||||||
result += t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func contentToString(content interface{}) string {
|
|
||||||
switch v := content.(type) {
|
|
||||||
case string:
|
|
||||||
return v
|
|
||||||
case []interface{}:
|
|
||||||
result := ""
|
|
||||||
for _, p := range v {
|
|
||||||
if m, ok := p.(map[string]interface{}); ok && m["type"] == "text" {
|
|
||||||
if t, ok := m["text"].(string); ok {
|
|
||||||
result += t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
@ -1,471 +0,0 @@
|
||||||
package handlers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"cursor-api-proxy/internal/agent"
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
"cursor-api-proxy/internal/httputil"
|
|
||||||
"cursor-api-proxy/internal/logger"
|
|
||||||
"cursor-api-proxy/internal/models"
|
|
||||||
"cursor-api-proxy/internal/openai"
|
|
||||||
"cursor-api-proxy/internal/parser"
|
|
||||||
"cursor-api-proxy/internal/pool"
|
|
||||||
"cursor-api-proxy/internal/sanitize"
|
|
||||||
"cursor-api-proxy/internal/toolcall"
|
|
||||||
"cursor-api-proxy/internal/winlimit"
|
|
||||||
"cursor-api-proxy/internal/workspace"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"regexp"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
var rateLimitRe = regexp.MustCompile(`(?i)\b429\b|rate.?limit|too many requests`)
|
|
||||||
var retryAfterRe = regexp.MustCompile(`(?i)retry-after:\s*(\d+)`)
|
|
||||||
|
|
||||||
func isRateLimited(stderr string) bool {
|
|
||||||
return rateLimitRe.MatchString(stderr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractRetryAfterMs(stderr string) int64 {
|
|
||||||
if m := retryAfterRe.FindStringSubmatch(stderr); len(m) > 1 {
|
|
||||||
if secs, err := strconv.ParseInt(m[1], 10, 64); err == nil && secs > 0 {
|
|
||||||
return secs * 1000
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 60000
|
|
||||||
}
|
|
||||||
|
|
||||||
func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, ph pool.PoolHandle, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
|
|
||||||
var bodyMap map[string]interface{}
|
|
||||||
if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil {
|
|
||||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
|
||||||
"error": map[string]string{"message": "invalid JSON body", "code": "bad_request"},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rawModel, _ := bodyMap["model"].(string)
|
|
||||||
requested := openai.NormalizeModelID(rawModel)
|
|
||||||
model := ResolveModel(requested, lastModelRef, cfg)
|
|
||||||
cursorModel := models.ResolveToCursorModel(model)
|
|
||||||
if cursorModel == "" {
|
|
||||||
cursorModel = model
|
|
||||||
}
|
|
||||||
|
|
||||||
var messages []interface{}
|
|
||||||
if m, ok := bodyMap["messages"].([]interface{}); ok {
|
|
||||||
messages = m
|
|
||||||
}
|
|
||||||
|
|
||||||
var tools []interface{}
|
|
||||||
if t, ok := bodyMap["tools"].([]interface{}); ok {
|
|
||||||
tools = t
|
|
||||||
}
|
|
||||||
var funcs []interface{}
|
|
||||||
if f, ok := bodyMap["functions"].([]interface{}); ok {
|
|
||||||
funcs = f
|
|
||||||
}
|
|
||||||
|
|
||||||
cleanMessages := sanitize.SanitizeMessages(messages)
|
|
||||||
|
|
||||||
toolsText := openai.ToolsToSystemText(tools, funcs)
|
|
||||||
messagesWithTools := cleanMessages
|
|
||||||
if toolsText != "" {
|
|
||||||
messagesWithTools = append([]interface{}{map[string]interface{}{"role": "system", "content": toolsText}}, cleanMessages...)
|
|
||||||
}
|
|
||||||
prompt := openai.BuildPromptFromMessages(messagesWithTools)
|
|
||||||
|
|
||||||
var trafficMsgs []logger.TrafficMessage
|
|
||||||
for _, raw := range cleanMessages {
|
|
||||||
if m, ok := raw.(map[string]interface{}); ok {
|
|
||||||
role, _ := m["role"].(string)
|
|
||||||
content := openai.MessageContentToText(m["content"])
|
|
||||||
trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: role, Content: content})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
isStream := false
|
|
||||||
if s, ok := bodyMap["stream"].(bool); ok {
|
|
||||||
isStream = s
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, isStream)
|
|
||||||
|
|
||||||
headerWs := r.Header.Get("x-cursor-workspace")
|
|
||||||
ws := workspace.ResolveWorkspace(cfg, headerWs)
|
|
||||||
|
|
||||||
promptLen := len(prompt)
|
|
||||||
if cfg.Verbose {
|
|
||||||
if promptLen > 200 {
|
|
||||||
logger.LogDebug("model=%s prompt_len=%d prompt_start=%q", cursorModel, promptLen, prompt[:200])
|
|
||||||
} else {
|
|
||||||
logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, promptLen, prompt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, isStream)
|
|
||||||
fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir)
|
|
||||||
|
|
||||||
if cfg.Verbose {
|
|
||||||
logger.LogDebug("cmd=%s args=%v", cfg.AgentBin, fit.Args)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !fit.OK {
|
|
||||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
|
||||||
"error": map[string]string{"message": fit.Error, "code": "windows_cmdline_limit"},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if fit.Truncated {
|
|
||||||
logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmdArgs := fit.Args
|
|
||||||
id := "chatcmpl_" + uuid.New().String()
|
|
||||||
created := time.Now().Unix()
|
|
||||||
|
|
||||||
var truncatedHeaders map[string]string
|
|
||||||
if fit.Truncated {
|
|
||||||
truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
|
|
||||||
}
|
|
||||||
|
|
||||||
hasTools := len(tools) > 0 || len(funcs) > 0
|
|
||||||
var toolNames map[string]bool
|
|
||||||
if hasTools {
|
|
||||||
toolNames = toolcall.CollectToolNames(tools)
|
|
||||||
for _, f := range funcs {
|
|
||||||
if fm, ok := f.(map[string]interface{}); ok {
|
|
||||||
if name, ok := fm["name"].(string); ok {
|
|
||||||
toolNames[name] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if isStream {
|
|
||||||
httputil.WriteSSEHeaders(w, truncatedHeaders)
|
|
||||||
flusher, _ := w.(http.Flusher)
|
|
||||||
|
|
||||||
var accumulated string
|
|
||||||
var chunkNum int
|
|
||||||
var p parser.Parser
|
|
||||||
|
|
||||||
// toolCallMarkerRe 偵測 tool call 開頭標記,一旦出現就停止即時輸出並進入累積模式
|
|
||||||
toolCallMarkerRe := regexp.MustCompile(`<tool_call>|<function_calls>`)
|
|
||||||
if hasTools {
|
|
||||||
var toolCallMode bool // 是否已進入 tool call 累積模式
|
|
||||||
p = parser.CreateStreamParserWithThinking(
|
|
||||||
func(text string) {
|
|
||||||
accumulated += text
|
|
||||||
chunkNum++
|
|
||||||
logger.LogStreamChunk(model, text, chunkNum)
|
|
||||||
if toolCallMode {
|
|
||||||
// 已進入累積模式,不即時輸出
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if toolCallMarkerRe.MatchString(text) {
|
|
||||||
// 偵測到 tool call 標記,切換為累積模式
|
|
||||||
toolCallMode = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
chunk := map[string]interface{}{
|
|
||||||
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
|
|
||||||
"choices": []map[string]interface{}{
|
|
||||||
{"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
data, _ := json.Marshal(chunk)
|
|
||||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
},
|
|
||||||
func(_ string) {}, // thinking ignored in tools mode
|
|
||||||
func() {
|
|
||||||
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
|
|
||||||
parsed := toolcall.ExtractToolCalls(accumulated, toolNames)
|
|
||||||
|
|
||||||
if parsed.HasToolCalls() {
|
|
||||||
if parsed.TextContent != "" && toolCallMode {
|
|
||||||
// 已有部分 text 被即時輸出,只補發剩餘的
|
|
||||||
chunk := map[string]interface{}{
|
|
||||||
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
|
|
||||||
"choices": []map[string]interface{}{
|
|
||||||
{"index": 0, "delta": map[string]interface{}{"role": "assistant", "content": parsed.TextContent}, "finish_reason": nil},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
data, _ := json.Marshal(chunk)
|
|
||||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for i, tc := range parsed.ToolCalls {
|
|
||||||
callID := "call_" + uuid.New().String()[:8]
|
|
||||||
chunk := map[string]interface{}{
|
|
||||||
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
|
|
||||||
"choices": []map[string]interface{}{
|
|
||||||
{"index": 0, "delta": map[string]interface{}{
|
|
||||||
"tool_calls": []map[string]interface{}{
|
|
||||||
{
|
|
||||||
"index": i,
|
|
||||||
"id": callID,
|
|
||||||
"type": "function",
|
|
||||||
"function": map[string]interface{}{
|
|
||||||
"name": tc.Name,
|
|
||||||
"arguments": tc.Arguments,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, "finish_reason": nil},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
data, _ := json.Marshal(chunk)
|
|
||||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stopChunk := map[string]interface{}{
|
|
||||||
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
|
|
||||||
"choices": []map[string]interface{}{
|
|
||||||
{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "tool_calls"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
data, _ := json.Marshal(stopChunk)
|
|
||||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
||||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
stopChunk := map[string]interface{}{
|
|
||||||
"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
|
|
||||||
"choices": []map[string]interface{}{
|
|
||||||
{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
data, _ := json.Marshal(stopChunk)
|
|
||||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
||||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
p = parser.CreateStreamParserWithThinking(
|
|
||||||
func(text string) {
|
|
||||||
accumulated += text
|
|
||||||
chunkNum++
|
|
||||||
logger.LogStreamChunk(model, text, chunkNum)
|
|
||||||
chunk := map[string]interface{}{
|
|
||||||
"id": id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": created,
|
|
||||||
"model": model,
|
|
||||||
"choices": []map[string]interface{}{
|
|
||||||
{"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
data, _ := json.Marshal(chunk)
|
|
||||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
},
|
|
||||||
func(thinking string) {
|
|
||||||
chunk := map[string]interface{}{
|
|
||||||
"id": id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": created,
|
|
||||||
"model": model,
|
|
||||||
"choices": []map[string]interface{}{
|
|
||||||
{"index": 0, "delta": map[string]interface{}{"reasoning_content": thinking}, "finish_reason": nil},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
data, _ := json.Marshal(chunk)
|
|
||||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
},
|
|
||||||
func() {
|
|
||||||
logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
|
|
||||||
stopChunk := map[string]interface{}{
|
|
||||||
"id": id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": created,
|
|
||||||
"model": model,
|
|
||||||
"choices": []map[string]interface{}{
|
|
||||||
{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
data, _ := json.Marshal(stopChunk)
|
|
||||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
||||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
configDir := ph.GetNextConfigDir()
|
|
||||||
logger.LogAccountAssigned(configDir)
|
|
||||||
ph.ReportRequestStart(configDir)
|
|
||||||
logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true)
|
|
||||||
streamStart := time.Now().UnixMilli()
|
|
||||||
|
|
||||||
ctx := r.Context()
|
|
||||||
wrappedParser := func(line string) {
|
|
||||||
logger.LogRawLine(line)
|
|
||||||
p.Parse(line)
|
|
||||||
}
|
|
||||||
result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, ctx)
|
|
||||||
|
|
||||||
// agent 結束後,若未收到 result/success 訊號,強制 flush 以確保 SSE stream 正確結尾
|
|
||||||
if ctx.Err() == nil {
|
|
||||||
p.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
latencyMs := time.Now().UnixMilli() - streamStart
|
|
||||||
ph.ReportRequestEnd(configDir)
|
|
||||||
|
|
||||||
if ctx.Err() == context.DeadlineExceeded {
|
|
||||||
logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
|
|
||||||
} else if ctx.Err() == context.Canceled {
|
|
||||||
logger.LogClientDisconnect(method, pathname, model, latencyMs)
|
|
||||||
} else if err == nil && isRateLimited(result.Stderr) {
|
|
||||||
ph.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil || (result.Code != 0 && ctx.Err() == nil) {
|
|
||||||
ph.ReportRequestError(configDir, latencyMs)
|
|
||||||
if err != nil {
|
|
||||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
|
|
||||||
} else {
|
|
||||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr)
|
|
||||||
}
|
|
||||||
logger.LogRequestDone(method, pathname, model, latencyMs, result.Code)
|
|
||||||
} else if ctx.Err() == nil {
|
|
||||||
ph.ReportRequestSuccess(configDir, latencyMs)
|
|
||||||
logger.LogRequestDone(method, pathname, model, latencyMs, 0)
|
|
||||||
}
|
|
||||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
configDir := ph.GetNextConfigDir()
|
|
||||||
logger.LogAccountAssigned(configDir)
|
|
||||||
ph.ReportRequestStart(configDir)
|
|
||||||
logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, false)
|
|
||||||
syncStart := time.Now().UnixMilli()
|
|
||||||
|
|
||||||
out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context())
|
|
||||||
syncLatency := time.Now().UnixMilli() - syncStart
|
|
||||||
ph.ReportRequestEnd(configDir)
|
|
||||||
|
|
||||||
ctx := r.Context()
|
|
||||||
if ctx.Err() == context.DeadlineExceeded {
|
|
||||||
logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
|
|
||||||
httputil.WriteJSON(w, 504, map[string]interface{}{
|
|
||||||
"error": map[string]string{"message": fmt.Sprintf("request timed out after %dms", cfg.TimeoutMs), "code": "timeout"},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if ctx.Err() == context.Canceled {
|
|
||||||
logger.LogClientDisconnect(method, pathname, model, syncLatency)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
ph.ReportRequestError(configDir, syncLatency)
|
|
||||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
|
||||||
logger.LogRequestDone(method, pathname, model, syncLatency, -1)
|
|
||||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
|
||||||
"error": map[string]string{"message": err.Error(), "code": "cursor_cli_error"},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if isRateLimited(out.Stderr) {
|
|
||||||
ph.ReportRateLimit(configDir, extractRetryAfterMs(out.Stderr))
|
|
||||||
}
|
|
||||||
|
|
||||||
if out.Code != 0 {
|
|
||||||
ph.ReportRequestError(configDir, syncLatency)
|
|
||||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
|
||||||
errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr)
|
|
||||||
logger.LogRequestDone(method, pathname, model, syncLatency, out.Code)
|
|
||||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
|
||||||
"error": map[string]string{"message": errMsg, "code": "cursor_cli_error"},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ph.ReportRequestSuccess(configDir, syncLatency)
|
|
||||||
content := strings.TrimSpace(out.Stdout)
|
|
||||||
logger.LogTrafficResponse(cfg.Verbose, model, content, false)
|
|
||||||
logger.LogAccountStats(cfg.Verbose, ph.GetStats())
|
|
||||||
logger.LogRequestDone(method, pathname, model, syncLatency, 0)
|
|
||||||
|
|
||||||
if hasTools {
|
|
||||||
parsed := toolcall.ExtractToolCalls(content, toolNames)
|
|
||||||
if parsed.HasToolCalls() {
|
|
||||||
msg := map[string]interface{}{"role": "assistant"}
|
|
||||||
if parsed.TextContent != "" {
|
|
||||||
msg["content"] = parsed.TextContent
|
|
||||||
} else {
|
|
||||||
msg["content"] = nil
|
|
||||||
}
|
|
||||||
var tcArr []map[string]interface{}
|
|
||||||
for _, tc := range parsed.ToolCalls {
|
|
||||||
callID := "call_" + uuid.New().String()[:8]
|
|
||||||
tcArr = append(tcArr, map[string]interface{}{
|
|
||||||
"id": callID,
|
|
||||||
"type": "function",
|
|
||||||
"function": map[string]interface{}{
|
|
||||||
"name": tc.Name,
|
|
||||||
"arguments": tc.Arguments,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
msg["tool_calls"] = tcArr
|
|
||||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
|
||||||
"id": id,
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": created,
|
|
||||||
"model": model,
|
|
||||||
"choices": []map[string]interface{}{
|
|
||||||
{"index": 0, "message": msg, "finish_reason": "tool_calls"},
|
|
||||||
},
|
|
||||||
"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
|
||||||
}, truncatedHeaders)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
|
||||||
"id": id,
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": created,
|
|
||||||
"model": model,
|
|
||||||
"choices": []map[string]interface{}{
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": map[string]string{"role": "assistant", "content": content},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
|
||||||
}, truncatedHeaders)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -1,203 +0,0 @@
|
||||||
package handlers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"cursor-api-proxy/internal/apitypes"
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
"cursor-api-proxy/internal/httputil"
|
|
||||||
"cursor-api-proxy/internal/logger"
|
|
||||||
"cursor-api-proxy/internal/providers/geminiweb"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
func HandleGeminiChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, rawBody, method, pathname, remoteAddress string) {
|
|
||||||
_ = context.Background() // 確保 context 被使用
|
|
||||||
var bodyMap map[string]interface{}
|
|
||||||
if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil {
|
|
||||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
|
||||||
"error": map[string]string{"message": "invalid JSON body", "code": "bad_request"},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rawModel, _ := bodyMap["model"].(string)
|
|
||||||
if rawModel == "" {
|
|
||||||
rawModel = "gemini-2.0-flash"
|
|
||||||
}
|
|
||||||
|
|
||||||
var messages []interface{}
|
|
||||||
if m, ok := bodyMap["messages"].([]interface{}); ok {
|
|
||||||
messages = m
|
|
||||||
}
|
|
||||||
|
|
||||||
isStream := false
|
|
||||||
if s, ok := bodyMap["stream"].(bool); ok {
|
|
||||||
isStream = s
|
|
||||||
}
|
|
||||||
|
|
||||||
// 轉換 messages 為 apitypes.Message
|
|
||||||
var apiMessages []apitypes.Message
|
|
||||||
for _, m := range messages {
|
|
||||||
if msgMap, ok := m.(map[string]interface{}); ok {
|
|
||||||
role, _ := msgMap["role"].(string)
|
|
||||||
content := ""
|
|
||||||
if c, ok := msgMap["content"].(string); ok {
|
|
||||||
content = c
|
|
||||||
}
|
|
||||||
apiMessages = append(apiMessages, apitypes.Message{
|
|
||||||
Role: role,
|
|
||||||
Content: content,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.LogRequestStart(method, pathname, rawModel, cfg.TimeoutMs, isStream)
|
|
||||||
start := time.Now().UnixMilli()
|
|
||||||
|
|
||||||
// 創建 Gemini provider (使用 Playwright)
|
|
||||||
provider, provErr := geminiweb.NewPlaywrightProvider(cfg)
|
|
||||||
if provErr != nil {
|
|
||||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, provErr.Error())
|
|
||||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
|
||||||
"error": map[string]string{"message": provErr.Error(), "code": "provider_error"},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if isStream {
|
|
||||||
httputil.WriteSSEHeaders(w, nil)
|
|
||||||
flusher, _ := w.(http.Flusher)
|
|
||||||
|
|
||||||
id := "chatcmpl_" + uuid.New().String()
|
|
||||||
created := time.Now().Unix()
|
|
||||||
var accumulated string
|
|
||||||
|
|
||||||
err := provider.Generate(r.Context(), rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) {
|
|
||||||
if chunk.Type == apitypes.ChunkText {
|
|
||||||
accumulated += chunk.Text
|
|
||||||
respChunk := map[string]interface{}{
|
|
||||||
"id": id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": created,
|
|
||||||
"model": rawModel,
|
|
||||||
"choices": []map[string]interface{}{
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"delta": map[string]string{"content": chunk.Text},
|
|
||||||
"finish_reason": nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
data, _ := json.Marshal(respChunk)
|
|
||||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
} else if chunk.Type == apitypes.ChunkThinking {
|
|
||||||
respChunk := map[string]interface{}{
|
|
||||||
"id": id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": created,
|
|
||||||
"model": rawModel,
|
|
||||||
"choices": []map[string]interface{}{
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"delta": map[string]interface{}{"reasoning_content": chunk.Thinking},
|
|
||||||
"finish_reason": nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
data, _ := json.Marshal(respChunk)
|
|
||||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
} else if chunk.Type == apitypes.ChunkDone {
|
|
||||||
stopChunk := map[string]interface{}{
|
|
||||||
"id": id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": created,
|
|
||||||
"model": rawModel,
|
|
||||||
"choices": []map[string]interface{}{
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"delta": map[string]interface{}{},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
data, _ := json.Marshal(stopChunk)
|
|
||||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
||||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
latencyMs := time.Now().UnixMilli() - start
|
|
||||||
if err != nil {
|
|
||||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
|
|
||||||
logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.LogTrafficResponse(cfg.Verbose, rawModel, accumulated, true)
|
|
||||||
logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 非串流模式
|
|
||||||
var resultText string
|
|
||||||
var resultThinking string
|
|
||||||
|
|
||||||
err := provider.Generate(r.Context(), rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) {
|
|
||||||
if chunk.Type == apitypes.ChunkText {
|
|
||||||
resultText += chunk.Text
|
|
||||||
} else if chunk.Type == apitypes.ChunkThinking {
|
|
||||||
resultThinking += chunk.Thinking
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
latencyMs := time.Now().UnixMilli() - start
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
|
|
||||||
logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1)
|
|
||||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
|
||||||
"error": map[string]string{"message": err.Error(), "code": "gemini_error"},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.LogTrafficResponse(cfg.Verbose, rawModel, resultText, false)
|
|
||||||
logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0)
|
|
||||||
|
|
||||||
id := "chatcmpl_" + uuid.New().String()
|
|
||||||
created := time.Now().Unix()
|
|
||||||
|
|
||||||
resp := map[string]interface{}{
|
|
||||||
"id": id,
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": created,
|
|
||||||
"model": rawModel,
|
|
||||||
"choices": []map[string]interface{}{
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": map[string]interface{}{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": resultText,
|
|
||||||
},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
|
||||||
}
|
|
||||||
|
|
||||||
httputil.WriteJSON(w, 200, resp, nil)
|
|
||||||
}
|
|
||||||
|
|
@ -1,20 +0,0 @@
|
||||||
package handlers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
"cursor-api-proxy/internal/httputil"
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
func HandleHealth(w http.ResponseWriter, r *http.Request, version string, cfg config.BridgeConfig) {
|
|
||||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
|
||||||
"ok": true,
|
|
||||||
"version": version,
|
|
||||||
"workspace": cfg.Workspace,
|
|
||||||
"mode": cfg.Mode,
|
|
||||||
"defaultModel": cfg.DefaultModel,
|
|
||||||
"force": cfg.Force,
|
|
||||||
"approveMcps": cfg.ApproveMcps,
|
|
||||||
"strictModel": cfg.StrictModel,
|
|
||||||
}, nil)
|
|
||||||
}
|
|
||||||
|
|
@ -1,107 +0,0 @@
|
||||||
package handlers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
"cursor-api-proxy/internal/httputil"
|
|
||||||
"cursor-api-proxy/internal/models"
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const modelCacheTTLMs = 5 * 60 * 1000
|
|
||||||
|
|
||||||
type ModelCache struct {
|
|
||||||
At int64
|
|
||||||
Models []models.CursorCliModel
|
|
||||||
}
|
|
||||||
|
|
||||||
type ModelCacheRef struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
cache *ModelCache
|
|
||||||
inflight bool
|
|
||||||
waiters []chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ref *ModelCacheRef) HandleModels(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig) {
|
|
||||||
now := time.Now().UnixMilli()
|
|
||||||
|
|
||||||
ref.mu.Lock()
|
|
||||||
if ref.cache != nil && now-ref.cache.At <= modelCacheTTLMs {
|
|
||||||
cache := ref.cache
|
|
||||||
ref.mu.Unlock()
|
|
||||||
writeModels(w, cache.Models)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if ref.inflight {
|
|
||||||
// Wait for the in-flight fetch
|
|
||||||
ch := make(chan struct{}, 1)
|
|
||||||
ref.waiters = append(ref.waiters, ch)
|
|
||||||
ref.mu.Unlock()
|
|
||||||
<-ch
|
|
||||||
ref.mu.Lock()
|
|
||||||
cache := ref.cache
|
|
||||||
ref.mu.Unlock()
|
|
||||||
writeModels(w, cache.Models)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ref.inflight = true
|
|
||||||
ref.mu.Unlock()
|
|
||||||
|
|
||||||
fetched, err := models.ListCursorCliModels(cfg.AgentBin, 60000)
|
|
||||||
|
|
||||||
ref.mu.Lock()
|
|
||||||
ref.inflight = false
|
|
||||||
if err == nil {
|
|
||||||
ref.cache = &ModelCache{At: time.Now().UnixMilli(), Models: fetched}
|
|
||||||
}
|
|
||||||
waiters := ref.waiters
|
|
||||||
ref.waiters = nil
|
|
||||||
ref.mu.Unlock()
|
|
||||||
|
|
||||||
for _, ch := range waiters {
|
|
||||||
ch <- struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
|
||||||
"error": map[string]string{"message": err.Error(), "code": "models_fetch_error"},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
writeModels(w, fetched)
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeModels(w http.ResponseWriter, mods []models.CursorCliModel) {
|
|
||||||
cursorModels := make([]map[string]interface{}, len(mods))
|
|
||||||
for i, m := range mods {
|
|
||||||
cursorModels[i] = map[string]interface{}{
|
|
||||||
"id": m.ID,
|
|
||||||
"object": "model",
|
|
||||||
"owned_by": "cursor",
|
|
||||||
"name": m.Name,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ids := make([]string, len(mods))
|
|
||||||
for i, m := range mods {
|
|
||||||
ids[i] = m.ID
|
|
||||||
}
|
|
||||||
aliases := models.GetAnthropicModelAliases(ids)
|
|
||||||
for _, a := range aliases {
|
|
||||||
cursorModels = append(cursorModels, map[string]interface{}{
|
|
||||||
"id": a.ID,
|
|
||||||
"object": "model",
|
|
||||||
"owned_by": "cursor",
|
|
||||||
"name": a.Name,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
httputil.WriteJSON(w, 200, map[string]interface{}{
|
|
||||||
"object": "list",
|
|
||||||
"data": cursorModels,
|
|
||||||
}, nil)
|
|
||||||
}
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
package handlers
|
|
||||||
|
|
||||||
import "cursor-api-proxy/internal/config"
|
|
||||||
|
|
||||||
func ResolveModel(requested string, lastModelRef *string, cfg config.BridgeConfig) string {
|
|
||||||
isAuto := requested == "auto"
|
|
||||||
var explicitModel string
|
|
||||||
if requested != "" && !isAuto {
|
|
||||||
explicitModel = requested
|
|
||||||
}
|
|
||||||
if explicitModel != "" {
|
|
||||||
*lastModelRef = explicitModel
|
|
||||||
}
|
|
||||||
if isAuto {
|
|
||||||
return "auto"
|
|
||||||
}
|
|
||||||
if explicitModel != "" {
|
|
||||||
return explicitModel
|
|
||||||
}
|
|
||||||
if cfg.StrictModel && *lastModelRef != "" {
|
|
||||||
return *lastModelRef
|
|
||||||
}
|
|
||||||
if *lastModelRef != "" {
|
|
||||||
return *lastModelRef
|
|
||||||
}
|
|
||||||
return cfg.DefaultModel
|
|
||||||
}
|
|
||||||
|
|
@ -1,50 +0,0 @@
|
||||||
package httputil
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"regexp"
|
|
||||||
)
|
|
||||||
|
|
||||||
var bearerRe = regexp.MustCompile(`(?i)^Bearer\s+(.+)$`)
|
|
||||||
|
|
||||||
func ExtractBearerToken(r *http.Request) string {
|
|
||||||
h := r.Header.Get("Authorization")
|
|
||||||
if h == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
m := bearerRe.FindStringSubmatch(h)
|
|
||||||
if m == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return m[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
func WriteJSON(w http.ResponseWriter, status int, body interface{}, extraHeaders map[string]string) {
|
|
||||||
for k, v := range extraHeaders {
|
|
||||||
w.Header().Set(k, v)
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(status)
|
|
||||||
_ = json.NewEncoder(w).Encode(body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func WriteSSEHeaders(w http.ResponseWriter, extraHeaders map[string]string) {
|
|
||||||
for k, v := range extraHeaders {
|
|
||||||
w.Header().Set(k, v)
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
|
||||||
w.Header().Set("Connection", "keep-alive")
|
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
w.WriteHeader(200)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReadBody(r *http.Request) (string, error) {
|
|
||||||
data, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return string(data), nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,50 +0,0 @@
|
||||||
package httputil
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestExtractBearerToken(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
header string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"Bearer mytoken123", "mytoken123"},
|
|
||||||
{"bearer MYTOKEN", "MYTOKEN"},
|
|
||||||
{"", ""},
|
|
||||||
{"Basic abc", ""},
|
|
||||||
{"Bearer ", ""},
|
|
||||||
}
|
|
||||||
for _, tc := range tests {
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
if tc.header != "" {
|
|
||||||
req.Header.Set("Authorization", tc.header)
|
|
||||||
}
|
|
||||||
got := ExtractBearerToken(req)
|
|
||||||
if got != tc.want {
|
|
||||||
t.Errorf("ExtractBearerToken(%q) = %q, want %q", tc.header, got, tc.want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWriteJSON(t *testing.T) {
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
WriteJSON(w, 200, map[string]string{"ok": "true"}, nil)
|
|
||||||
|
|
||||||
if w.Code != 200 {
|
|
||||||
t.Errorf("expected 200, got %d", w.Code)
|
|
||||||
}
|
|
||||||
if w.Header().Get("Content-Type") != "application/json" {
|
|
||||||
t.Errorf("expected application/json, got %s", w.Header().Get("Content-Type"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWriteJSONWithExtraHeaders(t *testing.T) {
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
WriteJSON(w, 201, nil, map[string]string{"X-Custom": "value"})
|
|
||||||
|
|
||||||
if w.Header().Get("X-Custom") != "value" {
|
|
||||||
t.Errorf("expected X-Custom=value, got %s", w.Header().Get("X-Custom"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,309 +0,0 @@
|
||||||
package logger
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
"cursor-api-proxy/internal/pool"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
cReset = "\x1b[0m"
|
|
||||||
cBold = "\x1b[1m"
|
|
||||||
cDim = "\x1b[2m"
|
|
||||||
cCyan = "\x1b[36m"
|
|
||||||
cBCyan = "\x1b[1;96m"
|
|
||||||
cGreen = "\x1b[32m"
|
|
||||||
cBGreen = "\x1b[1;92m"
|
|
||||||
cYellow = "\x1b[33m"
|
|
||||||
cMagenta = "\x1b[35m"
|
|
||||||
cBMagenta = "\x1b[1;95m"
|
|
||||||
cRed = "\x1b[31m"
|
|
||||||
cGray = "\x1b[90m"
|
|
||||||
cWhite = "\x1b[97m"
|
|
||||||
)
|
|
||||||
|
|
||||||
var roleStyle = map[string]string{
|
|
||||||
"system": cYellow,
|
|
||||||
"user": cCyan,
|
|
||||||
"assistant": cGreen,
|
|
||||||
}
|
|
||||||
|
|
||||||
var roleEmoji = map[string]string{
|
|
||||||
"system": "🔧",
|
|
||||||
"user": "👤",
|
|
||||||
"assistant": "🤖",
|
|
||||||
}
|
|
||||||
|
|
||||||
func ts() string {
|
|
||||||
return cGray + time.Now().UTC().Format("15:04:05") + cReset
|
|
||||||
}
|
|
||||||
|
|
||||||
func tsDate() string {
|
|
||||||
return cGray + time.Now().UTC().Format("2006-01-02 15:04:05") + cReset
|
|
||||||
}
|
|
||||||
|
|
||||||
func truncate(s string, max int) string {
|
|
||||||
if len(s) <= max {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
head := int(float64(max) * 0.6)
|
|
||||||
tail := max - head
|
|
||||||
omitted := len(s) - head - tail
|
|
||||||
return s[:head] + fmt.Sprintf("%s … (%d chars omitted) … ", cDim, omitted) + s[len(s)-tail:] + cReset
|
|
||||||
}
|
|
||||||
|
|
||||||
func hr(ch string, length int) string {
|
|
||||||
return cGray + strings.Repeat(ch, length) + cReset
|
|
||||||
}
|
|
||||||
|
|
||||||
type TrafficMessage struct {
|
|
||||||
Role string
|
|
||||||
Content string
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogDebug(format string, args ...interface{}) {
|
|
||||||
msg := fmt.Sprintf(format, args...)
|
|
||||||
fmt.Printf("%s %s[DEBUG]%s %s\n", ts(), cGray, cReset, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogServerStart(version, scheme, host string, port int, cfg config.BridgeConfig) {
|
|
||||||
provider := cfg.Provider
|
|
||||||
if provider == "" {
|
|
||||||
provider = "cursor"
|
|
||||||
}
|
|
||||||
fmt.Printf("\n%s%s╔══════════════════════════════════════════╗%s\n", cBold, cBCyan, cReset)
|
|
||||||
fmt.Printf("%s%s cursor-api-proxy %sv%s%s%s%s ready%s\n",
|
|
||||||
cBold, cBCyan, cReset, cBold, cWhite, version, cBCyan, cReset)
|
|
||||||
fmt.Printf("%s%s╚══════════════════════════════════════════╝%s\n\n", cBold, cBCyan, cReset)
|
|
||||||
url := fmt.Sprintf("%s://%s:%d", scheme, host, port)
|
|
||||||
fmt.Printf(" %s●%s listening %s%s%s\n", cBGreen, cReset, cBold, url, cReset)
|
|
||||||
fmt.Printf(" %s▸%s provider %s%s%s\n", cCyan, cReset, cBold, provider, cReset)
|
|
||||||
fmt.Printf(" %s▸%s agent %s%s%s\n", cCyan, cReset, cDim, cfg.AgentBin, cReset)
|
|
||||||
fmt.Printf(" %s▸%s workspace %s%s%s\n", cCyan, cReset, cDim, cfg.Workspace, cReset)
|
|
||||||
fmt.Printf(" %s▸%s model %s%s%s\n", cCyan, cReset, cDim, cfg.DefaultModel, cReset)
|
|
||||||
fmt.Printf(" %s▸%s mode %s%s%s\n", cCyan, cReset, cDim, cfg.Mode, cReset)
|
|
||||||
fmt.Printf(" %s▸%s timeout %s%d ms%s\n", cCyan, cReset, cDim, cfg.TimeoutMs, cReset)
|
|
||||||
|
|
||||||
// 顯示 Gemini Web Provider 相關設定
|
|
||||||
if provider == "gemini-web" {
|
|
||||||
fmt.Printf(" %s▸%s gemini-dir %s%s%s\n", cCyan, cReset, cDim, cfg.GeminiAccountDir, cReset)
|
|
||||||
fmt.Printf(" %s▸%s max-sess %s%d%s\n", cCyan, cReset, cDim, cfg.GeminiMaxSessions, cReset)
|
|
||||||
}
|
|
||||||
|
|
||||||
flags := []string{}
|
|
||||||
if cfg.Force {
|
|
||||||
flags = append(flags, "force")
|
|
||||||
}
|
|
||||||
if cfg.ApproveMcps {
|
|
||||||
flags = append(flags, "approve-mcps")
|
|
||||||
}
|
|
||||||
if cfg.MaxMode {
|
|
||||||
flags = append(flags, "max-mode")
|
|
||||||
}
|
|
||||||
if cfg.Verbose {
|
|
||||||
flags = append(flags, "verbose")
|
|
||||||
}
|
|
||||||
if cfg.ChatOnlyWorkspace {
|
|
||||||
flags = append(flags, "chat-only")
|
|
||||||
}
|
|
||||||
if cfg.RequiredKey != "" {
|
|
||||||
flags = append(flags, "api-key-required")
|
|
||||||
}
|
|
||||||
if len(flags) > 0 {
|
|
||||||
fmt.Printf(" %s▸%s flags %s%s%s\n", cCyan, cReset, cYellow, strings.Join(flags, " · "), cReset)
|
|
||||||
}
|
|
||||||
if len(cfg.ConfigDirs) > 0 {
|
|
||||||
fmt.Printf(" %s▸%s pool %s%d accounts%s\n", cCyan, cReset, cBGreen, len(cfg.ConfigDirs), cReset)
|
|
||||||
}
|
|
||||||
fmt.Println()
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogShutdown(sig string) {
|
|
||||||
fmt.Printf("\n%s %s⊘ %s received — shutting down gracefully…%s\n", tsDate(), cYellow, sig, cReset)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogRequestStart(method, pathname, model string, timeoutMs int, isStream bool) {
|
|
||||||
modeTag := fmt.Sprintf("%ssync%s", cDim, cReset)
|
|
||||||
if isStream {
|
|
||||||
modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset)
|
|
||||||
}
|
|
||||||
fmt.Printf("%s %s▶%s %s %s %s timeout:%dms %s\n",
|
|
||||||
ts(), cBCyan, cReset, method, pathname, model, timeoutMs, modeTag)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogRequestDone(method, pathname, model string, latencyMs int64, code int) {
|
|
||||||
statusColor := cBGreen
|
|
||||||
if code != 0 {
|
|
||||||
statusColor = cRed
|
|
||||||
}
|
|
||||||
fmt.Printf("%s %s■%s %s %s %s %s%dms exit:%d%s\n",
|
|
||||||
ts(), statusColor, cReset, method, pathname, model, cDim, latencyMs, code, cReset)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogRequestTimeout(method, pathname, model string, timeoutMs int) {
|
|
||||||
fmt.Printf("%s %s⏱%s %s %s %s %stimed-out after %dms%s\n",
|
|
||||||
ts(), cRed, cReset, method, pathname, model, cRed, timeoutMs, cReset)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogClientDisconnect(method, pathname, model string, latencyMs int64) {
|
|
||||||
fmt.Printf("%s %s⚡%s %s %s %s %sclient disconnected after %dms%s\n",
|
|
||||||
ts(), cYellow, cReset, method, pathname, model, cYellow, latencyMs, cReset)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogStreamChunk(model string, text string, chunkNum int) {
|
|
||||||
preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 120)
|
|
||||||
fmt.Printf("%s %s▸%s #%d %s%s%s\n",
|
|
||||||
ts(), cDim, cReset, chunkNum, cWhite, preview, cReset)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogRawLine(line string) {
|
|
||||||
preview := truncate(strings.ReplaceAll(line, "\n", "↵ "), 200)
|
|
||||||
fmt.Printf("%s %s│%s %sraw%s %s\n",
|
|
||||||
ts(), cGray, cReset, cDim, cReset, preview)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogIncoming(method, pathname, remoteAddress string) {
|
|
||||||
methodColor := cBCyan
|
|
||||||
switch method {
|
|
||||||
case "POST":
|
|
||||||
methodColor = cBMagenta
|
|
||||||
case "GET":
|
|
||||||
methodColor = cBCyan
|
|
||||||
case "DELETE":
|
|
||||||
methodColor = cRed
|
|
||||||
}
|
|
||||||
fmt.Printf("%s %s%s%s%s %s%s%s %s(%s)%s\n",
|
|
||||||
ts(),
|
|
||||||
methodColor, cBold, method, cReset,
|
|
||||||
cWhite, pathname, cReset,
|
|
||||||
cDim, remoteAddress, cReset,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogAccountAssigned(configDir string) {
|
|
||||||
if configDir == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
name := filepath.Base(configDir)
|
|
||||||
fmt.Printf("%s %s→%s account %s%s%s\n", ts(), cBCyan, cReset, cBold, name, cReset)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogAccountStats(verbose bool, stats []pool.AccountStat) {
|
|
||||||
if !verbose || len(stats) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
now := time.Now().UnixMilli()
|
|
||||||
fmt.Printf("%s┌─ Account Stats %s┐%s\n", cGray, strings.Repeat("─", 44), cReset)
|
|
||||||
for _, s := range stats {
|
|
||||||
name := fmt.Sprintf("%-20s", filepath.Base(s.ConfigDir))
|
|
||||||
active := fmt.Sprintf("%sactive:0%s", cDim, cReset)
|
|
||||||
if s.ActiveRequests > 0 {
|
|
||||||
active = fmt.Sprintf("%sactive:%d%s", cBCyan, s.ActiveRequests, cReset)
|
|
||||||
}
|
|
||||||
total := fmt.Sprintf("total:%s%d%s", cBold, s.TotalRequests, cReset)
|
|
||||||
ok := fmt.Sprintf("%sok:%d%s", cGreen, s.TotalSuccess, cReset)
|
|
||||||
errStr := fmt.Sprintf("%serr:0%s", cDim, cReset)
|
|
||||||
if s.TotalErrors > 0 {
|
|
||||||
errStr = fmt.Sprintf("%serr:%d%s", cRed, s.TotalErrors, cReset)
|
|
||||||
}
|
|
||||||
rl := fmt.Sprintf("%srl:0%s", cDim, cReset)
|
|
||||||
if s.TotalRateLimits > 0 {
|
|
||||||
rl = fmt.Sprintf("%srl:%d%s", cYellow, s.TotalRateLimits, cReset)
|
|
||||||
}
|
|
||||||
avg := "avg:-"
|
|
||||||
if s.TotalRequests > 0 {
|
|
||||||
avg = fmt.Sprintf("avg:%dms", s.TotalLatencyMs/int64(s.TotalRequests))
|
|
||||||
}
|
|
||||||
status := fmt.Sprintf("%s✓%s", cGreen, cReset)
|
|
||||||
if s.IsRateLimited {
|
|
||||||
recovers := time.UnixMilli(s.RateLimitUntil).UTC().Format(time.RFC3339)
|
|
||||||
_ = now
|
|
||||||
status = fmt.Sprintf("%s⛔ rate-limited (recovers %s)%s", cRed, recovers, cReset)
|
|
||||||
}
|
|
||||||
fmt.Printf(" %s%s%s %s %s %s %s %s %s%s%s %s\n",
|
|
||||||
cBold, name, cReset, active, total, ok, errStr, rl, cDim, avg, cReset, status)
|
|
||||||
}
|
|
||||||
fmt.Printf("%s└%s┘%s\n", cGray, strings.Repeat("─", 60), cReset)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogTrafficRequest(verbose bool, model string, messages []TrafficMessage, isStream bool) {
|
|
||||||
if !verbose {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
modeTag := fmt.Sprintf("%ssync%s", cDim, cReset)
|
|
||||||
if isStream {
|
|
||||||
modeTag = fmt.Sprintf("%s⚡ stream%s", cBCyan, cReset)
|
|
||||||
}
|
|
||||||
modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset)
|
|
||||||
fmt.Println(hr("─", 60))
|
|
||||||
fmt.Printf("%s 📤 %s%sREQUEST%s %s %s\n", ts(), cBCyan, cBold, cReset, modelStr, modeTag)
|
|
||||||
for _, m := range messages {
|
|
||||||
roleColor := cWhite
|
|
||||||
if c, ok := roleStyle[m.Role]; ok {
|
|
||||||
roleColor = c
|
|
||||||
}
|
|
||||||
emoji := "💬"
|
|
||||||
if e, ok := roleEmoji[m.Role]; ok {
|
|
||||||
emoji = e
|
|
||||||
}
|
|
||||||
label := fmt.Sprintf("%s%s[%s]%s", roleColor, cBold, m.Role, cReset)
|
|
||||||
charCount := fmt.Sprintf("%s(%d chars)%s", cDim, len(m.Content), cReset)
|
|
||||||
preview := truncate(strings.ReplaceAll(m.Content, "\n", "↵ "), 280)
|
|
||||||
fmt.Printf(" %s %s %s\n", emoji, label, charCount)
|
|
||||||
fmt.Printf(" %s%s%s\n", cDim, preview, cReset)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogTrafficResponse(verbose bool, model, text string, isStream bool) {
|
|
||||||
if !verbose {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
modeTag := fmt.Sprintf("%ssync%s", cDim, cReset)
|
|
||||||
if isStream {
|
|
||||||
modeTag = fmt.Sprintf("%s⚡ stream%s", cBGreen, cReset)
|
|
||||||
}
|
|
||||||
modelStr := fmt.Sprintf("%s✦ %s%s", cBMagenta, model, cReset)
|
|
||||||
charCount := fmt.Sprintf("%s%d%s%s chars%s", cBold, len(text), cReset, cDim, cReset)
|
|
||||||
preview := truncate(strings.ReplaceAll(text, "\n", "↵ "), 480)
|
|
||||||
fmt.Printf("%s 📥 %s%sRESPONSE%s %s %s %s\n", ts(), cBGreen, cBold, cReset, modelStr, modeTag, charCount)
|
|
||||||
fmt.Printf(" 🤖 %s%s%s\n", cGreen, preview, cReset)
|
|
||||||
fmt.Println(hr("─", 60))
|
|
||||||
}
|
|
||||||
|
|
||||||
func AppendSessionLine(logPath, method, pathname, remoteAddress string, statusCode int) {
|
|
||||||
line := fmt.Sprintf("%s %s %s %s %d\n", time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, statusCode)
|
|
||||||
dir := filepath.Dir(logPath)
|
|
||||||
if err := os.MkdirAll(dir, 0755); err == nil {
|
|
||||||
f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
|
||||||
if err == nil {
|
|
||||||
_, _ = f.WriteString(line)
|
|
||||||
f.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogTruncation(originalLen, finalLen int) {
|
|
||||||
fmt.Printf("%s %s⚠ prompt truncated%s %s(%d → %d chars, tail preserved)%s\n",
|
|
||||||
ts(), cYellow, cReset, cDim, originalLen, finalLen, cReset)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogAgentError(logPath, method, pathname, remoteAddress string, exitCode int, stderr string) string {
|
|
||||||
errMsg := fmt.Sprintf("Cursor CLI failed (exit %d): %s", exitCode, strings.TrimSpace(stderr))
|
|
||||||
fmt.Fprintf(os.Stderr, "%s %s✗ agent error%s %s%s%s\n", ts(), cRed, cReset, cDim, errMsg, cReset)
|
|
||||||
truncated := strings.TrimSpace(stderr)
|
|
||||||
if len(truncated) > 200 {
|
|
||||||
truncated = truncated[:200]
|
|
||||||
}
|
|
||||||
truncated = strings.ReplaceAll(truncated, "\n", " ")
|
|
||||||
line := fmt.Sprintf("%s ERROR %s %s %s agent_exit_%d %s\n",
|
|
||||||
time.Now().UTC().Format(time.RFC3339), method, pathname, remoteAddress, exitCode, truncated)
|
|
||||||
if f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil {
|
|
||||||
_, _ = f.WriteString(line)
|
|
||||||
f.Close()
|
|
||||||
}
|
|
||||||
return errMsg
|
|
||||||
}
|
|
||||||
|
|
@ -1,62 +0,0 @@
|
||||||
package models
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cursor-api-proxy/internal/process"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CursorCliModel struct {
|
|
||||||
ID string
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
var modelLineRe = regexp.MustCompile(`^([A-Za-z0-9][A-Za-z0-9._:/-]*)\s+-\s+(.*)$`)
|
|
||||||
var trailingParenRe = regexp.MustCompile(`\s*\([^)]*\)\s*$`)
|
|
||||||
|
|
||||||
func ParseCursorCliModels(output string) []CursorCliModel {
|
|
||||||
lines := strings.Split(output, "\n")
|
|
||||||
seen := make(map[string]CursorCliModel)
|
|
||||||
var order []string
|
|
||||||
|
|
||||||
for _, line := range lines {
|
|
||||||
line = strings.TrimSpace(line)
|
|
||||||
m := modelLineRe.FindStringSubmatch(line)
|
|
||||||
if m == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
id := m[1]
|
|
||||||
rawName := m[2]
|
|
||||||
name := strings.TrimSpace(trailingParenRe.ReplaceAllString(rawName, ""))
|
|
||||||
if name == "" {
|
|
||||||
name = id
|
|
||||||
}
|
|
||||||
if _, exists := seen[id]; !exists {
|
|
||||||
seen[id] = CursorCliModel{ID: id, Name: name}
|
|
||||||
order = append(order, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
result := make([]CursorCliModel, 0, len(order))
|
|
||||||
for _, id := range order {
|
|
||||||
result = append(result, seen[id])
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func ListCursorCliModels(agentBin string, timeoutMs int) ([]CursorCliModel, error) {
|
|
||||||
tmpDir := os.TempDir()
|
|
||||||
result, err := process.Run(agentBin, []string{"--list-models"}, process.RunOptions{
|
|
||||||
Cwd: tmpDir,
|
|
||||||
TimeoutMs: timeoutMs,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if result.Code != 0 {
|
|
||||||
return nil, fmt.Errorf("agent --list-models failed: %s", strings.TrimSpace(result.Stderr))
|
|
||||||
}
|
|
||||||
return ParseCursorCliModels(result.Stdout), nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,33 +0,0 @@
|
||||||
package models
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestParseCursorCliModels(t *testing.T) {
|
|
||||||
output := `
|
|
||||||
gpt-4o - GPT-4o (some info)
|
|
||||||
claude-3-5-sonnet - Claude 3.5 Sonnet
|
|
||||||
gpt-4o - GPT-4o duplicate
|
|
||||||
invalid line without dash
|
|
||||||
`
|
|
||||||
result := ParseCursorCliModels(output)
|
|
||||||
|
|
||||||
if len(result) != 2 {
|
|
||||||
t.Fatalf("expected 2 unique models, got %d: %v", len(result), result)
|
|
||||||
}
|
|
||||||
if result[0].ID != "gpt-4o" {
|
|
||||||
t.Errorf("expected gpt-4o, got %s", result[0].ID)
|
|
||||||
}
|
|
||||||
if result[0].Name != "GPT-4o" {
|
|
||||||
t.Errorf("expected 'GPT-4o', got %s", result[0].Name)
|
|
||||||
}
|
|
||||||
if result[1].ID != "claude-3-5-sonnet" {
|
|
||||||
t.Errorf("expected claude-3-5-sonnet, got %s", result[1].ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseCursorCliModelsEmpty(t *testing.T) {
|
|
||||||
result := ParseCursorCliModels("")
|
|
||||||
if len(result) != 0 {
|
|
||||||
t.Fatalf("expected empty, got %v", result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,123 +0,0 @@
|
||||||
package models
|
|
||||||
|
|
||||||
import (
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
var anthropicToCursor = map[string]string{
|
|
||||||
"claude-opus-4-6": "opus-4.6",
|
|
||||||
"claude-opus-4.6": "opus-4.6",
|
|
||||||
"claude-sonnet-4-6": "sonnet-4.6",
|
|
||||||
"claude-sonnet-4.6": "sonnet-4.6",
|
|
||||||
"claude-opus-4-5": "opus-4.5",
|
|
||||||
"claude-opus-4.5": "opus-4.5",
|
|
||||||
"claude-sonnet-4-5": "sonnet-4.5",
|
|
||||||
"claude-sonnet-4.5": "sonnet-4.5",
|
|
||||||
"claude-opus-4": "opus-4.6",
|
|
||||||
"claude-sonnet-4": "sonnet-4.6",
|
|
||||||
"claude-haiku-4-5-20251001": "sonnet-4.5",
|
|
||||||
"claude-haiku-4-5": "sonnet-4.5",
|
|
||||||
"claude-haiku-4-6": "sonnet-4.6",
|
|
||||||
"claude-haiku-4": "sonnet-4.5",
|
|
||||||
"claude-opus-4-6-thinking": "opus-4.6-thinking",
|
|
||||||
"claude-sonnet-4-6-thinking": "sonnet-4.6-thinking",
|
|
||||||
"claude-opus-4-5-thinking": "opus-4.5-thinking",
|
|
||||||
"claude-sonnet-4-5-thinking": "sonnet-4.5-thinking",
|
|
||||||
}
|
|
||||||
|
|
||||||
type ModelAlias struct {
|
|
||||||
CursorID string
|
|
||||||
AnthropicID string
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
var cursorToAnthropicAlias = []ModelAlias{
|
|
||||||
{"opus-4.6", "claude-opus-4-6", "Claude 4.6 Opus"},
|
|
||||||
{"opus-4.6-thinking", "claude-opus-4-6-thinking", "Claude 4.6 Opus (Thinking)"},
|
|
||||||
{"sonnet-4.6", "claude-sonnet-4-6", "Claude 4.6 Sonnet"},
|
|
||||||
{"sonnet-4.6-thinking", "claude-sonnet-4-6-thinking", "Claude 4.6 Sonnet (Thinking)"},
|
|
||||||
{"opus-4.5", "claude-opus-4-5", "Claude 4.5 Opus"},
|
|
||||||
{"opus-4.5-thinking", "claude-opus-4-5-thinking", "Claude 4.5 Opus (Thinking)"},
|
|
||||||
{"sonnet-4.5", "claude-sonnet-4-5", "Claude 4.5 Sonnet"},
|
|
||||||
{"sonnet-4.5-thinking", "claude-sonnet-4-5-thinking", "Claude 4.5 Sonnet (Thinking)"},
|
|
||||||
}
|
|
||||||
|
|
||||||
// cursorModelPattern matches cursor model IDs like "opus-4.6", "sonnet-4.7-thinking".
|
|
||||||
var cursorModelPattern = regexp.MustCompile(`^([a-zA-Z]+)-(\d+)\.(\d+)(-thinking)?$`)
|
|
||||||
|
|
||||||
// reverseDynamicPattern matches dynamically generated anthropic aliases
|
|
||||||
// like "claude-opus-4-7", "claude-sonnet-4-7-thinking".
|
|
||||||
var reverseDynamicPattern = regexp.MustCompile(`^claude-([a-zA-Z]+)-(\d+)-(\d+)(-thinking)?$`)
|
|
||||||
|
|
||||||
func generateDynamicAlias(cursorID string) (AnthropicAlias, bool) {
|
|
||||||
m := cursorModelPattern.FindStringSubmatch(cursorID)
|
|
||||||
if m == nil {
|
|
||||||
return AnthropicAlias{}, false
|
|
||||||
}
|
|
||||||
family := m[1]
|
|
||||||
major := m[2]
|
|
||||||
minor := m[3]
|
|
||||||
thinking := m[4]
|
|
||||||
|
|
||||||
anthropicID := "claude-" + family + "-" + major + "-" + minor + thinking
|
|
||||||
capFamily := strings.ToUpper(family[:1]) + family[1:]
|
|
||||||
name := capFamily + " " + major + "." + minor
|
|
||||||
if thinking == "-thinking" {
|
|
||||||
name += " (Thinking)"
|
|
||||||
}
|
|
||||||
return AnthropicAlias{ID: anthropicID, Name: name}, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func reverseDynamicAlias(anthropicID string) (string, bool) {
|
|
||||||
m := reverseDynamicPattern.FindStringSubmatch(anthropicID)
|
|
||||||
if m == nil {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
return m[1] + "-" + m[2] + "." + m[3] + m[4], true
|
|
||||||
}
|
|
||||||
|
|
||||||
func ResolveToCursorModel(requested string) string {
|
|
||||||
if strings.TrimSpace(requested) == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
key := strings.ToLower(strings.TrimSpace(requested))
|
|
||||||
if v, ok := anthropicToCursor[key]; ok {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
if v, ok := reverseDynamicAlias(key); ok {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
return strings.TrimSpace(requested)
|
|
||||||
}
|
|
||||||
|
|
||||||
type AnthropicAlias struct {
|
|
||||||
ID string
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetAnthropicModelAliases(availableCursorIDs []string) []AnthropicAlias {
|
|
||||||
set := make(map[string]bool, len(availableCursorIDs))
|
|
||||||
for _, id := range availableCursorIDs {
|
|
||||||
set[id] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
staticSet := make(map[string]bool, len(cursorToAnthropicAlias))
|
|
||||||
var result []AnthropicAlias
|
|
||||||
for _, a := range cursorToAnthropicAlias {
|
|
||||||
if set[a.CursorID] {
|
|
||||||
staticSet[a.CursorID] = true
|
|
||||||
result = append(result, AnthropicAlias{ID: a.AnthropicID, Name: a.Name})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, id := range availableCursorIDs {
|
|
||||||
if staticSet[id] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if alias, ok := generateDynamicAlias(id); ok {
|
|
||||||
result = append(result, alias)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
@ -1,163 +0,0 @@
|
||||||
package models
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestGetAnthropicModelAliases_StaticOnly(t *testing.T) {
|
|
||||||
aliases := GetAnthropicModelAliases([]string{"sonnet-4.6", "opus-4.5"})
|
|
||||||
if len(aliases) != 2 {
|
|
||||||
t.Fatalf("expected 2 aliases, got %d: %v", len(aliases), aliases)
|
|
||||||
}
|
|
||||||
ids := map[string]string{}
|
|
||||||
for _, a := range aliases {
|
|
||||||
ids[a.ID] = a.Name
|
|
||||||
}
|
|
||||||
if ids["claude-sonnet-4-6"] != "Claude 4.6 Sonnet" {
|
|
||||||
t.Errorf("unexpected name for claude-sonnet-4-6: %s", ids["claude-sonnet-4-6"])
|
|
||||||
}
|
|
||||||
if ids["claude-opus-4-5"] != "Claude 4.5 Opus" {
|
|
||||||
t.Errorf("unexpected name for claude-opus-4-5: %s", ids["claude-opus-4-5"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetAnthropicModelAliases_DynamicFallback(t *testing.T) {
|
|
||||||
aliases := GetAnthropicModelAliases([]string{"sonnet-4.7", "opus-5.0-thinking", "gpt-4o"})
|
|
||||||
ids := map[string]string{}
|
|
||||||
for _, a := range aliases {
|
|
||||||
ids[a.ID] = a.Name
|
|
||||||
}
|
|
||||||
if ids["claude-sonnet-4-7"] != "Sonnet 4.7" {
|
|
||||||
t.Errorf("unexpected name for claude-sonnet-4-7: %s", ids["claude-sonnet-4-7"])
|
|
||||||
}
|
|
||||||
if ids["claude-opus-5-0-thinking"] != "Opus 5.0 (Thinking)" {
|
|
||||||
t.Errorf("unexpected name for claude-opus-5-0-thinking: %s", ids["claude-opus-5-0-thinking"])
|
|
||||||
}
|
|
||||||
if _, ok := ids["claude-gpt-4o"]; ok {
|
|
||||||
t.Errorf("gpt-4o should not generate a claude alias")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetAnthropicModelAliases_Mixed(t *testing.T) {
|
|
||||||
aliases := GetAnthropicModelAliases([]string{"sonnet-4.6", "opus-4.7", "gpt-4o"})
|
|
||||||
ids := map[string]string{}
|
|
||||||
for _, a := range aliases {
|
|
||||||
ids[a.ID] = a.Name
|
|
||||||
}
|
|
||||||
// static entry keeps its custom name
|
|
||||||
if ids["claude-sonnet-4-6"] != "Claude 4.6 Sonnet" {
|
|
||||||
t.Errorf("static alias should keep original name, got: %s", ids["claude-sonnet-4-6"])
|
|
||||||
}
|
|
||||||
// dynamic entry uses auto-generated name
|
|
||||||
if ids["claude-opus-4-7"] != "Opus 4.7" {
|
|
||||||
t.Errorf("dynamic alias name mismatch: %s", ids["claude-opus-4-7"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetAnthropicModelAliases_UnknownPattern(t *testing.T) {
|
|
||||||
aliases := GetAnthropicModelAliases([]string{"some-unknown-model"})
|
|
||||||
if len(aliases) != 0 {
|
|
||||||
t.Fatalf("expected 0 aliases for unknown pattern, got %d: %v", len(aliases), aliases)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveToCursorModel_Static(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
input string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"claude-opus-4-6", "opus-4.6"},
|
|
||||||
{"claude-opus-4.6", "opus-4.6"},
|
|
||||||
{"claude-sonnet-4-5", "sonnet-4.5"},
|
|
||||||
{"claude-opus-4-6-thinking", "opus-4.6-thinking"},
|
|
||||||
}
|
|
||||||
for _, tc := range tests {
|
|
||||||
got := ResolveToCursorModel(tc.input)
|
|
||||||
if got != tc.want {
|
|
||||||
t.Errorf("ResolveToCursorModel(%q) = %q, want %q", tc.input, got, tc.want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveToCursorModel_DynamicFallback(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
input string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"claude-opus-4-7", "opus-4.7"},
|
|
||||||
{"claude-sonnet-5-0", "sonnet-5.0"},
|
|
||||||
{"claude-opus-4-7-thinking", "opus-4.7-thinking"},
|
|
||||||
{"claude-sonnet-5-0-thinking", "sonnet-5.0-thinking"},
|
|
||||||
}
|
|
||||||
for _, tc := range tests {
|
|
||||||
got := ResolveToCursorModel(tc.input)
|
|
||||||
if got != tc.want {
|
|
||||||
t.Errorf("ResolveToCursorModel(%q) = %q, want %q", tc.input, got, tc.want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveToCursorModel_Passthrough(t *testing.T) {
|
|
||||||
tests := []string{"sonnet-4.6", "gpt-4o", "custom-model"}
|
|
||||||
for _, input := range tests {
|
|
||||||
got := ResolveToCursorModel(input)
|
|
||||||
if got != input {
|
|
||||||
t.Errorf("ResolveToCursorModel(%q) = %q, want passthrough %q", input, got, input)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveToCursorModel_Empty(t *testing.T) {
|
|
||||||
if got := ResolveToCursorModel(""); got != "" {
|
|
||||||
t.Errorf("ResolveToCursorModel(\"\") = %q, want empty", got)
|
|
||||||
}
|
|
||||||
if got := ResolveToCursorModel(" "); got != "" {
|
|
||||||
t.Errorf("ResolveToCursorModel(\" \") = %q, want empty", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateDynamicAlias(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
input string
|
|
||||||
want AnthropicAlias
|
|
||||||
ok bool
|
|
||||||
}{
|
|
||||||
{"opus-4.7", AnthropicAlias{"claude-opus-4-7", "Opus 4.7"}, true},
|
|
||||||
{"sonnet-5.0-thinking", AnthropicAlias{"claude-sonnet-5-0-thinking", "Sonnet 5.0 (Thinking)"}, true},
|
|
||||||
{"gpt-4o", AnthropicAlias{}, false},
|
|
||||||
{"invalid", AnthropicAlias{}, false},
|
|
||||||
}
|
|
||||||
for _, tc := range tests {
|
|
||||||
got, ok := generateDynamicAlias(tc.input)
|
|
||||||
if ok != tc.ok {
|
|
||||||
t.Errorf("generateDynamicAlias(%q) ok = %v, want %v", tc.input, ok, tc.ok)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ok && (got.ID != tc.want.ID || got.Name != tc.want.Name) {
|
|
||||||
t.Errorf("generateDynamicAlias(%q) = {%q, %q}, want {%q, %q}", tc.input, got.ID, got.Name, tc.want.ID, tc.want.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReverseDynamicAlias(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
input string
|
|
||||||
want string
|
|
||||||
ok bool
|
|
||||||
}{
|
|
||||||
{"claude-opus-4-7", "opus-4.7", true},
|
|
||||||
{"claude-sonnet-5-0-thinking", "sonnet-5.0-thinking", true},
|
|
||||||
{"claude-opus-4-6", "opus-4.6", true},
|
|
||||||
{"claude-opus-4.6", "", false},
|
|
||||||
{"claude-haiku-4-5-20251001", "", false},
|
|
||||||
{"some-model", "", false},
|
|
||||||
}
|
|
||||||
for _, tc := range tests {
|
|
||||||
got, ok := reverseDynamicAlias(tc.input)
|
|
||||||
if ok != tc.ok {
|
|
||||||
t.Errorf("reverseDynamicAlias(%q) ok = %v, want %v", tc.input, ok, tc.ok)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ok && got != tc.want {
|
|
||||||
t.Errorf("reverseDynamicAlias(%q) = %q, want %q", tc.input, got, tc.want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,243 +0,0 @@
|
||||||
package openai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ChatCompletionRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Messages []interface{} `json:"messages"`
|
|
||||||
Stream bool `json:"stream"`
|
|
||||||
Tools []interface{} `json:"tools"`
|
|
||||||
ToolChoice interface{} `json:"tool_choice"`
|
|
||||||
Functions []interface{} `json:"functions"`
|
|
||||||
FunctionCall interface{} `json:"function_call"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NormalizeModelID(raw string) string {
|
|
||||||
trimmed := strings.TrimSpace(raw)
|
|
||||||
if trimmed == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
parts := strings.Split(trimmed, "/")
|
|
||||||
last := parts[len(parts)-1]
|
|
||||||
if last == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return last
|
|
||||||
}
|
|
||||||
|
|
||||||
func imageURLToText(imageURL interface{}) string {
|
|
||||||
if imageURL == nil {
|
|
||||||
return "[Image]"
|
|
||||||
}
|
|
||||||
var url string
|
|
||||||
switch v := imageURL.(type) {
|
|
||||||
case string:
|
|
||||||
url = v
|
|
||||||
case map[string]interface{}:
|
|
||||||
if u, ok := v["url"].(string); ok {
|
|
||||||
url = u
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if url == "" {
|
|
||||||
return "[Image]"
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(url, "data:") {
|
|
||||||
end := strings.Index(url, ";")
|
|
||||||
mime := "image"
|
|
||||||
if end > 5 {
|
|
||||||
mime = url[5:end]
|
|
||||||
}
|
|
||||||
return "[Image: base64 " + mime + "]"
|
|
||||||
}
|
|
||||||
return "[Image: " + url + "]"
|
|
||||||
}
|
|
||||||
|
|
||||||
func MessageContentToText(content interface{}) string {
|
|
||||||
if content == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
switch v := content.(type) {
|
|
||||||
case string:
|
|
||||||
return v
|
|
||||||
case []interface{}:
|
|
||||||
var parts []string
|
|
||||||
for _, p := range v {
|
|
||||||
if p == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
switch part := p.(type) {
|
|
||||||
case string:
|
|
||||||
parts = append(parts, part)
|
|
||||||
case map[string]interface{}:
|
|
||||||
typ, _ := part["type"].(string)
|
|
||||||
switch typ {
|
|
||||||
case "text":
|
|
||||||
if t, ok := part["text"].(string); ok {
|
|
||||||
parts = append(parts, t)
|
|
||||||
}
|
|
||||||
case "image_url":
|
|
||||||
parts = append(parts, imageURLToText(part["image_url"]))
|
|
||||||
case "image":
|
|
||||||
src := part["source"]
|
|
||||||
if src == nil {
|
|
||||||
src = part["url"]
|
|
||||||
}
|
|
||||||
parts = append(parts, imageURLToText(src))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.Join(parts, " ")
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func ToolsToSystemText(tools []interface{}, functions []interface{}) string {
|
|
||||||
var defs []interface{}
|
|
||||||
|
|
||||||
for _, t := range tools {
|
|
||||||
if m, ok := t.(map[string]interface{}); ok {
|
|
||||||
if m["type"] == "function" {
|
|
||||||
if fn := m["function"]; fn != nil {
|
|
||||||
defs = append(defs, fn)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
defs = append(defs, t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
defs = append(defs, functions...)
|
|
||||||
|
|
||||||
if len(defs) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
var lines []string
|
|
||||||
lines = append(lines, "Available tools (respond with a JSON object to call one):", "")
|
|
||||||
|
|
||||||
for _, raw := range defs {
|
|
||||||
fn, ok := raw.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
name, _ := fn["name"].(string)
|
|
||||||
desc, _ := fn["description"].(string)
|
|
||||||
params := "{}"
|
|
||||||
if p := fn["parameters"]; p != nil {
|
|
||||||
if b, err := json.MarshalIndent(p, "", " "); err == nil {
|
|
||||||
params = string(b)
|
|
||||||
}
|
|
||||||
} else if p := fn["input_schema"]; p != nil {
|
|
||||||
if b, err := json.MarshalIndent(p, "", " "); err == nil {
|
|
||||||
params = string(b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
lines = append(lines, "Function: "+name+"\nDescription: "+desc+"\nParameters: "+params)
|
|
||||||
}
|
|
||||||
|
|
||||||
lines = append(lines, "",
|
|
||||||
"When you want to call a tool, use this EXACT format:",
|
|
||||||
"",
|
|
||||||
"<tool_call>",
|
|
||||||
`{"name": "function_name", "arguments": {"param1": "value1"}}`,
|
|
||||||
"</tool_call>",
|
|
||||||
"",
|
|
||||||
"Rules:",
|
|
||||||
"- Write your reasoning BEFORE the tool call",
|
|
||||||
"- You may make multiple tool calls by using multiple <tool_call> blocks",
|
|
||||||
"- STOP writing after the last </tool_call> tag",
|
|
||||||
"- If no tool is needed, respond normally without <tool_call> tags",
|
|
||||||
)
|
|
||||||
|
|
||||||
return strings.Join(lines, "\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
type SimpleMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func BuildPromptFromMessages(messages []interface{}) string {
|
|
||||||
var systemParts []string
|
|
||||||
var convo []string
|
|
||||||
|
|
||||||
for _, raw := range messages {
|
|
||||||
m, ok := raw.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
role, _ := m["role"].(string)
|
|
||||||
text := MessageContentToText(m["content"])
|
|
||||||
|
|
||||||
switch role {
|
|
||||||
case "system", "developer":
|
|
||||||
if text != "" {
|
|
||||||
systemParts = append(systemParts, text)
|
|
||||||
}
|
|
||||||
case "user":
|
|
||||||
if text != "" {
|
|
||||||
convo = append(convo, "User: "+text)
|
|
||||||
}
|
|
||||||
case "assistant":
|
|
||||||
toolCalls, _ := m["tool_calls"].([]interface{})
|
|
||||||
if len(toolCalls) > 0 {
|
|
||||||
var parts []string
|
|
||||||
if text != "" {
|
|
||||||
parts = append(parts, text)
|
|
||||||
}
|
|
||||||
for _, tc := range toolCalls {
|
|
||||||
tcMap, ok := tc.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
fn, _ := tcMap["function"].(map[string]interface{})
|
|
||||||
if fn == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
name, _ := fn["name"].(string)
|
|
||||||
args, _ := fn["arguments"].(string)
|
|
||||||
if args == "" {
|
|
||||||
args = "{}"
|
|
||||||
}
|
|
||||||
parts = append(parts, fmt.Sprintf("<tool_call>\n{\"name\": \"%s\", \"arguments\": %s}\n</tool_call>", name, args))
|
|
||||||
}
|
|
||||||
if len(parts) > 0 {
|
|
||||||
convo = append(convo, "Assistant: "+strings.Join(parts, "\n"))
|
|
||||||
}
|
|
||||||
} else if text != "" {
|
|
||||||
convo = append(convo, "Assistant: "+text)
|
|
||||||
}
|
|
||||||
case "tool", "function":
|
|
||||||
name, _ := m["name"].(string)
|
|
||||||
toolCallID, _ := m["tool_call_id"].(string)
|
|
||||||
label := "Tool result"
|
|
||||||
if name != "" {
|
|
||||||
label = "Tool result (" + name + ")"
|
|
||||||
}
|
|
||||||
if toolCallID != "" {
|
|
||||||
label += " [id=" + toolCallID + "]"
|
|
||||||
}
|
|
||||||
if text != "" {
|
|
||||||
convo = append(convo, label+": "+text)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
system := ""
|
|
||||||
if len(systemParts) > 0 {
|
|
||||||
system = "System:\n" + strings.Join(systemParts, "\n\n") + "\n\n"
|
|
||||||
}
|
|
||||||
transcript := strings.Join(convo, "\n\n")
|
|
||||||
return system + transcript + "\n\nAssistant:"
|
|
||||||
}
|
|
||||||
|
|
||||||
func BuildPromptFromSimpleMessages(messages []SimpleMessage) string {
|
|
||||||
ifaces := make([]interface{}, len(messages))
|
|
||||||
for i, m := range messages {
|
|
||||||
ifaces[i] = map[string]interface{}{"role": m.Role, "content": m.Content}
|
|
||||||
}
|
|
||||||
return BuildPromptFromMessages(ifaces)
|
|
||||||
}
|
|
||||||
|
|
@ -1,80 +0,0 @@
|
||||||
package openai
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestNormalizeModelID(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
input string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"gpt-4", "gpt-4"},
|
|
||||||
{"openai/gpt-4", "gpt-4"},
|
|
||||||
{"anthropic/claude-3", "claude-3"},
|
|
||||||
{"", ""},
|
|
||||||
{" ", ""},
|
|
||||||
{"a/b/c", "c"},
|
|
||||||
}
|
|
||||||
for _, tc := range tests {
|
|
||||||
got := NormalizeModelID(tc.input)
|
|
||||||
if got != tc.want {
|
|
||||||
t.Errorf("NormalizeModelID(%q) = %q, want %q", tc.input, got, tc.want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildPromptFromMessages(t *testing.T) {
|
|
||||||
messages := []interface{}{
|
|
||||||
map[string]interface{}{"role": "system", "content": "You are helpful."},
|
|
||||||
map[string]interface{}{"role": "user", "content": "Hello"},
|
|
||||||
map[string]interface{}{"role": "assistant", "content": "Hi there"},
|
|
||||||
}
|
|
||||||
got := BuildPromptFromMessages(messages)
|
|
||||||
if got == "" {
|
|
||||||
t.Fatal("expected non-empty prompt")
|
|
||||||
}
|
|
||||||
containsSystem := false
|
|
||||||
containsUser := false
|
|
||||||
containsAssistant := false
|
|
||||||
for i := 0; i < len(got)-10; i++ {
|
|
||||||
if got[i:i+6] == "System" {
|
|
||||||
containsSystem = true
|
|
||||||
}
|
|
||||||
if got[i:i+4] == "User" {
|
|
||||||
containsUser = true
|
|
||||||
}
|
|
||||||
if got[i:i+9] == "Assistant" {
|
|
||||||
containsAssistant = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !containsSystem || !containsUser || !containsAssistant {
|
|
||||||
t.Errorf("prompt missing sections: system=%v user=%v assistant=%v\n%s",
|
|
||||||
containsSystem, containsUser, containsAssistant, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestToolsToSystemText(t *testing.T) {
|
|
||||||
tools := []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"type": "function",
|
|
||||||
"function": map[string]interface{}{
|
|
||||||
"name": "get_weather",
|
|
||||||
"description": "Get weather",
|
|
||||||
"parameters": map[string]interface{}{"type": "object"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
got := ToolsToSystemText(tools, nil)
|
|
||||||
if got == "" {
|
|
||||||
t.Fatal("expected non-empty tools text")
|
|
||||||
}
|
|
||||||
if len(got) < 10 {
|
|
||||||
t.Errorf("tools text too short: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestToolsToSystemTextEmpty(t *testing.T) {
|
|
||||||
got := ToolsToSystemText(nil, nil)
|
|
||||||
if got != "" {
|
|
||||||
t.Errorf("expected empty string for no tools, got %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,110 +0,0 @@
|
||||||
package parser
|
|
||||||
|
|
||||||
import "encoding/json"
|
|
||||||
|
|
||||||
type StreamParser func(line string)
|
|
||||||
|
|
||||||
type Parser struct {
|
|
||||||
Parse StreamParser
|
|
||||||
Flush func()
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateStreamParser 建立串流解析器(向後相容,不傳遞 thinking)
|
|
||||||
func CreateStreamParser(onText func(string), onDone func()) Parser {
|
|
||||||
return CreateStreamParserWithThinking(onText, nil, onDone)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateStreamParserWithThinking 建立串流解析器,支援思考過程輸出。
|
|
||||||
// onThinking 可為 nil,表示忽略思考過程。
|
|
||||||
func CreateStreamParserWithThinking(onText func(string), onThinking func(string), onDone func()) Parser {
|
|
||||||
// accumulated 是所有已輸出內容的串接
|
|
||||||
accumulatedText := ""
|
|
||||||
accumulatedThinking := ""
|
|
||||||
done := false
|
|
||||||
|
|
||||||
parse := func(line string) {
|
|
||||||
if done {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var obj struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Subtype string `json:"subtype"`
|
|
||||||
Message *struct {
|
|
||||||
Content []struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Text string `json:"text"`
|
|
||||||
Thinking string `json:"thinking"`
|
|
||||||
} `json:"content"`
|
|
||||||
} `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal([]byte(line), &obj); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if obj.Type == "assistant" && obj.Message != nil {
|
|
||||||
fullText := ""
|
|
||||||
fullThinking := ""
|
|
||||||
for _, p := range obj.Message.Content {
|
|
||||||
switch p.Type {
|
|
||||||
case "text":
|
|
||||||
if p.Text != "" {
|
|
||||||
fullText += p.Text
|
|
||||||
}
|
|
||||||
case "thinking":
|
|
||||||
if p.Thinking != "" {
|
|
||||||
fullThinking += p.Thinking
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 處理思考過程(不因去重而 return,避免跳過同行的文字內容)
|
|
||||||
if onThinking != nil && fullThinking != "" && fullThinking != accumulatedThinking {
|
|
||||||
// 增量模式:新內容以 accumulated 為前綴
|
|
||||||
if len(fullThinking) >= len(accumulatedThinking) && fullThinking[:len(accumulatedThinking)] == accumulatedThinking {
|
|
||||||
delta := fullThinking[len(accumulatedThinking):]
|
|
||||||
if delta != "" {
|
|
||||||
onThinking(delta)
|
|
||||||
}
|
|
||||||
accumulatedThinking = fullThinking
|
|
||||||
} else {
|
|
||||||
// 獨立片段:直接輸出,但 accumulated 要串接
|
|
||||||
onThinking(fullThinking)
|
|
||||||
accumulatedThinking = accumulatedThinking + fullThinking
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 處理一般文字
|
|
||||||
if fullText == "" || fullText == accumulatedText {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 增量模式:新內容以 accumulated 為前綴
|
|
||||||
if len(fullText) >= len(accumulatedText) && fullText[:len(accumulatedText)] == accumulatedText {
|
|
||||||
delta := fullText[len(accumulatedText):]
|
|
||||||
if delta != "" {
|
|
||||||
onText(delta)
|
|
||||||
}
|
|
||||||
accumulatedText = fullText
|
|
||||||
} else {
|
|
||||||
// 獨立片段:直接輸出,但 accumulated 要串接
|
|
||||||
onText(fullText)
|
|
||||||
accumulatedText = accumulatedText + fullText
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if obj.Type == "result" && obj.Subtype == "success" {
|
|
||||||
done = true
|
|
||||||
onDone()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
flush := func() {
|
|
||||||
if !done {
|
|
||||||
done = true
|
|
||||||
onDone()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return Parser{Parse: parse, Flush: flush}
|
|
||||||
}
|
|
||||||
|
|
@ -1,304 +0,0 @@
|
||||||
package parser
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func makeAssistantLine(text string) string {
|
|
||||||
obj := map[string]interface{}{
|
|
||||||
"type": "assistant",
|
|
||||||
"message": map[string]interface{}{
|
|
||||||
"content": []map[string]interface{}{
|
|
||||||
{"type": "text", "text": text},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
b, _ := json.Marshal(obj)
|
|
||||||
return string(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeResultLine() string {
|
|
||||||
b, _ := json.Marshal(map[string]string{"type": "result", "subtype": "success"})
|
|
||||||
return string(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStreamParserFragmentMode(t *testing.T) {
|
|
||||||
// cursor --stream-partial-output 模式:每個訊息是獨立 token fragment
|
|
||||||
var texts []string
|
|
||||||
p := CreateStreamParser(
|
|
||||||
func(text string) { texts = append(texts, text) },
|
|
||||||
func() {},
|
|
||||||
)
|
|
||||||
|
|
||||||
p.Parse(makeAssistantLine("你"))
|
|
||||||
p.Parse(makeAssistantLine("好!有"))
|
|
||||||
p.Parse(makeAssistantLine("什"))
|
|
||||||
p.Parse(makeAssistantLine("麼"))
|
|
||||||
|
|
||||||
if len(texts) != 4 {
|
|
||||||
t.Fatalf("expected 4 fragments, got %d: %v", len(texts), texts)
|
|
||||||
}
|
|
||||||
if texts[0] != "你" || texts[1] != "好!有" || texts[2] != "什" || texts[3] != "麼" {
|
|
||||||
t.Fatalf("unexpected texts: %v", texts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStreamParserDeduplicatesFinalFullText(t *testing.T) {
|
|
||||||
// 最後一個訊息是完整的累積文字,應被跳過(去重)
|
|
||||||
var texts []string
|
|
||||||
p := CreateStreamParser(
|
|
||||||
func(text string) { texts = append(texts, text) },
|
|
||||||
func() {},
|
|
||||||
)
|
|
||||||
|
|
||||||
p.Parse(makeAssistantLine("Hello"))
|
|
||||||
p.Parse(makeAssistantLine(" world"))
|
|
||||||
// 最後一個是完整累積文字,應被去重
|
|
||||||
p.Parse(makeAssistantLine("Hello world"))
|
|
||||||
|
|
||||||
if len(texts) != 2 {
|
|
||||||
t.Fatalf("expected 2 fragments (final full text deduplicated), got %d: %v", len(texts), texts)
|
|
||||||
}
|
|
||||||
if texts[0] != "Hello" || texts[1] != " world" {
|
|
||||||
t.Fatalf("unexpected texts: %v", texts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStreamParserCallsOnDone(t *testing.T) {
|
|
||||||
var texts []string
|
|
||||||
doneCount := 0
|
|
||||||
p := CreateStreamParser(
|
|
||||||
func(text string) { texts = append(texts, text) },
|
|
||||||
func() { doneCount++ },
|
|
||||||
)
|
|
||||||
|
|
||||||
p.Parse(makeResultLine())
|
|
||||||
if doneCount != 1 {
|
|
||||||
t.Fatalf("expected onDone called once, got %d", doneCount)
|
|
||||||
}
|
|
||||||
if len(texts) != 0 {
|
|
||||||
t.Fatalf("expected no text, got %v", texts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStreamParserIgnoresLinesAfterDone(t *testing.T) {
|
|
||||||
var texts []string
|
|
||||||
doneCount := 0
|
|
||||||
p := CreateStreamParser(
|
|
||||||
func(text string) { texts = append(texts, text) },
|
|
||||||
func() { doneCount++ },
|
|
||||||
)
|
|
||||||
|
|
||||||
p.Parse(makeResultLine())
|
|
||||||
p.Parse(makeAssistantLine("late"))
|
|
||||||
if len(texts) != 0 {
|
|
||||||
t.Fatalf("expected no text after done, got %v", texts)
|
|
||||||
}
|
|
||||||
if doneCount != 1 {
|
|
||||||
t.Fatalf("expected onDone called once, got %d", doneCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStreamParserIgnoresNonAssistantLines(t *testing.T) {
|
|
||||||
var texts []string
|
|
||||||
p := CreateStreamParser(
|
|
||||||
func(text string) { texts = append(texts, text) },
|
|
||||||
func() {},
|
|
||||||
)
|
|
||||||
|
|
||||||
b1, _ := json.Marshal(map[string]interface{}{"type": "user", "message": map[string]interface{}{}})
|
|
||||||
p.Parse(string(b1))
|
|
||||||
b2, _ := json.Marshal(map[string]interface{}{
|
|
||||||
"type": "assistant",
|
|
||||||
"message": map[string]interface{}{"content": []interface{}{}},
|
|
||||||
})
|
|
||||||
p.Parse(string(b2))
|
|
||||||
p.Parse(`{"type":"assistant","message":{"content":[{"type":"code","text":"x"}]}}`)
|
|
||||||
|
|
||||||
if len(texts) != 0 {
|
|
||||||
t.Fatalf("expected no texts, got %v", texts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStreamParserIgnoresParseErrors(t *testing.T) {
|
|
||||||
var texts []string
|
|
||||||
doneCount := 0
|
|
||||||
p := CreateStreamParser(
|
|
||||||
func(text string) { texts = append(texts, text) },
|
|
||||||
func() { doneCount++ },
|
|
||||||
)
|
|
||||||
|
|
||||||
p.Parse("not json")
|
|
||||||
p.Parse("{")
|
|
||||||
p.Parse("")
|
|
||||||
|
|
||||||
if len(texts) != 0 || doneCount != 0 {
|
|
||||||
t.Fatalf("expected nothing, got texts=%v done=%d", texts, doneCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStreamParserJoinsMultipleTextParts(t *testing.T) {
|
|
||||||
var texts []string
|
|
||||||
p := CreateStreamParser(
|
|
||||||
func(text string) { texts = append(texts, text) },
|
|
||||||
func() {},
|
|
||||||
)
|
|
||||||
|
|
||||||
obj := map[string]interface{}{
|
|
||||||
"type": "assistant",
|
|
||||||
"message": map[string]interface{}{
|
|
||||||
"content": []map[string]interface{}{
|
|
||||||
{"type": "text", "text": "Hello"},
|
|
||||||
{"type": "text", "text": " "},
|
|
||||||
{"type": "text", "text": "world"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
b, _ := json.Marshal(obj)
|
|
||||||
p.Parse(string(b))
|
|
||||||
|
|
||||||
if len(texts) != 1 || texts[0] != "Hello world" {
|
|
||||||
t.Fatalf("expected ['Hello world'], got %v", texts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStreamParserFlushTriggersDone(t *testing.T) {
|
|
||||||
var texts []string
|
|
||||||
doneCount := 0
|
|
||||||
p := CreateStreamParser(
|
|
||||||
func(text string) { texts = append(texts, text) },
|
|
||||||
func() { doneCount++ },
|
|
||||||
)
|
|
||||||
|
|
||||||
p.Parse(makeAssistantLine("Hello"))
|
|
||||||
// agent 結束但沒有 result/success,手動 flush
|
|
||||||
p.Flush()
|
|
||||||
if doneCount != 1 {
|
|
||||||
t.Fatalf("expected onDone called once after Flush, got %d", doneCount)
|
|
||||||
}
|
|
||||||
// 再 flush 不應重複觸發
|
|
||||||
p.Flush()
|
|
||||||
if doneCount != 1 {
|
|
||||||
t.Fatalf("expected onDone called only once, got %d", doneCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStreamParserFlushAfterDoneIsNoop(t *testing.T) {
|
|
||||||
doneCount := 0
|
|
||||||
p := CreateStreamParser(
|
|
||||||
func(text string) {},
|
|
||||||
func() { doneCount++ },
|
|
||||||
)
|
|
||||||
|
|
||||||
p.Parse(makeResultLine())
|
|
||||||
p.Flush()
|
|
||||||
if doneCount != 1 {
|
|
||||||
t.Fatalf("expected onDone called once, got %d", doneCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeThinkingLine(thinking string) string {
|
|
||||||
obj := map[string]interface{}{
|
|
||||||
"type": "assistant",
|
|
||||||
"message": map[string]interface{}{
|
|
||||||
"content": []map[string]interface{}{
|
|
||||||
{"type": "thinking", "thinking": thinking},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
b, _ := json.Marshal(obj)
|
|
||||||
return string(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeThinkingAndTextLine(thinking, text string) string {
|
|
||||||
obj := map[string]interface{}{
|
|
||||||
"type": "assistant",
|
|
||||||
"message": map[string]interface{}{
|
|
||||||
"content": []map[string]interface{}{
|
|
||||||
{"type": "thinking", "thinking": thinking},
|
|
||||||
{"type": "text", "text": text},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
b, _ := json.Marshal(obj)
|
|
||||||
return string(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStreamParserWithThinkingCallsOnThinking(t *testing.T) {
|
|
||||||
var texts []string
|
|
||||||
var thinkings []string
|
|
||||||
p := CreateStreamParserWithThinking(
|
|
||||||
func(text string) { texts = append(texts, text) },
|
|
||||||
func(thinking string) { thinkings = append(thinkings, thinking) },
|
|
||||||
func() {},
|
|
||||||
)
|
|
||||||
|
|
||||||
p.Parse(makeThinkingLine("思考中..."))
|
|
||||||
p.Parse(makeAssistantLine("回答"))
|
|
||||||
|
|
||||||
if len(thinkings) != 1 || thinkings[0] != "思考中..." {
|
|
||||||
t.Fatalf("expected thinkings=['思考中...'], got %v", thinkings)
|
|
||||||
}
|
|
||||||
if len(texts) != 1 || texts[0] != "回答" {
|
|
||||||
t.Fatalf("expected texts=['回答'], got %v", texts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStreamParserWithThinkingNilOnThinkingIgnoresThinking(t *testing.T) {
|
|
||||||
var texts []string
|
|
||||||
p := CreateStreamParserWithThinking(
|
|
||||||
func(text string) { texts = append(texts, text) },
|
|
||||||
nil,
|
|
||||||
func() {},
|
|
||||||
)
|
|
||||||
|
|
||||||
p.Parse(makeThinkingLine("忽略的思考"))
|
|
||||||
p.Parse(makeAssistantLine("文字"))
|
|
||||||
|
|
||||||
if len(texts) != 1 || texts[0] != "文字" {
|
|
||||||
t.Fatalf("expected texts=['文字'], got %v", texts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStreamParserWithThinkingDeduplication(t *testing.T) {
|
|
||||||
var thinkings []string
|
|
||||||
p := CreateStreamParserWithThinking(
|
|
||||||
func(text string) {},
|
|
||||||
func(thinking string) { thinkings = append(thinkings, thinking) },
|
|
||||||
func() {},
|
|
||||||
)
|
|
||||||
|
|
||||||
p.Parse(makeThinkingLine("A"))
|
|
||||||
p.Parse(makeThinkingLine("B"))
|
|
||||||
// 重複的完整思考,應被跳過
|
|
||||||
p.Parse(makeThinkingLine("AB"))
|
|
||||||
|
|
||||||
if len(thinkings) != 2 || thinkings[0] != "A" || thinkings[1] != "B" {
|
|
||||||
t.Fatalf("expected thinkings=['A','B'], got %v", thinkings)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestStreamParserThinkingDuplicateButTextStillEmitted 驗證 bug 修復:
|
|
||||||
// 當 thinking 重複(去重跳過)但同一行有 text 時,text 仍必須輸出。
|
|
||||||
func TestStreamParserThinkingDuplicateButTextStillEmitted(t *testing.T) {
|
|
||||||
var texts []string
|
|
||||||
var thinkings []string
|
|
||||||
p := CreateStreamParserWithThinking(
|
|
||||||
func(text string) { texts = append(texts, text) },
|
|
||||||
func(thinking string) { thinkings = append(thinkings, thinking) },
|
|
||||||
func() {},
|
|
||||||
)
|
|
||||||
|
|
||||||
// 第一行:thinking="思考中" + text(thinking 為新增,兩者都應輸出)
|
|
||||||
p.Parse(makeThinkingAndTextLine("思考中", "第一段"))
|
|
||||||
// 第二行:thinking 與上一行相同(去重),但 text 是新的,text 仍應輸出
|
|
||||||
p.Parse(makeThinkingAndTextLine("思考中", "第二段"))
|
|
||||||
|
|
||||||
if len(thinkings) != 1 || thinkings[0] != "思考中" {
|
|
||||||
t.Fatalf("expected thinkings=['思考中'], got %v", thinkings)
|
|
||||||
}
|
|
||||||
if len(texts) != 2 || texts[0] != "第一段" || texts[1] != "第二段" {
|
|
||||||
t.Fatalf("expected texts=['第一段','第二段'], got %v", texts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,284 +0,0 @@
|
||||||
package pool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type accountStatus struct {
|
|
||||||
configDir string
|
|
||||||
activeRequests int
|
|
||||||
lastUsed int64
|
|
||||||
rateLimitUntil int64
|
|
||||||
totalRequests int
|
|
||||||
totalSuccess int
|
|
||||||
totalErrors int
|
|
||||||
totalRateLimits int
|
|
||||||
totalLatencyMs int64
|
|
||||||
}
|
|
||||||
|
|
||||||
type AccountStat struct {
|
|
||||||
ConfigDir string
|
|
||||||
ActiveRequests int
|
|
||||||
TotalRequests int
|
|
||||||
TotalSuccess int
|
|
||||||
TotalErrors int
|
|
||||||
TotalRateLimits int
|
|
||||||
TotalLatencyMs int64
|
|
||||||
IsRateLimited bool
|
|
||||||
RateLimitUntil int64
|
|
||||||
}
|
|
||||||
|
|
||||||
type AccountPool struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
accounts []*accountStatus
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAccountPool(configDirs []string) *AccountPool {
|
|
||||||
accounts := make([]*accountStatus, 0, len(configDirs))
|
|
||||||
for _, dir := range configDirs {
|
|
||||||
accounts = append(accounts, &accountStatus{configDir: dir})
|
|
||||||
}
|
|
||||||
return &AccountPool{accounts: accounts}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *AccountPool) GetNextConfigDir() string {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
|
|
||||||
if len(p.accounts) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now().UnixMilli()
|
|
||||||
|
|
||||||
available := make([]*accountStatus, 0, len(p.accounts))
|
|
||||||
for _, a := range p.accounts {
|
|
||||||
if a.rateLimitUntil < now {
|
|
||||||
available = append(available, a)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
target := available
|
|
||||||
if len(target) == 0 {
|
|
||||||
target = make([]*accountStatus, len(p.accounts))
|
|
||||||
copy(target, p.accounts)
|
|
||||||
// sort by earliest recovery
|
|
||||||
for i := 1; i < len(target); i++ {
|
|
||||||
for j := i; j > 0 && target[j].rateLimitUntil < target[j-1].rateLimitUntil; j-- {
|
|
||||||
target[j], target[j-1] = target[j-1], target[j]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// pick least busy then least recently used
|
|
||||||
best := target[0]
|
|
||||||
for _, a := range target[1:] {
|
|
||||||
if a.activeRequests < best.activeRequests {
|
|
||||||
best = a
|
|
||||||
} else if a.activeRequests == best.activeRequests && a.lastUsed < best.lastUsed {
|
|
||||||
best = a
|
|
||||||
}
|
|
||||||
}
|
|
||||||
best.lastUsed = now
|
|
||||||
return best.configDir
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *AccountPool) find(configDir string) *accountStatus {
|
|
||||||
for _, a := range p.accounts {
|
|
||||||
if a.configDir == configDir {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *AccountPool) ReportRequestStart(configDir string) {
|
|
||||||
if configDir == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
if a := p.find(configDir); a != nil {
|
|
||||||
a.activeRequests++
|
|
||||||
a.totalRequests++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *AccountPool) ReportRequestEnd(configDir string) {
|
|
||||||
if configDir == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
if a := p.find(configDir); a != nil && a.activeRequests > 0 {
|
|
||||||
a.activeRequests--
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *AccountPool) ReportRequestSuccess(configDir string, latencyMs int64) {
|
|
||||||
if configDir == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
if a := p.find(configDir); a != nil {
|
|
||||||
a.totalSuccess++
|
|
||||||
a.totalLatencyMs += latencyMs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *AccountPool) ReportRequestError(configDir string, latencyMs int64) {
|
|
||||||
if configDir == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
if a := p.find(configDir); a != nil {
|
|
||||||
a.totalErrors++
|
|
||||||
a.totalLatencyMs += latencyMs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *AccountPool) ReportRateLimit(configDir string, penaltyMs int64) {
|
|
||||||
if configDir == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if penaltyMs <= 0 {
|
|
||||||
penaltyMs = 60000
|
|
||||||
}
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
if a := p.find(configDir); a != nil {
|
|
||||||
a.rateLimitUntil = time.Now().UnixMilli() + penaltyMs
|
|
||||||
a.totalRateLimits++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *AccountPool) GetStats() []AccountStat {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
now := time.Now().UnixMilli()
|
|
||||||
stats := make([]AccountStat, len(p.accounts))
|
|
||||||
for i, a := range p.accounts {
|
|
||||||
stats[i] = AccountStat{
|
|
||||||
ConfigDir: a.configDir,
|
|
||||||
ActiveRequests: a.activeRequests,
|
|
||||||
TotalRequests: a.totalRequests,
|
|
||||||
TotalSuccess: a.totalSuccess,
|
|
||||||
TotalErrors: a.totalErrors,
|
|
||||||
TotalRateLimits: a.totalRateLimits,
|
|
||||||
TotalLatencyMs: a.totalLatencyMs,
|
|
||||||
IsRateLimited: a.rateLimitUntil > now,
|
|
||||||
RateLimitUntil: a.rateLimitUntil,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return stats
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *AccountPool) Count() int {
|
|
||||||
return len(p.accounts)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// ─── PoolHandle interface ──────────────────────────────────────────────────
|
|
||||||
// PoolHandle 讓 handler 可以注入獨立的 pool 實例,避免多 port 模式共用全域 pool。
|
|
||||||
|
|
||||||
type PoolHandle interface {
|
|
||||||
GetNextConfigDir() string
|
|
||||||
ReportRequestStart(configDir string)
|
|
||||||
ReportRequestEnd(configDir string)
|
|
||||||
ReportRequestSuccess(configDir string, latencyMs int64)
|
|
||||||
ReportRequestError(configDir string, latencyMs int64)
|
|
||||||
ReportRateLimit(configDir string, penaltyMs int64)
|
|
||||||
GetStats() []AccountStat
|
|
||||||
}
|
|
||||||
|
|
||||||
// GlobalPoolHandle 包裝全域函式以實作 PoolHandle 介面(單 port 模式使用)
|
|
||||||
type GlobalPoolHandle struct{}
|
|
||||||
|
|
||||||
func (GlobalPoolHandle) GetNextConfigDir() string { return GetNextAccountConfigDir() }
|
|
||||||
func (GlobalPoolHandle) ReportRequestStart(d string) { ReportRequestStart(d) }
|
|
||||||
func (GlobalPoolHandle) ReportRequestEnd(d string) { ReportRequestEnd(d) }
|
|
||||||
func (GlobalPoolHandle) ReportRequestSuccess(d string, l int64) { ReportRequestSuccess(d, l) }
|
|
||||||
func (GlobalPoolHandle) ReportRequestError(d string, l int64) { ReportRequestError(d, l) }
|
|
||||||
func (GlobalPoolHandle) ReportRateLimit(d string, p int64) { ReportRateLimit(d, p) }
|
|
||||||
func (GlobalPoolHandle) GetStats() []AccountStat { return GetAccountStats() }
|
|
||||||
|
|
||||||
// ─── Global pool ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
var (
|
|
||||||
globalPool *AccountPool
|
|
||||||
globalMu sync.Mutex
|
|
||||||
)
|
|
||||||
|
|
||||||
func InitAccountPool(configDirs []string) {
|
|
||||||
globalMu.Lock()
|
|
||||||
defer globalMu.Unlock()
|
|
||||||
globalPool = NewAccountPool(configDirs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetNextAccountConfigDir() string {
|
|
||||||
globalMu.Lock()
|
|
||||||
p := globalPool
|
|
||||||
globalMu.Unlock()
|
|
||||||
if p == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return p.GetNextConfigDir()
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReportRequestStart(configDir string) {
|
|
||||||
globalMu.Lock()
|
|
||||||
p := globalPool
|
|
||||||
globalMu.Unlock()
|
|
||||||
if p != nil {
|
|
||||||
p.ReportRequestStart(configDir)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReportRequestEnd(configDir string) {
|
|
||||||
globalMu.Lock()
|
|
||||||
p := globalPool
|
|
||||||
globalMu.Unlock()
|
|
||||||
if p != nil {
|
|
||||||
p.ReportRequestEnd(configDir)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReportRequestSuccess(configDir string, latencyMs int64) {
|
|
||||||
globalMu.Lock()
|
|
||||||
p := globalPool
|
|
||||||
globalMu.Unlock()
|
|
||||||
if p != nil {
|
|
||||||
p.ReportRequestSuccess(configDir, latencyMs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReportRequestError(configDir string, latencyMs int64) {
|
|
||||||
globalMu.Lock()
|
|
||||||
p := globalPool
|
|
||||||
globalMu.Unlock()
|
|
||||||
if p != nil {
|
|
||||||
p.ReportRequestError(configDir, latencyMs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReportRateLimit(configDir string, penaltyMs int64) {
|
|
||||||
globalMu.Lock()
|
|
||||||
p := globalPool
|
|
||||||
globalMu.Unlock()
|
|
||||||
if p != nil {
|
|
||||||
p.ReportRateLimit(configDir, penaltyMs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetAccountStats() []AccountStat {
|
|
||||||
globalMu.Lock()
|
|
||||||
p := globalPool
|
|
||||||
globalMu.Unlock()
|
|
||||||
if p == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return p.GetStats()
|
|
||||||
}
|
|
||||||
|
|
@ -1,152 +0,0 @@
|
||||||
package pool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestEmptyPool(t *testing.T) {
|
|
||||||
p := NewAccountPool(nil)
|
|
||||||
if got := p.GetNextConfigDir(); got != "" {
|
|
||||||
t.Fatalf("expected empty string for empty pool, got %q", got)
|
|
||||||
}
|
|
||||||
if p.Count() != 0 {
|
|
||||||
t.Fatalf("expected count 0, got %d", p.Count())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSingleDir(t *testing.T) {
|
|
||||||
p := NewAccountPool([]string{"/dir1"})
|
|
||||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
|
||||||
t.Fatalf("expected /dir1, got %q", got)
|
|
||||||
}
|
|
||||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
|
||||||
t.Fatalf("expected /dir1 again, got %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRoundRobin(t *testing.T) {
|
|
||||||
p := NewAccountPool([]string{"/a", "/b", "/c"})
|
|
||||||
got := []string{
|
|
||||||
p.GetNextConfigDir(),
|
|
||||||
p.GetNextConfigDir(),
|
|
||||||
p.GetNextConfigDir(),
|
|
||||||
p.GetNextConfigDir(),
|
|
||||||
}
|
|
||||||
want := []string{"/a", "/b", "/c", "/a"}
|
|
||||||
for i, w := range want {
|
|
||||||
if got[i] != w {
|
|
||||||
t.Fatalf("call %d: expected %q, got %q", i, w, got[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLeastBusy(t *testing.T) {
|
|
||||||
p := NewAccountPool([]string{"/dir1", "/dir2", "/dir3"})
|
|
||||||
p.ReportRequestStart("/dir1")
|
|
||||||
p.ReportRequestStart("/dir2")
|
|
||||||
|
|
||||||
if got := p.GetNextConfigDir(); got != "/dir3" {
|
|
||||||
t.Fatalf("expected /dir3 (least busy), got %q", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
p.ReportRequestStart("/dir3")
|
|
||||||
p.ReportRequestEnd("/dir1")
|
|
||||||
|
|
||||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
|
||||||
t.Fatalf("expected /dir1 after end, got %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSkipsRateLimited(t *testing.T) {
|
|
||||||
p := NewAccountPool([]string{"/dir1", "/dir2"})
|
|
||||||
p.ReportRateLimit("/dir1", 60000)
|
|
||||||
|
|
||||||
if got := p.GetNextConfigDir(); got != "/dir2" {
|
|
||||||
t.Fatalf("expected /dir2, got %q", got)
|
|
||||||
}
|
|
||||||
if got := p.GetNextConfigDir(); got != "/dir2" {
|
|
||||||
t.Fatalf("expected /dir2 again, got %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFallbackToSoonestRecovery(t *testing.T) {
|
|
||||||
p := NewAccountPool([]string{"/dir1", "/dir2"})
|
|
||||||
p.ReportRateLimit("/dir1", 60000)
|
|
||||||
p.ReportRateLimit("/dir2", 30000)
|
|
||||||
|
|
||||||
// dir2 recovers sooner — should be selected
|
|
||||||
if got := p.GetNextConfigDir(); got != "/dir2" {
|
|
||||||
t.Fatalf("expected /dir2 (sooner recovery), got %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestActiveRequestsDoesNotGoNegative(t *testing.T) {
|
|
||||||
p := NewAccountPool([]string{"/dir1"})
|
|
||||||
p.ReportRequestEnd("/dir1")
|
|
||||||
p.ReportRequestEnd("/dir1")
|
|
||||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
|
||||||
t.Fatalf("pool should still work, got %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIgnoreUnknownConfigDir(t *testing.T) {
|
|
||||||
p := NewAccountPool([]string{"/dir1"})
|
|
||||||
p.ReportRequestStart("/nonexistent")
|
|
||||||
p.ReportRequestEnd("/nonexistent")
|
|
||||||
p.ReportRateLimit("/nonexistent", 60000)
|
|
||||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
|
||||||
t.Fatalf("expected /dir1, got %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRateLimitExpires(t *testing.T) {
|
|
||||||
p := NewAccountPool([]string{"/dir1", "/dir2"})
|
|
||||||
p.ReportRateLimit("/dir1", 50)
|
|
||||||
|
|
||||||
if got := p.GetNextConfigDir(); got != "/dir2" {
|
|
||||||
t.Fatalf("immediately expected /dir2, got %q", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
if got := p.GetNextConfigDir(); got != "/dir1" {
|
|
||||||
t.Fatalf("after expiry expected /dir1, got %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGlobalPool(t *testing.T) {
|
|
||||||
InitAccountPool([]string{"/g1", "/g2"})
|
|
||||||
if got := GetNextAccountConfigDir(); got != "/g1" {
|
|
||||||
t.Fatalf("expected /g1, got %q", got)
|
|
||||||
}
|
|
||||||
if got := GetNextAccountConfigDir(); got != "/g2" {
|
|
||||||
t.Fatalf("expected /g2, got %q", got)
|
|
||||||
}
|
|
||||||
if got := GetNextAccountConfigDir(); got != "/g1" {
|
|
||||||
t.Fatalf("expected /g1 again, got %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGlobalPoolEmpty(t *testing.T) {
|
|
||||||
InitAccountPool(nil)
|
|
||||||
if got := GetNextAccountConfigDir(); got != "" {
|
|
||||||
t.Fatalf("expected empty string for empty global pool, got %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGlobalPoolReinit(t *testing.T) {
|
|
||||||
InitAccountPool([]string{"/old1", "/old2"})
|
|
||||||
GetNextAccountConfigDir()
|
|
||||||
InitAccountPool([]string{"/new1"})
|
|
||||||
if got := GetNextAccountConfigDir(); got != "/new1" {
|
|
||||||
t.Fatalf("expected /new1 after reinit, got %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGlobalPoolFunctionsNoopBeforeInit(t *testing.T) {
|
|
||||||
InitAccountPool(nil)
|
|
||||||
ReportRequestStart("/dir1")
|
|
||||||
ReportRequestEnd("/dir1")
|
|
||||||
ReportRateLimit("/dir1", 1000)
|
|
||||||
}
|
|
||||||
|
|
@ -1,21 +0,0 @@
|
||||||
//go:build !windows
|
|
||||||
|
|
||||||
package process
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os/exec"
|
|
||||||
"syscall"
|
|
||||||
)
|
|
||||||
|
|
||||||
func killProcessGroup(c *exec.Cmd) error {
|
|
||||||
if c.Process == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// 殺死整個 process group(負號表示 group)
|
|
||||||
pgid, err := syscall.Getpgid(c.Process.Pid)
|
|
||||||
if err == nil {
|
|
||||||
_ = syscall.Kill(-pgid, syscall.SIGKILL)
|
|
||||||
}
|
|
||||||
// 同時也 kill 主程序,以防萬一
|
|
||||||
return c.Process.Kill()
|
|
||||||
}
|
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
//go:build windows
|
|
||||||
|
|
||||||
package process
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os/exec"
|
|
||||||
)
|
|
||||||
|
|
||||||
func killProcessGroup(c *exec.Cmd) error {
|
|
||||||
if c.Process == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return c.Process.Kill()
|
|
||||||
}
|
|
||||||
|
|
@ -1,250 +0,0 @@
|
||||||
package process
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"context"
|
|
||||||
"cursor-api-proxy/internal/env"
|
|
||||||
"fmt"
|
|
||||||
"os/exec"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type RunResult struct {
|
|
||||||
Code int
|
|
||||||
Stdout string
|
|
||||||
Stderr string
|
|
||||||
}
|
|
||||||
|
|
||||||
type RunOptions struct {
|
|
||||||
Cwd string
|
|
||||||
TimeoutMs int
|
|
||||||
MaxMode bool
|
|
||||||
ConfigDir string
|
|
||||||
Ctx context.Context
|
|
||||||
}
|
|
||||||
|
|
||||||
type RunStreamingOptions struct {
|
|
||||||
RunOptions
|
|
||||||
OnLine func(line string)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ─── Global child process registry ──────────────────────────────────────────
|
|
||||||
|
|
||||||
var (
|
|
||||||
activeMu sync.Mutex
|
|
||||||
activeChildren []*exec.Cmd
|
|
||||||
)
|
|
||||||
|
|
||||||
func registerChild(c *exec.Cmd) {
|
|
||||||
activeMu.Lock()
|
|
||||||
activeChildren = append(activeChildren, c)
|
|
||||||
activeMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func unregisterChild(c *exec.Cmd) {
|
|
||||||
activeMu.Lock()
|
|
||||||
for i, ch := range activeChildren {
|
|
||||||
if ch == c {
|
|
||||||
activeChildren = append(activeChildren[:i], activeChildren[i+1:]...)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
activeMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func KillAllChildProcesses() {
|
|
||||||
activeMu.Lock()
|
|
||||||
all := make([]*exec.Cmd, len(activeChildren))
|
|
||||||
copy(all, activeChildren)
|
|
||||||
activeChildren = nil
|
|
||||||
activeMu.Unlock()
|
|
||||||
for _, c := range all {
|
|
||||||
killProcessGroup(c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ─── Spawn ────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
func spawnChild(cmdStr string, args []string, opts *RunOptions, maxModeFn func(scriptPath, configDir string)) *exec.Cmd {
|
|
||||||
envSrc := env.OsEnvToMap()
|
|
||||||
resolved := env.ResolveAgentCommand(cmdStr, args, envSrc, opts.Cwd)
|
|
||||||
|
|
||||||
if opts.MaxMode && maxModeFn != nil {
|
|
||||||
maxModeFn(resolved.AgentScriptPath, opts.ConfigDir)
|
|
||||||
}
|
|
||||||
|
|
||||||
envMap := make(map[string]string, len(resolved.Env))
|
|
||||||
for k, v := range resolved.Env {
|
|
||||||
envMap[k] = v
|
|
||||||
}
|
|
||||||
if opts.ConfigDir != "" {
|
|
||||||
envMap["CURSOR_CONFIG_DIR"] = opts.ConfigDir
|
|
||||||
} else if resolved.ConfigDir != "" {
|
|
||||||
if _, exists := envMap["CURSOR_CONFIG_DIR"]; !exists {
|
|
||||||
envMap["CURSOR_CONFIG_DIR"] = resolved.ConfigDir
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
envSlice := make([]string, 0, len(envMap))
|
|
||||||
for k, v := range envMap {
|
|
||||||
envSlice = append(envSlice, k+"="+v)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := opts.Ctx
|
|
||||||
if ctx == nil {
|
|
||||||
ctx = context.Background()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 使用 WaitDelay 確保 context cancel 後子程序 goroutine 能及時退出
|
|
||||||
c := exec.CommandContext(ctx, resolved.Command, resolved.Args...)
|
|
||||||
c.Dir = opts.Cwd
|
|
||||||
c.Env = envSlice
|
|
||||||
// 設定新的 process group,使 kill 能傳遞給所有子孫程序
|
|
||||||
c.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
|
||||||
// WaitDelay:context cancel 後額外等待這麼久再強制關閉 pipes
|
|
||||||
c.WaitDelay = 5 * time.Second
|
|
||||||
// Cancel 函式:殺死整個 process group
|
|
||||||
c.Cancel = func() error {
|
|
||||||
return killProcessGroup(c)
|
|
||||||
}
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// MaxModeFn is set by the agent package to avoid import cycle.
|
|
||||||
var MaxModeFn func(agentScriptPath, configDir string)
|
|
||||||
|
|
||||||
func Run(cmdStr string, args []string, opts RunOptions) (RunResult, error) {
|
|
||||||
ctx := opts.Ctx
|
|
||||||
var cancel context.CancelFunc
|
|
||||||
if opts.TimeoutMs > 0 {
|
|
||||||
if ctx == nil {
|
|
||||||
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond)
|
|
||||||
} else {
|
|
||||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond)
|
|
||||||
}
|
|
||||||
defer cancel()
|
|
||||||
opts.Ctx = ctx
|
|
||||||
} else if ctx == nil {
|
|
||||||
opts.Ctx = context.Background()
|
|
||||||
}
|
|
||||||
|
|
||||||
c := spawnChild(cmdStr, args, &opts, MaxModeFn)
|
|
||||||
var stdoutBuf, stderrBuf strings.Builder
|
|
||||||
c.Stdout = &stdoutBuf
|
|
||||||
c.Stderr = &stderrBuf
|
|
||||||
|
|
||||||
if err := c.Start(); err != nil {
|
|
||||||
// context 已取消或命令找不到時
|
|
||||||
if opts.Ctx != nil && opts.Ctx.Err() != nil {
|
|
||||||
return RunResult{Code: -1}, nil
|
|
||||||
}
|
|
||||||
if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") {
|
|
||||||
return RunResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr)
|
|
||||||
}
|
|
||||||
return RunResult{}, err
|
|
||||||
}
|
|
||||||
registerChild(c)
|
|
||||||
defer unregisterChild(c)
|
|
||||||
|
|
||||||
err := c.Wait()
|
|
||||||
code := 0
|
|
||||||
if err != nil {
|
|
||||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
|
||||||
code = exitErr.ExitCode()
|
|
||||||
if code == 0 {
|
|
||||||
code = -1
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// context cancelled or killed — return -1 but no error
|
|
||||||
return RunResult{Code: -1, Stdout: stdoutBuf.String(), Stderr: stderrBuf.String()}, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return RunResult{
|
|
||||||
Code: code,
|
|
||||||
Stdout: stdoutBuf.String(),
|
|
||||||
Stderr: stderrBuf.String(),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type StreamResult struct {
|
|
||||||
Code int
|
|
||||||
Stderr string
|
|
||||||
}
|
|
||||||
|
|
||||||
func RunStreaming(cmdStr string, args []string, opts RunStreamingOptions) (StreamResult, error) {
|
|
||||||
ctx := opts.Ctx
|
|
||||||
var cancel context.CancelFunc
|
|
||||||
if opts.TimeoutMs > 0 {
|
|
||||||
if ctx == nil {
|
|
||||||
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(opts.TimeoutMs)*time.Millisecond)
|
|
||||||
} else {
|
|
||||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(opts.TimeoutMs)*time.Millisecond)
|
|
||||||
}
|
|
||||||
defer cancel()
|
|
||||||
opts.RunOptions.Ctx = ctx
|
|
||||||
} else if opts.RunOptions.Ctx == nil {
|
|
||||||
opts.RunOptions.Ctx = context.Background()
|
|
||||||
}
|
|
||||||
|
|
||||||
c := spawnChild(cmdStr, args, &opts.RunOptions, MaxModeFn)
|
|
||||||
stdoutPipe, err := c.StdoutPipe()
|
|
||||||
if err != nil {
|
|
||||||
return StreamResult{}, err
|
|
||||||
}
|
|
||||||
stderrPipe, err := c.StderrPipe()
|
|
||||||
if err != nil {
|
|
||||||
return StreamResult{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.Start(); err != nil {
|
|
||||||
if strings.Contains(err.Error(), "exec: ") || strings.Contains(err.Error(), "no such file") {
|
|
||||||
return StreamResult{}, fmt.Errorf("command not found: %s. Install Cursor CLI (agent) or set CURSOR_AGENT_BIN to its path", cmdStr)
|
|
||||||
}
|
|
||||||
return StreamResult{}, err
|
|
||||||
}
|
|
||||||
registerChild(c)
|
|
||||||
defer unregisterChild(c)
|
|
||||||
|
|
||||||
var stderrBuf strings.Builder
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
scanner := bufio.NewScanner(stdoutPipe)
|
|
||||||
scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024)
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
if strings.TrimSpace(line) != "" {
|
|
||||||
opts.OnLine(line)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
scanner := bufio.NewScanner(stderrPipe)
|
|
||||||
scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024)
|
|
||||||
for scanner.Scan() {
|
|
||||||
stderrBuf.WriteString(scanner.Text())
|
|
||||||
stderrBuf.WriteString("\n")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
err = c.Wait()
|
|
||||||
code := 0
|
|
||||||
if err != nil {
|
|
||||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
|
||||||
code = exitErr.ExitCode()
|
|
||||||
if code == 0 {
|
|
||||||
code = -1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return StreamResult{Code: code, Stderr: stderrBuf.String()}, nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,283 +0,0 @@
|
||||||
package process_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"cursor-api-proxy/internal/process"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// sh 是跨平台 shell 執行小 script 的輔助函式
|
|
||||||
func sh(t *testing.T, script string, opts process.RunOptions) (process.RunResult, error) {
|
|
||||||
t.Helper()
|
|
||||||
return process.Run("sh", []string{"-c", script}, opts)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRun_StdoutAndStderr(t *testing.T) {
|
|
||||||
result, err := sh(t, "echo hello; echo world >&2", process.RunOptions{})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if result.Code != 0 {
|
|
||||||
t.Errorf("Code = %d, want 0", result.Code)
|
|
||||||
}
|
|
||||||
if result.Stdout != "hello\n" {
|
|
||||||
t.Errorf("Stdout = %q, want %q", result.Stdout, "hello\n")
|
|
||||||
}
|
|
||||||
if result.Stderr != "world\n" {
|
|
||||||
t.Errorf("Stderr = %q, want %q", result.Stderr, "world\n")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRun_BasicSpawn(t *testing.T) {
|
|
||||||
result, err := sh(t, "printf ok", process.RunOptions{})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if result.Code != 0 {
|
|
||||||
t.Errorf("Code = %d, want 0", result.Code)
|
|
||||||
}
|
|
||||||
if result.Stdout != "ok" {
|
|
||||||
t.Errorf("Stdout = %q, want ok", result.Stdout)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRun_ConfigDir_Propagated(t *testing.T) {
|
|
||||||
result, err := process.Run("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`},
|
|
||||||
process.RunOptions{ConfigDir: "/test/account/dir"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if result.Stdout != "/test/account/dir" {
|
|
||||||
t.Errorf("Stdout = %q, want /test/account/dir", result.Stdout)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRun_ConfigDir_Absent(t *testing.T) {
|
|
||||||
// 確保沒有殘留的環境變數
|
|
||||||
_ = os.Unsetenv("CURSOR_CONFIG_DIR")
|
|
||||||
result, err := process.Run("sh", []string{"-c", `printf "${CURSOR_CONFIG_DIR:-unset}"`},
|
|
||||||
process.RunOptions{})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if result.Stdout != "unset" {
|
|
||||||
t.Errorf("Stdout = %q, want unset", result.Stdout)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRun_NonZeroExit(t *testing.T) {
|
|
||||||
result, err := sh(t, "exit 42", process.RunOptions{})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if result.Code != 42 {
|
|
||||||
t.Errorf("Code = %d, want 42", result.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRun_Timeout(t *testing.T) {
|
|
||||||
start := time.Now()
|
|
||||||
result, err := sh(t, "sleep 30", process.RunOptions{TimeoutMs: 300})
|
|
||||||
elapsed := time.Since(start)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if result.Code == 0 {
|
|
||||||
t.Error("expected non-zero exit code after timeout")
|
|
||||||
}
|
|
||||||
if elapsed > 2*time.Second {
|
|
||||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunStreaming_OnLine(t *testing.T) {
|
|
||||||
var lines []string
|
|
||||||
result, err := process.RunStreaming("sh", []string{"-c", "printf 'a\nb\nc\n'"},
|
|
||||||
process.RunStreamingOptions{
|
|
||||||
OnLine: func(line string) { lines = append(lines, line) },
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if result.Code != 0 {
|
|
||||||
t.Errorf("Code = %d, want 0", result.Code)
|
|
||||||
}
|
|
||||||
if len(lines) != 3 {
|
|
||||||
t.Errorf("got %d lines, want 3: %v", len(lines), lines)
|
|
||||||
}
|
|
||||||
if lines[0] != "a" || lines[1] != "b" || lines[2] != "c" {
|
|
||||||
t.Errorf("lines = %v, want [a b c]", lines)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunStreaming_FlushFinalLine(t *testing.T) {
|
|
||||||
var lines []string
|
|
||||||
result, err := process.RunStreaming("sh", []string{"-c", "printf tail"},
|
|
||||||
process.RunStreamingOptions{
|
|
||||||
OnLine: func(line string) { lines = append(lines, line) },
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if result.Code != 0 {
|
|
||||||
t.Errorf("Code = %d, want 0", result.Code)
|
|
||||||
}
|
|
||||||
if len(lines) != 1 {
|
|
||||||
t.Errorf("got %d lines, want 1: %v", len(lines), lines)
|
|
||||||
}
|
|
||||||
if lines[0] != "tail" {
|
|
||||||
t.Errorf("lines[0] = %q, want tail", lines[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunStreaming_ConfigDir(t *testing.T) {
|
|
||||||
var lines []string
|
|
||||||
_, err := process.RunStreaming("sh", []string{"-c", `printf "$CURSOR_CONFIG_DIR"`},
|
|
||||||
process.RunStreamingOptions{
|
|
||||||
RunOptions: process.RunOptions{ConfigDir: "/my/config/dir"},
|
|
||||||
OnLine: func(line string) { lines = append(lines, line) },
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if len(lines) != 1 || lines[0] != "/my/config/dir" {
|
|
||||||
t.Errorf("lines = %v, want [/my/config/dir]", lines)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunStreaming_Stderr(t *testing.T) {
|
|
||||||
result, err := process.RunStreaming("sh", []string{"-c", "echo err-output >&2"},
|
|
||||||
process.RunStreamingOptions{OnLine: func(string) {}})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if result.Stderr == "" {
|
|
||||||
t.Error("expected stderr to contain output")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunStreaming_Timeout(t *testing.T) {
|
|
||||||
start := time.Now()
|
|
||||||
result, err := process.RunStreaming("sh", []string{"-c", "sleep 30"},
|
|
||||||
process.RunStreamingOptions{
|
|
||||||
RunOptions: process.RunOptions{TimeoutMs: 300},
|
|
||||||
OnLine: func(string) {},
|
|
||||||
})
|
|
||||||
elapsed := time.Since(start)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if result.Code == 0 {
|
|
||||||
t.Error("expected non-zero exit code after timeout")
|
|
||||||
}
|
|
||||||
if elapsed > 2*time.Second {
|
|
||||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunStreaming_Concurrent(t *testing.T) {
|
|
||||||
var lines1, lines2 []string
|
|
||||||
done := make(chan struct{}, 2)
|
|
||||||
|
|
||||||
run := func(label string, target *[]string) {
|
|
||||||
process.RunStreaming("sh", []string{"-c", "printf '" + label + "'"},
|
|
||||||
process.RunStreamingOptions{
|
|
||||||
OnLine: func(line string) { *target = append(*target, line) },
|
|
||||||
})
|
|
||||||
done <- struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
go run("stream1", &lines1)
|
|
||||||
go run("stream2", &lines2)
|
|
||||||
|
|
||||||
<-done
|
|
||||||
<-done
|
|
||||||
|
|
||||||
if len(lines1) != 1 || lines1[0] != "stream1" {
|
|
||||||
t.Errorf("lines1 = %v, want [stream1]", lines1)
|
|
||||||
}
|
|
||||||
if len(lines2) != 1 || lines2[0] != "stream2" {
|
|
||||||
t.Errorf("lines2 = %v, want [stream2]", lines2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunStreaming_ContextCancel(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
start := time.Now()
|
|
||||||
done := make(chan struct{})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
process.RunStreaming("sh", []string{"-c", "sleep 30"},
|
|
||||||
process.RunStreamingOptions{
|
|
||||||
RunOptions: process.RunOptions{Ctx: ctx},
|
|
||||||
OnLine: func(string) {},
|
|
||||||
})
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
time.AfterFunc(100*time.Millisecond, cancel)
|
|
||||||
<-done
|
|
||||||
elapsed := time.Since(start)
|
|
||||||
|
|
||||||
if elapsed > 2*time.Second {
|
|
||||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRun_ContextCancel(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
start := time.Now()
|
|
||||||
done := make(chan process.RunResult, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx})
|
|
||||||
done <- r
|
|
||||||
}()
|
|
||||||
|
|
||||||
time.AfterFunc(100*time.Millisecond, cancel)
|
|
||||||
result := <-done
|
|
||||||
elapsed := time.Since(start)
|
|
||||||
|
|
||||||
if result.Code == 0 {
|
|
||||||
t.Error("expected non-zero exit code after cancel")
|
|
||||||
}
|
|
||||||
if elapsed > 2*time.Second {
|
|
||||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRun_AlreadyCancelledContext(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
cancel() // 已取消
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
result, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{Ctx: ctx})
|
|
||||||
elapsed := time.Since(start)
|
|
||||||
|
|
||||||
if result.Code == 0 {
|
|
||||||
t.Error("expected non-zero exit code")
|
|
||||||
}
|
|
||||||
if elapsed > 2*time.Second {
|
|
||||||
t.Errorf("elapsed %v, want < 2s", elapsed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKillAllChildProcesses(t *testing.T) {
|
|
||||||
done := make(chan process.RunResult, 1)
|
|
||||||
go func() {
|
|
||||||
r, _ := process.Run("sh", []string{"-c", "sleep 30"}, process.RunOptions{})
|
|
||||||
done <- r
|
|
||||||
}()
|
|
||||||
|
|
||||||
time.Sleep(80 * time.Millisecond)
|
|
||||||
process.KillAllChildProcesses()
|
|
||||||
result := <-done
|
|
||||||
|
|
||||||
if result.Code == 0 {
|
|
||||||
t.Error("expected non-zero exit code after kill")
|
|
||||||
}
|
|
||||||
// 再次呼叫不應 panic
|
|
||||||
process.KillAllChildProcesses()
|
|
||||||
}
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
package cursor
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"cursor-api-proxy/internal/apitypes"
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Provider struct {
|
|
||||||
cfg config.BridgeConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewProvider(cfg config.BridgeConfig) *Provider {
|
|
||||||
return &Provider{cfg: cfg}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Provider) Name() string {
|
|
||||||
return "cursor"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Provider) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Provider) Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,32 +0,0 @@
|
||||||
package providers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"cursor-api-proxy/internal/apitypes"
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
"cursor-api-proxy/internal/providers/cursor"
|
|
||||||
"cursor-api-proxy/internal/providers/geminiweb"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Provider interface {
|
|
||||||
Name() string
|
|
||||||
Close() error
|
|
||||||
Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewProvider(cfg config.BridgeConfig) (Provider, error) {
|
|
||||||
providerType := cfg.Provider
|
|
||||||
if providerType == "" {
|
|
||||||
providerType = "cursor"
|
|
||||||
}
|
|
||||||
|
|
||||||
switch providerType {
|
|
||||||
case "cursor":
|
|
||||||
return cursor.NewProvider(cfg), nil
|
|
||||||
case "gemini-web":
|
|
||||||
return geminiweb.NewPlaywrightProvider(cfg)
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unknown provider: %s", providerType)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,125 +0,0 @@
|
||||||
package geminiweb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-rod/rod"
|
|
||||||
"github.com/go-rod/rod/lib/launcher"
|
|
||||||
"github.com/go-rod/rod/lib/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Browser struct {
|
|
||||||
browser *rod.Browser
|
|
||||||
visible bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewBrowser(visible bool) (*Browser, error) {
|
|
||||||
l := launcher.New()
|
|
||||||
if visible {
|
|
||||||
l = l.Headless(false)
|
|
||||||
} else {
|
|
||||||
l = l.Headless(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
url, err := l.Launch()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to launch browser: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
b := rod.New().ControlURL(url)
|
|
||||||
if err := b.Connect(); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to connect browser: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Browser{browser: b, visible: visible}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Browser) Close() error {
|
|
||||||
if b.browser != nil {
|
|
||||||
return b.browser.Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Browser) NewPage() (*rod.Page, error) {
|
|
||||||
return b.browser.Page(proto.TargetCreateTarget{URL: "about:blank"})
|
|
||||||
}
|
|
||||||
|
|
||||||
type Cookie struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Value string `json:"value"`
|
|
||||||
Domain string `json:"domain"`
|
|
||||||
Path string `json:"path"`
|
|
||||||
Expires float64 `json:"expires"`
|
|
||||||
HTTPOnly bool `json:"httpOnly"`
|
|
||||||
Secure bool `json:"secure"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadCookiesFromFile(cookieFile string) ([]Cookie, error) {
|
|
||||||
data, err := os.ReadFile(cookieFile)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read cookies: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var cookies []Cookie
|
|
||||||
if err := json.Unmarshal(data, &cookies); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse cookies: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return cookies, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func SaveCookiesToFile(cookies []Cookie, cookieFile string) error {
|
|
||||||
data, err := json.MarshalIndent(cookies, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal cookies: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
dir := filepath.Dir(cookieFile)
|
|
||||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
|
||||||
return fmt.Errorf("failed to create cookie dir: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.WriteFile(cookieFile, data, 0644); err != nil {
|
|
||||||
return fmt.Errorf("failed to write cookies: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func SetCookiesOnPage(page *rod.Page, cookies []Cookie) error {
|
|
||||||
var protoCookies []*proto.NetworkCookieParam
|
|
||||||
for _, c := range cookies {
|
|
||||||
p := &proto.NetworkCookieParam{
|
|
||||||
Name: c.Name,
|
|
||||||
Value: c.Value,
|
|
||||||
Domain: c.Domain,
|
|
||||||
Path: c.Path,
|
|
||||||
HTTPOnly: c.HTTPOnly,
|
|
||||||
Secure: c.Secure,
|
|
||||||
}
|
|
||||||
if c.Expires > 0 {
|
|
||||||
exp := proto.TimeSinceEpoch(c.Expires)
|
|
||||||
p.Expires = exp
|
|
||||||
}
|
|
||||||
protoCookies = append(protoCookies, p)
|
|
||||||
}
|
|
||||||
return page.SetCookies(protoCookies)
|
|
||||||
}
|
|
||||||
|
|
||||||
func WaitForElement(page *rod.Page, selector string, timeout time.Duration) (*rod.Element, error) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
defer cancel()
|
|
||||||
return page.Context(ctx).Element(selector)
|
|
||||||
}
|
|
||||||
|
|
||||||
func WaitForElements(page *rod.Page, selector string, timeout time.Duration) (rod.Elements, error) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
defer cancel()
|
|
||||||
return page.Context(ctx).Elements(selector)
|
|
||||||
}
|
|
||||||
|
|
@ -1,173 +0,0 @@
|
||||||
package geminiweb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/go-rod/rod"
|
|
||||||
"github.com/go-rod/rod/lib/launcher"
|
|
||||||
"github.com/go-rod/rod/lib/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// BrowserManager 管理瀏覽器實例的生命週期
|
|
||||||
type BrowserManager struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
browser *rod.Browser
|
|
||||||
userDataDir string
|
|
||||||
page *rod.Page
|
|
||||||
visible bool
|
|
||||||
isRunning bool
|
|
||||||
currentModel string
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
globalManager *BrowserManager
|
|
||||||
globalMu sync.Mutex
|
|
||||||
)
|
|
||||||
|
|
||||||
// GetBrowserManager 獲取全域瀏覽器管理器(單例)
|
|
||||||
func GetBrowserManager(userDataDir string, visible bool) (*BrowserManager, error) {
|
|
||||||
globalMu.Lock()
|
|
||||||
defer globalMu.Unlock()
|
|
||||||
|
|
||||||
if globalManager != nil {
|
|
||||||
return globalManager, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
manager, err := NewBrowserManager(userDataDir, visible)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
globalManager = manager
|
|
||||||
return globalManager, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewBrowserManager 建立新的瀏覽器管理器
|
|
||||||
func NewBrowserManager(userDataDir string, visible bool) (*BrowserManager, error) {
|
|
||||||
cleanLockFiles(userDataDir)
|
|
||||||
|
|
||||||
if err := os.MkdirAll(userDataDir, 0755); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create user data dir: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &BrowserManager{
|
|
||||||
userDataDir: userDataDir,
|
|
||||||
visible: visible,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanLockFiles 清理 Chrome 的殘留鎖檔案
|
|
||||||
func cleanLockFiles(userDataDir string) {
|
|
||||||
lockFiles := []string{
|
|
||||||
"SingletonLock",
|
|
||||||
"SingletonCookie",
|
|
||||||
"SingletonSocket",
|
|
||||||
"Default/SingletonLock",
|
|
||||||
"Default/SingletonCookie",
|
|
||||||
"Default/SingletonSocket",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, file := range lockFiles {
|
|
||||||
path := filepath.Join(userDataDir, file)
|
|
||||||
os.Remove(path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Launch 啟動瀏覽器(如果尚未啟動)
|
|
||||||
func (m *BrowserManager) Launch() error {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
if m.isRunning && m.browser != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
l := launcher.New()
|
|
||||||
|
|
||||||
if m.visible {
|
|
||||||
l = l.Headless(false)
|
|
||||||
} else {
|
|
||||||
l = l.Headless(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
l = l.UserDataDir(m.userDataDir)
|
|
||||||
|
|
||||||
url, err := l.Launch()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to launch browser: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
b := rod.New().ControlURL(url)
|
|
||||||
if err := b.Connect(); err != nil {
|
|
||||||
return fmt.Errorf("failed to connect browser: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.browser = b
|
|
||||||
|
|
||||||
page, err := b.Page(proto.TargetCreateTarget{URL: "about:blank"})
|
|
||||||
if err != nil {
|
|
||||||
_ = b.Close()
|
|
||||||
return fmt.Errorf("failed to create page: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.page = page
|
|
||||||
m.isRunning = true
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPage 獲取頁面
|
|
||||||
func (m *BrowserManager) GetPage() (*rod.Page, error) {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
if !m.isRunning || m.browser == nil {
|
|
||||||
return nil, fmt.Errorf("browser not running")
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.page, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close 關閉瀏覽器
|
|
||||||
func (m *BrowserManager) Close() error {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
if !m.isRunning {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
if m.browser != nil {
|
|
||||||
err = m.browser.Close()
|
|
||||||
m.browser = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
m.page = nil
|
|
||||||
m.isRunning = false
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsRunning 檢查瀏覽器是否正在運行
|
|
||||||
func (m *BrowserManager) IsRunning() bool {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
return m.isRunning
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCurrentModel 設定當前模型
|
|
||||||
func (m *BrowserManager) SetCurrentModel(model string) {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
m.currentModel = model
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCurrentModel 獲取當前模型
|
|
||||||
func (m *BrowserManager) GetCurrentModel() string {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
return m.currentModel
|
|
||||||
}
|
|
||||||
|
|
@ -1,250 +0,0 @@
|
||||||
package geminiweb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-rod/rod"
|
|
||||||
)
|
|
||||||
|
|
||||||
const geminiURL = "https://gemini.google.com/app"
|
|
||||||
|
|
||||||
// 輸入框選擇器(依優先順序)
|
|
||||||
var inputSelectors = []string{
|
|
||||||
".ProseMirror",
|
|
||||||
"rich-textarea",
|
|
||||||
"div[role='textbox'][contenteditable='true']",
|
|
||||||
"div[contenteditable='true']",
|
|
||||||
"textarea",
|
|
||||||
}
|
|
||||||
|
|
||||||
// NavigateToGemini 導航到 Gemini
|
|
||||||
func NavigateToGemini(page *rod.Page) error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if err := page.Context(ctx).Navigate(geminiURL); err != nil {
|
|
||||||
return fmt.Errorf("failed to navigate: %w", err)
|
|
||||||
}
|
|
||||||
return page.Context(ctx).WaitLoad()
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsLoggedIn 檢查是否已登入
|
|
||||||
func IsLoggedIn(page *rod.Page) bool {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
for _, sel := range inputSelectors {
|
|
||||||
if _, err := page.Context(ctx).Element(sel); err == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// SelectModel 選擇模型(可選)
|
|
||||||
func SelectModel(page *rod.Page, model string) error {
|
|
||||||
fmt.Printf("[GeminiWeb] Model selection skipped (using current model)\n")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TypeInput 在輸入框中輸入文字
|
|
||||||
func TypeInput(page *rod.Page, text string) error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
fmt.Println("[GeminiWeb] Looking for input field...")
|
|
||||||
|
|
||||||
// 1. 嘗試所有選擇器
|
|
||||||
var inputEl *rod.Element
|
|
||||||
var err error
|
|
||||||
|
|
||||||
for _, sel := range inputSelectors {
|
|
||||||
fmt.Printf(" Trying: %s\n", sel)
|
|
||||||
inputEl, err = page.Context(ctx).Element(sel)
|
|
||||||
if err == nil {
|
|
||||||
fmt.Printf(" ✓ Found with: %s\n", sel)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
// 2. Fallback: 嘗試等待頁面載入完成後重試
|
|
||||||
fmt.Println("[GeminiWeb] Waiting for page to fully load...")
|
|
||||||
time.Sleep(3 * time.Second)
|
|
||||||
|
|
||||||
for _, sel := range inputSelectors {
|
|
||||||
fmt.Printf(" Retrying: %s\n", sel)
|
|
||||||
inputEl, err = page.Context(ctx).Element(sel)
|
|
||||||
if err == nil {
|
|
||||||
fmt.Printf(" ✓ Found with: %s\n", sel)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
// 3. Debug: 印出頁面標題和 URL
|
|
||||||
info, _ := page.Info()
|
|
||||||
fmt.Printf("[GeminiWeb] DEBUG: URL=%s Title=%s\n", info.URL, info.Title)
|
|
||||||
|
|
||||||
// 4. Fallback: 嘗試更通用的選擇器
|
|
||||||
fmt.Println("[GeminiWeb] Trying generic selectors...")
|
|
||||||
genericSelectors := []string{
|
|
||||||
"div[contenteditable]",
|
|
||||||
"[contenteditable]",
|
|
||||||
"textarea",
|
|
||||||
"input[type='text']",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, sel := range genericSelectors {
|
|
||||||
fmt.Printf(" Trying generic: %s\n", sel)
|
|
||||||
inputEl, err = page.Context(ctx).Element(sel)
|
|
||||||
if err == nil {
|
|
||||||
fmt.Printf(" ✓ Found with: %s\n", sel)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
info, _ := page.Info()
|
|
||||||
return fmt.Errorf("input field not found after trying all selectors (URL=%s)", info.URL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. Focus 輸入框
|
|
||||||
fmt.Printf("[GeminiWeb] Focusing input field...\n")
|
|
||||||
if err := inputEl.Focus(); err != nil {
|
|
||||||
return fmt.Errorf("failed to focus input: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
|
|
||||||
// 3. 使用 Input 方法
|
|
||||||
fmt.Printf("[GeminiWeb] Typing %d chars...\n", len(text))
|
|
||||||
if err := inputEl.Input(text); err != nil {
|
|
||||||
return fmt.Errorf("failed to input text: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(200 * time.Millisecond)
|
|
||||||
|
|
||||||
fmt.Println("[GeminiWeb] Input complete")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClickSend 發送訊息
|
|
||||||
func ClickSend(page *rod.Page) error {
|
|
||||||
// 方法 1: 按 Enter
|
|
||||||
if err := page.Keyboard.Press('\r'); err != nil {
|
|
||||||
return fmt.Errorf("failed to press Enter: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(200 * time.Millisecond)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WaitForReady 等待頁面空閒
|
|
||||||
func WaitForReady(page *rod.Page) error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
fmt.Println("[GeminiWeb] Checking if page is ready...")
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
fmt.Println("[GeminiWeb] Page ready check timeout, proceeding anyway")
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
|
|
||||||
// 檢查是否有停止按鈕
|
|
||||||
hasStopBtn := false
|
|
||||||
stopBtns, _ := page.Elements("button[aria-label*='Stop'], button[aria-label*='停止']")
|
|
||||||
for _, btn := range stopBtns {
|
|
||||||
visible, _ := btn.Visible()
|
|
||||||
if visible {
|
|
||||||
hasStopBtn = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hasStopBtn {
|
|
||||||
fmt.Println("[GeminiWeb] Page is ready")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExtractResponse 提取回應文字
|
|
||||||
func ExtractResponse(page *rod.Page) (string, error) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
var lastText string
|
|
||||||
lastUpdate := time.Now()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
if lastText != "" {
|
|
||||||
return lastText, nil
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("response timeout")
|
|
||||||
default:
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
|
|
||||||
// 尋找回應文字
|
|
||||||
for _, sel := range responseSelectors {
|
|
||||||
elements, err := page.Elements(sel)
|
|
||||||
if err != nil || len(elements) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 取得最後一個元素的文字
|
|
||||||
lastEl := elements[len(elements)-1]
|
|
||||||
text, err := lastEl.Text()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
text = strings.TrimSpace(text)
|
|
||||||
if text != "" && text != lastText && len(text) > len(lastText) {
|
|
||||||
lastText = text
|
|
||||||
lastUpdate = time.Now()
|
|
||||||
fmt.Printf("[GeminiWeb] Response length: %d\n", len(text))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 檢查是否已完成(2 秒內沒有新內容)
|
|
||||||
if time.Since(lastUpdate) > 2*time.Second && lastText != "" {
|
|
||||||
// 最後檢查一次是否還有停止按鈕
|
|
||||||
hasStopBtn := false
|
|
||||||
stopBtns, _ := page.Elements("button[aria-label*='Stop'], button[aria-label*='停止']")
|
|
||||||
for _, btn := range stopBtns {
|
|
||||||
visible, _ := btn.Visible()
|
|
||||||
if visible {
|
|
||||||
hasStopBtn = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hasStopBtn {
|
|
||||||
return lastText, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 默認的回應選擇器
|
|
||||||
var responseSelectors = []string{
|
|
||||||
".model-response-text",
|
|
||||||
".message-content",
|
|
||||||
".markdown",
|
|
||||||
".prose",
|
|
||||||
"model-response",
|
|
||||||
}
|
|
||||||
|
|
@ -1,641 +0,0 @@
|
||||||
package geminiweb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"cursor-api-proxy/internal/apitypes"
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/playwright-community/playwright-go"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PlaywrightProvider 使用 Playwright 的 Gemini Provider
|
|
||||||
type PlaywrightProvider struct {
|
|
||||||
cfg config.BridgeConfig
|
|
||||||
pw *playwright.Playwright
|
|
||||||
browser playwright.Browser
|
|
||||||
context playwright.BrowserContext
|
|
||||||
page playwright.Page
|
|
||||||
mu sync.Mutex
|
|
||||||
userDataDir string
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
playwrightInstance *playwright.Playwright
|
|
||||||
playwrightOnce sync.Once
|
|
||||||
playwrightErr error
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewPlaywrightProvider 建立新的 Playwright Provider
|
|
||||||
func NewPlaywrightProvider(cfg config.BridgeConfig) (*PlaywrightProvider, error) {
|
|
||||||
// 確保 Playwright 已初始化(單例)
|
|
||||||
playwrightOnce.Do(func() {
|
|
||||||
playwrightInstance, playwrightErr = playwright.Run()
|
|
||||||
if playwrightErr != nil {
|
|
||||||
playwrightErr = fmt.Errorf("failed to run playwright: %w", playwrightErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if playwrightErr != nil {
|
|
||||||
return nil, playwrightErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// 清理 Chrome 鎖檔案
|
|
||||||
userDataDir := filepath.Join(cfg.GeminiAccountDir, "default-session")
|
|
||||||
cleanLockFiles(userDataDir)
|
|
||||||
|
|
||||||
// 確保目錄存在
|
|
||||||
if err := os.MkdirAll(userDataDir, 0755); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create user data dir: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &PlaywrightProvider{
|
|
||||||
cfg: cfg,
|
|
||||||
pw: playwrightInstance,
|
|
||||||
userDataDir: userDataDir,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getName 返回 Provider 名稱
|
|
||||||
func (p *PlaywrightProvider) Name() string {
|
|
||||||
return "gemini-web"
|
|
||||||
}
|
|
||||||
|
|
||||||
// launchIfNeeded 如果需要則啟動瀏覽器
|
|
||||||
func (p *PlaywrightProvider) launchIfNeeded() error {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
|
|
||||||
if p.context != nil && p.page != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("[GeminiWeb] Launching Chromium...")
|
|
||||||
|
|
||||||
// 使用 LaunchPersistentContext(自動保存 session)
|
|
||||||
context, err := p.pw.Chromium.LaunchPersistentContext(p.userDataDir,
|
|
||||||
playwright.BrowserTypeLaunchPersistentContextOptions{
|
|
||||||
Headless: playwright.Bool(!p.cfg.GeminiBrowserVisible),
|
|
||||||
Args: []string{
|
|
||||||
"--no-first-run",
|
|
||||||
"--no-default-browser-check",
|
|
||||||
"--disable-background-networking",
|
|
||||||
"--disable-extensions",
|
|
||||||
"--disable-plugins",
|
|
||||||
"--disable-sync",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to launch persistent context: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
p.context = context
|
|
||||||
|
|
||||||
// 取得或建立頁面
|
|
||||||
pages := context.Pages()
|
|
||||||
if len(pages) > 0 {
|
|
||||||
p.page = pages[0]
|
|
||||||
} else {
|
|
||||||
page, err := context.NewPage()
|
|
||||||
if err != nil {
|
|
||||||
_ = context.Close()
|
|
||||||
return fmt.Errorf("failed to create page: %w", err)
|
|
||||||
}
|
|
||||||
p.page = page
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("[GeminiWeb] Browser launched")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate 生成回應
|
|
||||||
func (p *PlaywrightProvider) Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) (err error) {
|
|
||||||
// 確保在返回錯誤時保存診斷
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("[GeminiWeb] Error occurred, saving diagnostics...")
|
|
||||||
_ = p.saveDiagnostics()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
fmt.Printf("[GeminiWeb] Starting generation with model: %s\n", model)
|
|
||||||
|
|
||||||
// 1. 確保瀏覽器已啟動
|
|
||||||
if err := p.launchIfNeeded(); err != nil {
|
|
||||||
return fmt.Errorf("failed to launch browser: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. 導航到 Gemini(如果需要)
|
|
||||||
currentURL := p.page.URL()
|
|
||||||
if !strings.Contains(currentURL, "gemini.google.com") {
|
|
||||||
fmt.Println("[GeminiWeb] Navigating to Gemini...")
|
|
||||||
if _, err := p.page.Goto("https://gemini.google.com/app", playwright.PageGotoOptions{
|
|
||||||
WaitUntil: playwright.WaitUntilStateDomcontentloaded,
|
|
||||||
Timeout: playwright.Float(60000),
|
|
||||||
}); err != nil {
|
|
||||||
return fmt.Errorf("failed to navigate: %w", err)
|
|
||||||
}
|
|
||||||
// 額外等待 JavaScript 載入
|
|
||||||
fmt.Println("[GeminiWeb] Waiting for page to initialize...")
|
|
||||||
time.Sleep(3 * time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 調試模式:等待用戶確認
|
|
||||||
if p.cfg.GeminiBrowserVisible {
|
|
||||||
fmt.Println("\n" + strings.Repeat("=", 70))
|
|
||||||
fmt.Println("🔍 調試模式:瀏覽器已開啟")
|
|
||||||
fmt.Println("請檢查瀏覽器畫面,然後按 ENTER 繼續...")
|
|
||||||
fmt.Println("如果有問題,請查看: /tmp/gemini-debug.*")
|
|
||||||
fmt.Println(strings.Repeat("=", 70))
|
|
||||||
|
|
||||||
var input string
|
|
||||||
fmt.Scanln(&input)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. 等待頁面完全載入(project-golem 策略)
|
|
||||||
fmt.Println("[GeminiWeb] Waiting for page to be ready...")
|
|
||||||
if err := p.waitForPageReady(); err != nil {
|
|
||||||
fmt.Printf("[GeminiWeb] Warning: %v\n", err)
|
|
||||||
|
|
||||||
// 額外調試:輸出頁面 HTML 結構
|
|
||||||
if p.cfg.GeminiBrowserVisible {
|
|
||||||
html, _ := p.page.Content()
|
|
||||||
debugPath := "/tmp/gemini-debug.html"
|
|
||||||
if err := os.WriteFile(debugPath, []byte(html), 0644); err == nil {
|
|
||||||
fmt.Printf("[GeminiWeb] HTML saved to: %s\n", debugPath)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. 檢查登入狀態
|
|
||||||
fmt.Println("[GeminiWeb] Checking login status...")
|
|
||||||
loggedIn := p.isLoggedIn()
|
|
||||||
if !loggedIn {
|
|
||||||
fmt.Println("[GeminiWeb] Not logged in, continuing anyway")
|
|
||||||
if p.cfg.GeminiBrowserVisible {
|
|
||||||
fmt.Println("\n========================================")
|
|
||||||
fmt.Println("Browser is open. You can:")
|
|
||||||
fmt.Println("1. Log in to Gemini now")
|
|
||||||
fmt.Println("2. Continue without login")
|
|
||||||
fmt.Println("========================================\n")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
fmt.Println("[GeminiWeb] ✓ Logged in")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 5. 選擇模型(如果支援)
|
|
||||||
if err := p.selectModel(model); err != nil {
|
|
||||||
fmt.Printf("[GeminiWeb] Warning: model selection failed: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 6. 建構提示詞
|
|
||||||
prompt := buildPromptFromMessagesPlaywright(messages)
|
|
||||||
fmt.Printf("[GeminiWeb] Typing prompt (%d chars)...\n", len(prompt))
|
|
||||||
|
|
||||||
// 7. 輸入文字(使用 Playwright 的 Auto-wait)
|
|
||||||
if err := p.typeInput(prompt); err != nil {
|
|
||||||
return fmt.Errorf("failed to type: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 7. 發送訊息
|
|
||||||
fmt.Println("[GeminiWeb] Sending message...")
|
|
||||||
if err := p.sendMessage(); err != nil {
|
|
||||||
return fmt.Errorf("failed to send: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 8. 提取回應
|
|
||||||
fmt.Println("[GeminiWeb] Waiting for response...")
|
|
||||||
response, err := p.extractResponse()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to extract response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 9. 回調
|
|
||||||
cb(apitypes.StreamChunk{Type: apitypes.ChunkText, Text: response})
|
|
||||||
cb(apitypes.StreamChunk{Type: apitypes.ChunkDone, Done: true})
|
|
||||||
|
|
||||||
fmt.Printf("[GeminiWeb] Response complete (%d chars)\n", len(response))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close 關閉 Provider
|
|
||||||
func (p *PlaywrightProvider) Close() error {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
|
|
||||||
if p.context != nil {
|
|
||||||
if err := p.context.Close(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
p.context = nil
|
|
||||||
p.page = nil
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// saveDiagnostics 保存診斷信息
|
|
||||||
func (p *PlaywrightProvider) saveDiagnostics() error {
|
|
||||||
if p.page == nil {
|
|
||||||
return fmt.Errorf("no page available")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 截圖
|
|
||||||
screenshotPath := "/tmp/gemini-debug.png"
|
|
||||||
if _, err := p.page.Screenshot(playwright.PageScreenshotOptions{
|
|
||||||
Path: playwright.String(screenshotPath),
|
|
||||||
}); err == nil {
|
|
||||||
fmt.Printf("[GeminiWeb] Screenshot saved: %s\n", screenshotPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
// HTML
|
|
||||||
htmlPath := "/tmp/gemini-debug.html"
|
|
||||||
if html, err := p.page.Content(); err == nil {
|
|
||||||
if err := os.WriteFile(htmlPath, []byte(html), 0644); err == nil {
|
|
||||||
fmt.Printf("[GeminiWeb] HTML saved: %s\n", htmlPath)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 輸出頁面信息
|
|
||||||
url := p.page.URL()
|
|
||||||
title, _ := p.page.Title()
|
|
||||||
fmt.Printf("[GeminiWeb] Diagnostics: URL=%s, Title=%s\n", url, title)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// waitForPageReady 等待頁面完全就緒(project-golem 策略)
|
|
||||||
func (p *PlaywrightProvider) waitForPageReady() error {
|
|
||||||
fmt.Println("[GeminiWeb] Checking for ready state...")
|
|
||||||
|
|
||||||
// 1. 等待停止按鈕消失(如果存在)
|
|
||||||
_, _ = p.page.WaitForSelector("button[aria-label*='Stop'], button[aria-label*='停止']", playwright.PageWaitForSelectorOptions{
|
|
||||||
State: playwright.WaitForSelectorStateDetached,
|
|
||||||
Timeout: playwright.Float(5000),
|
|
||||||
})
|
|
||||||
|
|
||||||
// 2. 嘗試多種等待策略
|
|
||||||
inputSelectors := []string{
|
|
||||||
".ql-editor.ql-blank",
|
|
||||||
".ql-editor",
|
|
||||||
"div[contenteditable='true'][role='textbox']",
|
|
||||||
"div[contenteditable='true']",
|
|
||||||
".ProseMirror",
|
|
||||||
"rich-textarea",
|
|
||||||
"textarea",
|
|
||||||
}
|
|
||||||
|
|
||||||
// 策略 A: 等待任一輸入框出現
|
|
||||||
for i, sel := range inputSelectors {
|
|
||||||
fmt.Printf(" [%d/%d] Waiting for: %s\n", i+1, len(inputSelectors), sel)
|
|
||||||
locator := p.page.Locator(sel)
|
|
||||||
if err := locator.WaitFor(playwright.LocatorWaitForOptions{
|
|
||||||
Timeout: playwright.Float(5000),
|
|
||||||
State: playwright.WaitForSelectorStateVisible,
|
|
||||||
}); err == nil {
|
|
||||||
fmt.Printf(" ✓ Input field found: %s\n", sel)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 策略 B: 等待頁面完全載入
|
|
||||||
fmt.Println("[GeminiWeb] Waiting for page load...")
|
|
||||||
time.Sleep(3 * time.Second)
|
|
||||||
|
|
||||||
// 策略 C: 使用 JavaScript 檢查
|
|
||||||
fmt.Println("[GeminiWeb] Checking with JavaScript...")
|
|
||||||
result, err := p.page.Evaluate(`
|
|
||||||
() => {
|
|
||||||
// 檢查所有可能的輸入元素
|
|
||||||
const selectors = [
|
|
||||||
'.ql-editor.ql-blank',
|
|
||||||
'.ql-editor',
|
|
||||||
'div[contenteditable="true"][role="textbox"]',
|
|
||||||
'div[contenteditable="true"]',
|
|
||||||
'.ProseMirror',
|
|
||||||
'rich-textarea',
|
|
||||||
'textarea'
|
|
||||||
];
|
|
||||||
|
|
||||||
for (const sel of selectors) {
|
|
||||||
const el = document.querySelector(sel);
|
|
||||||
if (el) {
|
|
||||||
return {
|
|
||||||
found: true,
|
|
||||||
selector: sel,
|
|
||||||
tagName: el.tagName,
|
|
||||||
className: el.className,
|
|
||||||
visible: el.offsetParent !== null
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return { found: false };
|
|
||||||
}
|
|
||||||
`)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
if m, ok := result.(map[string]interface{}); ok {
|
|
||||||
if found, _ := m["found"].(bool); found {
|
|
||||||
sel, _ := m["selector"].(string)
|
|
||||||
fmt.Printf(" ✓ JavaScript found: %s\n", sel)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 策略 D: 調試模式 - 輸出頁面結構
|
|
||||||
if p.cfg.GeminiBrowserVisible {
|
|
||||||
fmt.Println("[GeminiWeb].dump: Page structure analysis")
|
|
||||||
_, _ = p.page.Evaluate(`
|
|
||||||
() => {
|
|
||||||
const allElements = document.querySelectorAll('*');
|
|
||||||
const inputLike = [];
|
|
||||||
for (const el of allElements) {
|
|
||||||
if (el.contentEditable === 'true' ||
|
|
||||||
el.role === 'textbox' ||
|
|
||||||
el.tagName === 'TEXTAREA' ||
|
|
||||||
el.tagName === 'INPUT') {
|
|
||||||
inputLike.push({
|
|
||||||
tag: el.tagName,
|
|
||||||
class: el.className,
|
|
||||||
id: el.id,
|
|
||||||
role: el.role,
|
|
||||||
contentEditable: el.contentEditable
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
console.log('Input-like elements:', inputLike);
|
|
||||||
}
|
|
||||||
`)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("no input field found after all strategies")
|
|
||||||
}
|
|
||||||
|
|
||||||
// isLoggedIn 檢查是否已登入
|
|
||||||
func (p *PlaywrightProvider) isLoggedIn() bool {
|
|
||||||
// 嘗試找輸入框(登入狀態的主要特徵)
|
|
||||||
selectors := []string{
|
|
||||||
".ProseMirror",
|
|
||||||
"rich-textarea",
|
|
||||||
"div[role='textbox']",
|
|
||||||
"div[contenteditable='true']",
|
|
||||||
"textarea",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, sel := range selectors {
|
|
||||||
locator := p.page.Locator(sel)
|
|
||||||
if count, _ := locator.Count(); count > 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// typeInput 輸入文字(使用 Playwright 的 Auto-wait)
|
|
||||||
func (p *PlaywrightProvider) typeInput(text string) error {
|
|
||||||
fmt.Println("[GeminiWeb] Looking for input field...")
|
|
||||||
|
|
||||||
selectors := []string{
|
|
||||||
".ql-editor.ql-blank",
|
|
||||||
".ql-editor",
|
|
||||||
"div[contenteditable='true'][role='textbox']",
|
|
||||||
"div[contenteditable='true']",
|
|
||||||
".ProseMirror",
|
|
||||||
"rich-textarea",
|
|
||||||
"textarea",
|
|
||||||
}
|
|
||||||
|
|
||||||
var inputLocator playwright.Locator
|
|
||||||
var found bool
|
|
||||||
|
|
||||||
for _, sel := range selectors {
|
|
||||||
fmt.Printf(" Trying: %s\n", sel)
|
|
||||||
locator := p.page.Locator(sel)
|
|
||||||
if err := locator.WaitFor(playwright.LocatorWaitForOptions{
|
|
||||||
Timeout: playwright.Float(3000),
|
|
||||||
}); err == nil {
|
|
||||||
inputLocator = locator
|
|
||||||
found = true
|
|
||||||
fmt.Printf(" ✓ Found with: %s\n", sel)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
// 錯誤會被 Generate 的 defer 捕獲並保存診斷
|
|
||||||
url := p.page.URL()
|
|
||||||
title, _ := p.page.Title()
|
|
||||||
return fmt.Errorf("input field not found (URL=%s, Title=%s). Diagnostics will be saved to /tmp/", url, title)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Focus 並填充(Playwright 自動等待)
|
|
||||||
fmt.Printf("[GeminiWeb] Typing %d chars...\n", len(text))
|
|
||||||
if err := inputLocator.Fill(text); err != nil {
|
|
||||||
return fmt.Errorf("failed to fill: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("[GeminiWeb] Input complete")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sendMessage 發送訊息
|
|
||||||
func (p *PlaywrightProvider) sendMessage() error {
|
|
||||||
// 方法 1: 按 Enter(最可靠)
|
|
||||||
if err := p.page.Keyboard().Press("Enter"); err != nil {
|
|
||||||
return fmt.Errorf("failed to press Enter: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(200 * time.Millisecond)
|
|
||||||
|
|
||||||
// 方法 2: 嘗試點擊發送按鈕(補強)
|
|
||||||
_, _ = p.page.Evaluate(`
|
|
||||||
() => {
|
|
||||||
const keywords = ['發送', 'Send', '傳送'];
|
|
||||||
const buttons = Array.from(document.querySelectorAll('button, [role="button"]'));
|
|
||||||
|
|
||||||
for (const btn of buttons) {
|
|
||||||
const text = (btn.innerText || btn.textContent || '').trim();
|
|
||||||
const label = (btn.getAttribute('aria-label') || '').trim();
|
|
||||||
|
|
||||||
// 跳過停止按鈕
|
|
||||||
if (['停止', 'Stop', '中斷'].includes(text) || label.toLowerCase().includes('stop')) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (keywords.some(kw => text.includes(kw) || label.includes(kw))) {
|
|
||||||
btn.click();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
`)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractResponse 提取回應
|
|
||||||
func (p *PlaywrightProvider) extractResponse() (string, error) {
|
|
||||||
var lastText string
|
|
||||||
var stableCount int
|
|
||||||
lastUpdate := time.Now()
|
|
||||||
timeout := 120 * time.Second
|
|
||||||
startTime := time.Now()
|
|
||||||
|
|
||||||
for time.Since(startTime) < timeout {
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
|
|
||||||
// 使用 JavaScript 提取回應文字(更精確)
|
|
||||||
result, err := p.page.Evaluate(`
|
|
||||||
() => {
|
|
||||||
// 尋找所有可能的回應容器
|
|
||||||
const selectors = [
|
|
||||||
'model-response',
|
|
||||||
'.model-response',
|
|
||||||
'message-content',
|
|
||||||
'.message-content'
|
|
||||||
];
|
|
||||||
|
|
||||||
for (const sel of selectors) {
|
|
||||||
const el = document.querySelector(sel);
|
|
||||||
if (el) {
|
|
||||||
// 嘗試找markdown內容
|
|
||||||
const markdown = el.querySelector('.markdown, .prose, [class*="markdown"]');
|
|
||||||
if (markdown && markdown.innerText.trim()) {
|
|
||||||
let text = markdown.innerText.trim();
|
|
||||||
// 移除常見的標籤前綴
|
|
||||||
text = text.replace(/^Gemini said\s*\n*/i, '').replace(/^Gemini\s*[::]\s*\n*/i, '').trim();
|
|
||||||
return { text: text, source: sel + ' .markdown' };
|
|
||||||
}
|
|
||||||
|
|
||||||
// 嘗試找純文字內容(排除標籤)
|
|
||||||
let textContent = el.innerText.trim();
|
|
||||||
if (textContent) {
|
|
||||||
// 移除常見的標籤前綴
|
|
||||||
textContent = textContent.replace(/^Gemini said\s*\n*/i, '').replace(/^Gemini\s*[::]\s*\n*/i, '').trim();
|
|
||||||
return { text: textContent, source: sel };
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return { text: '', source: 'none' };
|
|
||||||
}
|
|
||||||
`)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
if m, ok := result.(map[string]interface{}); ok {
|
|
||||||
text, _ := m["text"].(string)
|
|
||||||
text = strings.TrimSpace(text)
|
|
||||||
|
|
||||||
if text != "" && len(text) > len(lastText) {
|
|
||||||
lastText = text
|
|
||||||
lastUpdate = time.Now()
|
|
||||||
stableCount = 0
|
|
||||||
fmt.Printf("[GeminiWeb] Response: %d chars\n", len(text))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 檢查是否完成(需要連續 3 次穩定)
|
|
||||||
if time.Since(lastUpdate) > 500*time.Millisecond && lastText != "" {
|
|
||||||
stableCount++
|
|
||||||
if stableCount >= 3 {
|
|
||||||
// 最終檢查:停止按鈕是否還存在
|
|
||||||
stopBtn := p.page.Locator("button[aria-label*='Stop'], button[aria-label*='停止'], button[data-test-id='stop-button']")
|
|
||||||
count, _ := stopBtn.Count()
|
|
||||||
|
|
||||||
if count == 0 {
|
|
||||||
fmt.Println("[GeminiWeb] ✓ Response complete")
|
|
||||||
return lastText, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if lastText != "" {
|
|
||||||
fmt.Println("[GeminiWeb] ✓ Response complete (timeout)")
|
|
||||||
return lastText, nil
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("response timeout")
|
|
||||||
}
|
|
||||||
|
|
||||||
// selectModel 選擇 Gemini 模型
|
|
||||||
// Gemini Web 只有三種模型:fast, thinking, pro
|
|
||||||
func (p *PlaywrightProvider) selectModel(model string) error {
|
|
||||||
// 映射模型名稱到 Gemini Web 的模型選擇器
|
|
||||||
modelMap := map[string]string{
|
|
||||||
"fast": "Fast",
|
|
||||||
"thinking": "Thinking",
|
|
||||||
"pro": "Pro",
|
|
||||||
"gemini-fast": "Fast",
|
|
||||||
"gemini-thinking": "Thinking",
|
|
||||||
"gemini-pro": "Pro",
|
|
||||||
"gemini-2.0-fast": "Fast",
|
|
||||||
"gemini-2.0-flash": "Fast", // 相容舊名稱
|
|
||||||
"gemini-2.5-pro": "Pro",
|
|
||||||
"gemini-2.5-pro-thinking": "Thinking",
|
|
||||||
}
|
|
||||||
|
|
||||||
// 從完整模型名稱中提取類型
|
|
||||||
modelType := ""
|
|
||||||
modelLower := strings.ToLower(model)
|
|
||||||
for key, value := range modelMap {
|
|
||||||
if strings.Contains(modelLower, strings.ToLower(key)) || modelLower == strings.ToLower(key) {
|
|
||||||
modelType = value
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelType == "" {
|
|
||||||
// 默認使用 Fast
|
|
||||||
fmt.Printf("[GeminiWeb] Unknown model '%s', defaulting to Fast\n", model)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("[GeminiWeb] Selecting model: %s\n", modelType)
|
|
||||||
|
|
||||||
// 點擊模型選擇器
|
|
||||||
modelSelector := p.page.Locator("button[aria-label*='Model'], button[aria-label*='模型'], [data-test-id='model-selector']")
|
|
||||||
if count, _ := modelSelector.Count(); count > 0 {
|
|
||||||
if err := modelSelector.First().Click(); err != nil {
|
|
||||||
fmt.Printf("[GeminiWeb] Warning: could not click model selector: %v\n", err)
|
|
||||||
} else {
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
|
|
||||||
// 選擇對應的模型選項
|
|
||||||
optionSelector := p.page.Locator(fmt.Sprintf("button:has-text('%s'), [role='menuitem']:has-text('%s')", modelType, modelType))
|
|
||||||
if count, _ := optionSelector.Count(); count > 0 {
|
|
||||||
if err := optionSelector.First().Click(); err != nil {
|
|
||||||
fmt.Printf("[GeminiWeb] Warning: could not select model: %v\n", err)
|
|
||||||
} else {
|
|
||||||
fmt.Printf("[GeminiWeb] ✓ Model selected: %s\n", modelType)
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildPromptFromMessages 從訊息列表建構提示詞
|
|
||||||
func buildPromptFromMessagesPlaywright(messages []apitypes.Message) string {
|
|
||||||
var prompt string
|
|
||||||
for _, m := range messages {
|
|
||||||
switch m.Role {
|
|
||||||
case "system":
|
|
||||||
prompt += "System: " + m.Content + "\n\n"
|
|
||||||
case "user":
|
|
||||||
prompt += m.Content + "\n\n"
|
|
||||||
case "assistant":
|
|
||||||
prompt += "Assistant: " + m.Content + "\n\n"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return prompt
|
|
||||||
}
|
|
||||||
|
|
@ -1,169 +0,0 @@
|
||||||
package geminiweb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type GeminiSession struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
CookieFile string `json:"cookie_file"`
|
|
||||||
LastUsed int64 `json:"last_used"`
|
|
||||||
ActiveCount int `json:"active_count"`
|
|
||||||
RateLimitEnd int64 `json:"rate_limit_end"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type SessionPool struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
sessions []*GeminiSession
|
|
||||||
dir string
|
|
||||||
maxCount int
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSessionPool(dir string, maxSessions int) (*SessionPool, error) {
|
|
||||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create session dir: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sessions, err := loadSessions(dir)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to load sessions: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &SessionPool{
|
|
||||||
sessions: sessions,
|
|
||||||
dir: dir,
|
|
||||||
maxCount: maxSessions,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadSessions(dir string) ([]*GeminiSession, error) {
|
|
||||||
entries, err := os.ReadDir(dir)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var sessions []*GeminiSession
|
|
||||||
for _, entry := range entries {
|
|
||||||
if !entry.IsDir() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
name := entry.Name()
|
|
||||||
metaPath := filepath.Join(dir, name, "session.json")
|
|
||||||
data, err := os.ReadFile(metaPath)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var s GeminiSession
|
|
||||||
if err := json.Unmarshal(data, &s); err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
sessions = append(sessions, &s)
|
|
||||||
}
|
|
||||||
|
|
||||||
return sessions, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *SessionPool) Count() int {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
return len(p.sessions)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *SessionPool) GetAvailable() *GeminiSession {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
|
|
||||||
now := time.Now().UnixMilli()
|
|
||||||
|
|
||||||
var available []*GeminiSession
|
|
||||||
for _, s := range p.sessions {
|
|
||||||
if s.RateLimitEnd < now {
|
|
||||||
available = append(available, s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(available) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var best *GeminiSession
|
|
||||||
for _, s := range available {
|
|
||||||
if best == nil || s.ActiveCount < best.ActiveCount {
|
|
||||||
best = s
|
|
||||||
} else if s.ActiveCount == best.ActiveCount && s.LastUsed < best.LastUsed {
|
|
||||||
best = s
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return best
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *SessionPool) StartSession(s *GeminiSession) {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
s.ActiveCount++
|
|
||||||
s.LastUsed = time.Now().UnixMilli()
|
|
||||||
p.saveSession(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *SessionPool) EndSession(s *GeminiSession) {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
if s.ActiveCount > 0 {
|
|
||||||
s.ActiveCount--
|
|
||||||
}
|
|
||||||
p.saveSession(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *SessionPool) RateLimitSession(s *GeminiSession, durationMs int64) {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
s.RateLimitEnd = time.Now().UnixMilli() + durationMs
|
|
||||||
p.saveSession(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *SessionPool) saveSession(s *GeminiSession) {
|
|
||||||
metaPath := filepath.Join(p.dir, s.Name, "session.json")
|
|
||||||
data, err := json.MarshalIndent(s, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = os.WriteFile(metaPath, data, 0644)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *SessionPool) CreateSession(name string) (*GeminiSession, error) {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
|
|
||||||
sessionDir := filepath.Join(p.dir, name)
|
|
||||||
if err := os.MkdirAll(sessionDir, 0755); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create session dir: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s := &GeminiSession{
|
|
||||||
Name: name,
|
|
||||||
CookieFile: filepath.Join(sessionDir, "cookies.json"),
|
|
||||||
LastUsed: time.Now().UnixMilli(),
|
|
||||||
}
|
|
||||||
|
|
||||||
p.sessions = append(p.sessions, s)
|
|
||||||
p.saveSession(s)
|
|
||||||
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *SessionPool) GetSessionNames() []string {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
names := make([]string, len(p.sessions))
|
|
||||||
for i, s := range p.sessions {
|
|
||||||
names[i] = s.Name
|
|
||||||
}
|
|
||||||
return names
|
|
||||||
}
|
|
||||||
|
|
@ -1,196 +0,0 @@
|
||||||
package geminiweb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"cursor-api-proxy/internal/apitypes"
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Provider 使用持久化瀏覽器管理器
|
|
||||||
type Provider struct {
|
|
||||||
cfg config.BridgeConfig
|
|
||||||
managerOnce sync.Once
|
|
||||||
manager *BrowserManager
|
|
||||||
managerErr error
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewProvider 建立新的 Provider
|
|
||||||
func NewProvider(cfg config.BridgeConfig) *Provider {
|
|
||||||
return &Provider{cfg: cfg}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getName 返回 Provider 名稱
|
|
||||||
func (p *Provider) Name() string {
|
|
||||||
return "gemini-web"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close 關閉瀏覽器
|
|
||||||
func (p *Provider) Close() error {
|
|
||||||
if p.manager != nil {
|
|
||||||
return p.manager.Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getManager 獲取或初始化瀏覽器管理器(單例)
|
|
||||||
func (p *Provider) getManager() (*BrowserManager, error) {
|
|
||||||
p.managerOnce.Do(func() {
|
|
||||||
sessionDir := p.getSessionDir()
|
|
||||||
p.manager, p.managerErr = GetBrowserManager(sessionDir, p.cfg.GeminiBrowserVisible)
|
|
||||||
})
|
|
||||||
return p.manager, p.managerErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// getSessionDir 獲取 session 目錄
|
|
||||||
func (p *Provider) getSessionDir() string {
|
|
||||||
// 使用單一 session 目錄(簡化設計)
|
|
||||||
return filepath.Join(p.cfg.GeminiAccountDir, "default-session")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate 生成回應
|
|
||||||
func (p *Provider) Generate(ctx context.Context, model string, messages []apitypes.Message, tools []apitypes.Tool, cb func(apitypes.StreamChunk)) error {
|
|
||||||
fmt.Printf("[GeminiWeb] Starting generation with model: %s\n", model)
|
|
||||||
|
|
||||||
// 1. 獲取瀏覽器管理器
|
|
||||||
manager, err := p.getManager()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get browser manager: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. 啟動瀏覽器(如果尚未啟動)
|
|
||||||
if !manager.IsRunning() {
|
|
||||||
fmt.Printf("[GeminiWeb] Launching browser...\n")
|
|
||||||
if err := manager.Launch(); err != nil {
|
|
||||||
return fmt.Errorf("failed to launch browser: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 獲取頁面
|
|
||||||
page, err := manager.GetPage()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get page: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. 檢查當前 URL,如果不是 Gemini 則導航
|
|
||||||
currentURL, _ := page.Info()
|
|
||||||
if !strings.Contains(currentURL.URL, "gemini.google.com") {
|
|
||||||
fmt.Printf("[GeminiWeb] Navigating to Gemini...\n")
|
|
||||||
if err := NavigateToGemini(page); err != nil {
|
|
||||||
return fmt.Errorf("failed to navigate: %w", err)
|
|
||||||
}
|
|
||||||
time.Sleep(2 * time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 5. 檢查登入狀態
|
|
||||||
fmt.Printf("[GeminiWeb] Checking login status...\n")
|
|
||||||
if !IsLoggedIn(page) {
|
|
||||||
fmt.Printf("[GeminiWeb] Not logged in, continuing anyway\n")
|
|
||||||
|
|
||||||
if p.cfg.GeminiBrowserVisible {
|
|
||||||
fmt.Println("\n========================================")
|
|
||||||
fmt.Println("Browser is open. You can:")
|
|
||||||
fmt.Println("1. Log in to Gemini now")
|
|
||||||
fmt.Println("2. Continue without login")
|
|
||||||
fmt.Println("========================================\n")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
fmt.Printf("[GeminiWeb] Logged in\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 6. 等待頁面就緒
|
|
||||||
if err := WaitForReady(page); err != nil {
|
|
||||||
fmt.Printf("[GeminiWeb] Warning: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 7. 建構提示詞
|
|
||||||
prompt := buildPromptFromMessages(messages)
|
|
||||||
fmt.Printf("[GeminiWeb] Typing prompt (%d chars)...\n", len(prompt))
|
|
||||||
|
|
||||||
// 8. 輸入文字
|
|
||||||
if err := TypeInput(page, prompt); err != nil {
|
|
||||||
return fmt.Errorf("failed to type input: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 9. 發送
|
|
||||||
fmt.Printf("[GeminiWeb] Sending message...\n")
|
|
||||||
if err := ClickSend(page); err != nil {
|
|
||||||
return fmt.Errorf("failed to send: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 10. 提取回應
|
|
||||||
fmt.Printf("[GeminiWeb] Waiting for response...\n")
|
|
||||||
response, err := ExtractResponse(page)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to extract response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 11. 串流回調
|
|
||||||
cb(apitypes.StreamChunk{Type: apitypes.ChunkText, Text: response})
|
|
||||||
cb(apitypes.StreamChunk{Type: apitypes.ChunkDone, Done: true})
|
|
||||||
|
|
||||||
fmt.Printf("[GeminiWeb] Response complete (%d chars)\n", len(response))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildPromptFromMessages 從訊息列表建構提示詞
|
|
||||||
func buildPromptFromMessages(messages []apitypes.Message) string {
|
|
||||||
var prompt string
|
|
||||||
for _, m := range messages {
|
|
||||||
switch m.Role {
|
|
||||||
case "system":
|
|
||||||
prompt += "System: " + m.Content + "\n\n"
|
|
||||||
case "user":
|
|
||||||
prompt += m.Content + "\n\n"
|
|
||||||
case "assistant":
|
|
||||||
prompt += "Assistant: " + m.Content + "\n\n"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunLogin 執行登入流程(供 gemini-login 命令使用)
|
|
||||||
func RunLogin(cfg config.BridgeConfig, sessionName string) error {
|
|
||||||
if sessionName == "" {
|
|
||||||
sessionName = "default-session"
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionDir := filepath.Join(cfg.GeminiAccountDir, sessionName)
|
|
||||||
if err := os.MkdirAll(sessionDir, 0755); err != nil {
|
|
||||||
return fmt.Errorf("failed to create session dir: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Starting browser for login. Session: %s\n", sessionName)
|
|
||||||
fmt.Printf("Session directory: %s\n", sessionDir)
|
|
||||||
fmt.Println("Please log in to your Gemini account in the browser window.")
|
|
||||||
fmt.Println("Press Ctrl+C when you have completed the login...")
|
|
||||||
|
|
||||||
manager, err := NewBrowserManager(sessionDir, true) // visible=true
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create browser manager: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := manager.Launch(); err != nil {
|
|
||||||
return fmt.Errorf("failed to launch browser: %w", err)
|
|
||||||
}
|
|
||||||
defer manager.Close()
|
|
||||||
|
|
||||||
page, err := manager.GetPage()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get page: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := NavigateToGemini(page); err != nil {
|
|
||||||
return fmt.Errorf("failed to navigate: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 等待用戶手動登入...
|
|
||||||
// 使用 Ctrl+C 退出,瀏覽器資料會自動保存在 userDataDir
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,147 +0,0 @@
|
||||||
package router
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
"cursor-api-proxy/internal/handlers"
|
|
||||||
"cursor-api-proxy/internal/httputil"
|
|
||||||
"cursor-api-proxy/internal/logger"
|
|
||||||
"cursor-api-proxy/internal/pool"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type RouterOptions struct {
|
|
||||||
Version string
|
|
||||||
Config config.BridgeConfig
|
|
||||||
ModelCache *handlers.ModelCacheRef
|
|
||||||
LastModel *string
|
|
||||||
Pool pool.PoolHandle
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRouter(opts RouterOptions) http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
cfg := opts.Config
|
|
||||||
pathname := r.URL.Path
|
|
||||||
method := r.Method
|
|
||||||
remoteAddress := r.RemoteAddr
|
|
||||||
if r.Header.Get("X-Real-IP") != "" {
|
|
||||||
remoteAddress = r.Header.Get("X-Real-IP")
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.LogIncoming(method, pathname, remoteAddress)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
logger.AppendSessionLine(cfg.SessionsLogPath, method, pathname, remoteAddress, 200)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if cfg.RequiredKey != "" {
|
|
||||||
token := httputil.ExtractBearerToken(r)
|
|
||||||
if token != cfg.RequiredKey {
|
|
||||||
httputil.WriteJSON(w, 401, map[string]interface{}{
|
|
||||||
"error": map[string]string{"message": "Invalid API key", "code": "unauthorized"},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case method == "GET" && pathname == "/health":
|
|
||||||
handlers.HandleHealth(w, r, opts.Version, cfg)
|
|
||||||
|
|
||||||
case method == "GET" && pathname == "/v1/models":
|
|
||||||
opts.ModelCache.HandleModels(w, r, cfg)
|
|
||||||
|
|
||||||
case method == "POST" && pathname == "/v1/chat/completions":
|
|
||||||
raw, err := httputil.ReadBody(r)
|
|
||||||
if err != nil {
|
|
||||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
|
||||||
"error": map[string]string{"message": "failed to read body", "code": "bad_request"},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 根據 Provider 選擇處理方式
|
|
||||||
provider := cfg.Provider
|
|
||||||
if provider == "" {
|
|
||||||
provider = "cursor"
|
|
||||||
}
|
|
||||||
if provider == "gemini-web" {
|
|
||||||
handlers.HandleGeminiChatCompletions(w, r, cfg, raw, method, pathname, remoteAddress)
|
|
||||||
} else {
|
|
||||||
handlers.HandleChatCompletions(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress)
|
|
||||||
}
|
|
||||||
|
|
||||||
case method == "POST" && pathname == "/v1/messages":
|
|
||||||
raw, err := httputil.ReadBody(r)
|
|
||||||
if err != nil {
|
|
||||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
|
||||||
"error": map[string]string{"message": "failed to read body", "code": "bad_request"},
|
|
||||||
}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
handlers.HandleAnthropicMessages(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress)
|
|
||||||
|
|
||||||
case (method == "POST" || method == "GET") && pathname == "/v1/completions":
|
|
||||||
httputil.WriteJSON(w, 404, map[string]interface{}{
|
|
||||||
"error": map[string]string{
|
|
||||||
"message": "Legacy completions endpoint is not supported. Use POST /v1/chat/completions instead.",
|
|
||||||
"code": "not_found",
|
|
||||||
},
|
|
||||||
}, nil)
|
|
||||||
|
|
||||||
case pathname == "/v1/embeddings":
|
|
||||||
httputil.WriteJSON(w, 404, map[string]interface{}{
|
|
||||||
"error": map[string]string{
|
|
||||||
"message": "Embeddings are not supported by this proxy.",
|
|
||||||
"code": "not_found",
|
|
||||||
},
|
|
||||||
}, nil)
|
|
||||||
|
|
||||||
default:
|
|
||||||
httputil.WriteJSON(w, 404, map[string]interface{}{
|
|
||||||
"error": map[string]string{"message": "Not found", "code": "not_found"},
|
|
||||||
}, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func recoveryMiddleware(logPath string, next http.HandlerFunc) http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
defer func() {
|
|
||||||
if rec := recover(); rec != nil {
|
|
||||||
msg := fmt.Sprintf("%v", rec)
|
|
||||||
fmt.Fprintf(os.Stderr, "[%s] Proxy panic: %s\n", time.Now().UTC().Format(time.RFC3339), msg)
|
|
||||||
line := fmt.Sprintf("%s ERROR %s %s %s %s\n",
|
|
||||||
time.Now().UTC().Format(time.RFC3339), r.Method, r.URL.Path, r.RemoteAddr,
|
|
||||||
msg[:min(200, len(msg))])
|
|
||||||
if f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil {
|
|
||||||
_, _ = f.WriteString(line)
|
|
||||||
f.Close()
|
|
||||||
}
|
|
||||||
if !isHeaderWritten(w) {
|
|
||||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
|
||||||
"error": map[string]string{"message": msg, "code": "internal_error"},
|
|
||||||
}, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
next(w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func isHeaderWritten(w http.ResponseWriter) bool {
|
|
||||||
// Can't reliably detect without wrapping; always try to write
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func min(a, b int) int {
|
|
||||||
if a < b {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
func WrapWithRecovery(logPath string, handler http.HandlerFunc) http.HandlerFunc {
|
|
||||||
return recoveryMiddleware(logPath, handler)
|
|
||||||
}
|
|
||||||
|
|
@ -1,95 +0,0 @@
|
||||||
package sanitize
|
|
||||||
|
|
||||||
import "regexp"
|
|
||||||
|
|
||||||
type rule struct {
|
|
||||||
pattern *regexp.Regexp
|
|
||||||
replacement string
|
|
||||||
}
|
|
||||||
|
|
||||||
var rules = []rule{
|
|
||||||
{regexp.MustCompile(`(?i)x-anthropic-billing-header:[^\n]*\n?`), ""},
|
|
||||||
{regexp.MustCompile(`(?i)\bcc_version=[^\s;,\n]+[;,]?\s*`), ""},
|
|
||||||
{regexp.MustCompile(`(?i)\bcc_entrypoint=[^\s;,\n]+[;,]?\s*`), ""},
|
|
||||||
{regexp.MustCompile(`(?i)\bcch=[a-f0-9]+[;,]?\s*`), ""},
|
|
||||||
{regexp.MustCompile(`\bClaude Code\b`), "Cursor"},
|
|
||||||
{regexp.MustCompile(`(?i)Anthropic['']s official CLI for Claude`), "Cursor AI assistant"},
|
|
||||||
{regexp.MustCompile(`\bAnthropic\b`), "Cursor"},
|
|
||||||
{regexp.MustCompile(`(?i)anthropic\.com`), "cursor.com"},
|
|
||||||
{regexp.MustCompile(`(?i)claude\.ai`), "cursor.sh"},
|
|
||||||
{regexp.MustCompile(`^[;,\s]+`), ""},
|
|
||||||
}
|
|
||||||
|
|
||||||
func SanitizeText(text string) string {
|
|
||||||
for _, r := range rules {
|
|
||||||
text = r.pattern.ReplaceAllString(text, r.replacement)
|
|
||||||
}
|
|
||||||
return text
|
|
||||||
}
|
|
||||||
|
|
||||||
func SanitizeMessages(messages []interface{}) []interface{} {
|
|
||||||
result := make([]interface{}, len(messages))
|
|
||||||
for i, raw := range messages {
|
|
||||||
if raw == nil {
|
|
||||||
result[i] = raw
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
m, ok := raw.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
result[i] = raw
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newMsg := make(map[string]interface{}, len(m))
|
|
||||||
for k, v := range m {
|
|
||||||
newMsg[k] = v
|
|
||||||
}
|
|
||||||
switch c := m["content"].(type) {
|
|
||||||
case string:
|
|
||||||
newMsg["content"] = SanitizeText(c)
|
|
||||||
case []interface{}:
|
|
||||||
newParts := make([]interface{}, len(c))
|
|
||||||
for j, p := range c {
|
|
||||||
if pm, ok := p.(map[string]interface{}); ok && pm["type"] == "text" {
|
|
||||||
if t, ok := pm["text"].(string); ok {
|
|
||||||
newPart := make(map[string]interface{}, len(pm))
|
|
||||||
for k, v := range pm {
|
|
||||||
newPart[k] = v
|
|
||||||
}
|
|
||||||
newPart["text"] = SanitizeText(t)
|
|
||||||
newParts[j] = newPart
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
newParts[j] = p
|
|
||||||
}
|
|
||||||
newMsg["content"] = newParts
|
|
||||||
}
|
|
||||||
result[i] = newMsg
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func SanitizeSystem(system interface{}) interface{} {
|
|
||||||
switch v := system.(type) {
|
|
||||||
case string:
|
|
||||||
return SanitizeText(v)
|
|
||||||
case []interface{}:
|
|
||||||
result := make([]interface{}, len(v))
|
|
||||||
for i, p := range v {
|
|
||||||
if pm, ok := p.(map[string]interface{}); ok && pm["type"] == "text" {
|
|
||||||
if t, ok := pm["text"].(string); ok {
|
|
||||||
newPart := make(map[string]interface{}, len(pm))
|
|
||||||
for k, val := range pm {
|
|
||||||
newPart[k] = val
|
|
||||||
}
|
|
||||||
newPart["text"] = SanitizeText(t)
|
|
||||||
result[i] = newPart
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result[i] = p
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
return system
|
|
||||||
}
|
|
||||||
|
|
@ -1,60 +0,0 @@
|
||||||
package sanitize
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSanitizeTextAnthropicBilling(t *testing.T) {
|
|
||||||
input := "x-anthropic-billing-header: abc123\nHello"
|
|
||||||
got := SanitizeText(input)
|
|
||||||
if strings.Contains(got, "x-anthropic-billing-header") {
|
|
||||||
t.Errorf("billing header not removed: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSanitizeTextClaudeCode(t *testing.T) {
|
|
||||||
input := "I am Claude Code assistant"
|
|
||||||
got := SanitizeText(input)
|
|
||||||
if strings.Contains(got, "Claude Code") {
|
|
||||||
t.Errorf("'Claude Code' not replaced: %q", got)
|
|
||||||
}
|
|
||||||
if !strings.Contains(got, "Cursor") {
|
|
||||||
t.Errorf("'Cursor' not present in output: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSanitizeTextAnthropic(t *testing.T) {
|
|
||||||
input := "Powered by Anthropic's technology and anthropic.com"
|
|
||||||
got := SanitizeText(input)
|
|
||||||
if strings.Contains(got, "Anthropic") {
|
|
||||||
t.Errorf("'Anthropic' not replaced: %q", got)
|
|
||||||
}
|
|
||||||
if strings.Contains(got, "anthropic.com") {
|
|
||||||
t.Errorf("'anthropic.com' not replaced: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSanitizeTextNoChange(t *testing.T) {
|
|
||||||
input := "Hello, this is a normal message about cursor."
|
|
||||||
got := SanitizeText(input)
|
|
||||||
if got != input {
|
|
||||||
t.Errorf("unexpected change: %q -> %q", input, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSanitizeMessages(t *testing.T) {
|
|
||||||
messages := []interface{}{
|
|
||||||
map[string]interface{}{"role": "user", "content": "Ask Claude Code something"},
|
|
||||||
map[string]interface{}{"role": "system", "content": "Use Anthropic's tools"},
|
|
||||||
}
|
|
||||||
result := SanitizeMessages(messages)
|
|
||||||
|
|
||||||
for _, raw := range result {
|
|
||||||
m := raw.(map[string]interface{})
|
|
||||||
c := m["content"].(string)
|
|
||||||
if strings.Contains(c, "Claude Code") || strings.Contains(c, "Anthropic") {
|
|
||||||
t.Errorf("found unsanitized content: %q", c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,159 +0,0 @@
|
||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
"cursor-api-proxy/internal/handlers"
|
|
||||||
"cursor-api-proxy/internal/pool"
|
|
||||||
"cursor-api-proxy/internal/process"
|
|
||||||
"cursor-api-proxy/internal/logger"
|
|
||||||
"cursor-api-proxy/internal/router"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ServerOptions struct {
|
|
||||||
Version string
|
|
||||||
Config config.BridgeConfig
|
|
||||||
Pool pool.PoolHandle
|
|
||||||
}
|
|
||||||
|
|
||||||
func StartBridgeServer(opts ServerOptions) []*http.Server {
|
|
||||||
cfg := opts.Config
|
|
||||||
var servers []*http.Server
|
|
||||||
|
|
||||||
if len(cfg.ConfigDirs) > 0 {
|
|
||||||
if cfg.MultiPort {
|
|
||||||
for i, dir := range cfg.ConfigDirs {
|
|
||||||
port := cfg.Port + i
|
|
||||||
subCfg := cfg
|
|
||||||
subCfg.Port = port
|
|
||||||
subCfg.ConfigDirs = []string{dir}
|
|
||||||
subCfg.MultiPort = false
|
|
||||||
subPool := pool.NewAccountPool([]string{dir})
|
|
||||||
srv := startSingleServer(ServerOptions{Version: opts.Version, Config: subCfg, Pool: subPool})
|
|
||||||
servers = append(servers, srv)
|
|
||||||
}
|
|
||||||
return servers
|
|
||||||
}
|
|
||||||
pool.InitAccountPool(cfg.ConfigDirs)
|
|
||||||
}
|
|
||||||
|
|
||||||
servers = append(servers, startSingleServer(opts))
|
|
||||||
return servers
|
|
||||||
}
|
|
||||||
|
|
||||||
func startSingleServer(opts ServerOptions) *http.Server {
|
|
||||||
cfg := opts.Config
|
|
||||||
|
|
||||||
modelCache := &handlers.ModelCacheRef{}
|
|
||||||
lastModel := cfg.DefaultModel
|
|
||||||
|
|
||||||
ph := opts.Pool
|
|
||||||
if ph == nil {
|
|
||||||
ph = pool.GlobalPoolHandle{}
|
|
||||||
}
|
|
||||||
handler := router.NewRouter(router.RouterOptions{
|
|
||||||
Version: opts.Version,
|
|
||||||
Config: cfg,
|
|
||||||
ModelCache: modelCache,
|
|
||||||
LastModel: &lastModel,
|
|
||||||
Pool: ph,
|
|
||||||
})
|
|
||||||
handler = router.WrapWithRecovery(cfg.SessionsLogPath, handler)
|
|
||||||
|
|
||||||
useTLS := cfg.TLSCertPath != "" && cfg.TLSKeyPath != ""
|
|
||||||
|
|
||||||
srv := &http.Server{
|
|
||||||
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
|
||||||
Handler: handler,
|
|
||||||
}
|
|
||||||
|
|
||||||
if useTLS {
|
|
||||||
cert, err := tls.LoadX509KeyPair(cfg.TLSCertPath, cfg.TLSKeyPath)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "TLS error: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
|
|
||||||
}
|
|
||||||
|
|
||||||
scheme := "http"
|
|
||||||
if useTLS {
|
|
||||||
scheme = "https"
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
var err error
|
|
||||||
if useTLS {
|
|
||||||
err = srv.ListenAndServeTLS("", "")
|
|
||||||
} else {
|
|
||||||
err = srv.ListenAndServe()
|
|
||||||
}
|
|
||||||
if err != nil && err != http.ErrServerClosed {
|
|
||||||
if isAddrInUse(err) {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Port %d is already in use. Set CURSOR_BRIDGE_PORT to use a different port.\n", cfg.Port)
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Server error: %v\n", err)
|
|
||||||
}
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
logger.LogServerStart(opts.Version, scheme, cfg.Host, cfg.Port, cfg)
|
|
||||||
|
|
||||||
return srv
|
|
||||||
}
|
|
||||||
|
|
||||||
func SetupGracefulShutdown(servers []*http.Server, timeoutMs int) {
|
|
||||||
sigCh := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
sig := <-sigCh
|
|
||||||
logger.LogShutdown(sig.String())
|
|
||||||
|
|
||||||
process.KillAllChildProcesses()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
for _, srv := range servers {
|
|
||||||
_ = srv.Shutdown(ctx)
|
|
||||||
}
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
os.Exit(0)
|
|
||||||
case <-ctx.Done():
|
|
||||||
fmt.Fprintln(os.Stderr, "[shutdown] Timed out waiting for connections to drain — forcing exit.")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func isAddrInUse(err error) bool {
|
|
||||||
return err != nil && (contains(err.Error(), "address already in use") || contains(err.Error(), "bind: address already in use"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func contains(s, sub string) bool {
|
|
||||||
return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsHelper(s, sub))
|
|
||||||
}
|
|
||||||
|
|
||||||
func containsHelper(s, sub string) bool {
|
|
||||||
for i := 0; i <= len(s)-len(sub); i++ {
|
|
||||||
if s[i:i+len(sub)] == sub {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
@ -1,331 +0,0 @@
|
||||||
package server_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
"cursor-api-proxy/internal/server"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// freePort 取得一個暫時可用的隨機 port
|
|
||||||
func freePort(t *testing.T) int {
|
|
||||||
t.Helper()
|
|
||||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
port := l.Addr().(*net.TCPAddr).Port
|
|
||||||
l.Close()
|
|
||||||
return port
|
|
||||||
}
|
|
||||||
|
|
||||||
// makeFakeAgentBin 建立一個 shell script,模擬 agent 固定輸出
|
|
||||||
// sync 模式:直接輸出一行文字
|
|
||||||
// stream 模式:輸出 JSON stream 行
|
|
||||||
func makeFakeAgentBin(t *testing.T, syncOutput string) string {
|
|
||||||
t.Helper()
|
|
||||||
dir := t.TempDir()
|
|
||||||
script := dir + "/agent"
|
|
||||||
content := fmt.Sprintf(`#!/bin/sh
|
|
||||||
# 若有 --stream-json 則輸出 stream 格式
|
|
||||||
for arg; do
|
|
||||||
if [ "$arg" = "--stream-json" ]; then
|
|
||||||
printf '%%s\n' '{"type":"assistant","message":{"content":[{"type":"text","text":"%s"}]}}'
|
|
||||||
printf '%%s\n' '{"type":"result","subtype":"success"}'
|
|
||||||
exit 0
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
# 否則輸出 sync 格式
|
|
||||||
printf '%%s' '%s'
|
|
||||||
`, syncOutput, syncOutput)
|
|
||||||
if err := os.WriteFile(script, []byte(content), 0755); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
return script
|
|
||||||
}
|
|
||||||
|
|
||||||
// makeFakeAgentBinWithModels 額外支援 --list-models 輸出
|
|
||||||
func makeFakeAgentBinWithModels(t *testing.T) string {
|
|
||||||
t.Helper()
|
|
||||||
dir := t.TempDir()
|
|
||||||
script := dir + "/agent"
|
|
||||||
content := `#!/bin/sh
|
|
||||||
for arg; do
|
|
||||||
if [ "$arg" = "--list-models" ]; then
|
|
||||||
printf 'claude-3-opus - Claude 3 Opus\n'
|
|
||||||
printf 'claude-3-sonnet - Claude 3 Sonnet\n'
|
|
||||||
exit 0
|
|
||||||
fi
|
|
||||||
if [ "$arg" = "--stream-json" ]; then
|
|
||||||
printf '%s\n' '{"type":"assistant","message":{"content":[{"type":"text","text":"Hello"}]}}'
|
|
||||||
printf '%s\n' '{"type":"result","subtype":"success"}'
|
|
||||||
exit 0
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
printf 'Hello from agent'
|
|
||||||
`
|
|
||||||
if err := os.WriteFile(script, []byte(content), 0755); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
return script
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeTestConfig(agentBin string, port int, overrides ...func(*config.BridgeConfig)) config.BridgeConfig {
|
|
||||||
cfg := config.BridgeConfig{
|
|
||||||
AgentBin: agentBin,
|
|
||||||
Host: "127.0.0.1",
|
|
||||||
Port: port,
|
|
||||||
DefaultModel: "auto",
|
|
||||||
Mode: "ask",
|
|
||||||
Force: false,
|
|
||||||
ApproveMcps: false,
|
|
||||||
StrictModel: true,
|
|
||||||
Workspace: os.TempDir(),
|
|
||||||
TimeoutMs: 30000,
|
|
||||||
SessionsLogPath: os.TempDir() + "/test-sessions.log",
|
|
||||||
ChatOnlyWorkspace: true,
|
|
||||||
Verbose: false,
|
|
||||||
MaxMode: false,
|
|
||||||
ConfigDirs: []string{},
|
|
||||||
MultiPort: false,
|
|
||||||
WinCmdlineMax: 30000,
|
|
||||||
}
|
|
||||||
for _, fn := range overrides {
|
|
||||||
fn(&cfg)
|
|
||||||
}
|
|
||||||
return cfg
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitListening(t *testing.T, host string, port int, timeout time.Duration) {
|
|
||||||
t.Helper()
|
|
||||||
deadline := time.Now().Add(timeout)
|
|
||||||
for time.Now().Before(deadline) {
|
|
||||||
conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), 50*time.Millisecond)
|
|
||||||
if err == nil {
|
|
||||||
conn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(20 * time.Millisecond)
|
|
||||||
}
|
|
||||||
t.Fatalf("server on port %d did not start within %v", port, timeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
func doRequest(t *testing.T, method, url, body string, headers map[string]string) (int, string) {
|
|
||||||
t.Helper()
|
|
||||||
var reqBody io.Reader
|
|
||||||
if body != "" {
|
|
||||||
reqBody = strings.NewReader(body)
|
|
||||||
}
|
|
||||||
req, err := http.NewRequest(method, url, reqBody)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if body != "" {
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
}
|
|
||||||
for k, v := range headers {
|
|
||||||
req.Header.Set(k, v)
|
|
||||||
}
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
data, _ := io.ReadAll(resp.Body)
|
|
||||||
return resp.StatusCode, string(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBridgeServer_Health(t *testing.T) {
|
|
||||||
port := freePort(t)
|
|
||||||
agentBin := makeFakeAgentBinWithModels(t)
|
|
||||||
cfg := makeTestConfig(agentBin, port)
|
|
||||||
|
|
||||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
|
||||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
|
||||||
defer func() {
|
|
||||||
for _, s := range srvs {
|
|
||||||
s.Shutdown(context.Background())
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", nil)
|
|
||||||
if status != 200 {
|
|
||||||
t.Fatalf("status = %d, want 200; body: %s", status, body)
|
|
||||||
}
|
|
||||||
var result map[string]interface{}
|
|
||||||
json.Unmarshal([]byte(body), &result)
|
|
||||||
if result["ok"] != true {
|
|
||||||
t.Errorf("ok = %v, want true", result["ok"])
|
|
||||||
}
|
|
||||||
if result["version"] != "1.0.0" {
|
|
||||||
t.Errorf("version = %v, want 1.0.0", result["version"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBridgeServer_Models(t *testing.T) {
|
|
||||||
port := freePort(t)
|
|
||||||
agentBin := makeFakeAgentBinWithModels(t)
|
|
||||||
cfg := makeTestConfig(agentBin, port)
|
|
||||||
|
|
||||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
|
||||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
|
||||||
defer func() {
|
|
||||||
for _, s := range srvs {
|
|
||||||
s.Shutdown(context.Background())
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/v1/models", port), "", nil)
|
|
||||||
if status != 200 {
|
|
||||||
t.Fatalf("status = %d, want 200; body: %s", status, body)
|
|
||||||
}
|
|
||||||
var result map[string]interface{}
|
|
||||||
json.Unmarshal([]byte(body), &result)
|
|
||||||
if result["object"] != "list" {
|
|
||||||
t.Errorf("object = %v, want list", result["object"])
|
|
||||||
}
|
|
||||||
data := result["data"].([]interface{})
|
|
||||||
if len(data) < 2 {
|
|
||||||
t.Errorf("data len = %d, want >= 2", len(data))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBridgeServer_Unauthorized(t *testing.T) {
|
|
||||||
port := freePort(t)
|
|
||||||
agentBin := makeFakeAgentBinWithModels(t)
|
|
||||||
cfg := makeTestConfig(agentBin, port, func(c *config.BridgeConfig) {
|
|
||||||
c.RequiredKey = "secret123"
|
|
||||||
})
|
|
||||||
|
|
||||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
|
||||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
|
||||||
defer func() {
|
|
||||||
for _, s := range srvs {
|
|
||||||
s.Shutdown(context.Background())
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", nil)
|
|
||||||
if status != 401 {
|
|
||||||
t.Fatalf("status = %d, want 401; body: %s", status, body)
|
|
||||||
}
|
|
||||||
var result map[string]interface{}
|
|
||||||
json.Unmarshal([]byte(body), &result)
|
|
||||||
errObj := result["error"].(map[string]interface{})
|
|
||||||
if errObj["message"] != "Invalid API key" {
|
|
||||||
t.Errorf("message = %v, want 'Invalid API key'", errObj["message"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBridgeServer_AuthorizedKey(t *testing.T) {
|
|
||||||
port := freePort(t)
|
|
||||||
agentBin := makeFakeAgentBinWithModels(t)
|
|
||||||
cfg := makeTestConfig(agentBin, port, func(c *config.BridgeConfig) {
|
|
||||||
c.RequiredKey = "secret123"
|
|
||||||
})
|
|
||||||
|
|
||||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
|
||||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
|
||||||
defer func() {
|
|
||||||
for _, s := range srvs {
|
|
||||||
s.Shutdown(context.Background())
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
status, _ := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", map[string]string{
|
|
||||||
"Authorization": "Bearer secret123",
|
|
||||||
})
|
|
||||||
if status != 200 {
|
|
||||||
t.Errorf("status = %d, want 200", status)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBridgeServer_NotFound(t *testing.T) {
|
|
||||||
port := freePort(t)
|
|
||||||
agentBin := makeFakeAgentBinWithModels(t)
|
|
||||||
cfg := makeTestConfig(agentBin, port)
|
|
||||||
|
|
||||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
|
||||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
|
||||||
defer func() {
|
|
||||||
for _, s := range srvs {
|
|
||||||
s.Shutdown(context.Background())
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/unknown", port), "", nil)
|
|
||||||
if status != 404 {
|
|
||||||
t.Fatalf("status = %d, want 404; body: %s", status, body)
|
|
||||||
}
|
|
||||||
var result map[string]interface{}
|
|
||||||
json.Unmarshal([]byte(body), &result)
|
|
||||||
errObj := result["error"].(map[string]interface{})
|
|
||||||
if errObj["code"] != "not_found" {
|
|
||||||
t.Errorf("code = %v, want not_found", errObj["code"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBridgeServer_ChatCompletions_Sync(t *testing.T) {
|
|
||||||
port := freePort(t)
|
|
||||||
agentBin := makeFakeAgentBin(t, "Hello from agent")
|
|
||||||
cfg := makeTestConfig(agentBin, port)
|
|
||||||
|
|
||||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
|
||||||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
|
||||||
defer func() {
|
|
||||||
for _, s := range srvs {
|
|
||||||
s.Shutdown(context.Background())
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
reqBody := `{"model":"claude-3-opus","messages":[{"role":"user","content":"Hi"}]}`
|
|
||||||
status, body := doRequest(t, "POST", fmt.Sprintf("http://127.0.0.1:%d/v1/chat/completions", port), reqBody, nil)
|
|
||||||
if status != 200 {
|
|
||||||
t.Fatalf("status = %d, want 200; body: %s", status, body)
|
|
||||||
}
|
|
||||||
var result map[string]interface{}
|
|
||||||
json.Unmarshal([]byte(body), &result)
|
|
||||||
if result["object"] != "chat.completion" {
|
|
||||||
t.Errorf("object = %v, want chat.completion", result["object"])
|
|
||||||
}
|
|
||||||
choices := result["choices"].([]interface{})
|
|
||||||
msg := choices[0].(map[string]interface{})["message"].(map[string]interface{})
|
|
||||||
if msg["content"] != "Hello from agent" {
|
|
||||||
t.Errorf("content = %v, want 'Hello from agent'", msg["content"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBridgeServer_MultiPort(t *testing.T) {
|
|
||||||
basePort := freePort(t)
|
|
||||||
agentBin := makeFakeAgentBinWithModels(t)
|
|
||||||
|
|
||||||
dir1 := t.TempDir()
|
|
||||||
dir2 := t.TempDir()
|
|
||||||
|
|
||||||
cfg := makeTestConfig(agentBin, basePort, func(c *config.BridgeConfig) {
|
|
||||||
c.ConfigDirs = []string{dir1, dir2}
|
|
||||||
c.MultiPort = true
|
|
||||||
})
|
|
||||||
|
|
||||||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
|
||||||
if len(srvs) != 2 {
|
|
||||||
t.Fatalf("got %d servers, want 2", len(srvs))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 等待兩個 server 啟動(port 可能會衝突,這裡不嚴格測試 port 分配)
|
|
||||||
time.Sleep(200 * time.Millisecond)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
for _, s := range srvs {
|
|
||||||
s.Shutdown(context.Background())
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
@ -1,154 +0,0 @@
|
||||||
package toolcall
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ToolCall struct {
|
|
||||||
Name string
|
|
||||||
Arguments string // JSON string
|
|
||||||
}
|
|
||||||
|
|
||||||
type ParsedResponse struct {
|
|
||||||
TextContent string
|
|
||||||
ToolCalls []ToolCall
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ParsedResponse) HasToolCalls() bool {
|
|
||||||
return len(p.ToolCalls) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Modified regex to handle nested JSON
|
|
||||||
var toolCallTagRe = regexp.MustCompile(`(?s)行政法规\s*(\{(?:[^{}]|\{[^{}]*\})*\})\s*ugalakh`)
|
|
||||||
var antmlFunctionCallsRe = regexp.MustCompile(`(?s)<function_calls>\s*(.*?)\s*</function_calls>`)
|
|
||||||
var antmlInvokeRe = regexp.MustCompile(`(?s)<invoke\s+name="([^"]+)">\s*(.*?)\s*</invoke>`)
|
|
||||||
var antmlParamRe = regexp.MustCompile(`(?s)<parameter\s+name="([^"]+)">(.*?)</parameter>`)
|
|
||||||
|
|
||||||
func ExtractToolCalls(text string, toolNames map[string]bool) *ParsedResponse {
|
|
||||||
result := &ParsedResponse{}
|
|
||||||
|
|
||||||
if locs := toolCallTagRe.FindAllStringSubmatchIndex(text, -1); len(locs) > 0 {
|
|
||||||
var calls []ToolCall
|
|
||||||
var textParts []string
|
|
||||||
last := 0
|
|
||||||
for _, loc := range locs {
|
|
||||||
if loc[0] > last {
|
|
||||||
textParts = append(textParts, text[last:loc[0]])
|
|
||||||
}
|
|
||||||
jsonStr := text[loc[2]:loc[3]]
|
|
||||||
if tc := parseToolCallJSON(jsonStr, toolNames); tc != nil {
|
|
||||||
calls = append(calls, *tc)
|
|
||||||
} else {
|
|
||||||
textParts = append(textParts, text[loc[0]:loc[1]])
|
|
||||||
}
|
|
||||||
last = loc[1]
|
|
||||||
}
|
|
||||||
if last < len(text) {
|
|
||||||
textParts = append(textParts, text[last:])
|
|
||||||
}
|
|
||||||
if len(calls) > 0 {
|
|
||||||
result.TextContent = strings.TrimSpace(strings.Join(textParts, ""))
|
|
||||||
result.ToolCalls = calls
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if locs := antmlFunctionCallsRe.FindAllStringSubmatchIndex(text, -1); len(locs) > 0 {
|
|
||||||
var calls []ToolCall
|
|
||||||
var textParts []string
|
|
||||||
last := 0
|
|
||||||
for _, loc := range locs {
|
|
||||||
if loc[0] > last {
|
|
||||||
textParts = append(textParts, text[last:loc[0]])
|
|
||||||
}
|
|
||||||
block := text[loc[2]:loc[3]]
|
|
||||||
invokes := antmlInvokeRe.FindAllStringSubmatch(block, -1)
|
|
||||||
for _, inv := range invokes {
|
|
||||||
name := inv[1]
|
|
||||||
if toolNames != nil && len(toolNames) > 0 && !toolNames[name] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
body := inv[2]
|
|
||||||
args := map[string]interface{}{}
|
|
||||||
params := antmlParamRe.FindAllStringSubmatch(body, -1)
|
|
||||||
for _, p := range params {
|
|
||||||
paramName := p[1]
|
|
||||||
paramValue := strings.TrimSpace(p[2])
|
|
||||||
var jsonVal interface{}
|
|
||||||
if err := json.Unmarshal([]byte(paramValue), &jsonVal); err == nil {
|
|
||||||
args[paramName] = jsonVal
|
|
||||||
} else {
|
|
||||||
args[paramName] = paramValue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
argsJSON, _ := json.Marshal(args)
|
|
||||||
calls = append(calls, ToolCall{Name: name, Arguments: string(argsJSON)})
|
|
||||||
}
|
|
||||||
last = loc[1]
|
|
||||||
}
|
|
||||||
if last < len(text) {
|
|
||||||
textParts = append(textParts, text[last:])
|
|
||||||
}
|
|
||||||
if len(calls) > 0 {
|
|
||||||
result.TextContent = strings.TrimSpace(strings.Join(textParts, ""))
|
|
||||||
result.ToolCalls = calls
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
result.TextContent = text
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseToolCallJSON(jsonStr string, toolNames map[string]bool) *ToolCall {
|
|
||||||
var raw map[string]interface{}
|
|
||||||
if err := json.Unmarshal([]byte(jsonStr), &raw); err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
name, _ := raw["name"].(string)
|
|
||||||
if name == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if toolNames != nil && len(toolNames) > 0 && !toolNames[name] {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var argsStr string
|
|
||||||
switch a := raw["arguments"].(type) {
|
|
||||||
case string:
|
|
||||||
argsStr = a
|
|
||||||
case map[string]interface{}, []interface{}:
|
|
||||||
b, _ := json.Marshal(a)
|
|
||||||
argsStr = string(b)
|
|
||||||
default:
|
|
||||||
if p, ok := raw["parameters"]; ok {
|
|
||||||
b, _ := json.Marshal(p)
|
|
||||||
argsStr = string(b)
|
|
||||||
} else {
|
|
||||||
argsStr = "{}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &ToolCall{Name: name, Arguments: argsStr}
|
|
||||||
}
|
|
||||||
|
|
||||||
func CollectToolNames(tools []interface{}) map[string]bool {
|
|
||||||
names := map[string]bool{}
|
|
||||||
for _, t := range tools {
|
|
||||||
m, ok := t.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if m["type"] == "function" {
|
|
||||||
if fn, ok := m["function"].(map[string]interface{}); ok {
|
|
||||||
if name, ok := fn["name"].(string); ok {
|
|
||||||
names[name] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if name, ok := m["name"].(string); ok {
|
|
||||||
names[name] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return names
|
|
||||||
}
|
|
||||||
|
|
@ -1,181 +0,0 @@
|
||||||
package winlimit
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cursor-api-proxy/internal/env"
|
|
||||||
"runtime"
|
|
||||||
)
|
|
||||||
|
|
||||||
const WinPromptOmissionPrefix = "[Earlier messages omitted: Windows command-line length limit.]\n\n"
|
|
||||||
const LinuxPromptOmissionPrefix = "[Earlier messages omitted: Linux ARG_MAX command-line length limit.]\n\n"
|
|
||||||
|
|
||||||
// safeLinuxArgMax returns a conservative estimate of ARG_MAX on Linux.
|
|
||||||
// The actual limit is typically 2MB; we use 1.5MB to leave room for env vars.
|
|
||||||
func safeLinuxArgMax() int {
|
|
||||||
return 1536 * 1024
|
|
||||||
}
|
|
||||||
|
|
||||||
type FitPromptResult struct {
|
|
||||||
OK bool
|
|
||||||
Args []string
|
|
||||||
Truncated bool
|
|
||||||
OriginalLength int
|
|
||||||
FinalPromptLength int
|
|
||||||
Error string
|
|
||||||
}
|
|
||||||
|
|
||||||
func estimateCmdlineLength(resolved env.AgentCommand) int {
|
|
||||||
argv := append([]string{resolved.Command}, resolved.Args...)
|
|
||||||
if resolved.WindowsVerbatimArguments {
|
|
||||||
n := 0
|
|
||||||
for _, a := range argv {
|
|
||||||
n += len(a)
|
|
||||||
}
|
|
||||||
if len(argv) > 1 {
|
|
||||||
n += len(argv) - 1
|
|
||||||
}
|
|
||||||
return n + 512
|
|
||||||
}
|
|
||||||
dstLen := 0
|
|
||||||
for _, a := range argv {
|
|
||||||
dstLen += len(a)
|
|
||||||
}
|
|
||||||
dstLen = dstLen*2 + len(argv)*2
|
|
||||||
if len(argv) > 1 {
|
|
||||||
dstLen += len(argv) - 1
|
|
||||||
}
|
|
||||||
return dstLen + 512
|
|
||||||
}
|
|
||||||
|
|
||||||
func FitPromptToWinCmdline(agentBin string, fixedArgs []string, prompt string, maxCmdline int, cwd string) FitPromptResult {
|
|
||||||
if runtime.GOOS != "windows" {
|
|
||||||
return fitPromptLinux(fixedArgs, prompt)
|
|
||||||
}
|
|
||||||
|
|
||||||
e := env.OsEnvToMap()
|
|
||||||
measured := func(p string) int {
|
|
||||||
args := make([]string, len(fixedArgs)+1)
|
|
||||||
copy(args, fixedArgs)
|
|
||||||
args[len(fixedArgs)] = p
|
|
||||||
resolved := env.ResolveAgentCommand(agentBin, args, e, cwd)
|
|
||||||
return estimateCmdlineLength(resolved)
|
|
||||||
}
|
|
||||||
|
|
||||||
if measured("") > maxCmdline {
|
|
||||||
return FitPromptResult{
|
|
||||||
OK: false,
|
|
||||||
Error: "Windows command line exceeds the configured limit even without a prompt; shorten workspace path, model id, or CURSOR_BRIDGE_WIN_CMDLINE_MAX.",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if measured(prompt) <= maxCmdline {
|
|
||||||
args := make([]string, len(fixedArgs)+1)
|
|
||||||
copy(args, fixedArgs)
|
|
||||||
args[len(fixedArgs)] = prompt
|
|
||||||
return FitPromptResult{
|
|
||||||
OK: true,
|
|
||||||
Args: args,
|
|
||||||
Truncated: false,
|
|
||||||
OriginalLength: len(prompt),
|
|
||||||
FinalPromptLength: len(prompt),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
prefix := WinPromptOmissionPrefix
|
|
||||||
if measured(prefix) > maxCmdline {
|
|
||||||
return FitPromptResult{
|
|
||||||
OK: false,
|
|
||||||
Error: "Windows command line too long to fit even the truncation notice; shorten workspace path or flags.",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
lo, hi, best := 0, len(prompt), 0
|
|
||||||
for lo <= hi {
|
|
||||||
mid := (lo + hi) / 2
|
|
||||||
var tail string
|
|
||||||
if mid > 0 {
|
|
||||||
tail = prompt[len(prompt)-mid:]
|
|
||||||
}
|
|
||||||
candidate := prefix + tail
|
|
||||||
if measured(candidate) <= maxCmdline {
|
|
||||||
best = mid
|
|
||||||
lo = mid + 1
|
|
||||||
} else {
|
|
||||||
hi = mid - 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var finalPrompt string
|
|
||||||
if best == 0 {
|
|
||||||
finalPrompt = prefix
|
|
||||||
} else {
|
|
||||||
finalPrompt = prefix + prompt[len(prompt)-best:]
|
|
||||||
}
|
|
||||||
|
|
||||||
args := make([]string, len(fixedArgs)+1)
|
|
||||||
copy(args, fixedArgs)
|
|
||||||
args[len(fixedArgs)] = finalPrompt
|
|
||||||
return FitPromptResult{
|
|
||||||
OK: true,
|
|
||||||
Args: args,
|
|
||||||
Truncated: true,
|
|
||||||
OriginalLength: len(prompt),
|
|
||||||
FinalPromptLength: len(finalPrompt),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// fitPromptLinux handles Linux ARG_MAX truncation.
|
|
||||||
func fitPromptLinux(fixedArgs []string, prompt string) FitPromptResult {
|
|
||||||
argMax := safeLinuxArgMax()
|
|
||||||
|
|
||||||
// Estimate total cmdline size: sum of all fixed args + prompt + null terminators
|
|
||||||
fixedLen := 0
|
|
||||||
for _, a := range fixedArgs {
|
|
||||||
fixedLen += len(a) + 1
|
|
||||||
}
|
|
||||||
totalLen := fixedLen + len(prompt) + 1
|
|
||||||
|
|
||||||
if totalLen <= argMax {
|
|
||||||
args := make([]string, len(fixedArgs)+1)
|
|
||||||
copy(args, fixedArgs)
|
|
||||||
args[len(fixedArgs)] = prompt
|
|
||||||
return FitPromptResult{
|
|
||||||
OK: true,
|
|
||||||
Args: args,
|
|
||||||
Truncated: false,
|
|
||||||
OriginalLength: len(prompt),
|
|
||||||
FinalPromptLength: len(prompt),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Need to truncate: keep the tail of the prompt (most recent messages)
|
|
||||||
prefix := LinuxPromptOmissionPrefix
|
|
||||||
available := argMax - fixedLen - len(prefix) - 1
|
|
||||||
if available < 0 {
|
|
||||||
available = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
var finalPrompt string
|
|
||||||
if available <= 0 {
|
|
||||||
finalPrompt = prefix
|
|
||||||
} else if available >= len(prompt) {
|
|
||||||
finalPrompt = prefix + prompt
|
|
||||||
} else {
|
|
||||||
finalPrompt = prefix + prompt[len(prompt)-available:]
|
|
||||||
}
|
|
||||||
|
|
||||||
args := make([]string, len(fixedArgs)+1)
|
|
||||||
copy(args, fixedArgs)
|
|
||||||
args[len(fixedArgs)] = finalPrompt
|
|
||||||
return FitPromptResult{
|
|
||||||
OK: true,
|
|
||||||
Args: args,
|
|
||||||
Truncated: true,
|
|
||||||
OriginalLength: len(prompt),
|
|
||||||
FinalPromptLength: len(finalPrompt),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WarnPromptTruncated(originalLength, finalLength int) {
|
|
||||||
_ = originalLength
|
|
||||||
_ = finalLength
|
|
||||||
}
|
|
||||||
|
|
@ -1,37 +0,0 @@
|
||||||
package winlimit
|
|
||||||
|
|
||||||
import (
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNonWindowsPassThrough(t *testing.T) {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
t.Skip("Skipping non-Windows test on Windows")
|
|
||||||
}
|
|
||||||
|
|
||||||
fixedArgs := []string{"--print", "--model", "gpt-4"}
|
|
||||||
prompt := "Hello world"
|
|
||||||
result := FitPromptToWinCmdline("agent", fixedArgs, prompt, 30000, "/tmp")
|
|
||||||
|
|
||||||
if !result.OK {
|
|
||||||
t.Fatalf("expected OK=true on non-Windows, got error: %s", result.Error)
|
|
||||||
}
|
|
||||||
if result.Truncated {
|
|
||||||
t.Error("expected no truncation on non-Windows")
|
|
||||||
}
|
|
||||||
if result.OriginalLength != len(prompt) {
|
|
||||||
t.Errorf("expected original length %d, got %d", len(prompt), result.OriginalLength)
|
|
||||||
}
|
|
||||||
// Last arg should be the prompt
|
|
||||||
if len(result.Args) == 0 || result.Args[len(result.Args)-1] != prompt {
|
|
||||||
t.Errorf("expected last arg to be prompt, got %v", result.Args)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOmissionPrefix(t *testing.T) {
|
|
||||||
if !strings.Contains(WinPromptOmissionPrefix, "Earlier messages omitted") {
|
|
||||||
t.Errorf("omission prefix should mention earlier messages, got: %q", WinPromptOmissionPrefix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,30 +0,0 @@
|
||||||
package workspace
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
type WorkspaceResult struct {
|
|
||||||
WorkspaceDir string
|
|
||||||
TempDir string
|
|
||||||
}
|
|
||||||
|
|
||||||
func ResolveWorkspace(cfg config.BridgeConfig, workspaceHeader string) WorkspaceResult {
|
|
||||||
if cfg.ChatOnlyWorkspace {
|
|
||||||
tempDir, err := os.MkdirTemp("", "cursor-proxy-")
|
|
||||||
if err != nil {
|
|
||||||
tempDir = filepath.Join(os.TempDir(), "cursor-proxy-fallback")
|
|
||||||
_ = os.MkdirAll(tempDir, 0700)
|
|
||||||
}
|
|
||||||
return WorkspaceResult{WorkspaceDir: tempDir, TempDir: tempDir}
|
|
||||||
}
|
|
||||||
|
|
||||||
headerWs := strings.TrimSpace(workspaceHeader)
|
|
||||||
if headerWs != "" {
|
|
||||||
return WorkspaceResult{WorkspaceDir: headerWs}
|
|
||||||
}
|
|
||||||
return WorkspaceResult{WorkspaceDir: cfg.Workspace}
|
|
||||||
}
|
|
||||||
73
main.go
73
main.go
|
|
@ -1,73 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cursor-api-proxy/cmd"
|
|
||||||
"cursor-api-proxy/internal/config"
|
|
||||||
"cursor-api-proxy/internal/env"
|
|
||||||
"cursor-api-proxy/internal/server"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
)
|
|
||||||
|
|
||||||
const version = "1.0.0"
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
args, err := cmd.ParseArgs(os.Args[1:])
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
if args.Help {
|
|
||||||
cmd.PrintHelp(version)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if args.Login {
|
|
||||||
if err := cmd.HandleLogin(args.AccountName, args.Proxies); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if args.Logout {
|
|
||||||
if err := cmd.HandleLogout(args.AccountName); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if args.AccountsList {
|
|
||||||
if err := cmd.HandleAccountsList(); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if args.ResetHwid {
|
|
||||||
if err := cmd.HandleResetHwid(args.DeepClean, args.DryRun); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
e := env.OsEnvToMap()
|
|
||||||
if args.Tailscale {
|
|
||||||
e["CURSOR_BRIDGE_HOST"] = "0.0.0.0"
|
|
||||||
}
|
|
||||||
|
|
||||||
cwd, _ := os.Getwd()
|
|
||||||
cfg := config.LoadBridgeConfig(e, cwd)
|
|
||||||
|
|
||||||
servers := server.StartBridgeServer(server.ServerOptions{
|
|
||||||
Version: version,
|
|
||||||
Config: cfg,
|
|
||||||
})
|
|
||||||
server.SetupGracefulShutdown(servers, 10000)
|
|
||||||
|
|
||||||
select {}
|
|
||||||
}
|
|
||||||
|
|
@ -1,159 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cursor-api-proxy/internal/providers/geminiweb"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/go-rod/rod"
|
|
||||||
"github.com/go-rod/rod/lib/launcher"
|
|
||||||
"github.com/go-rod/rod/lib/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
fmt.Println("Starting Gemini DOM detection...")
|
|
||||||
fmt.Println("This will open a browser and analyze the Gemini web interface.")
|
|
||||||
fmt.Println()
|
|
||||||
|
|
||||||
// 啟動可見瀏覽器
|
|
||||||
l := launcher.New().Headless(false)
|
|
||||||
url, err := l.Launch()
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Failed to launch browser: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
browser := rod.New().ControlURL(url)
|
|
||||||
if err := browser.Connect(); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Failed to connect browser: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
defer browser.Close()
|
|
||||||
|
|
||||||
page, err := browser.Page(proto.TargetCreateTarget{URL: "about:blank"})
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Failed to create page: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 載入 cookies(如果有)
|
|
||||||
home, _ := os.UserHomeDir()
|
|
||||||
cookieFile := home + "/.cursor-api-proxy/gemini-accounts/session-1/cookies.json"
|
|
||||||
if _, err := os.Stat(cookieFile); err == nil {
|
|
||||||
cookies, err := geminiweb.LoadCookiesFromFile(cookieFile)
|
|
||||||
if err == nil {
|
|
||||||
geminiweb.SetCookiesOnPage(page, cookies)
|
|
||||||
fmt.Println("Loaded existing cookies")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 導航到 Gemini
|
|
||||||
fmt.Println("Navigating to gemini.google.com...")
|
|
||||||
if err := geminiweb.NavigateToGemini(page); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Failed to navigate: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println()
|
|
||||||
fmt.Println("Browser is now open. Please:")
|
|
||||||
fmt.Println("1. Log in if needed")
|
|
||||||
fmt.Println("2. Wait for the chat interface to fully load")
|
|
||||||
fmt.Println("3. Look for the model selector dropdown")
|
|
||||||
fmt.Println()
|
|
||||||
fmt.Println("Press Enter to analyze the DOM...")
|
|
||||||
fmt.Scanln()
|
|
||||||
|
|
||||||
// 分析 DOM
|
|
||||||
analyzeDOM(page)
|
|
||||||
|
|
||||||
fmt.Println()
|
|
||||||
fmt.Println("Press Enter to close...")
|
|
||||||
fmt.Scanln()
|
|
||||||
}
|
|
||||||
|
|
||||||
func analyzeDOM(page *rod.Page) {
|
|
||||||
fmt.Println("=== DOM Analysis ===")
|
|
||||||
fmt.Println()
|
|
||||||
|
|
||||||
// 尋找可能的輸入框
|
|
||||||
fmt.Println("Looking for input fields...")
|
|
||||||
selectors := []string{
|
|
||||||
`textarea`,
|
|
||||||
`[contenteditable="true"]`,
|
|
||||||
`[role="textbox"]`,
|
|
||||||
`input[type="text"]`,
|
|
||||||
}
|
|
||||||
for _, sel := range selectors {
|
|
||||||
elements, err := page.Elements(sel)
|
|
||||||
if err == nil && len(elements) > 0 {
|
|
||||||
fmt.Printf(" Found %d elements with: %s\n", len(elements), sel)
|
|
||||||
for i, el := range elements {
|
|
||||||
tag, _ := el.Property("tagName")
|
|
||||||
class, _ := el.Attribute("class")
|
|
||||||
ariaLabel, _ := el.Attribute("aria-label")
|
|
||||||
placeholder, _ := el.Attribute("placeholder")
|
|
||||||
fmt.Printf(" [%d] tag=%s class=%s aria-label=%s placeholder=%s\n",
|
|
||||||
i, tag, class, ariaLabel, placeholder)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 尋找可能的發送按鈕
|
|
||||||
fmt.Println()
|
|
||||||
fmt.Println("Looking for send buttons...")
|
|
||||||
buttonSelectors := []string{
|
|
||||||
`button`,
|
|
||||||
`[role="button"]`,
|
|
||||||
`[type="submit"]`,
|
|
||||||
}
|
|
||||||
for _, sel := range buttonSelectors {
|
|
||||||
elements, err := page.Elements(sel)
|
|
||||||
if err == nil && len(elements) > 0 {
|
|
||||||
fmt.Printf(" Found %d elements with: %s\n", len(elements), sel)
|
|
||||||
for i, el := range elements {
|
|
||||||
if i >= 5 {
|
|
||||||
fmt.Printf(" ... and %d more\n", len(elements)-5)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
tag, _ := el.Property("tagName")
|
|
||||||
class, _ := el.Attribute("class")
|
|
||||||
ariaLabel, _ := el.Attribute("aria-label")
|
|
||||||
text, _ := el.Text()
|
|
||||||
text = truncate(text, 30)
|
|
||||||
fmt.Printf(" [%d] tag=%s class=%s aria-label=%s text=%s\n",
|
|
||||||
i, tag, class, ariaLabel, text)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 尋找模型選擇器
|
|
||||||
fmt.Println()
|
|
||||||
fmt.Println("Looking for model selector...")
|
|
||||||
modelSelectors := []string{
|
|
||||||
`[aria-label*="model"]`,
|
|
||||||
`[aria-label*="Model"]`,
|
|
||||||
`button[aria-haspopup]`,
|
|
||||||
`[data-test-id*="model"]`,
|
|
||||||
`[class*="model"]`,
|
|
||||||
}
|
|
||||||
for _, sel := range modelSelectors {
|
|
||||||
elements, err := page.Elements(sel)
|
|
||||||
if err == nil && len(elements) > 0 {
|
|
||||||
fmt.Printf(" Found with: %s\n", sel)
|
|
||||||
for i, el := range elements {
|
|
||||||
tag, _ := el.Property("tagName")
|
|
||||||
class, _ := el.Attribute("class")
|
|
||||||
ariaLabel, _ := el.Attribute("aria-label")
|
|
||||||
text, _ := el.Text()
|
|
||||||
fmt.Printf(" [%d] tag=%s class=%s aria-label=%s text=%s\n",
|
|
||||||
i, tag, class, ariaLabel, truncate(text, 30))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func truncate(s string, max int) string {
|
|
||||||
if len(s) <= max {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
return s[:max] + "..."
|
|
||||||
}
|
|
||||||
Loading…
Reference in New Issue