diff --git a/internal/handler/chat/chat_completions_handler.go b/internal/handler/chat/chat_completions_handler.go index 1b09da3..7e15c0a 100644 --- a/internal/handler/chat/chat_completions_handler.go +++ b/internal/handler/chat/chat_completions_handler.go @@ -9,6 +9,7 @@ import ( "cursor-api-proxy/internal/logic/chat" "cursor-api-proxy/internal/svc" "cursor-api-proxy/internal/types" + "github.com/zeromicro/go-zero/rest/httpx" ) @@ -21,11 +22,21 @@ func ChatCompletionsHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { } l := chat.NewChatCompletionsLogic(r.Context(), svcCtx) - err := l.ChatCompletions(&req) - if err != nil { - httpx.ErrorCtx(r.Context(), w, err) + if req.Stream { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + err := l.ChatCompletionsStream(&req, w) + if err != nil { + w.Write([]byte("event: error\ndata: " + err.Error() + "\n\n")) + } } else { - httpx.Ok(w) + resp, err := l.ChatCompletions(&req) + if err != nil { + httpx.ErrorCtx(r.Context(), w, err) + } else { + httpx.OkJsonCtx(r.Context(), w, resp) + } } } } diff --git a/internal/logic/chat/anthropic_messages_logic.go b/internal/logic/chat/anthropic_messages_logic.go index b8a840c..cd3b1e0 100644 --- a/internal/logic/chat/anthropic_messages_logic.go +++ b/internal/logic/chat/anthropic_messages_logic.go @@ -7,7 +7,7 @@ import ( "context" "cursor-api-proxy/internal/svc" - "cursor-api-proxy/internal/types" + apitypes "cursor-api-proxy/internal/types" "github.com/zeromicro/go-zero/core/logx" ) @@ -26,8 +26,7 @@ func NewAnthropicMessagesLogic(ctx context.Context, svcCtx *svc.ServiceContext) } } -func (l *AnthropicMessagesLogic) AnthropicMessages(req *types.AnthropicRequest) error { - // todo: add your logic here and delete this line - +func (l *AnthropicMessagesLogic) AnthropicMessages(req *apitypes.AnthropicRequest) error { + // TODO: implement Anthropic messages API // This should convert from Anthropic format to Cursor/Gemini provider return nil } diff --git a/internal/logic/chat/chat_completions_logic.go b/internal/logic/chat/chat_completions_logic.go index 36cbe67..ee3db38 100644 --- a/internal/logic/chat/chat_completions_logic.go +++ b/internal/logic/chat/chat_completions_logic.go @@ -5,10 +5,20 @@ package chat import ( "context" + "encoding/json" + "fmt" + "net/http" + "cursor-api-proxy/internal/config" "cursor-api-proxy/internal/svc" - "cursor-api-proxy/internal/types" + apitypes "cursor-api-proxy/internal/types" + "cursor-api-proxy/pkg/adapter/openai" + "cursor-api-proxy/pkg/domain/types" + "cursor-api-proxy/pkg/infrastructure/logger" + "cursor-api-proxy/pkg/infrastructure/parser" + "cursor-api-proxy/pkg/usecase" + "github.com/google/uuid" "github.com/zeromicro/go-zero/core/logx" ) @@ -26,8 +36,201 @@ func NewChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *C } } -func (l *ChatCompletionsLogic) ChatCompletions(req *types.ChatCompletionRequest) error { - // todo: add your logic here and delete this line +func (l *ChatCompletionsLogic) ChatCompletions(req *apitypes.ChatCompletionRequest) (*apitypes.ChatCompletionResponse, error) { + cfg := configToBridge(l.svcCtx.Config) + model := openai.NormalizeModelID(req.Model) + cursorModel := types.ResolveToCursorModel(model) + if cursorModel == "" { + cursorModel = model + } + + messages := convertMessages(req.Messages) + tools := convertTools(req.Tools) + functions := convertFunctions(req.Functions) + + cleanMessages := usecase.SanitizeMessages(messages) + toolsText := openai.ToolsToSystemText(tools, functions) + messagesWithTools := cleanMessages + if toolsText != "" { + messagesWithTools = append([]interface{}{ + map[string]interface{}{"role": "system", "content": toolsText}, + }, cleanMessages...) + } + + prompt := openai.BuildPromptFromMessages(messagesWithTools) + + // TODO: implement non-streaming execution + _ = cfg + _ = cursorModel + _ = prompt + + return &apitypes.ChatCompletionResponse{ + Id: "chatcmpl_" + uuid.New().String(), + Object: "chat.completion", + Created: 0, + Model: model, + Choices: []apitypes.Choice{ + { + Index: 0, + Message: apitypes.RespMessage{ + Role: "assistant", + Content: "TODO: implement non-streaming response", + }, + FinishReason: "stop", + }, + }, + }, nil +} + +func (l *ChatCompletionsLogic) ChatCompletionsStream(req *apitypes.ChatCompletionRequest, w http.ResponseWriter) error { + cfg := configToBridge(l.svcCtx.Config) + model := openai.NormalizeModelID(req.Model) + cursorModel := types.ResolveToCursorModel(model) + if cursorModel == "" { + cursorModel = model + } + + messages := convertMessages(req.Messages) + tools := convertTools(req.Tools) + functions := convertFunctions(req.Functions) + + cleanMessages := usecase.SanitizeMessages(messages) + toolsText := openai.ToolsToSystemText(tools, functions) + messagesWithTools := cleanMessages + if toolsText != "" { + messagesWithTools = append([]interface{}{ + map[string]interface{}{"role": "system", "content": toolsText}, + }, cleanMessages...) + } + + prompt := openai.BuildPromptFromMessages(messagesWithTools) + + if l.svcCtx.Config.Verbose { + logger.LogDebug("model=%s prompt_len=%d", cursorModel, len(prompt)) + } + + id := "chatcmpl_" + uuid.New().String() + created := int64(0) + + hasTools := len(tools) > 0 || len(functions) > 0 + + flusher, _ := w.(http.Flusher) + + var accumulated string + var chunkNum int + var p parser.Parser + + // TODO: implement proper streaming with usecase.RunAgentStream + // For now, return a placeholder response + _ = cfg + _ = prompt + _ = hasTools + _ = p + _ = chunkNum + _ = accumulated + + chunk := map[string]interface{}{ + "id": id, "object": "chat.completion.chunk", "created": created, "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]string{"content": "Streaming not yet implemented"}, "finish_reason": nil}, + }, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + if flusher != nil { + flusher.Flush() + } + + stopChunk := map[string]interface{}{ + "id": id, "object": "chat.completion.chunk", "created": created, "model": model, + "choices": []map[string]interface{}{ + {"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"}, + }, + } + data, _ = json.Marshal(stopChunk) + fmt.Fprintf(w, "data: %s\n\n", data) + fmt.Fprintf(w, "data: [DONE]\n\n") + if flusher != nil { + flusher.Flush() + } return nil } + +// Stub implementations - TODO: full implementation +func convertMessages(msgs []apitypes.Message) []interface{} { + result := make([]interface{}, len(msgs)) + for i, m := range msgs { + result[i] = map[string]interface{}{ + "role": m.Role, + "content": m.Content, + } + } + return result +} + +func convertTools(tools []apitypes.Tool) []interface{} { + if tools == nil { + return nil + } + result := make([]interface{}, len(tools)) + for i, t := range tools { + result[i] = map[string]interface{}{ + "type": t.Type, + "function": map[string]interface{}{ + "name": t.Function.Name, + "description": t.Function.Description, + "parameters": t.Function.Parameters, + }, + } + } + return result +} + +func convertFunctions(funcs []apitypes.Function) []interface{} { + if funcs == nil { + return nil + } + result := make([]interface{}, len(funcs)) + for i, f := range funcs { + result[i] = map[string]interface{}{ + "name": f.Name, + "description": f.Description, + "parameters": f.Parameters, + } + } + return result +} + +func configToBridge(c config.Config) config.BridgeConfig { + return config.BridgeConfig{ + AgentBin: c.AgentBin, + Host: c.Host, + Port: c.Port, + RequiredKey: c.RequiredKey, + DefaultModel: c.DefaultModel, + Mode: "ask", + Provider: c.Provider, + Force: c.Force, + ApproveMcps: c.ApproveMcps, + StrictModel: c.StrictModel, + Workspace: c.Workspace, + TimeoutMs: c.TimeoutMs, + TLSCertPath: c.TLSCertPath, + TLSKeyPath: c.TLSKeyPath, + SessionsLogPath: c.SessionsLogPath, + ChatOnlyWorkspace: c.ChatOnlyWorkspace, + Verbose: c.Verbose, + MaxMode: c.MaxMode, + ConfigDirs: c.ConfigDirs, + MultiPort: c.MultiPort, + WinCmdlineMax: c.WinCmdlineMax, + GeminiAccountDir: c.GeminiAccountDir, + GeminiBrowserVisible: c.GeminiBrowserVisible, + GeminiMaxSessions: c.GeminiMaxSessions, + } +} + +// Placeholder for usecase functions +// These should be properly implemented with the usecase package +var _ = usecase.AccountsDir diff --git a/internal/logic/chat/health_logic.go b/internal/logic/chat/health_logic.go index e7046ac..24fba97 100644 --- a/internal/logic/chat/health_logic.go +++ b/internal/logic/chat/health_logic.go @@ -27,7 +27,8 @@ func NewHealthLogic(ctx context.Context, svcCtx *svc.ServiceContext) *HealthLogi } func (l *HealthLogic) Health() (resp *types.HealthResponse, err error) { - // todo: add your logic here and delete this line - - return + return &types.HealthResponse{ + Status: "ok", + Version: "1.0.0", + }, nil } diff --git a/internal/logic/chat/models_logic.go b/internal/logic/chat/models_logic.go index 31eac5e..1fe2ead 100644 --- a/internal/logic/chat/models_logic.go +++ b/internal/logic/chat/models_logic.go @@ -1,17 +1,33 @@ -// Code scaffolded by goctl. Safe to edit. -// goctl 1.10.1 - package chat import ( "context" + "sync" + "time" "cursor-api-proxy/internal/svc" - "cursor-api-proxy/internal/types" + apitypes "cursor-api-proxy/internal/types" + "cursor-api-proxy/pkg/domain/types" "github.com/zeromicro/go-zero/core/logx" ) +const modelCacheTTLMs = 5 * 60 * 1000 + +type ModelCache struct { + At int64 + Models []types.CursorCliModel +} + +type ModelCacheRef struct { + mu sync.Mutex + cache *ModelCache + inflight bool + waiters []chan struct{} +} + +var globalModelCache = &ModelCacheRef{} + type ModelsLogic struct { logx.Logger ctx context.Context @@ -26,8 +42,77 @@ func NewModelsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ModelsLogi } } -func (l *ModelsLogic) Models() (resp *types.ModelsResponse, err error) { - // todo: add your logic here and delete this line +func (l *ModelsLogic) Models() (resp *apitypes.ModelsResponse, err error) { + now := time.Now().UnixMilli() - return + globalModelCache.mu.Lock() + if globalModelCache.cache != nil && now-globalModelCache.cache.At <= modelCacheTTLMs { + cache := globalModelCache.cache + globalModelCache.mu.Unlock() + return buildModelsResponse(cache.Models), nil + } + + if globalModelCache.inflight { + ch := make(chan struct{}, 1) + globalModelCache.waiters = append(globalModelCache.waiters, ch) + globalModelCache.mu.Unlock() + <-ch + globalModelCache.mu.Lock() + cache := globalModelCache.cache + globalModelCache.mu.Unlock() + return buildModelsResponse(cache.Models), nil + } + + globalModelCache.inflight = true + globalModelCache.mu.Unlock() + + fetched, err := types.ListCursorCliModels(l.svcCtx.Config.AgentBin, l.svcCtx.Config.TimeoutMs) + + globalModelCache.mu.Lock() + globalModelCache.inflight = false + if err == nil { + globalModelCache.cache = &ModelCache{At: time.Now().UnixMilli(), Models: fetched} + } + waiters := globalModelCache.waiters + globalModelCache.waiters = nil + globalModelCache.mu.Unlock() + + for _, ch := range waiters { + ch <- struct{}{} + } + + if err != nil { + return nil, err + } + + return buildModelsResponse(fetched), nil +} + +func buildModelsResponse(mods []types.CursorCliModel) *apitypes.ModelsResponse { + models := make([]apitypes.ModelData, len(mods)) + for i, m := range mods { + models[i] = apitypes.ModelData{ + Id: m.ID, + Object: "model", + OwnedBy: "cursor", + } + } + + ids := make([]string, len(mods)) + for i, m := range mods { + ids[i] = m.ID + } + aliases := types.GetAnthropicModelAliases(ids) + for _, a := range aliases { + models = append(models, apitypes.ModelData{ + Id: a.ID, + Object: "model", + OwnedBy: "cursor", + }) + } + + return &apitypes.ModelsResponse{ + Object: "list", + Data: models, + } } diff --git a/pkg/domain/types/models.go b/pkg/domain/types/models.go index 8842664..7220832 100644 --- a/pkg/domain/types/models.go +++ b/pkg/domain/types/models.go @@ -1,30 +1,174 @@ package types -// Model mappings for Cursor API -var AnthropicToCursor = map[string]string{ - "claude-3-5-sonnet": "claude-3.5-sonnet", +import ( + "fmt" + "os" + "regexp" + "strings" + + "cursor-api-proxy/pkg/infrastructure/process" +) + +type CursorCliModel struct { + ID string + Name string +} + +type ModelAlias struct { + CursorID string + AnthropicID string + Name string +} + +var anthropicToCursor = map[string]string{ + "claude-opus-4-6": "opus-4.6", + "claude-opus-4.6": "opus-4.6", + "claude-sonnet-4-6": "sonnet-4.6", + "claude-sonnet-4.6": "sonnet-4.6", + "claude-opus-4-5": "opus-4.5", + "claude-opus-4.5": "opus-4.5", + "claude-sonnet-4-5": "sonnet-4.5", + "claude-sonnet-4.5": "sonnet-4.5", + "claude-opus-4": "opus-4.6", + "claude-sonnet-4": "sonnet-4.6", + "claude-haiku-4-5-20251001": "sonnet-4.5", + "claude-haiku-4-5": "sonnet-4.5", + "claude-haiku-4-6": "sonnet-4.6", + "claude-haiku-4": "sonnet-4.5", + "claude-opus-4-6-thinking": "opus-4.6-thinking", + "claude-sonnet-4-6-thinking": "sonnet-4.6-thinking", + "claude-opus-4-5-thinking": "opus-4.5-thinking", + "claude-sonnet-4-5-thinking": "sonnet-4.5-thinking", + "claude-3-5-sonnet": "claude-3.5-sonnet", "claude-3-5-sonnet-20241022": "claude-3.5-sonnet", - "claude-3-5-haiku": "claude-3.5-haiku", - "claude-3-opus": "claude-3-opus", - "claude-3-sonnet": "claude-3-sonnet", - "claude-3-haiku": "claude-3-haiku", + "claude-3-5-haiku": "claude-3.5-haiku", + "claude-3-opus": "claude-3-opus", + "claude-3-sonnet": "claude-3-sonnet", + "claude-3-haiku": "claude-3-haiku", } -// Cursor model aliases -var CursorModelAliases = []string{ - "auto", - "claude-3.5-sonnet", - "claude-3.5-haiku", - "claude-3-opus", - "gpt-4", - "gpt-4o", - "gemini-2.0-flash", +var cursorToAnthropicAlias = []ModelAlias{ + {"opus-4.6", "claude-opus-4-6", "Claude 4.6 Opus"}, + {"opus-4.6-thinking", "claude-opus-4-6-thinking", "Claude 4.6 Opus (Thinking)"}, + {"sonnet-4.6", "claude-sonnet-4-6", "Claude 4.6 Sonnet"}, + {"sonnet-4.6-thinking", "claude-sonnet-4-6-thinking", "Claude 4.6 Sonnet (Thinking)"}, + {"opus-4.5", "claude-opus-4-5", "Claude 4.5 Opus"}, + {"opus-4.5-thinking", "claude-opus-4-5-thinking", "Claude 4.5 Opus (Thinking)"}, + {"sonnet-4.5", "claude-sonnet-4-5", "Claude 4.5 Sonnet"}, + {"sonnet-4.5-thinking", "claude-sonnet-4-5-thinking", "Claude 4.5 Sonnet (Thinking)"}, } -// ResolveToCursorModel resolves a model name to Cursor model -func ResolveToCursorModel(model string) string { - if mapped, ok := AnthropicToCursor[model]; ok { +var modelLineRe = regexp.MustCompile(`^([A-Za-z0-9][A-Za-z0-9._:/-]*)\s+-\s+(.*)$`) +var trailingParenRe = regexp.MustCompile(`\s*\([^)]*\)\s*$`) +var cursorModelPattern = regexp.MustCompile(`^([a-zA-Z]+)-(\d+)\.(\d+)(-thinking)?$`) +var reverseDynamicPattern = regexp.MustCompile(`^claude-([a-zA-Z]+)-(\d+)-(\d+)(-thinking)?$`) + +type AnthropicAlias struct { + ID string + Name string +} + +func ParseCursorCliModels(output string) []CursorCliModel { + lines := strings.Split(output, "\n") + seen := make(map[string]CursorCliModel) + var order []string + + for _, line := range lines { + line = strings.TrimSpace(line) + m := modelLineRe.FindStringSubmatch(line) + if m == nil { + continue + } + id := m[1] + rawName := m[2] + name := strings.TrimSpace(trailingParenRe.ReplaceAllString(rawName, "")) + if name == "" { + name = id + } + if _, exists := seen[id]; !exists { + seen[id] = CursorCliModel{ID: id, Name: name} + order = append(order, id) + } + } + + result := make([]CursorCliModel, 0, len(order)) + for _, id := range order { + result = append(result, seen[id]) + } + return result +} + +func ListCursorCliModels(agentBin string, timeoutMs int) ([]CursorCliModel, error) { + tmpDir := os.TempDir() + result, err := process.Run(agentBin, []string{"--print-models_oneline"}, process.RunOptions{ + Cwd: tmpDir, + TimeoutMs: timeoutMs, + }) + if err != nil { + return nil, err + } + if result.Code != 0 { + return nil, fmt.Errorf("cursor cli failed: %s", result.Stderr) + } + return ParseCursorCliModels(result.Stdout), nil +} + +func generateDynamicAlias(cursorID string) (AnthropicAlias, bool) { + m := cursorModelPattern.FindStringSubmatch(cursorID) + if m == nil { + return AnthropicAlias{}, false + } + family := m[1] + major := m[2] + minor := m[3] + thinking := m[4] + + anthropicID := "claude-" + family + "-" + major + "-" + minor + thinking + capFamily := strings.ToUpper(family[:1]) + family[1:] + name := capFamily + " " + major + "." + minor + if thinking == "-thinking" { + name += " (Thinking)" + } + return AnthropicAlias{ID: anthropicID, Name: name}, true +} + +func reverseDynamicAlias(anthropicID string) (string, bool) { + m := reverseDynamicPattern.FindStringSubmatch(anthropicID) + if m == nil { + return "", false + } + return m[1] + "-" + m[2] + "." + m[3] + m[4], true +} + +func ResolveToCursorModel(requested string) string { + if mapped, ok := anthropicToCursor[requested]; ok { return mapped } - return model -} \ No newline at end of file + if cursorID, ok := reverseDynamicAlias(requested); ok { + return cursorID + } + return requested +} + +func GetAnthropicModelAliases(cursorIDs []string) []AnthropicAlias { + result := make([]AnthropicAlias, 0, len(cursorToAnthropicAlias)+len(cursorIDs)) + seen := make(map[string]bool) + + for _, a := range cursorToAnthropicAlias { + result = append(result, AnthropicAlias{ + ID: a.AnthropicID, + Name: a.Name, + }) + seen[a.CursorID] = true + } + + for _, id := range cursorIDs { + if seen[id] { + continue + } + if alias, ok := generateDynamicAlias(id); ok { + result = append(result, alias) + } + } + return result +}