package handlers

import (
	"context"
	"cursor-api-proxy/internal/agent"
	"cursor-api-proxy/internal/anthropic"
	"cursor-api-proxy/internal/config"
	"cursor-api-proxy/internal/httputil"
	"cursor-api-proxy/internal/logger"
	"cursor-api-proxy/internal/models"
	"cursor-api-proxy/internal/openai"
	"cursor-api-proxy/internal/parser"
	"cursor-api-proxy/internal/pool"
	"cursor-api-proxy/internal/sanitize"
	"cursor-api-proxy/internal/toolcall"
	"cursor-api-proxy/internal/winlimit"
	"cursor-api-proxy/internal/workspace"
	"encoding/json"
	"fmt"
	"net/http"
	"regexp"
	"strings"
	"time"

	"github.com/google/uuid"
)

// HandleAnthropicMessages serves an Anthropic-style /v1/messages request by
// translating it into a prompt for a local "agent" subprocess and translating
// the subprocess output back into Anthropic response shapes.
//
// High-level flow:
//  1. Decode and sanitize the request (system prompt + messages), fold tool
//     definitions into the system text, and build a single prompt string.
//  2. Fit the prompt to the Windows command-line length limit (possibly
//     truncating it), pick an account/config dir from the pool, and run the
//     agent either streaming (SSE) or synchronously.
//  3. Emit Anthropic SSE events (message_start / content_block_* /
//     message_delta / message_stop) for streams, or a single JSON message
//     object for non-streaming calls, extracting tool_use blocks when the
//     request declared tools.
//
// The pool handle `ph` receives start/end/success/error/rate-limit reports so
// account rotation stays accurate; `lastModelRef` lets ResolveModel fall back
// to the previously used model.
func HandleAnthropicMessages(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, ph pool.PoolHandle, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
	var req anthropic.MessagesRequest
	if err := json.Unmarshal([]byte(rawBody), &req); err != nil {
		httputil.WriteJSON(w, 400, map[string]interface{}{
			"error": map[string]string{"type": "invalid_request_error", "message": "invalid JSON body"},
		}, nil)
		return
	}
	requested := openai.NormalizeModelID(req.Model)
	model := ResolveModel(requested, lastModelRef, cfg)
	// NOTE(review): rawMap is populated but never read anywhere below —
	// looks like dead code left over from an earlier revision; confirm and remove.
	var rawMap map[string]interface{}
	_ = json.Unmarshal([]byte(rawBody), &rawMap)

	// Sanitize the system prompt and the message list. SanitizeMessages works
	// on generic maps, so the typed messages are converted to maps and back.
	cleanSystem := sanitize.SanitizeSystem(req.System)
	rawMessages := make([]interface{}, len(req.Messages))
	for i, m := range req.Messages {
		rawMessages[i] = map[string]interface{}{"role": m.Role, "content": m.Content}
	}
	cleanRawMessages := sanitize.SanitizeMessages(rawMessages)
	var cleanMessages []anthropic.MessageParam
	for _, raw := range cleanRawMessages {
		if m, ok := raw.(map[string]interface{}); ok {
			role, _ := m["role"].(string)
			cleanMessages = append(cleanMessages, anthropic.MessageParam{Role: role, Content: m["content"]})
		}
	}

	// Tool definitions are folded into the system prompt as text (the agent
	// has no native tool-calling channel). Only a string-typed system prompt
	// is concatenated; structured system content falls through to the
	// toolsText-only branch.
	toolsText := openai.ToolsToSystemText(req.Tools, nil)
	var systemWithTools interface{}
	if toolsText != "" {
		sysStr := ""
		switch v := cleanSystem.(type) {
		case string:
			sysStr = v
		}
		if sysStr != "" {
			systemWithTools = sysStr + "\n\n" + toolsText
		} else {
			systemWithTools = toolsText
		}
	} else {
		systemWithTools = cleanSystem
	}
	prompt := anthropic.BuildPromptFromAnthropicMessages(cleanMessages, systemWithTools)

	// Anthropic's API requires max_tokens; mirror that validation. (The value
	// itself is not forwarded to the agent anywhere in this handler.)
	if req.MaxTokens == 0 {
		httputil.WriteJSON(w, 400, map[string]interface{}{
			"error": map[string]string{"type": "invalid_request_error", "message": "max_tokens is required"},
		}, nil)
		return
	}
	// Map the public model ID to the agent's model name; fall back to the
	// requested name when no mapping exists.
	cursorModel := models.ResolveToCursorModel(model)
	if cursorModel == "" {
		cursorModel = model
	}

	// Verbose traffic log: flatten system + messages to plain text.
	var trafficMsgs []logger.TrafficMessage
	if s := systemToString(cleanSystem); s != "" {
		trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: "system", Content: s})
	}
	for _, m := range cleanMessages {
		text := contentToString(m.Content)
		if text != "" {
			trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: m.Role, Content: text})
		}
	}
	logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, req.Stream)

	// Resolve the working directory (header override allowed) and shrink the
	// prompt if the resulting command line would exceed the Windows limit.
	headerWs := r.Header.Get("x-cursor-workspace")
	ws := workspace.ResolveWorkspace(cfg, headerWs)
	fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, req.Stream)
	fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir)
	if cfg.Verbose {
		if len(prompt) > 200 {
			logger.LogDebug("model=%s prompt_len=%d prompt_preview=%q", cursorModel, len(prompt), prompt[:200]+"...")
		} else {
			logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, len(prompt), prompt)
		}
		logger.LogDebug("cmd_args=%v", fit.Args)
	}
	if !fit.OK {
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"type": "api_error", "message": fit.Error},
		}, nil)
		return
	}
	if fit.Truncated {
		logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength)
	}
	cmdArgs := fit.Args
	msgID := "msg_" + uuid.New().String()
	// Surface prompt truncation to the client via a response header.
	var truncatedHeaders map[string]string
	if fit.Truncated {
		truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
	}
	hasTools := len(req.Tools) > 0
	var toolNames map[string]bool
	if hasTools {
		toolNames = toolcall.CollectToolNames(req.Tools)
	}

	if req.Stream {
		// ---- Streaming (SSE) path -------------------------------------
		httputil.WriteSSEHeaders(w, truncatedHeaders)
		flusher, _ := w.(http.Flusher)
		// writeEvent serializes one Anthropic SSE event and flushes it so
		// the client sees deltas immediately. Marshal errors are ignored:
		// the event maps here only hold marshalable values.
		writeEvent := func(evt interface{}) {
			data, _ := json.Marshal(evt)
			fmt.Fprintf(w, "data: %s\n\n", data)
			if flusher != nil {
				flusher.Flush()
			}
		}
		var accumulated string
		// NOTE(review): accumulatedThinking is appended to in both branches
		// below but never read afterwards — write-only; confirm and remove.
		var accumulatedThinking string
		var chunkNum int
		var p parser.Parser
		writeEvent(map[string]interface{}{
			"type": "message_start",
			"message": map[string]interface{}{
				"id": msgID, "type": "message", "role": "assistant", "model": model, "content": []interface{}{},
			},
		})
		if hasTools {
			// Tools mode: stream text/thinking deltas until a chunk looks
			// like the start of a tool call, then stop streaming and defer
			// to ExtractToolCalls over the full accumulated output at
			// end-of-stream.
			//
			// NOTE(review): this pattern has an EMPTY alternation branch
			// (`…|`), so MatchString matches every string — toolCallMode is
			// entered on the very first text chunk and incremental text
			// streaming is effectively disabled in tools mode. The CJK
			// literal also looks like corrupted text. Confirm the intended
			// tool-call marker (presumably whatever toolcall.ExtractToolCalls
			// parses) and fix both the literal and the trailing `|`.
			toolCallMarkerRe := regexp.MustCompile(`行政法规|`)
			var toolCallMode bool
			// Block bookkeeping: at most one text and one thinking block may
			// be open; indices are assigned from a monotonically increasing
			// blockCount as required by the Anthropic event protocol.
			textBlockOpen := false
			textBlockIndex := 0
			thinkingOpen := false
			thinkingBlockIndex := 0
			blockCount := 0
			p = parser.CreateStreamParserWithThinking(
				// onText: stream a text delta (unless tool-call mode began).
				func(text string) {
					accumulated += text
					chunkNum++
					logger.LogStreamChunk(model, text, chunkNum)
					if toolCallMode {
						return
					}
					if toolCallMarkerRe.MatchString(text) {
						// Close any open blocks before going silent; the
						// tool_use blocks are emitted in the done callback.
						if textBlockOpen {
							writeEvent(map[string]interface{}{"type": "content_block_stop", "index": textBlockIndex})
							textBlockOpen = false
						}
						if thinkingOpen {
							writeEvent(map[string]interface{}{"type": "content_block_stop", "index": thinkingBlockIndex})
							thinkingOpen = false
						}
						toolCallMode = true
						return
					}
					if !textBlockOpen && !thinkingOpen {
						textBlockIndex = blockCount
						writeEvent(map[string]interface{}{
							"type": "content_block_start", "index": textBlockIndex,
							"content_block": map[string]string{"type": "text", "text": ""},
						})
						textBlockOpen = true
						blockCount++
					}
					// Text after thinking: close the thinking block first.
					if thinkingOpen {
						writeEvent(map[string]interface{}{"type": "content_block_stop", "index": thinkingBlockIndex})
						thinkingOpen = false
					}
					writeEvent(map[string]interface{}{
						"type": "content_block_delta", "index": textBlockIndex,
						"delta": map[string]string{"type": "text_delta", "text": text},
					})
				},
				// onThinking: stream a thinking delta.
				func(thinking string) {
					accumulatedThinking += thinking
					chunkNum++
					if toolCallMode {
						return
					}
					if !thinkingOpen {
						thinkingBlockIndex = blockCount
						writeEvent(map[string]interface{}{
							"type": "content_block_start", "index": thinkingBlockIndex,
							"content_block": map[string]string{"type": "thinking", "thinking": ""},
						})
						thinkingOpen = true
						blockCount++
					}
					writeEvent(map[string]interface{}{
						"type": "content_block_delta", "index": thinkingBlockIndex,
						"delta": map[string]string{"type": "thinking_delta", "thinking": thinking},
					})
				},
				// onDone: close open blocks, emit tool_use blocks (if any)
				// from the accumulated output, then message_delta/stop.
				func() {
					logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
					parsed := toolcall.ExtractToolCalls(accumulated, toolNames)
					blockIndex := 0
					if thinkingOpen {
						writeEvent(map[string]interface{}{"type": "content_block_stop", "index": thinkingBlockIndex})
						thinkingOpen = false
					}
					if parsed.HasToolCalls() {
						if textBlockOpen {
							writeEvent(map[string]interface{}{"type": "content_block_stop", "index": textBlockIndex})
							blockIndex = textBlockIndex + 1
						}
						// Text that was never streamed (tool-call mode was
						// entered before any text block opened) is emitted
						// here as one complete block.
						if parsed.TextContent != "" && !textBlockOpen && !toolCallMode {
							writeEvent(map[string]interface{}{
								"type": "content_block_start", "index": blockIndex,
								"content_block": map[string]string{"type": "text", "text": ""},
							})
							writeEvent(map[string]interface{}{
								"type": "content_block_delta", "index": blockIndex,
								"delta": map[string]string{"type": "text_delta", "text": parsed.TextContent},
							})
							writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
							blockIndex++
						}
						for _, tc := range parsed.ToolCalls {
							toolID := "toolu_" + uuid.New().String()[:12]
							// Validate the arguments parse as JSON; fall back
							// to an empty object. The block_start carries an
							// empty input; the real JSON arrives as a single
							// input_json_delta (Anthropic streaming shape).
							var inputObj interface{}
							_ = json.Unmarshal([]byte(tc.Arguments), &inputObj)
							if inputObj == nil {
								inputObj = map[string]interface{}{}
							}
							writeEvent(map[string]interface{}{
								"type": "content_block_start", "index": blockIndex,
								"content_block": map[string]interface{}{
									"type": "tool_use", "id": toolID, "name": tc.Name, "input": map[string]interface{}{},
								},
							})
							writeEvent(map[string]interface{}{
								"type": "content_block_delta", "index": blockIndex,
								"delta": map[string]interface{}{
									"type": "input_json_delta", "partial_json": tc.Arguments,
								},
							})
							writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
							blockIndex++
						}
						writeEvent(map[string]interface{}{
							"type": "message_delta",
							"delta": map[string]interface{}{"stop_reason": "tool_use", "stop_sequence": nil},
							"usage": map[string]int{"output_tokens": 0},
						})
						writeEvent(map[string]interface{}{"type": "message_stop"})
					} else {
						if textBlockOpen {
							writeEvent(map[string]interface{}{"type": "content_block_stop", "index": textBlockIndex})
						} else if accumulated != "" {
							// Nothing was streamed (tool-call mode false
							// positives are possible); emit the whole text
							// as one block now.
							writeEvent(map[string]interface{}{
								"type": "content_block_start", "index": blockIndex,
								"content_block": map[string]string{"type": "text", "text": ""},
							})
							writeEvent(map[string]interface{}{
								"type": "content_block_delta", "index": blockIndex,
								"delta": map[string]string{"type": "text_delta", "text": accumulated},
							})
							writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
							blockIndex++
						} else {
							// Empty output: still emit one (empty) text block
							// so the event sequence stays well-formed.
							writeEvent(map[string]interface{}{
								"type": "content_block_start", "index": blockIndex,
								"content_block": map[string]string{"type": "text", "text": ""},
							})
							writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockIndex})
							blockIndex++
						}
						writeEvent(map[string]interface{}{
							"type": "message_delta",
							"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
							"usage": map[string]int{"output_tokens": 0},
						})
						writeEvent(map[string]interface{}{"type": "message_stop"})
					}
				},
			)
		} else {
			// Non-tools mode: stream thinking and text immediately.
			blockCount := 0
			thinkingOpen := false
			textOpen := false
			p = parser.CreateStreamParserWithThinking(
				// onText: close a trailing thinking block, open the text
				// block on first use, then emit the delta. The most recently
				// opened block always has index blockCount-1.
				func(text string) {
					accumulated += text
					chunkNum++
					logger.LogStreamChunk(model, text, chunkNum)
					if thinkingOpen {
						writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
						thinkingOpen = false
					}
					if !textOpen {
						writeEvent(map[string]interface{}{
							"type": "content_block_start", "index": blockCount,
							"content_block": map[string]string{"type": "text", "text": ""},
						})
						textOpen = true
						blockCount++
					}
					writeEvent(map[string]interface{}{
						"type": "content_block_delta", "index": blockCount - 1,
						"delta": map[string]string{"type": "text_delta", "text": text},
					})
				},
				// onThinking: open the thinking block on first use, emit delta.
				func(thinking string) {
					accumulatedThinking += thinking
					chunkNum++
					if !thinkingOpen {
						writeEvent(map[string]interface{}{
							"type": "content_block_start", "index": blockCount,
							"content_block": map[string]string{"type": "thinking", "thinking": ""},
						})
						thinkingOpen = true
						blockCount++
					}
					writeEvent(map[string]interface{}{
						"type": "content_block_delta", "index": blockCount - 1,
						"delta": map[string]string{"type": "thinking_delta", "thinking": thinking},
					})
				},
				// onDone: ensure at least one text block exists, close the
				// last open block, then finish the message.
				func() {
					logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
					if thinkingOpen {
						writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
						thinkingOpen = false
					}
					if !textOpen {
						writeEvent(map[string]interface{}{
							"type": "content_block_start", "index": blockCount,
							"content_block": map[string]string{"type": "text", "text": ""},
						})
						blockCount++
					}
					writeEvent(map[string]interface{}{"type": "content_block_stop", "index": blockCount - 1})
					writeEvent(map[string]interface{}{
						"type": "message_delta",
						"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
						"usage": map[string]int{"output_tokens": 0},
					})
					writeEvent(map[string]interface{}{"type": "message_stop"})
				},
			)
		}
		// Acquire an account and run the agent, feeding each raw output line
		// through the stream parser. Flush is skipped on cancellation so no
		// events are written to a dead connection.
		configDir := ph.GetNextConfigDir()
		logger.LogAccountAssigned(configDir)
		ph.ReportRequestStart(configDir)
		logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true)
		streamStart := time.Now().UnixMilli()
		ctx := r.Context()
		wrappedParser := func(line string) {
			logger.LogRawLine(line)
			p.Parse(line)
		}
		result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, ctx)
		if ctx.Err() == nil {
			p.Flush()
		}
		latencyMs := time.Now().UnixMilli() - streamStart
		ph.ReportRequestEnd(configDir)
		// Classify the outcome for pool accounting: timeout vs client
		// disconnect vs rate limit vs error vs success.
		if ctx.Err() == context.DeadlineExceeded {
			logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
		} else if ctx.Err() == context.Canceled {
			logger.LogClientDisconnect(method, pathname, model, latencyMs)
		} else if err == nil && isRateLimited(result.Stderr) {
			ph.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr))
		}
		if err != nil || (result.Code != 0 && ctx.Err() == nil) {
			ph.ReportRequestError(configDir, latencyMs)
			if err != nil {
				logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
			} else {
				logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr)
			}
			logger.LogRequestDone(method, pathname, model, latencyMs, result.Code)
		} else if ctx.Err() == nil {
			ph.ReportRequestSuccess(configDir, latencyMs)
			logger.LogRequestDone(method, pathname, model, latencyMs, 0)
		}
		logger.LogAccountStats(cfg.Verbose, ph.GetStats())
		return
	}

	// ---- Non-streaming path -------------------------------------------
	configDir := ph.GetNextConfigDir()
	logger.LogAccountAssigned(configDir)
	ph.ReportRequestStart(configDir)
	logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, false)
	syncStart := time.Now().UnixMilli()
	out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context())
	syncLatency := time.Now().UnixMilli() - syncStart
	ph.ReportRequestEnd(configDir)
	ctx := r.Context()
	if ctx.Err() == context.DeadlineExceeded {
		logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
		httputil.WriteJSON(w, 504, map[string]interface{}{
			"error": map[string]string{"type": "api_error", "message": fmt.Sprintf("request timed out after %dms", cfg.TimeoutMs)},
		}, nil)
		return
	}
	if ctx.Err() == context.Canceled {
		// Client went away; nothing to write.
		logger.LogClientDisconnect(method, pathname, model, syncLatency)
		return
	}
	if err != nil {
		ph.ReportRequestError(configDir, syncLatency)
		logger.LogAccountStats(cfg.Verbose, ph.GetStats())
		logger.LogRequestDone(method, pathname, model, syncLatency, -1)
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"type": "api_error", "message": err.Error()},
		}, nil)
		return
	}
	if isRateLimited(out.Stderr) {
		ph.ReportRateLimit(configDir, extractRetryAfterMs(out.Stderr))
	}
	if out.Code != 0 {
		ph.ReportRequestError(configDir, syncLatency)
		logger.LogAccountStats(cfg.Verbose, ph.GetStats())
		errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr)
		logger.LogRequestDone(method, pathname, model, syncLatency, out.Code)
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"type": "api_error", "message": errMsg},
		}, nil)
		return
	}
	ph.ReportRequestSuccess(configDir, syncLatency)
	content := strings.TrimSpace(out.Stdout)
	logger.LogTrafficResponse(cfg.Verbose, model, content, false)
	logger.LogAccountStats(cfg.Verbose, ph.GetStats())
	logger.LogRequestDone(method, pathname, model, syncLatency, 0)
	if hasTools {
		// Extract tool calls from the complete output; when found, respond
		// with text + tool_use content blocks and stop_reason "tool_use".
		parsed := toolcall.ExtractToolCalls(content, toolNames)
		if parsed.HasToolCalls() {
			var contentBlocks []map[string]interface{}
			if parsed.TextContent != "" {
				contentBlocks = append(contentBlocks, map[string]interface{}{
					"type": "text", "text": parsed.TextContent,
				})
			}
			for _, tc := range parsed.ToolCalls {
				toolID := "toolu_" + uuid.New().String()[:12]
				// Best-effort parse of the arguments; empty object on failure.
				var inputObj interface{}
				_ = json.Unmarshal([]byte(tc.Arguments), &inputObj)
				if inputObj == nil {
					inputObj = map[string]interface{}{}
				}
				contentBlocks = append(contentBlocks, map[string]interface{}{
					"type": "tool_use", "id": toolID, "name": tc.Name, "input": inputObj,
				})
			}
			httputil.WriteJSON(w, 200, map[string]interface{}{
				"id": msgID, "type": "message", "role": "assistant",
				"content": contentBlocks, "model": model, "stop_reason": "tool_use",
				"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
			}, truncatedHeaders)
			return
		}
	}
	// Plain text response (no tools, or no tool calls detected).
	httputil.WriteJSON(w, 200, map[string]interface{}{
		"id": msgID, "type": "message", "role": "assistant",
		"content": []map[string]string{{"type": "text", "text": content}},
		"model": model, "stop_reason": "end_turn",
		"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
	}, truncatedHeaders)
}

// systemToString flattens an Anthropic system prompt to plain text: a string
// is returned as-is; a list of content parts contributes the concatenation of
// its {"type": "text"} parts. Anything else yields "".
func systemToString(system interface{}) string {
	switch v := system.(type) {
	case string:
		return v
	case []interface{}:
		result := ""
		for _, p := range v {
			if m, ok := p.(map[string]interface{}); ok && m["type"] == "text" {
				if t, ok := m["text"].(string); ok {
					result += t
				}
			}
		}
		return result
	}
	return ""
}

// contentToString flattens a message's content to plain text using the same
// rules as systemToString: strings pass through, text parts are concatenated,
// everything else (tool_use, images, …) is ignored.
func contentToString(content interface{}) string {
	switch v := content.(type) {
	case string:
		return v
	case []interface{}:
		result := ""
		for _, p := range v {
			if m, ok := p.(map[string]interface{}); ok && m["type"] == "text" {
				if t, ok := m["text"].(string); ok {
					result += t
				}
			}
		}
		return result
	}
	return ""
}