opencode-cursor-agent/internal/handlers/anthropic_handler.go

519 lines
17 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package handlers
import (
"context"
"cursor-api-proxy/internal/agent"
"cursor-api-proxy/internal/anthropic"
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/httputil"
"cursor-api-proxy/internal/logger"
"cursor-api-proxy/internal/models"
"cursor-api-proxy/internal/openai"
"cursor-api-proxy/internal/parser"
"cursor-api-proxy/internal/pool"
"cursor-api-proxy/internal/sanitize"
"cursor-api-proxy/internal/toolcall"
"cursor-api-proxy/internal/winlimit"
"cursor-api-proxy/internal/workspace"
"encoding/json"
"strings"
"fmt"
"net/http"
"time"
"github.com/google/uuid"
)
// HandleAnthropicMessages serves an Anthropic-style /v1/messages request by
// translating it into a cursor-agent invocation. It supports both SSE
// streaming and synchronous JSON responses, optional tool use (tool calls are
// extracted from the model's text output), and prompt truncation to respect
// the Windows command-line length limit.
//
// Parameters beyond the usual (w, r): cfg is the bridge configuration, ph the
// account pool handle used for per-config-dir accounting, lastModelRef a
// mutable reference to the last model used (consumed by ResolveModel),
// rawBody the already-read request body, and method/pathname/remoteAddress
// request metadata used only for logging.
func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, ph pool.PoolHandle, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
	var req anthropic.MessagesRequest
	if err := json.Unmarshal([]byte(rawBody), &req); err != nil {
		httputil.WriteJSON(w, 400, map[string]interface{}{
			"error": map[string]string{"type": "invalid_request_error", "message": "invalid JSON body"},
		}, nil)
		return
	}
	requested := openai.NormalizeModelID(req.Model)
	model := ResolveModel(requested, lastModelRef, cfg)
	cleanSystem := sanitize.SanitizeSystem(req.System)
	rawMessages := make([]interface{}, len(req.Messages))
	for i, m := range req.Messages {
		rawMessages[i] = map[string]interface{}{"role": m.Role, "content": m.Content}
	}
	cleanRawMessages := sanitize.SanitizeMessages(rawMessages)
	var cleanMessages []anthropic.MessageParam
	for _, raw := range cleanRawMessages {
		if m, ok := raw.(map[string]interface{}); ok {
			role, _ := m["role"].(string)
			cleanMessages = append(cleanMessages, anthropic.MessageParam{Role: role, Content: m["content"]})
		}
	}
	toolsText := openai.ToolsToSystemText(req.Tools, nil)
	var systemWithTools interface{}
	if toolsText != "" {
		// Flatten the system prompt to plain text so the tool instructions can
		// be appended. systemToString also handles the array-of-blocks form,
		// so block-style system prompts are not dropped when tools are set.
		sysStr := systemToString(cleanSystem)
		if sysStr != "" {
			systemWithTools = sysStr + "\n\n" + toolsText
		} else {
			systemWithTools = toolsText
		}
	} else {
		systemWithTools = cleanSystem
	}
	prompt := anthropic.BuildPromptFromAnthropicMessages(cleanMessages, systemWithTools)
	// Anthropic's API requires max_tokens; mirror that contract here.
	if req.MaxTokens == 0 {
		httputil.WriteJSON(w, 400, map[string]interface{}{
			"error": map[string]string{"type": "invalid_request_error", "message": "max_tokens is required"},
		}, nil)
		return
	}
	cursorModel := models.ResolveToCursorModel(model)
	if cursorModel == "" {
		cursorModel = model
	}
	// Build the flattened message list for traffic logging only.
	var trafficMsgs []logger.TrafficMessage
	if s := systemToString(cleanSystem); s != "" {
		trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: "system", Content: s})
	}
	for _, m := range cleanMessages {
		text := contentToString(m.Content)
		if text != "" {
			trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: m.Role, Content: text})
		}
	}
	logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, req.Stream)
	headerWs := r.Header.Get("x-cursor-workspace")
	ws := workspace.ResolveWorkspace(cfg, headerWs)
	fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, req.Stream)
	// Shrink the prompt if the assembled command line would exceed the
	// Windows command-line length limit.
	fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir)
	if cfg.Verbose {
		if len(prompt) > 200 {
			logger.LogDebug("model=%s prompt_len=%d prompt_preview=%q", cursorModel, len(prompt), prompt[:200]+"...")
		} else {
			logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, len(prompt), prompt)
		}
		logger.LogDebug("cmd_args=%v", fit.Args)
	}
	if !fit.OK {
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"type": "api_error", "message": fit.Error},
		}, nil)
		return
	}
	if fit.Truncated {
		logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength)
	}
	cmdArgs := fit.Args
	msgID := "msg_" + uuid.New().String()
	var truncatedHeaders map[string]string
	if fit.Truncated {
		truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
	}
	hasTools := len(req.Tools) > 0
	var toolNames map[string]bool
	if hasTools {
		toolNames = toolcall.CollectToolNames(req.Tools)
	}
	if req.Stream {
		httputil.WriteSSEHeaders(w, truncatedHeaders)
		flusher, _ := w.(http.Flusher)
		// writeEvent serializes one SSE event and flushes it immediately.
		writeEvent := func(evt interface{}) {
			data, _ := json.Marshal(evt)
			fmt.Fprintf(w, "data: %s\n\n", data)
			if flusher != nil {
				flusher.Flush()
			}
		}
		var accumulated string
		var accumulatedThinking string
		var chunkNum int
		var p parser.Parser
		writeEvent(map[string]interface{}{
			"type": "message_start",
			"message": map[string]interface{}{
				"id":      msgID,
				"type":    "message",
				"role":    "assistant",
				"model":   model,
				"content": []interface{}{},
			},
		})
		if hasTools {
			// Tools mode: accumulate the full output first and emit everything
			// on completion, because tool calls can only be parsed from the
			// complete text.
			p = parser.CreateStreamParserWithThinking(
				func(text string) {
					accumulated += text
					chunkNum++
					logger.LogStreamChunk(model, text, chunkNum)
				},
				func(thinking string) {
					accumulatedThinking += thinking
				},
				func() {
					logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
					parsed := toolcall.ExtractToolCalls(accumulated, toolNames)
					blockIndex := 0
					if accumulatedThinking != "" {
						writeEvent(map[string]interface{}{
							"type": "content_block_start", "index": blockIndex,
							"content_block": map[string]string{"type": "thinking", "thinking": ""},
						})
						writeEvent(map[string]interface{}{
							"type": "content_block_delta", "index": blockIndex,
							"delta": map[string]string{"type": "thinking_delta", "thinking": accumulatedThinking},
						})
						writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
						blockIndex++
					}
					if parsed.HasToolCalls() {
						if parsed.TextContent != "" {
							writeEvent(map[string]interface{}{
								"type": "content_block_start", "index": blockIndex,
								"content_block": map[string]string{"type": "text", "text": ""},
							})
							writeEvent(map[string]interface{}{
								"type": "content_block_delta", "index": blockIndex,
								"delta": map[string]string{"type": "text_delta", "text": parsed.TextContent},
							})
							writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
							blockIndex++
						}
						for _, tc := range parsed.ToolCalls {
							toolID := "toolu_" + uuid.New().String()[:12]
							// Per the Anthropic streaming protocol: start with an
							// empty input object and deliver the arguments via an
							// input_json_delta event.
							writeEvent(map[string]interface{}{
								"type": "content_block_start", "index": blockIndex,
								"content_block": map[string]interface{}{
									"type": "tool_use", "id": toolID, "name": tc.Name, "input": map[string]interface{}{},
								},
							})
							writeEvent(map[string]interface{}{
								"type": "content_block_delta", "index": blockIndex,
								"delta": map[string]interface{}{
									"type": "input_json_delta", "partial_json": tc.Arguments,
								},
							})
							writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
							blockIndex++
						}
						writeEvent(map[string]interface{}{
							"type":  "message_delta",
							"delta": map[string]interface{}{"stop_reason": "tool_use", "stop_sequence": nil},
							"usage": map[string]int{"output_tokens": 0},
						})
						writeEvent(map[string]interface{}{"type": "message_stop"})
					} else {
						writeEvent(map[string]interface{}{
							"type": "content_block_start", "index": blockIndex,
							"content_block": map[string]string{"type": "text", "text": ""},
						})
						if accumulated != "" {
							writeEvent(map[string]interface{}{
								"type": "content_block_delta", "index": blockIndex,
								"delta": map[string]string{"type": "text_delta", "text": accumulated},
							})
						}
						writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
						writeEvent(map[string]interface{}{
							"type":  "message_delta",
							"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
							"usage": map[string]int{"output_tokens": 0},
						})
						writeEvent(map[string]interface{}{"type": "message_stop"})
					}
				},
			)
		} else {
			// Non-tools mode: stream thinking and text deltas in real time.
			// blockCount tracks how many blocks have been opened so far;
			// thinkingOpen / textOpen track whether the corresponding block is
			// currently open (started but not yet stopped).
			blockCount := 0
			thinkingOpen := false
			textOpen := false
			p = parser.CreateStreamParserWithThinking(
				func(text string) {
					accumulated += text
					chunkNum++
					logger.LogStreamChunk(model, text, chunkNum)
					// Close a still-open thinking block before emitting text.
					if thinkingOpen {
						writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
						thinkingOpen = false
					}
					// Open the text block on first text chunk.
					if !textOpen {
						writeEvent(map[string]interface{}{
							"type":          "content_block_start",
							"index":         blockCount,
							"content_block": map[string]string{"type": "text", "text": ""},
						})
						textOpen = true
						blockCount++
					}
					writeEvent(map[string]interface{}{
						"type":  "content_block_delta",
						"index": blockCount - 1,
						"delta": map[string]string{"type": "text_delta", "text": text},
					})
				},
				func(thinking string) {
					accumulatedThinking += thinking
					chunkNum++
					// Open the thinking block on first thinking chunk.
					if !thinkingOpen {
						writeEvent(map[string]interface{}{
							"type":          "content_block_start",
							"index":         blockCount,
							"content_block": map[string]string{"type": "thinking", "thinking": ""},
						})
						thinkingOpen = true
						blockCount++
					}
					writeEvent(map[string]interface{}{
						"type":  "content_block_delta",
						"index": blockCount - 1,
						"delta": map[string]string{"type": "thinking_delta", "thinking": thinking},
					})
				},
				func() {
					logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
					// Close any thinking block that is still open.
					if thinkingOpen {
						writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
						thinkingOpen = false
					}
					// If no text block was ever opened (output was all thinking),
					// open an empty text block so the stream still carries one.
					if !textOpen {
						writeEvent(map[string]interface{}{
							"type":          "content_block_start",
							"index":         blockCount,
							"content_block": map[string]string{"type": "text", "text": ""},
						})
						blockCount++
					}
					// Close the text block and finish the message.
					writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
					writeEvent(map[string]interface{}{
						"type":  "message_delta",
						"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
						"usage": map[string]int{"output_tokens": 0},
					})
					writeEvent(map[string]interface{}{"type": "message_stop"})
				},
			)
		}
		configDir := ph.GetNextConfigDir()
		logger.LogAccountAssigned(configDir)
		ph.ReportRequestStart(configDir)
		logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true)
		streamStart := time.Now().UnixMilli()
		ctx := r.Context()
		wrappedParser := func(line string) {
			logger.LogRawLine(line)
			p.Parse(line)
		}
		result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, ctx)
		// After the agent exits, force a flush (unless the client is gone) so
		// the SSE stream terminates correctly even if no result/success signal
		// was received.
		if ctx.Err() == nil {
			p.Flush()
		}
		latencyMs := time.Now().UnixMilli() - streamStart
		ph.ReportRequestEnd(configDir)
		if ctx.Err() == context.DeadlineExceeded {
			logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
		} else if ctx.Err() == context.Canceled {
			logger.LogClientDisconnect(method, pathname, model, latencyMs)
		} else if err == nil && isRateLimited(result.Stderr) {
			ph.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr))
		}
		if err != nil || (result.Code != 0 && ctx.Err() == nil) {
			ph.ReportRequestError(configDir, latencyMs)
			if err != nil {
				logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
			} else {
				logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr)
			}
			logger.LogRequestDone(method, pathname, model, latencyMs, result.Code)
		} else if ctx.Err() == nil {
			ph.ReportRequestSuccess(configDir, latencyMs)
			logger.LogRequestDone(method, pathname, model, latencyMs, 0)
		}
		logger.LogAccountStats(cfg.Verbose, ph.GetStats())
		return
	}
	// Synchronous (non-streaming) path.
	configDir := ph.GetNextConfigDir()
	logger.LogAccountAssigned(configDir)
	ph.ReportRequestStart(configDir)
	logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, false)
	syncStart := time.Now().UnixMilli()
	out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context())
	syncLatency := time.Now().UnixMilli() - syncStart
	ph.ReportRequestEnd(configDir)
	ctx := r.Context()
	if ctx.Err() == context.DeadlineExceeded {
		logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
		httputil.WriteJSON(w, 504, map[string]interface{}{
			"error": map[string]string{"type": "api_error", "message": fmt.Sprintf("request timed out after %dms", cfg.TimeoutMs)},
		}, nil)
		return
	}
	if ctx.Err() == context.Canceled {
		logger.LogClientDisconnect(method, pathname, model, syncLatency)
		return
	}
	if err != nil {
		ph.ReportRequestError(configDir, syncLatency)
		logger.LogAccountStats(cfg.Verbose, ph.GetStats())
		logger.LogRequestDone(method, pathname, model, syncLatency, -1)
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"type": "api_error", "message": err.Error()},
		}, nil)
		return
	}
	if isRateLimited(out.Stderr) {
		ph.ReportRateLimit(configDir, extractRetryAfterMs(out.Stderr))
	}
	if out.Code != 0 {
		ph.ReportRequestError(configDir, syncLatency)
		logger.LogAccountStats(cfg.Verbose, ph.GetStats())
		errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr)
		logger.LogRequestDone(method, pathname, model, syncLatency, out.Code)
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"type": "api_error", "message": errMsg},
		}, nil)
		return
	}
	ph.ReportRequestSuccess(configDir, syncLatency)
	content := strings.TrimSpace(out.Stdout)
	logger.LogTrafficResponse(cfg.Verbose, model, content, false)
	logger.LogAccountStats(cfg.Verbose, ph.GetStats())
	logger.LogRequestDone(method, pathname, model, syncLatency, 0)
	if hasTools {
		parsed := toolcall.ExtractToolCalls(content, toolNames)
		if parsed.HasToolCalls() {
			var contentBlocks []map[string]interface{}
			if parsed.TextContent != "" {
				contentBlocks = append(contentBlocks, map[string]interface{}{
					"type": "text", "text": parsed.TextContent,
				})
			}
			for _, tc := range parsed.ToolCalls {
				toolID := "toolu_" + uuid.New().String()[:12]
				// Decode the tool arguments; fall back to an empty object on
				// malformed JSON so the response stays schema-valid.
				var inputObj interface{}
				_ = json.Unmarshal([]byte(tc.Arguments), &inputObj)
				if inputObj == nil {
					inputObj = map[string]interface{}{}
				}
				contentBlocks = append(contentBlocks, map[string]interface{}{
					"type": "tool_use", "id": toolID, "name": tc.Name, "input": inputObj,
				})
			}
			httputil.WriteJSON(w, 200, map[string]interface{}{
				"id":          msgID,
				"type":        "message",
				"role":        "assistant",
				"content":     contentBlocks,
				"model":       model,
				"stop_reason": "tool_use",
				"usage":       map[string]int{"input_tokens": 0, "output_tokens": 0},
			}, truncatedHeaders)
			return
		}
	}
	httputil.WriteJSON(w, 200, map[string]interface{}{
		"id":          msgID,
		"type":        "message",
		"role":        "assistant",
		"content":     []map[string]string{{"type": "text", "text": content}},
		"model":       model,
		"stop_reason": "end_turn",
		"usage":       map[string]int{"input_tokens": 0, "output_tokens": 0},
	}, truncatedHeaders)
}
// systemToString flattens an Anthropic system prompt into plain text.
// A plain string is returned unchanged; a slice of content blocks is
// reduced to the concatenation of its "text" blocks (non-text blocks and
// non-string "text" values are skipped). Anything else yields "".
func systemToString(system interface{}) string {
	if s, ok := system.(string); ok {
		return s
	}
	blocks, ok := system.([]interface{})
	if !ok {
		return ""
	}
	var b strings.Builder
	for _, blk := range blocks {
		m, isMap := blk.(map[string]interface{})
		if !isMap {
			continue
		}
		if kind, _ := m["type"].(string); kind != "text" {
			continue
		}
		if txt, isStr := m["text"].(string); isStr {
			b.WriteString(txt)
		}
	}
	return b.String()
}
// contentToString flattens an Anthropic message content value into plain
// text. A plain string is returned as-is; a slice of content blocks is
// joined from its "text" blocks only. Any other shape yields "".
func contentToString(content interface{}) string {
	switch c := content.(type) {
	case string:
		return c
	case []interface{}:
		parts := make([]string, 0, len(c))
		for _, item := range c {
			block, isMap := item.(map[string]interface{})
			if !isMap || block["type"] != "text" {
				continue
			}
			if s, isStr := block["text"].(string); isStr {
				parts = append(parts, s)
			}
		}
		return strings.Join(parts, "")
	default:
		return ""
	}
}