diff --git a/internal/handler/chat/chat_completions_handler.go b/internal/handler/chat/chat_completions_handler.go index 7e15c0a..7f0f083 100644 --- a/internal/handler/chat/chat_completions_handler.go +++ b/internal/handler/chat/chat_completions_handler.go @@ -26,10 +26,8 @@ func ChatCompletionsHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - err := l.ChatCompletionsStream(&req, w) - if err != nil { - w.Write([]byte("event: error\ndata: " + err.Error() + "\n\n")) - } + w.Header().Set("X-Accel-Buffering", "no") + _ = l.ChatCompletionsStream(&req, w, r.Method, r.URL.Path) } else { resp, err := l.ChatCompletions(&req) if err != nil { diff --git a/internal/logic/chat/anthropic_messages_logic.go b/internal/logic/chat/anthropic_messages_logic.go index cd3b1e0..a1637c0 100644 --- a/internal/logic/chat/anthropic_messages_logic.go +++ b/internal/logic/chat/anthropic_messages_logic.go @@ -5,6 +5,7 @@ package chat import ( "context" + "net/http" "cursor-api-proxy/internal/svc" apitypes "cursor-api-proxy/internal/types" @@ -27,6 +28,16 @@ func NewAnthropicMessagesLogic(ctx context.Context, svcCtx *svc.ServiceContext) } func (l *AnthropicMessagesLogic) AnthropicMessages(req *apitypes.AnthropicRequest) error { - // TODO: implement Anthropic messages API // This should convert from Anthropic format to Cursor/Gemini provider + // TODO: implement Anthropic Messages API + // This should convert Anthropic format to Cursor/Gemini provider + // Similar to ChatCompletions but with Anthropic-style response format + return nil +} + +// AnthropicMessagesStream handles streaming for Anthropic Messages API +func (l *AnthropicMessagesLogic) AnthropicMessagesStream(req *apitypes.AnthropicRequest, w http.ResponseWriter) error { + // TODO: implement Anthropic Messages streaming + // This should convert Anthropic format to Cursor/Gemini provider + // And stream back in Anthropic event format return nil } diff --git a/internal/logic/chat/chat_completions_logic.go b/internal/logic/chat/chat_completions_logic.go index ee3db38..77b277b 100644 --- a/internal/logic/chat/chat_completions_logic.go +++ b/internal/logic/chat/chat_completions_logic.go @@ -8,20 +8,43 @@ import ( "encoding/json" "fmt" "net/http" + "regexp" + "strconv" + "strings" + "time" "cursor-api-proxy/internal/config" "cursor-api-proxy/internal/svc" apitypes "cursor-api-proxy/internal/types" "cursor-api-proxy/pkg/adapter/openai" "cursor-api-proxy/pkg/domain/types" + "cursor-api-proxy/pkg/infrastructure/httputil" "cursor-api-proxy/pkg/infrastructure/logger" "cursor-api-proxy/pkg/infrastructure/parser" + "cursor-api-proxy/pkg/infrastructure/winlimit" + "cursor-api-proxy/pkg/infrastructure/workspace" "cursor-api-proxy/pkg/usecase" "github.com/google/uuid" "github.com/zeromicro/go-zero/core/logx" ) +var rateLimitRe = regexp.MustCompile(`(?i)\b429\b|rate.?limit|too many requests`) +var retryAfterRe = regexp.MustCompile(`(?i)retry-after:\s*(\d+)`) + +func isRateLimited(stderr string) bool { + return rateLimitRe.MatchString(stderr) +} + +func extractRetryAfterMs(stderr string) int64 { + if m := retryAfterRe.FindStringSubmatch(stderr); len(m) > 1 { + if secs, err := strconv.ParseInt(m[1], 10, 64); err == nil && secs > 0 { + return secs * 1000 + } + } + return 60000 +} + type ChatCompletionsLogic struct { logx.Logger ctx context.Context @@ -36,55 +59,42 @@ func NewChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *C } } -func (l *ChatCompletionsLogic) ChatCompletions(req *apitypes.ChatCompletionRequest) (*apitypes.ChatCompletionResponse, error) { - cfg := configToBridge(l.svcCtx.Config) - model := openai.NormalizeModelID(req.Model) - cursorModel := types.ResolveToCursorModel(model) - if cursorModel == "" { - cursorModel = model +func (l *ChatCompletionsLogic) resolveModel(requested string, lastModelRef *string) string { + cfg := l.svcCtx.Config + isAuto := requested == "auto" + var explicitModel string + if requested != "" && !isAuto { + explicitModel = requested } - - messages := convertMessages(req.Messages) - tools := convertTools(req.Tools) - functions := convertFunctions(req.Functions) - - cleanMessages := usecase.SanitizeMessages(messages) - toolsText := openai.ToolsToSystemText(tools, functions) - messagesWithTools := cleanMessages - if toolsText != "" { - messagesWithTools = append([]interface{}{ - map[string]interface{}{"role": "system", "content": toolsText}, - }, cleanMessages...) + if explicitModel != "" { + *lastModelRef = explicitModel } - - prompt := openai.BuildPromptFromMessages(messagesWithTools) - - // TODO: implement non-streaming execution - _ = cfg - _ = cursorModel - _ = prompt - - return &apitypes.ChatCompletionResponse{ - Id: "chatcmpl_" + uuid.New().String(), - Object: "chat.completion", - Created: 0, - Model: model, - Choices: []apitypes.Choice{ - { - Index: 0, - Message: apitypes.RespMessage{ - Role: "assistant", - Content: "TODO: implement non-streaming response", - }, - FinishReason: "stop", - }, - }, - }, nil + if isAuto { + return "auto" + } + if explicitModel != "" { + return explicitModel + } + if cfg.StrictModel && *lastModelRef != "" { + return *lastModelRef + } + if *lastModelRef != "" { + return *lastModelRef + } + return cfg.DefaultModel } -func (l *ChatCompletionsLogic) ChatCompletionsStream(req *apitypes.ChatCompletionRequest, w http.ResponseWriter) error { +func (l *ChatCompletionsLogic) ChatCompletions(req *apitypes.ChatCompletionRequest) (*apitypes.ChatCompletionResponse, error) { + return nil, fmt.Errorf("non-streaming not yet implemented, use stream=true") +} + +func (l *ChatCompletionsLogic) ChatCompletionsStream(req *apitypes.ChatCompletionRequest, w http.ResponseWriter, method, pathname string) error { cfg := configToBridge(l.svcCtx.Config) - model := openai.NormalizeModelID(req.Model) + + rawModel := req.Model + requested := openai.NormalizeModelID(rawModel) + lastModelRef := new(string) + model := l.resolveModel(requested, lastModelRef) cursorModel := types.ResolveToCursorModel(model) if cursorModel == "" { cursorModel = model @@ -102,62 +112,285 @@ func (l *ChatCompletionsLogic) ChatCompletionsStream(req *apitypes.ChatCompletio map[string]interface{}{"role": "system", "content": toolsText}, }, cleanMessages...) } - prompt := openai.BuildPromptFromMessages(messagesWithTools) - if l.svcCtx.Config.Verbose { - logger.LogDebug("model=%s prompt_len=%d", cursorModel, len(prompt)) + var trafficMsgs []logger.TrafficMessage + for _, raw := range cleanMessages { + if m, ok := raw.(map[string]interface{}); ok { + role, _ := m["role"].(string) + content := openai.MessageContentToText(m["content"]) + trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: role, Content: content}) + } } + logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, true) + + ws := workspace.ResolveWorkspace(cfg, "") + + promptLen := len(prompt) + if cfg.Verbose { + if promptLen > 200 { + logger.LogDebug("model=%s prompt_len=%d prompt_start=%q", cursorModel, promptLen, prompt[:200]) + } else { + logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, promptLen, prompt) + } + } + + maxCmdline := cfg.WinCmdlineMax + if maxCmdline == 0 { + maxCmdline = 32768 + } + fixedArgs := usecase.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, true) + fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, maxCmdline, ws.WorkspaceDir) + + if l.svcCtx.Config.Verbose { + logger.LogDebug("cmd=%s args=%v", cfg.AgentBin, fit.Args) + } + + if !fit.OK { + httputil.WriteJSON(w, 500, map[string]interface{}{ + "error": map[string]string{"message": fit.Error, "code": "windows_cmdline_limit"}, + }, nil) + return nil + } + if fit.Truncated { + logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength) + } + + cmdArgs := fit.Args id := "chatcmpl_" + uuid.New().String() - created := int64(0) + created := time.Now().Unix() + + var truncatedHeaders map[string]string + if fit.Truncated { + truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"} + } hasTools := len(tools) > 0 || len(functions) > 0 + var toolNames map[string]bool + if hasTools { + toolNames = usecase.CollectToolNames(tools) + for _, f := range functions { + if fm, ok := f.(map[string]interface{}); ok { + if name, ok := fm["name"].(string); ok { + toolNames[name] = true + } + } + } + } + httputil.WriteSSEHeaders(w, truncatedHeaders) flusher, _ := w.(http.Flusher) var accumulated string var chunkNum int var p parser.Parser - // TODO: implement proper streaming with usecase.RunAgentStream - // For now, return a placeholder response - _ = cfg - _ = prompt - _ = hasTools - _ = p - _ = chunkNum - _ = accumulated + toolCallMarkerRe := regexp.MustCompile(`\x1e|`) + if hasTools { + var toolCallMode bool + p = parser.CreateStreamParserWithThinking( + func(text string) { + accumulated += text + chunkNum++ + logger.LogStreamChunk(model, text, chunkNum) + if toolCallMode { + return + } + if toolCallMarkerRe.MatchString(text) { + toolCallMode = true + return + } + chunk := map[string]interface{}{ + "id": id, "object": "chat.completion.chunk", "created": created, "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil}, + }, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + if flusher != nil { + flusher.Flush() + } + }, + func(thinking string) { + chunk := map[string]interface{}{ + "id": id, "object": "chat.completion.chunk", "created": created, "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]interface{}{"reasoning_content": thinking}, "finish_reason": nil}, + }, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + if flusher != nil { + flusher.Flush() + } + }, + func() { + logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true) + parsed := usecase.ExtractToolCalls(accumulated, toolNames) - chunk := map[string]interface{}{ - "id": id, "object": "chat.completion.chunk", "created": created, "model": model, - "choices": []map[string]interface{}{ - {"index": 0, "delta": map[string]string{"content": "Streaming not yet implemented"}, "finish_reason": nil}, - }, - } - data, _ := json.Marshal(chunk) - fmt.Fprintf(w, "data: %s\n\n", data) - if flusher != nil { - flusher.Flush() + if parsed.HasToolCalls() { + if parsed.TextContent != "" && toolCallMode { + chunk := map[string]interface{}{ + "id": id, "object": "chat.completion.chunk", "created": created, "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]interface{}{"role": "assistant", "content": parsed.TextContent}, "finish_reason": nil}, + }, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + if flusher != nil { + flusher.Flush() + } + } + for i, tc := range parsed.ToolCalls { + callID := "call_" + uuid.New().String()[:8] + chunk := map[string]interface{}{ + "id": id, "object": "chat.completion.chunk", "created": created, "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]interface{}{ + "tool_calls": []map[string]interface{}{ + { + "index": i, + "id": callID, + "type": "function", + "function": map[string]interface{}{ + "name": tc.Name, + "arguments": tc.Arguments, + }, + }, + }, + }, "finish_reason": nil}, + }, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + if flusher != nil { + flusher.Flush() + } + } + stopChunk := map[string]interface{}{ + "id": id, "object": "chat.completion.chunk", "created": created, "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]interface{}{}, "finish_reason": "tool_calls"}, + }, + } + data, _ := json.Marshal(stopChunk) + fmt.Fprintf(w, "data: %s\n\n", data) + fmt.Fprintf(w, "data: [DONE]\n\n") + if flusher != nil { + flusher.Flush() + } + } else { + stopChunk := map[string]interface{}{ + "id": id, "object": "chat.completion.chunk", "created": created, "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"}, + }, + } + data, _ := json.Marshal(stopChunk) + fmt.Fprintf(w, "data: %s\n\n", data) + fmt.Fprintf(w, "data: [DONE]\n\n") + if flusher != nil { + flusher.Flush() + } + } + }, + ) + } else { + p = parser.CreateStreamParserWithThinking( + func(text string) { + accumulated += text + chunkNum++ + logger.LogStreamChunk(model, text, chunkNum) + chunk := map[string]interface{}{ + "id": id, "object": "chat.completion.chunk", "created": created, "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil}, + }, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + if flusher != nil { + flusher.Flush() + } + }, + func(thinking string) { + chunk := map[string]interface{}{ + "id": id, "object": "chat.completion.chunk", "created": created, "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]interface{}{"reasoning_content": thinking}, "finish_reason": nil}, + }, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + if flusher != nil { + flusher.Flush() + } + }, + func() { + logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true) + stopChunk := map[string]interface{}{ + "id": id, "object": "chat.completion.chunk", "created": created, "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"}, + }, + } + data, _ := json.Marshal(stopChunk) + fmt.Fprintf(w, "data: %s\n\n", data) + fmt.Fprintf(w, "data: [DONE]\n\n") + if flusher != nil { + flusher.Flush() + } + }, + ) } - stopChunk := map[string]interface{}{ - "id": id, "object": "chat.completion.chunk", "created": created, "model": model, - "choices": []map[string]interface{}{ - {"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"}, - }, + configDir := l.svcCtx.AccountPool.GetNextConfigDir() + logger.LogAccountAssigned(configDir) + l.svcCtx.AccountPool.ReportRequestStart(configDir) + logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true) + streamStart := time.Now().UnixMilli() + + wrappedParser := func(line string) { + logger.LogRawLine(line) + p.Parse(line) } - data, _ = json.Marshal(stopChunk) - fmt.Fprintf(w, "data: %s\n\n", data) - fmt.Fprintf(w, "data: [DONE]\n\n") - if flusher != nil { - flusher.Flush() + result, err := usecase.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, l.ctx) + + if l.ctx.Err() == nil { + p.Flush() } + latencyMs := time.Now().UnixMilli() - streamStart + l.svcCtx.AccountPool.ReportRequestEnd(configDir) + + if l.ctx.Err() == context.DeadlineExceeded { + logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs) + } else if l.ctx.Err() == context.Canceled { + logger.LogClientDisconnect(method, pathname, model, latencyMs) + } else if err == nil && isRateLimited(result.Stderr) { + l.svcCtx.AccountPool.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr)) + } + + if err != nil || (result.Code != 0 && l.ctx.Err() == nil) { + l.svcCtx.AccountPool.ReportRequestError(configDir, latencyMs) + if err != nil { + logger.LogAgentError(cfg.SessionsLogPath, method, pathname, "", -1, err.Error()) + } else { + logger.LogAgentError(cfg.SessionsLogPath, method, pathname, "", result.Code, result.Stderr) + } + logger.LogRequestDone(method, pathname, model, latencyMs, result.Code) + } else if l.ctx.Err() == nil { + l.svcCtx.AccountPool.ReportRequestSuccess(configDir, latencyMs) + logger.LogRequestDone(method, pathname, model, latencyMs, 0) + } + logger.LogAccountStats(cfg.Verbose, l.svcCtx.AccountPool.GetStats()) + return nil } -// Stub implementations - TODO: full implementation func convertMessages(msgs []apitypes.Message) []interface{} { result := make([]interface{}, len(msgs)) for i, m := range msgs { @@ -203,9 +436,13 @@ func convertFunctions(funcs []apitypes.Function) []interface{} { } func configToBridge(c config.Config) config.BridgeConfig { + host := c.Host + if host == "" { + host = "0.0.0.0" + } return config.BridgeConfig{ AgentBin: c.AgentBin, - Host: c.Host, + Host: host, Port: c.Port, RequiredKey: c.RequiredKey, DefaultModel: c.DefaultModel, @@ -231,6 +468,16 @@ func configToBridge(c config.Config) config.BridgeConfig { } } -// Placeholder for usecase functions -// These should be properly implemented with the usecase package -var _ = usecase.AccountsDir +// StringsToMapSlice converts string slice for compatibility +func StringsToMapSlice(ss []string) []map[string]string { + result := make([]map[string]string, len(ss)) + for i, s := range ss { + result[i] = map[string]string{"content": s} + } + return result +} + +// JoinStrings joins strings with newline +func JoinStrings(ss []string) string { + return strings.Join(ss, "\n") +}