package handlers

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"regexp"
	"strconv"
	"strings"
	"time"

	"cursor-api-proxy/internal/agent"
	"cursor-api-proxy/internal/config"
	"cursor-api-proxy/internal/httputil"
	"cursor-api-proxy/internal/logger"
	"cursor-api-proxy/internal/models"
	"cursor-api-proxy/internal/openai"
	"cursor-api-proxy/internal/parser"
	"cursor-api-proxy/internal/pool"
	"cursor-api-proxy/internal/sanitize"
	"cursor-api-proxy/internal/toolcall"
	"cursor-api-proxy/internal/winlimit"
	"cursor-api-proxy/internal/workspace"

	"github.com/google/uuid"
)

var (
	// rateLimitRe matches agent stderr that signals a rate limit: a standalone
	// "429", "rate limit"/"ratelimit"/"rate-limit", or "too many requests".
	rateLimitRe = regexp.MustCompile(`(?i)\b429\b|rate.?limit|too many requests`)

	// retryAfterRe captures the seconds value of a "Retry-After: N" line.
	retryAfterRe = regexp.MustCompile(`(?i)retry-after:\s*(\d+)`)

	// chatToolCallMarkerRe flags a streamed text fragment as the start of a
	// tool call; from that fragment on, tools-mode output is held back instead
	// of being streamed live, and is resolved once the agent finishes.
	// NOTE(review): `|` is an empty alternation and matches every fragment, so
	// tools-mode currently buffers the whole reply. The intended marker tokens
	// appear to have been lost — confirm them against toolcall.ExtractToolCalls.
	chatToolCallMarkerRe = regexp.MustCompile(`|`)
)

// isRateLimited reports whether the agent's stderr output looks like a rate
// limit response.
func isRateLimited(stderr string) bool {
	return rateLimitRe.MatchString(stderr)
}

// extractRetryAfterMs returns the Retry-After value found in stderr, converted
// to milliseconds. It falls back to 60000 (60s) when no positive, parseable
// value is present.
func extractRetryAfterMs(stderr string) int64 {
	if m := retryAfterRe.FindStringSubmatch(stderr); len(m) > 1 {
		if secs, err := strconv.ParseInt(m[1], 10, 64); err == nil && secs > 0 {
			return secs * 1000
		}
	}
	return 60000
}

// HandleChatCompletions serves an OpenAI-style chat completions request by
// running the Cursor agent CLI.
//
// rawBody is the JSON request body. The requested model is normalized and
// resolved through ResolveModel (which receives lastModelRef), messages are
// sanitized, and any "tools"/"functions" definitions are rendered into a
// leading system message before the prompt is built. The prompt is fitted to
// the Windows command-line limit; if it had to be truncated, the response
// carries an "X-Cursor-Proxy-Prompt-Truncated: true" header.
//
// With "stream": true, agent output is relayed as chat.completion.chunk
// server-sent events ending with "data: [DONE]". Otherwise a single
// chat.completion JSON object is written. When tools were supplied, the output
// is scanned for tool calls, which are returned with finish_reason
// "tool_calls". Each request takes an account config dir from ph and reports
// its start/end, success/error and rate-limit outcomes back to ph. method,
// pathname and remoteAddress are used only for logging.
func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, ph pool.PoolHandle, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
	// writeError writes an OpenAI-shaped {"error": {...}} JSON body.
	writeError := func(status int, message, code string) {
		httputil.WriteJSON(w, status, map[string]interface{}{
			"error": map[string]string{"message": message, "code": code},
		}, nil)
	}

	var bodyMap map[string]interface{}
	if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil {
		writeError(400, "invalid JSON body", "bad_request")
		return
	}

	rawModel, _ := bodyMap["model"].(string)
	model := ResolveModel(openai.NormalizeModelID(rawModel), lastModelRef, cfg)
	cursorModel := models.ResolveToCursorModel(model)
	if cursorModel == "" {
		// No Cursor-specific mapping: hand the resolved name to the agent as-is.
		cursorModel = model
	}

	// Missing or wrongly-typed fields simply yield nil slices.
	messages, _ := bodyMap["messages"].([]interface{})
	tools, _ := bodyMap["tools"].([]interface{})
	funcs, _ := bodyMap["functions"].([]interface{})

	cleanMessages := sanitize.SanitizeMessages(messages)
	messagesWithTools := cleanMessages
	if toolsText := openai.ToolsToSystemText(tools, funcs); toolsText != "" {
		// Tool definitions reach the agent as a leading system message.
		messagesWithTools = append([]interface{}{map[string]interface{}{"role": "system", "content": toolsText}}, cleanMessages...)
	}
	prompt := openai.BuildPromptFromMessages(messagesWithTools)

	var trafficMsgs []logger.TrafficMessage
	for _, raw := range cleanMessages {
		if m, ok := raw.(map[string]interface{}); ok {
			role, _ := m["role"].(string)
			trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: role, Content: openai.MessageContentToText(m["content"])})
		}
	}

	isStream, _ := bodyMap["stream"].(bool)
	logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, isStream)

	ws := workspace.ResolveWorkspace(cfg, r.Header.Get("x-cursor-workspace"))
	if cfg.Verbose {
		if len(prompt) > 200 {
			logger.LogDebug("model=%s prompt_len=%d prompt_start=%q", cursorModel, len(prompt), prompt[:200])
		} else {
			logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, len(prompt), prompt)
		}
	}

	fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, isStream)
	fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir)
	if cfg.Verbose {
		logger.LogDebug("cmd=%s args=%v", cfg.AgentBin, fit.Args)
	}
	if !fit.OK {
		writeError(500, fit.Error, "windows_cmdline_limit")
		return
	}
	var truncatedHeaders map[string]string
	if fit.Truncated {
		logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength)
		truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
	}
	cmdArgs := fit.Args

	id := "chatcmpl_" + uuid.New().String()
	created := time.Now().Unix()

	hasTools := len(tools) > 0 || len(funcs) > 0
	var toolNames map[string]bool
	if hasTools {
		toolNames = toolcall.CollectToolNames(tools)
		if toolNames == nil {
			// Legacy "functions" may arrive without "tools"; never write to a nil map.
			toolNames = make(map[string]bool)
		}
		for _, f := range funcs {
			if fm, ok := f.(map[string]interface{}); ok {
				if name, ok := fm["name"].(string); ok {
					toolNames[name] = true
				}
			}
		}
	}

	if isStream {
		httputil.WriteSSEHeaders(w, truncatedHeaders)
		flusher, _ := w.(http.Flusher)

		// writeChunk emits one chat.completion.chunk event with a single choice
		// carrying delta and finishReason (nil encodes as JSON null).
		writeChunk := func(delta map[string]interface{}, finishReason interface{}) {
			// Marshal error ignored: the payload is built from plain maps and scalars.
			data, _ := json.Marshal(map[string]interface{}{
				"id":      id,
				"object":  "chat.completion.chunk",
				"created": created,
				"model":   model,
				"choices": []map[string]interface{}{
					{"index": 0, "delta": delta, "finish_reason": finishReason},
				},
			})
			fmt.Fprintf(w, "data: %s\n\n", data)
			if flusher != nil {
				flusher.Flush()
			}
		}

		// finish emits the closing empty-delta chunk and the [DONE] sentinel.
		finish := func(finishReason string) {
			writeChunk(map[string]interface{}{}, finishReason)
			fmt.Fprint(w, "data: [DONE]\n\n")
			if flusher != nil {
				flusher.Flush()
			}
		}

		var accumulated strings.Builder
		streamedLen := 0     // bytes of accumulated already sent live as content deltas
		chunkNum := 0        // running count of text fragments, for logging
		toolCallMode := false // tools mode only: live streaming stopped at a marker

		onText := func(text string) {
			accumulated.WriteString(text)
			chunkNum++
			logger.LogStreamChunk(model, text, chunkNum)
			if hasTools {
				if toolCallMode {
					return
				}
				if chatToolCallMarkerRe.MatchString(text) {
					// A tool call may be starting: hold back everything from here on.
					toolCallMode = true
					return
				}
			}
			writeChunk(map[string]interface{}{"content": text}, nil)
			streamedLen = accumulated.Len()
		}

		onThinking := func(thinking string) {
			if hasTools {
				return // reasoning is not forwarded in tools mode
			}
			writeChunk(map[string]interface{}{"reasoning_content": thinking}, nil)
		}

		onDone := func() {
			full := accumulated.String()
			logger.LogTrafficResponse(cfg.Verbose, model, full, true)
			if !hasTools {
				finish("stop")
				return
			}

			parsed := toolcall.ExtractToolCalls(full, toolNames)
			if !parsed.HasToolCalls() {
				// No tool call after all: release the text that was held back,
				// otherwise it would never reach the client.
				if rest := full[streamedLen:]; rest != "" {
					writeChunk(map[string]interface{}{"content": rest}, nil)
				}
				finish("stop")
				return
			}

			if parsed.TextContent != "" && toolCallMode {
				// NOTE(review): if TextContent also covers the prefix that was already
				// streamed live, the client sees it twice — confirm ExtractToolCalls.
				writeChunk(map[string]interface{}{"role": "assistant", "content": parsed.TextContent}, nil)
			}
			for i, tc := range parsed.ToolCalls {
				writeChunk(map[string]interface{}{
					"tool_calls": []map[string]interface{}{{
						"index": i,
						"id":    "call_" + uuid.New().String()[:8],
						"type":  "function",
						"function": map[string]interface{}{
							"name":      tc.Name,
							"arguments": tc.Arguments,
						},
					}},
				}, nil)
			}
			finish("tool_calls")
		}

		var p parser.Parser = parser.CreateStreamParserWithThinking(onText, onThinking, onDone)

		configDir := ph.GetNextConfigDir()
		logger.LogAccountAssigned(configDir)
		ph.ReportRequestStart(configDir)
		logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true)
		streamStart := time.Now().UnixMilli()
		ctx := r.Context()
		result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, func(line string) {
			logger.LogRawLine(line)
			p.Parse(line)
		}, ws.TempDir, configDir, ctx)

		ctxErr := ctx.Err()
		if ctxErr == nil {
			// The agent may exit without emitting its result/success event;
			// flush so the SSE stream is still terminated properly.
			p.Flush()
		}
		latencyMs := time.Now().UnixMilli() - streamStart
		ph.ReportRequestEnd(configDir)

		// Timeouts and client disconnects are not counted against the account,
		// matching the non-streaming path below.
		switch {
		case ctxErr == context.DeadlineExceeded:
			logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
		case ctxErr == context.Canceled:
			logger.LogClientDisconnect(method, pathname, model, latencyMs)
		case err != nil:
			ph.ReportRequestError(configDir, latencyMs)
			logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
			logger.LogRequestDone(method, pathname, model, latencyMs, -1)
		default:
			if isRateLimited(result.Stderr) {
				ph.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr))
			}
			if result.Code != 0 {
				ph.ReportRequestError(configDir, latencyMs)
				logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr)
				logger.LogRequestDone(method, pathname, model, latencyMs, result.Code)
			} else {
				ph.ReportRequestSuccess(configDir, latencyMs)
				logger.LogRequestDone(method, pathname, model, latencyMs, 0)
			}
		}
		logger.LogAccountStats(cfg.Verbose, ph.GetStats())
		return
	}

	// writeCompletion writes a 200 chat.completion object with a single choice.
	// Usage is always reported as zero.
	writeCompletion := func(message interface{}, finishReason string) {
		httputil.WriteJSON(w, 200, map[string]interface{}{
			"id":      id,
			"object":  "chat.completion",
			"created": created,
			"model":   model,
			"choices": []map[string]interface{}{
				{"index": 0, "message": message, "finish_reason": finishReason},
			},
			"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
		}, truncatedHeaders)
	}

	configDir := ph.GetNextConfigDir()
	logger.LogAccountAssigned(configDir)
	ph.ReportRequestStart(configDir)
	logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, false)
	ctx := r.Context()
	syncStart := time.Now().UnixMilli()
	out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, ctx)
	syncLatency := time.Now().UnixMilli() - syncStart
	ph.ReportRequestEnd(configDir)

	switch ctx.Err() {
	case context.DeadlineExceeded:
		logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
		writeError(504, fmt.Sprintf("request timed out after %dms", cfg.TimeoutMs), "timeout")
		return
	case context.Canceled:
		// The client is gone; there is nobody to write a response to.
		logger.LogClientDisconnect(method, pathname, model, syncLatency)
		return
	}

	if err != nil {
		ph.ReportRequestError(configDir, syncLatency)
		logger.LogAccountStats(cfg.Verbose, ph.GetStats())
		logger.LogRequestDone(method, pathname, model, syncLatency, -1)
		writeError(500, err.Error(), "cursor_cli_error")
		return
	}
	if isRateLimited(out.Stderr) {
		ph.ReportRateLimit(configDir, extractRetryAfterMs(out.Stderr))
	}
	if out.Code != 0 {
		ph.ReportRequestError(configDir, syncLatency)
		logger.LogAccountStats(cfg.Verbose, ph.GetStats())
		errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr)
		logger.LogRequestDone(method, pathname, model, syncLatency, out.Code)
		writeError(500, errMsg, "cursor_cli_error")
		return
	}

	ph.ReportRequestSuccess(configDir, syncLatency)
	content := strings.TrimSpace(out.Stdout)
	logger.LogTrafficResponse(cfg.Verbose, model, content, false)
	logger.LogAccountStats(cfg.Verbose, ph.GetStats())
	logger.LogRequestDone(method, pathname, model, syncLatency, 0)

	if hasTools {
		if parsed := toolcall.ExtractToolCalls(content, toolNames); parsed.HasToolCalls() {
			msg := map[string]interface{}{"role": "assistant", "content": nil}
			if parsed.TextContent != "" {
				msg["content"] = parsed.TextContent
			}
			toolCalls := make([]map[string]interface{}, 0, len(parsed.ToolCalls))
			for _, tc := range parsed.ToolCalls {
				toolCalls = append(toolCalls, map[string]interface{}{
					"id":   "call_" + uuid.New().String()[:8],
					"type": "function",
					"function": map[string]interface{}{
						"name":      tc.Name,
						"arguments": tc.Arguments,
					},
				})
			}
			msg["tool_calls"] = toolCalls
			writeCompletion(msg, "tool_calls")
			return
		}
	}
	writeCompletion(map[string]string{"role": "assistant", "content": content}, "stop")
}