package handlers
import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/google/uuid"

	"cursor-api-proxy/internal/agent"
	"cursor-api-proxy/internal/anthropic"
	"cursor-api-proxy/internal/config"
	"cursor-api-proxy/internal/httputil"
	"cursor-api-proxy/internal/logger"
	"cursor-api-proxy/internal/models"
	"cursor-api-proxy/internal/openai"
	"cursor-api-proxy/internal/parser"
	"cursor-api-proxy/internal/pool"
	"cursor-api-proxy/internal/sanitize"
	"cursor-api-proxy/internal/toolcall"
	"cursor-api-proxy/internal/winlimit"
	"cursor-api-proxy/internal/workspace"
)
// HandleAnthropicMessages implements an Anthropic-style POST /v1/messages
// endpoint on top of the Cursor agent CLI.
//
// Flow: decode the request, resolve the model, sanitize the system prompt and
// messages, fold tool definitions into the system text, flatten everything
// into a single prompt, fit that prompt to the Windows command-line length
// limit, run the agent through an account picked from the pool, and translate
// the agent output back into either Anthropic SSE events (req.Stream) or a
// single JSON message. Outcomes (success, error, rate limit) are reported to
// the pool handle for per-account accounting.
//
// Parameters:
//   w, r         - response writer and incoming request.
//   cfg          - bridge configuration (agent binary, timeout, verbosity, ...).
//   ph           - account pool; supplies the config dir and records outcomes.
//   lastModelRef - last-used model name, consulted by ResolveModel.
//   rawBody      - request body, already read by the caller.
//   method, pathname, remoteAddress - request metadata, used for logging only.
func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, ph pool.PoolHandle, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
	var req anthropic.MessagesRequest
	if err := json.Unmarshal([]byte(rawBody), &req); err != nil {
		httputil.WriteJSON(w, 400, map[string]interface{}{
			"error": map[string]string{"type": "invalid_request_error", "message": "invalid JSON body"},
		}, nil)
		return
	}

	// Map the requested model ID onto the model this bridge will actually use.
	requested := openai.NormalizeModelID(req.Model)
	model := ResolveModel(requested, lastModelRef, cfg)

	// NOTE(review): rawMap is decoded but never read anywhere below — this
	// looks like dead code left from an earlier revision; confirm and remove.
	var rawMap map[string]interface{}
	_ = json.Unmarshal([]byte(rawBody), &rawMap)

	cleanSystem := sanitize.SanitizeSystem(req.System)

	// Re-shape the typed messages into generic maps so the sanitizer can walk them.
	rawMessages := make([]interface{}, len(req.Messages))
	for i, m := range req.Messages {
		rawMessages[i] = map[string]interface{}{"role": m.Role, "content": m.Content}
	}
	cleanRawMessages := sanitize.SanitizeMessages(rawMessages)

	// Convert the sanitized maps back into MessageParam values; entries that
	// are not maps are silently dropped.
	var cleanMessages []anthropic.MessageParam
	for _, raw := range cleanRawMessages {
		if m, ok := raw.(map[string]interface{}); ok {
			role, _ := m["role"].(string)
			cleanMessages = append(cleanMessages, anthropic.MessageParam{Role: role, Content: m["content"]})
		}
	}

	// Tool definitions are injected as extra system text. Only a string-typed
	// system prompt is merged; a block-list system falls through with
	// sysStr == "" so the tools text stands alone.
	toolsText := openai.ToolsToSystemText(req.Tools, nil)
	var systemWithTools interface{}
	if toolsText != "" {
		sysStr := ""
		switch v := cleanSystem.(type) {
		case string:
			sysStr = v
		}
		if sysStr != "" {
			systemWithTools = sysStr + "\n\n" + toolsText
		} else {
			systemWithTools = toolsText
		}
	} else {
		systemWithTools = cleanSystem
	}

	prompt := anthropic.BuildPromptFromAnthropicMessages(cleanMessages, systemWithTools)

	// The Anthropic API requires max_tokens; a missing field decodes to 0.
	if req.MaxTokens == 0 {
		httputil.WriteJSON(w, 400, map[string]interface{}{
			"error": map[string]string{"type": "invalid_request_error", "message": "max_tokens is required"},
		}, nil)
		return
	}

	// Translate the public model name to the Cursor-internal one, falling back
	// to the name as-is when no mapping exists.
	cursorModel := models.ResolveToCursorModel(model)
	if cursorModel == "" {
		cursorModel = model
	}

	// Build a flattened transcript for traffic logging.
	var trafficMsgs []logger.TrafficMessage
	if s := systemToString(cleanSystem); s != "" {
		trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: "system", Content: s})
	}
	for _, m := range cleanMessages {
		text := contentToString(m.Content)
		if text != "" {
			trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: m.Role, Content: text})
		}
	}
	logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, req.Stream)

	// The workspace can be overridden per request via the x-cursor-workspace header.
	headerWs := r.Header.Get("x-cursor-workspace")
	ws := workspace.ResolveWorkspace(cfg, headerWs)

	// Assemble the agent command line and shrink the prompt if it would exceed
	// the Windows command-line length limit.
	fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, req.Stream)
	fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir)

	if cfg.Verbose {
		if len(prompt) > 200 {
			logger.LogDebug("model=%s prompt_len=%d prompt_preview=%q", cursorModel, len(prompt), prompt[:200]+"...")
		} else {
			logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, len(prompt), prompt)
		}
		logger.LogDebug("cmd_args=%v", fit.Args)
	}

	if !fit.OK {
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"type": "api_error", "message": fit.Error},
		}, nil)
		return
	}
	if fit.Truncated {
		logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength)
	}

	cmdArgs := fit.Args
	msgID := "msg_" + uuid.New().String()

	// Surface prompt truncation to the client via a response header.
	var truncatedHeaders map[string]string
	if fit.Truncated {
		truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
	}

	hasTools := len(req.Tools) > 0
	var toolNames map[string]bool
	if hasTools {
		toolNames = toolcall.CollectToolNames(req.Tools)
	}

	if req.Stream {
		// ---- Streaming (SSE) path ----
		httputil.WriteSSEHeaders(w, truncatedHeaders)
		flusher, _ := w.(http.Flusher) // may be nil; Flush is then skipped

		// writeEvent serializes one Anthropic SSE event and flushes it out.
		writeEvent := func(evt interface{}) {
			data, _ := json.Marshal(evt)
			fmt.Fprintf(w, "data: %s\n\n", data)
			if flusher != nil {
				flusher.Flush()
			}
		}

		var accumulated string         // all assistant text seen so far
		var accumulatedThinking string // all "thinking" output seen so far
		var chunkNum int               // running chunk counter for logging
		var p parser.Parser

		writeEvent(map[string]interface{}{
			"type": "message_start",
			"message": map[string]interface{}{
				"id":      msgID,
				"type":    "message",
				"role":    "assistant",
				"model":   model,
				"content": []interface{}{},
			},
		})

		if hasTools {
			// Tools mode: accumulate all content first and emit it in one shot
			// once complete (tool_calls require the full text to parse).
			p = parser.CreateStreamParserWithThinking(
				func(text string) {
					accumulated += text
					chunkNum++
					logger.LogStreamChunk(model, text, chunkNum)
				},
				func(thinking string) {
					accumulatedThinking += thinking
				},
				func() {
					// Completion callback: parse tool calls out of the buffered
					// text, then replay everything as Anthropic content blocks.
					logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
					parsed := toolcall.ExtractToolCalls(accumulated, toolNames)

					blockIndex := 0
					// Thinking (if any) always goes out as the first block.
					if accumulatedThinking != "" {
						writeEvent(map[string]interface{}{
							"type": "content_block_start", "index": blockIndex,
							"content_block": map[string]string{"type": "thinking", "thinking": ""},
						})
						writeEvent(map[string]interface{}{
							"type": "content_block_delta", "index": blockIndex,
							"delta": map[string]string{"type": "thinking_delta", "thinking": accumulatedThinking},
						})
						writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
						blockIndex++
					}

					if parsed.HasToolCalls() {
						// Leading prose (if any), then one tool_use block per call.
						if parsed.TextContent != "" {
							writeEvent(map[string]interface{}{
								"type": "content_block_start", "index": blockIndex,
								"content_block": map[string]string{"type": "text", "text": ""},
							})
							writeEvent(map[string]interface{}{
								"type": "content_block_delta", "index": blockIndex,
								"delta": map[string]string{"type": "text_delta", "text": parsed.TextContent},
							})
							writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
							blockIndex++
						}
						for _, tc := range parsed.ToolCalls {
							toolID := "toolu_" + uuid.New().String()[:12]
							// NOTE(review): inputObj is computed but never used in
							// this streaming branch — the start event deliberately
							// sends an empty input and the delta carries
							// tc.Arguments verbatim. Confirm inputObj can be removed.
							var inputObj interface{}
							_ = json.Unmarshal([]byte(tc.Arguments), &inputObj)
							if inputObj == nil {
								inputObj = map[string]interface{}{}
							}
							writeEvent(map[string]interface{}{
								"type": "content_block_start", "index": blockIndex,
								"content_block": map[string]interface{}{
									"type": "tool_use", "id": toolID, "name": tc.Name, "input": map[string]interface{}{},
								},
							})
							writeEvent(map[string]interface{}{
								"type": "content_block_delta", "index": blockIndex,
								"delta": map[string]interface{}{
									"type": "input_json_delta", "partial_json": tc.Arguments,
								},
							})
							writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
							blockIndex++
						}
						writeEvent(map[string]interface{}{
							"type": "message_delta",
							"delta": map[string]interface{}{"stop_reason": "tool_use", "stop_sequence": nil},
							// Token usage is not tracked; zero is a placeholder.
							"usage": map[string]int{"output_tokens": 0},
						})
						writeEvent(map[string]interface{}{"type": "message_stop"})
					} else {
						// No tool calls: emit the accumulated text as one block.
						writeEvent(map[string]interface{}{
							"type": "content_block_start", "index": blockIndex,
							"content_block": map[string]string{"type": "text", "text": ""},
						})
						if accumulated != "" {
							writeEvent(map[string]interface{}{
								"type": "content_block_delta", "index": blockIndex,
								"delta": map[string]string{"type": "text_delta", "text": accumulated},
							})
						}
						writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
						writeEvent(map[string]interface{}{
							"type": "message_delta",
							"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
							"usage": map[string]int{"output_tokens": 0},
						})
						writeEvent(map[string]interface{}{"type": "message_stop"})
					}
				},
			)
		} else {
			// Non-tools mode: stream thinking and text to the client in real time.
			// blockCount   - number of content blocks opened so far
			// thinkingOpen - a thinking block is open and not yet closed
			// textOpen     - a text block is open and not yet closed
			blockCount := 0
			thinkingOpen := false
			textOpen := false

			p = parser.CreateStreamParserWithThinking(
				func(text string) {
					accumulated += text
					chunkNum++
					logger.LogStreamChunk(model, text, chunkNum)
					// Close the thinking block first if it is still open.
					if thinkingOpen {
						writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
						thinkingOpen = false
					}
					// Open the text block if it is not open yet.
					if !textOpen {
						writeEvent(map[string]interface{}{
							"type":          "content_block_start",
							"index":         blockCount,
							"content_block": map[string]string{"type": "text", "text": ""},
						})
						textOpen = true
						blockCount++
					}
					writeEvent(map[string]interface{}{
						"type":  "content_block_delta",
						"index": blockCount - 1,
						"delta": map[string]string{"type": "text_delta", "text": text},
					})
				},
				func(thinking string) {
					accumulatedThinking += thinking
					chunkNum++
					// Open the thinking block if it is not open yet.
					if !thinkingOpen {
						writeEvent(map[string]interface{}{
							"type":          "content_block_start",
							"index":         blockCount,
							"content_block": map[string]string{"type": "thinking", "thinking": ""},
						})
						thinkingOpen = true
						blockCount++
					}
					writeEvent(map[string]interface{}{
						"type":  "content_block_delta",
						"index": blockCount - 1,
						"delta": map[string]string{"type": "thinking_delta", "thinking": thinking},
					})
				},
				func() {
					logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
					// Close a thinking block that is still open.
					if thinkingOpen {
						writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
						thinkingOpen = false
					}
					// If no text block was ever opened (the output was thinking
					// only), open an empty text block so the stop below has a
					// target and the client always sees at least one text block.
					if !textOpen {
						writeEvent(map[string]interface{}{
							"type":          "content_block_start",
							"index":         blockCount,
							"content_block": map[string]string{"type": "text", "text": ""},
						})
						blockCount++
					}
					// Close the text block and finish the message.
					writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
					writeEvent(map[string]interface{}{
						"type":  "message_delta",
						"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
						// Token usage is not tracked; zero is a placeholder.
						"usage": map[string]int{"output_tokens": 0},
					})
					writeEvent(map[string]interface{}{"type": "message_stop"})
				},
			)
		}

		// Pick an account, run the agent, and feed every raw output line
		// through the parser built above.
		configDir := ph.GetNextConfigDir()
		logger.LogAccountAssigned(configDir)
		ph.ReportRequestStart(configDir)
		logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true)
		streamStart := time.Now().UnixMilli()

		ctx := r.Context()
		wrappedParser := func(line string) {
			logger.LogRawLine(line)
			p.Parse(line)
		}
		result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, ctx)

		// After the agent exits, if the request was not cancelled/timed out,
		// force a flush so the SSE stream is terminated correctly even when no
		// explicit result/success signal was received.
		if ctx.Err() == nil {
			p.Flush()
		}

		latencyMs := time.Now().UnixMilli() - streamStart
		ph.ReportRequestEnd(configDir)

		if ctx.Err() == context.DeadlineExceeded {
			logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
		} else if ctx.Err() == context.Canceled {
			logger.LogClientDisconnect(method, pathname, model, latencyMs)
		} else if err == nil && isRateLimited(result.Stderr) {
			// Agent ran but reported a rate limit; park this account for a while.
			ph.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr))
		}

		if err != nil || (result.Code != 0 && ctx.Err() == nil) {
			ph.ReportRequestError(configDir, latencyMs)
			if err != nil {
				// Spawn-level failure: no exit code, log -1.
				logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
			} else {
				logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr)
			}
			logger.LogRequestDone(method, pathname, model, latencyMs, result.Code)
		} else if ctx.Err() == nil {
			ph.ReportRequestSuccess(configDir, latencyMs)
			logger.LogRequestDone(method, pathname, model, latencyMs, 0)
		}
		logger.LogAccountStats(cfg.Verbose, ph.GetStats())
		return
	}

	// ---- Non-streaming (sync) path ----
	configDir := ph.GetNextConfigDir()
	logger.LogAccountAssigned(configDir)
	ph.ReportRequestStart(configDir)
	logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, false)
	syncStart := time.Now().UnixMilli()

	out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context())
	syncLatency := time.Now().UnixMilli() - syncStart
	ph.ReportRequestEnd(configDir)

	ctx := r.Context()
	if ctx.Err() == context.DeadlineExceeded {
		logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
		httputil.WriteJSON(w, 504, map[string]interface{}{
			"error": map[string]string{"type": "api_error", "message": fmt.Sprintf("request timed out after %dms", cfg.TimeoutMs)},
		}, nil)
		return
	}
	if ctx.Err() == context.Canceled {
		// Client went away; nothing can be written back.
		logger.LogClientDisconnect(method, pathname, model, syncLatency)
		return
	}

	if err != nil {
		ph.ReportRequestError(configDir, syncLatency)
		logger.LogAccountStats(cfg.Verbose, ph.GetStats())
		logger.LogRequestDone(method, pathname, model, syncLatency, -1)
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"type": "api_error", "message": err.Error()},
		}, nil)
		return
	}

	if isRateLimited(out.Stderr) {
		ph.ReportRateLimit(configDir, extractRetryAfterMs(out.Stderr))
	}

	if out.Code != 0 {
		ph.ReportRequestError(configDir, syncLatency)
		logger.LogAccountStats(cfg.Verbose, ph.GetStats())
		errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr)
		logger.LogRequestDone(method, pathname, model, syncLatency, out.Code)
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"type": "api_error", "message": errMsg},
		}, nil)
		return
	}

	ph.ReportRequestSuccess(configDir, syncLatency)
	content := strings.TrimSpace(out.Stdout)
	logger.LogTrafficResponse(cfg.Verbose, model, content, false)
	logger.LogAccountStats(cfg.Verbose, ph.GetStats())
	logger.LogRequestDone(method, pathname, model, syncLatency, 0)

	// With tools declared, try to lift tool calls out of the plain-text reply
	// and answer with tool_use content blocks.
	if hasTools {
		parsed := toolcall.ExtractToolCalls(content, toolNames)
		if parsed.HasToolCalls() {
			var contentBlocks []map[string]interface{}
			if parsed.TextContent != "" {
				contentBlocks = append(contentBlocks, map[string]interface{}{
					"type": "text", "text": parsed.TextContent,
				})
			}
			for _, tc := range parsed.ToolCalls {
				toolID := "toolu_" + uuid.New().String()[:12]
				// Arguments arrive as a JSON string; fall back to an empty
				// object when they do not parse.
				var inputObj interface{}
				_ = json.Unmarshal([]byte(tc.Arguments), &inputObj)
				if inputObj == nil {
					inputObj = map[string]interface{}{}
				}
				contentBlocks = append(contentBlocks, map[string]interface{}{
					"type": "tool_use", "id": toolID, "name": tc.Name, "input": inputObj,
				})
			}
			httputil.WriteJSON(w, 200, map[string]interface{}{
				"id":          msgID,
				"type":        "message",
				"role":        "assistant",
				"content":     contentBlocks,
				"model":       model,
				"stop_reason": "tool_use",
				// Token usage is not tracked; zeros are placeholders.
				"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
			}, truncatedHeaders)
			return
		}
	}

	// Plain text response (no tools, or no tool calls detected).
	httputil.WriteJSON(w, 200, map[string]interface{}{
		"id":          msgID,
		"type":        "message",
		"role":        "assistant",
		"content":     []map[string]string{{"type": "text", "text": content}},
		"model":       model,
		"stop_reason": "end_turn",
		"usage":       map[string]int{"input_tokens": 0, "output_tokens": 0},
	}, truncatedHeaders)
}
// systemToString flattens an Anthropic `system` value into plain text.
// A bare string is returned unchanged; a slice of content blocks is reduced
// to the concatenation of the text of its "text"-typed blocks. Anything else
// (including nil) yields "". The shape handling mirrors contentToString.
func systemToString(system interface{}) string {
	switch sys := system.(type) {
	case string:
		return sys
	case []interface{}:
		out := ""
		for _, item := range sys {
			block, isMap := item.(map[string]interface{})
			if !isMap || block["type"] != "text" {
				continue
			}
			if txt, isStr := block["text"].(string); isStr {
				out += txt
			}
		}
		return out
	default:
		return ""
	}
}
// contentToString flattens an Anthropic message `content` value into plain
// text. A bare string is returned as-is; a slice of content blocks is reduced
// to the concatenated text of its "text"-typed blocks. Any other shape
// (including nil) yields "".
func contentToString(content interface{}) string {
	if s, ok := content.(string); ok {
		return s
	}
	parts, ok := content.([]interface{})
	if !ok {
		return ""
	}
	joined := ""
	for _, part := range parts {
		block, isMap := part.(map[string]interface{})
		if !isMap {
			continue
		}
		if block["type"] != "text" {
			continue
		}
		if text, isStr := block["text"].(string); isStr {
			joined += text
		}
	}
	return joined
}