// Code scaffolded by goctl. Safe to edit. // goctl 1.10.1 package chat import ( "context" "encoding/json" "fmt" "net/http" "regexp" "strconv" "strings" "time" "cursor-api-proxy/internal/config" "cursor-api-proxy/internal/svc" apitypes "cursor-api-proxy/internal/types" "cursor-api-proxy/pkg/adapter/openai" "cursor-api-proxy/pkg/domain/types" "cursor-api-proxy/pkg/infrastructure/httputil" "cursor-api-proxy/pkg/infrastructure/logger" "cursor-api-proxy/pkg/infrastructure/parser" "cursor-api-proxy/pkg/infrastructure/winlimit" "cursor-api-proxy/pkg/infrastructure/workspace" "cursor-api-proxy/pkg/usecase" "github.com/google/uuid" "github.com/zeromicro/go-zero/core/logx" ) var rateLimitRe = regexp.MustCompile(`(?i)\b429\b|rate.?limit|too many requests`) var retryAfterRe = regexp.MustCompile(`(?i)retry-after:\s*(\d+)`) func isRateLimited(stderr string) bool { return rateLimitRe.MatchString(stderr) } func extractRetryAfterMs(stderr string) int64 { if m := retryAfterRe.FindStringSubmatch(stderr); len(m) > 1 { if secs, err := strconv.ParseInt(m[1], 10, 64); err == nil && secs > 0 { return secs * 1000 } } return 60000 } type ChatCompletionsLogic struct { logx.Logger ctx context.Context svcCtx *svc.ServiceContext } func NewChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ChatCompletionsLogic { return &ChatCompletionsLogic{ Logger: logx.WithContext(ctx), ctx: ctx, svcCtx: svcCtx, } } func (l *ChatCompletionsLogic) resolveModel(requested string, lastModelRef *string) string { cfg := l.svcCtx.Config isAuto := requested == "auto" var explicitModel string if requested != "" && !isAuto { explicitModel = requested } if explicitModel != "" { *lastModelRef = explicitModel } if isAuto { return "auto" } if explicitModel != "" { return explicitModel } if cfg.StrictModel && *lastModelRef != "" { return *lastModelRef } if *lastModelRef != "" { return *lastModelRef } return cfg.DefaultModel } func (l *ChatCompletionsLogic) ChatCompletions(req *apitypes.ChatCompletionRequest) (*apitypes.ChatCompletionResponse, error) { return nil, fmt.Errorf("non-streaming not yet implemented, use stream=true") } func (l *ChatCompletionsLogic) ChatCompletionsStream(req *apitypes.ChatCompletionRequest, w http.ResponseWriter, method, pathname string) error { cfg := configToBridge(l.svcCtx.Config) rawModel := req.Model requested := openai.NormalizeModelID(rawModel) lastModelRef := new(string) model := l.resolveModel(requested, lastModelRef) cursorModel := types.ResolveToCursorModel(model) if cursorModel == "" { cursorModel = model } messages := convertMessages(req.Messages) tools := convertTools(req.Tools) functions := convertFunctions(req.Functions) cleanMessages := usecase.SanitizeMessages(messages) toolsText := openai.ToolsToSystemText(tools, functions) messagesWithTools := cleanMessages if toolsText != "" { messagesWithTools = append([]interface{}{ map[string]interface{}{"role": "system", "content": toolsText}, }, cleanMessages...) } prompt := openai.BuildPromptFromMessages(messagesWithTools) var trafficMsgs []logger.TrafficMessage for _, raw := range cleanMessages { if m, ok := raw.(map[string]interface{}); ok { role, _ := m["role"].(string) content := openai.MessageContentToText(m["content"]) trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: role, Content: content}) } } logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, true) ws := workspace.ResolveWorkspace(cfg, "") promptLen := len(prompt) if cfg.Verbose { if promptLen > 200 { logger.LogDebug("model=%s prompt_len=%d prompt_start=%q", cursorModel, promptLen, prompt[:200]) } else { logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, promptLen, prompt) } } maxCmdline := cfg.WinCmdlineMax if maxCmdline == 0 { maxCmdline = 32768 } fixedArgs := usecase.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, true) fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, maxCmdline, ws.WorkspaceDir) if l.svcCtx.Config.Verbose { logger.LogDebug("cmd=%s args=%v", cfg.AgentBin, fit.Args) } if !fit.OK { httputil.WriteJSON(w, 500, map[string]interface{}{ "error": map[string]string{"message": fit.Error, "code": "windows_cmdline_limit"}, }, nil) return nil } if fit.Truncated { logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength) } cmdArgs := fit.Args id := "chatcmpl_" + uuid.New().String() created := time.Now().Unix() var truncatedHeaders map[string]string if fit.Truncated { truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"} } hasTools := len(tools) > 0 || len(functions) > 0 var toolNames map[string]bool if hasTools { toolNames = usecase.CollectToolNames(tools) for _, f := range functions { if fm, ok := f.(map[string]interface{}); ok { if name, ok := fm["name"].(string); ok { toolNames[name] = true } } } } httputil.WriteSSEHeaders(w, truncatedHeaders) flusher, _ := w.(http.Flusher) var accumulated string var chunkNum int var p parser.Parser toolCallMarkerRe := regexp.MustCompile(`\x1e|`) if hasTools { var toolCallMode bool p = parser.CreateStreamParserWithThinking( func(text string) { accumulated += text chunkNum++ logger.LogStreamChunk(model, text, chunkNum) if toolCallMode { return } if toolCallMarkerRe.MatchString(text) { toolCallMode = true return } chunk := map[string]interface{}{ "id": id, "object": "chat.completion.chunk", "created": created, "model": model, "choices": []map[string]interface{}{ {"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil}, }, } data, _ := json.Marshal(chunk) fmt.Fprintf(w, "data: %s\n\n", data) if flusher != nil { flusher.Flush() } }, func(thinking string) { chunk := map[string]interface{}{ "id": id, "object": "chat.completion.chunk", "created": created, "model": model, "choices": []map[string]interface{}{ {"index": 0, "delta": map[string]interface{}{"reasoning_content": thinking}, "finish_reason": nil}, }, } data, _ := json.Marshal(chunk) fmt.Fprintf(w, "data: %s\n\n", data) if flusher != nil { flusher.Flush() } }, func() { logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true) parsed := usecase.ExtractToolCalls(accumulated, toolNames) if parsed.HasToolCalls() { if parsed.TextContent != "" && toolCallMode { chunk := map[string]interface{}{ "id": id, "object": "chat.completion.chunk", "created": created, "model": model, "choices": []map[string]interface{}{ {"index": 0, "delta": map[string]interface{}{"role": "assistant", "content": parsed.TextContent}, "finish_reason": nil}, }, } data, _ := json.Marshal(chunk) fmt.Fprintf(w, "data: %s\n\n", data) if flusher != nil { flusher.Flush() } } for i, tc := range parsed.ToolCalls { callID := "call_" + uuid.New().String()[:8] chunk := map[string]interface{}{ "id": id, "object": "chat.completion.chunk", "created": created, "model": model, "choices": []map[string]interface{}{ {"index": 0, "delta": map[string]interface{}{ "tool_calls": []map[string]interface{}{ { "index": i, "id": callID, "type": "function", "function": map[string]interface{}{ "name": tc.Name, "arguments": tc.Arguments, }, }, }, }, "finish_reason": nil}, }, } data, _ := json.Marshal(chunk) fmt.Fprintf(w, "data: %s\n\n", data) if flusher != nil { flusher.Flush() } } stopChunk := map[string]interface{}{ "id": id, "object": "chat.completion.chunk", "created": created, "model": model, "choices": []map[string]interface{}{ {"index": 0, "delta": map[string]interface{}{}, "finish_reason": "tool_calls"}, }, } data, _ := json.Marshal(stopChunk) fmt.Fprintf(w, "data: %s\n\n", data) fmt.Fprintf(w, "data: [DONE]\n\n") if flusher != nil { flusher.Flush() } } else { stopChunk := map[string]interface{}{ "id": id, "object": "chat.completion.chunk", "created": created, "model": model, "choices": []map[string]interface{}{ {"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"}, }, } data, _ := json.Marshal(stopChunk) fmt.Fprintf(w, "data: %s\n\n", data) fmt.Fprintf(w, "data: [DONE]\n\n") if flusher != nil { flusher.Flush() } } }, ) } else { p = parser.CreateStreamParserWithThinking( func(text string) { accumulated += text chunkNum++ logger.LogStreamChunk(model, text, chunkNum) chunk := map[string]interface{}{ "id": id, "object": "chat.completion.chunk", "created": created, "model": model, "choices": []map[string]interface{}{ {"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil}, }, } data, _ := json.Marshal(chunk) fmt.Fprintf(w, "data: %s\n\n", data) if flusher != nil { flusher.Flush() } }, func(thinking string) { chunk := map[string]interface{}{ "id": id, "object": "chat.completion.chunk", "created": created, "model": model, "choices": []map[string]interface{}{ {"index": 0, "delta": map[string]interface{}{"reasoning_content": thinking}, "finish_reason": nil}, }, } data, _ := json.Marshal(chunk) fmt.Fprintf(w, "data: %s\n\n", data) if flusher != nil { flusher.Flush() } }, func() { logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true) stopChunk := map[string]interface{}{ "id": id, "object": "chat.completion.chunk", "created": created, "model": model, "choices": []map[string]interface{}{ {"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"}, }, } data, _ := json.Marshal(stopChunk) fmt.Fprintf(w, "data: %s\n\n", data) fmt.Fprintf(w, "data: [DONE]\n\n") if flusher != nil { flusher.Flush() } }, ) } configDir := l.svcCtx.AccountPool.GetNextConfigDir() logger.LogAccountAssigned(configDir) l.svcCtx.AccountPool.ReportRequestStart(configDir) logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true) streamStart := time.Now().UnixMilli() wrappedParser := func(line string) { logger.LogRawLine(line) p.Parse(line) } result, err := usecase.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, l.ctx) if l.ctx.Err() == nil { p.Flush() } latencyMs := time.Now().UnixMilli() - streamStart l.svcCtx.AccountPool.ReportRequestEnd(configDir) if l.ctx.Err() == context.DeadlineExceeded { logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs) } else if l.ctx.Err() == context.Canceled { logger.LogClientDisconnect(method, pathname, model, latencyMs) } else if err == nil && isRateLimited(result.Stderr) { l.svcCtx.AccountPool.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr)) } if err != nil || (result.Code != 0 && l.ctx.Err() == nil) { l.svcCtx.AccountPool.ReportRequestError(configDir, latencyMs) if err != nil { logger.LogAgentError(cfg.SessionsLogPath, method, pathname, "", -1, err.Error()) } else { logger.LogAgentError(cfg.SessionsLogPath, method, pathname, "", result.Code, result.Stderr) } logger.LogRequestDone(method, pathname, model, latencyMs, result.Code) } else if l.ctx.Err() == nil { l.svcCtx.AccountPool.ReportRequestSuccess(configDir, latencyMs) logger.LogRequestDone(method, pathname, model, latencyMs, 0) } logger.LogAccountStats(cfg.Verbose, l.svcCtx.AccountPool.GetStats()) return nil } func convertMessages(msgs []apitypes.Message) []interface{} { result := make([]interface{}, len(msgs)) for i, m := range msgs { result[i] = map[string]interface{}{ "role": m.Role, "content": m.Content, } } return result } func convertTools(tools []apitypes.Tool) []interface{} { if tools == nil { return nil } result := make([]interface{}, len(tools)) for i, t := range tools { result[i] = map[string]interface{}{ "type": t.Type, "function": map[string]interface{}{ "name": t.Function.Name, "description": t.Function.Description, "parameters": t.Function.Parameters, }, } } return result } func convertFunctions(funcs []apitypes.Function) []interface{} { if funcs == nil { return nil } result := make([]interface{}, len(funcs)) for i, f := range funcs { result[i] = map[string]interface{}{ "name": f.Name, "description": f.Description, "parameters": f.Parameters, } } return result } func configToBridge(c config.Config) config.BridgeConfig { host := c.Host if host == "" { host = "0.0.0.0" } return config.BridgeConfig{ AgentBin: c.AgentBin, Host: host, Port: c.Port, RequiredKey: c.RequiredKey, DefaultModel: c.DefaultModel, Mode: "ask", Provider: c.Provider, Force: c.Force, ApproveMcps: c.ApproveMcps, StrictModel: c.StrictModel, Workspace: c.Workspace, TimeoutMs: c.TimeoutMs, TLSCertPath: c.TLSCertPath, TLSKeyPath: c.TLSKeyPath, SessionsLogPath: c.SessionsLogPath, ChatOnlyWorkspace: c.ChatOnlyWorkspace, Verbose: c.Verbose, MaxMode: c.MaxMode, ConfigDirs: c.ConfigDirs, MultiPort: c.MultiPort, WinCmdlineMax: c.WinCmdlineMax, GeminiAccountDir: c.GeminiAccountDir, GeminiBrowserVisible: c.GeminiBrowserVisible, GeminiMaxSessions: c.GeminiMaxSessions, } } // StringsToMapSlice converts string slice for compatibility func StringsToMapSlice(ss []string) []map[string]string { result := make([]map[string]string, len(ss)) for i, s := range ss { result[i] = map[string]string{"content": s} } return result } // JoinStrings joins strings with newline func JoinStrings(ss []string) string { return strings.Join(ss, "\n") }