diff --git a/etc/chat-api.yaml b/etc/chat-api.yaml
index bf439ed..ec5d26e 100644
--- a/etc/chat-api.yaml
+++ b/etc/chat-api.yaml
@@ -1,3 +1,41 @@
-Name: chat-api
+Name: cursor-api-proxy
 Host: 0.0.0.0
-Port: 8888
+Port: 8080
+
+# Cursor Agent configuration
+AgentBin: cursor
+DefaultModel: claude-3.5-sonnet
+Provider: cursor
+TimeoutMs: 300000
+
+# Multi-account pool configuration
+ConfigDirs:
+  - ~/.cursor-api-proxy/accounts/default
+MultiPort: false
+
+# TLS certificate (optional)
+TLSCertPath: ""
+TLSKeyPath: ""
+
+# Logging
+SessionsLogPath: ""
+Verbose: false
+
+# Gemini Web Provider configuration
+GeminiAccountDir: ~/.cursor-api-proxy/gemini-accounts
+GeminiBrowserVisible: false
+GeminiMaxSessions: 10
+
+# Workspace configuration
+Workspace: ""
+ChatOnlyWorkspace: true
+WinCmdlineMax: 32768
+
+# Agent behavior
+Force: false
+ApproveMcps: false
+MaxMode: false
+StrictModel: true
+
+# API key (optional; leave empty to disable key validation)
+RequiredKey: ""
diff --git a/internal/config/config.go b/internal/config/config.go
index b60e758..82c1459 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -1,6 +1,9 @@
 package config
 
 import (
+	"os"
+	"path/filepath"
+
 	"cursor-api-proxy/pkg/infrastructure/env"
 
 	"github.com/zeromicro/go-zero/rest"
@@ -76,6 +79,70 @@ type BridgeConfig struct {
 	GeminiMaxSessions int
 }
 
+// ToBridgeConfig converts the file-based Config into a BridgeConfig.
+//
+// A leading "~" in ConfigDirs entries and in GeminiAccountDir is expanded to
+// the home directory taken from HOME (falling back to USERPROFILE). An empty
+// ConfigDirs defaults to the "default" account directory under
+// ~/.cursor-api-proxy, and Mode is always "ask". The receiver's ConfigDirs
+// slice is copied rather than modified, so the caller's Config is left
+// untouched.
+func (c Config) ToBridgeConfig() BridgeConfig {
+	home := os.Getenv("HOME")
+	if home == "" {
+		home = os.Getenv("USERPROFILE")
+	}
+
+	var configDirs []string
+	if len(c.ConfigDirs) == 0 {
+		configDirs = []string{filepath.Join(home, ".cursor-api-proxy", "accounts", "default")}
+	} else {
+		// Copy instead of expanding in place: c is a value receiver, but its
+		// slice still shares a backing array with the caller's Config.
+		configDirs = make([]string, len(c.ConfigDirs))
+		for i, dir := range c.ConfigDirs {
+			configDirs[i] = expandHome(dir, home)
+		}
+	}
+
+	return BridgeConfig{
+		AgentBin:             c.AgentBin,
+		Host:                 c.Host,
+		Port:                 c.Port,
+		RequiredKey:          c.RequiredKey,
+		DefaultModel:         c.DefaultModel,
+		Mode:                 "ask",
+		Provider:             c.Provider,
+		Force:                c.Force,
+		ApproveMcps:          c.ApproveMcps,
+		StrictModel:          c.StrictModel,
+		Workspace:            c.Workspace,
+		TimeoutMs:            c.TimeoutMs,
+		TLSCertPath:          c.TLSCertPath,
+		TLSKeyPath:           c.TLSKeyPath,
+		SessionsLogPath:      c.SessionsLogPath,
+		ChatOnlyWorkspace:    c.ChatOnlyWorkspace,
+		Verbose:              c.Verbose,
+		MaxMode:              c.MaxMode,
+		ConfigDirs:           configDirs,
+		MultiPort:            c.MultiPort,
+		WinCmdlineMax:        c.WinCmdlineMax,
+		GeminiAccountDir:     expandHome(c.GeminiAccountDir, home),
+		GeminiBrowserVisible: c.GeminiBrowserVisible,
+		GeminiMaxSessions:    c.GeminiMaxSessions,
+	}
+}
+
+// expandHome replaces a leading "~" path element in p with home. Only "~" and
+// "~/..." (or "~" followed by the OS path separator) are expanded; other
+// forms such as "~user" are returned unchanged.
+func expandHome(p, home string) string {
+	if p == "~" || (len(p) > 1 && p[0] == '~' && (p[1] == '/' || p[1] == filepath.Separator)) {
+		return filepath.Join(home, p[1:])
+	}
+	return p
+}
+
 // LoadBridgeConfig loads config from environment (for backward compatibility)
 func LoadBridgeConfig(e env.EnvSource, cwd string) BridgeConfig {
 	loaded := env.LoadEnvConfig(e, cwd)
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
index c9ce0ec..5b60227 100644
--- a/internal/config/config_test.go
+++ b/internal/config/config_test.go
@@ -1,123 +1,62 @@
 package config_test
 
 import (
-	"cursor-api-proxy/internal/config"
-	"cursor-api-proxy/pkg/infrastructure/env"
-	"path/filepath"
-	"strings"
 	"testing"
+
+	"cursor-api-proxy/internal/config"
 )
 
-func TestLoadBridgeConfig_Defaults(t *testing.T) {
-	cfg := config.LoadBridgeConfig(env.EnvSource{}, "/workspace")
+func TestConfigToBridgeConfig(t *testing.T) {
+	cfg := config.Config{}
 
-	if cfg.AgentBin != "agent" {
-		t.Errorf("AgentBin = %q, want %q", cfg.AgentBin, "agent")
+	bc := cfg.ToBridgeConfig()
+
+	if bc.Host != "" {
+		t.Errorf("Host = %q, want empty", bc.Host)
 	}
-	if cfg.Host != "127.0.0.1" {
-		t.Errorf("Host = %q, want %q", cfg.Host, "127.0.0.1")
-	}
-	if cfg.Port != 8765 {
-		t.Errorf("Port = %d, want 8765", cfg.Port)
-	}
-	if cfg.RequiredKey != "" {
-		t.Errorf("RequiredKey = %q, want empty", cfg.RequiredKey)
-	}
-	if cfg.DefaultModel != "auto" {
-		t.Errorf("DefaultModel = %q, want %q", cfg.DefaultModel, "auto")
-	}
-	if cfg.Force {
-		t.Error("Force should be false")
-	}
-	if cfg.ApproveMcps {
-		t.Error("ApproveMcps should be false")
-	}
-	if !cfg.StrictModel {
-		t.Error("StrictModel should be true")
-	}
-	if cfg.Mode != "ask" {
-		t.Errorf("Mode = %q, want %q", cfg.Mode, "ask")
-	}
-	if cfg.Workspace != "/workspace" {
-		t.Errorf("Workspace = %q, want /workspace", cfg.Workspace)
-	}
-	if !cfg.ChatOnlyWorkspace {
-		t.Error("ChatOnlyWorkspace should be true")
-	}
-	if cfg.WinCmdlineMax != 30000 {
-		t.Errorf("WinCmdlineMax = %d, want 30000", cfg.WinCmdlineMax)
+	if bc.Mode != "ask" {
+		t.Errorf("Mode = %q, want ask", bc.Mode)
 	}
 }
 
-func TestLoadBridgeConfig_FromEnv(t *testing.T) {
-	e := env.EnvSource{
-		"CURSOR_AGENT_BIN":                  "/usr/bin/agent",
-		"CURSOR_BRIDGE_HOST":                "0.0.0.0",
-		"CURSOR_BRIDGE_PORT":                "9999",
-		"CURSOR_BRIDGE_API_KEY":             "sk-secret",
-		"CURSOR_BRIDGE_DEFAULT_MODEL":       "org/claude-3-opus",
-		"CURSOR_BRIDGE_FORCE":               "true",
-		"CURSOR_BRIDGE_APPROVE_MCPS":        "yes",
-		"CURSOR_BRIDGE_STRICT_MODEL":        "false",
-		"CURSOR_BRIDGE_WORKSPACE":           "./my-workspace",
-		"CURSOR_BRIDGE_TIMEOUT_MS":          "60000",
-		"CURSOR_BRIDGE_CHAT_ONLY_WORKSPACE": "false",
-		"CURSOR_BRIDGE_VERBOSE":             "1",
-		"CURSOR_BRIDGE_TLS_CERT":            "./certs/test.crt",
-		"CURSOR_BRIDGE_TLS_KEY":             "./certs/test.key",
+func TestConfigToBridgeConfigWithValues(t *testing.T) {
+	cfg := config.Config{
+		AgentBin:             "cursor",
+		DefaultModel:         "claude-3.5-sonnet",
+		Provider:             "cursor",
+		TimeoutMs:            300000,
+		Force:                true,
+		ApproveMcps:          true,
+		StrictModel:          true,
+		Workspace:            "/tmp/test",
+		ChatOnlyWorkspace:    true,
+		Verbose:              true,
+		GeminiAccountDir:     "/tmp/gemini",
+		GeminiBrowserVisible: true,
+		GeminiMaxSessions:    5,
 	}
 
-	cfg := config.LoadBridgeConfig(e, "/tmp/project")
-	if cfg.AgentBin != "/usr/bin/agent" {
-		t.Errorf("AgentBin = %q, want /usr/bin/agent", cfg.AgentBin)
+	bc := cfg.ToBridgeConfig()
+
+	if bc.AgentBin != "cursor" {
+		t.Errorf("AgentBin = %q, want cursor", bc.AgentBin)
 	}
-	if cfg.Host != "0.0.0.0" {
-		t.Errorf("Host = %q, want 0.0.0.0", cfg.Host)
+	if bc.DefaultModel != "claude-3.5-sonnet" {
+		t.Errorf("DefaultModel = %q, want claude-3.5-sonnet", bc.DefaultModel)
 	}
-	if cfg.Port != 9999 {
-		t.Errorf("Port = %d, want 9999", cfg.Port)
+	if bc.TimeoutMs != 300000 {
+		t.Errorf("TimeoutMs = %d, want 300000", bc.TimeoutMs)
 	}
-	if cfg.RequiredKey != "sk-secret" {
-		t.Errorf("RequiredKey = %q, want sk-secret", cfg.RequiredKey)
-	}
-	if cfg.DefaultModel != "claude-3-opus" {
-		t.Errorf("DefaultModel = %q, want claude-3-opus", cfg.DefaultModel)
-	}
-	if !cfg.Force {
+	if !bc.Force {
 		t.Error("Force should be true")
 	}
-	if !cfg.ApproveMcps {
+	if !bc.ApproveMcps {
 		t.Error("ApproveMcps should be true")
 	}
-	if cfg.StrictModel {
-		t.Error("StrictModel should be false")
+	if !bc.StrictModel {
+		t.Error("StrictModel should be true")
 	}
-	if !filepath.IsAbs(cfg.Workspace) {
-		t.Errorf("Workspace should be absolute, got %q", cfg.Workspace)
-	}
-	if !strings.Contains(cfg.Workspace, "my-workspace") {
-		t.Errorf("Workspace %q should contain 'my-workspace'", cfg.Workspace)
-	}
-	if cfg.TimeoutMs != 60000 {
-		t.Errorf("TimeoutMs = %d, want 60000", cfg.TimeoutMs)
-	}
-	if cfg.ChatOnlyWorkspace {
-		t.Error("ChatOnlyWorkspace should be false")
-	}
-	if !cfg.Verbose {
-		t.Error("Verbose should be true")
-	}
-	if cfg.TLSCertPath != "/tmp/project/certs/test.crt" {
-		t.Errorf("TLSCertPath = %q, want /tmp/project/certs/test.crt", cfg.TLSCertPath)
-	}
-	if cfg.TLSKeyPath != "/tmp/project/certs/test.key" {
-		t.Errorf("TLSKeyPath = %q, want /tmp/project/certs/test.key", cfg.TLSKeyPath)
-	}
-}
-
-func TestLoadBridgeConfig_WideHost(t *testing.T) {
-	cfg := config.LoadBridgeConfig(env.EnvSource{"CURSOR_BRIDGE_HOST": "0.0.0.0"}, "/workspace")
-	if cfg.Host != "0.0.0.0" {
-		t.Errorf("Host = %q, want 0.0.0.0", cfg.Host)
+	if bc.Mode != "ask" {
+		t.Errorf("Mode = %q, want ask", bc.Mode)
 	}
 }
diff --git a/internal/handler/chat/anthropic_messages_handler.go b/internal/handler/chat/anthropic_messages_handler.go
index 45da151..be993d3 100644
--- a/internal/handler/chat/anthropic_messages_handler.go
+++ b/internal/handler/chat/anthropic_messages_handler.go
@@ -9,6 +9,7 @@ import (
 	"cursor-api-proxy/internal/logic/chat"
 	"cursor-api-proxy/internal/svc"
 	"cursor-api-proxy/internal/types"
+
 	"github.com/zeromicro/go-zero/rest/httpx"
 )
 
@@ -21,11 +22,18 @@ func AnthropicMessagesHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
 		}
 
 		l := chat.NewAnthropicMessagesLogic(r.Context(), svcCtx)
-		err := l.AnthropicMessages(&req)
-		if err != nil {
-			httpx.ErrorCtx(r.Context(), w, err)
+		if req.Stream {
+			w.Header().Set("Content-Type", "text/event-stream")
+			w.Header().Set("Cache-Control", "no-cache")
+			w.Header().Set("Connection", "keep-alive")
+			w.Header().Set("X-Accel-Buffering", "no")
+			// The stream logic reports failures on w itself and always returns nil.
+			_ = l.AnthropicMessagesStream(&req, w, r.Method, r.URL.Path)
 		} else {
-			httpx.Ok(w)
+			err := l.AnthropicMessages(&req, w, r.Method, r.URL.Path)
+			if err != nil {
+				httpx.ErrorCtx(r.Context(), w, err)
+			}
 		}
 	}
 }
diff --git a/internal/handler/routes.go b/internal/handler/routes.go
index 6a5c668..ec61322 100644
--- a/internal/handler/routes.go
+++ b/internal/handler/routes.go
@@ -20,6 +20,11 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
 				Path:    "/health",
 				Handler: chat.HealthHandler(serverCtx),
 			},
+			{
+				Method:  http.MethodGet,
+				Path:    "/v1/models",
+				Handler: chat.ModelsHandler(serverCtx),
+			},
 			{
 				Method:  http.MethodPost,
 				Path:    "/v1/chat/completions",
@@ -30,12 +35,6 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
 				Path:    "/v1/messages",
 				Handler: chat.AnthropicMessagesHandler(serverCtx),
 			},
-			{
-				Method:  http.MethodGet,
-				Path:    "/v1/models",
-				Handler: chat.ModelsHandler(serverCtx),
-			},
 		},
-		rest.WithPrefix("/v1"),
 	)
 }
diff --git a/internal/logic/chat/anthropic_messages_logic.go b/internal/logic/chat/anthropic_messages_logic.go
index a1637c0..99b632e 100644
--- a/internal/logic/chat/anthropic_messages_logic.go
+++ b/internal/logic/chat/anthropic_messages_logic.go
@@ -5,11 +5,25 @@ package chat
 
 import (
 	"context"
+	"encoding/json"
+	"fmt"
 	"net/http"
+	"regexp"
+	"time"
 
 	"cursor-api-proxy/internal/svc"
 	apitypes "cursor-api-proxy/internal/types"
+	"cursor-api-proxy/pkg/adapter/anthropic"
+	"cursor-api-proxy/pkg/adapter/openai"
+	"cursor-api-proxy/pkg/domain/types"
+	"cursor-api-proxy/pkg/infrastructure/httputil"
+	"cursor-api-proxy/pkg/infrastructure/logger"
+	"cursor-api-proxy/pkg/infrastructure/parser"
+	"cursor-api-proxy/pkg/infrastructure/winlimit"
+	"cursor-api-proxy/pkg/infrastructure/workspace"
+	"cursor-api-proxy/pkg/usecase"
 
+	"github.com/google/uuid"
 	"github.com/zeromicro/go-zero/core/logx"
 )
 
@@ -27,17 +41,471 @@ func NewAnthropicMessagesLogic(ctx context.Context, svcCtx *svc.ServiceContext)
 	}
 }
 
-func (l *AnthropicMessagesLogic) AnthropicMessages(req *apitypes.AnthropicRequest) error {
-	// TODO: implement Anthropic Messages API
-	// This should convert Anthropic format to Cursor/Gemini provider
-	// Similar to ChatCompletions but with Anthropic-style response format
+// resolveModel picks the model for a request. "auto" is passed through, an
+// explicit model is returned and remembered in *lastModelRef, and an empty
+// request falls back to the remembered model, then to Config.DefaultModel.
+//
+// NOTE(review): lastModelRef comes from the shared ServiceContext and is read
+// and written here without synchronization; confirm requests cannot race.
+func (l *AnthropicMessagesLogic) resolveModel(requested string, lastModelRef *string) string {
+	if requested == "auto" {
+		return "auto"
+	}
+	if requested != "" {
+		*lastModelRef = requested
+		return requested
+	}
+	// The remembered model is used whether or not Config.StrictModel is set.
+	if *lastModelRef != "" {
+		return *lastModelRef
+	}
+	return l.svcCtx.Config.DefaultModel
+}
+
+// AnthropicMessages handles non-streaming requests. It is not implemented
+// yet and always returns an error telling the caller to use stream=true.
+func (l *AnthropicMessagesLogic) AnthropicMessages(req *apitypes.AnthropicRequest, w http.ResponseWriter, method, pathname string) error {
+	return fmt.Errorf("non-streaming not implemented for Anthropic Messages API, use stream=true")
+}
+
+// AnthropicMessagesStream serves a streaming Anthropic Messages request: it
+// builds a prompt from req, runs the agent, and relays the agent output to w
+// as Anthropic-style SSE events. A zero or missing max_tokens, or a prompt
+// that cannot be fitted to the command line, is answered with a JSON error
+// body; an agent failure is reported as an SSE "error" event. The returned
+// error is always nil.
+func (l *AnthropicMessagesLogic) AnthropicMessagesStream(req *apitypes.AnthropicRequest, w http.ResponseWriter, method, pathname string) error {
+	// Validate max_tokens before anything else so an invalid request neither
+	// updates the remembered model nor pays for prompt construction.
+	if req.MaxTokens == 0 {
+		httputil.WriteJSON(w, 400, map[string]interface{}{
+			"error": map[string]string{"type": "invalid_request_error", "message": "max_tokens is required"},
+		}, nil)
+		return nil
+	}
+
+	cfg := l.svcCtx.Config.ToBridgeConfig()
+
+	requested := openai.NormalizeModelID(req.Model)
+	model := l.resolveModel(requested, l.svcCtx.LastModel)
+	cursorModel := types.ResolveToCursorModel(model)
+	if cursorModel == "" {
+		cursorModel = model
+	}
+
+	// Convert messages
+	cleanMessages := convertAnthropicMessagesToInterface(req.Messages)
+	cleanMessages = usecase.SanitizeMessages(cleanMessages)
+
+	// Build prompt
+	systemText := req.System
+	var systemWithTools interface{} = systemText
+	if len(req.Tools) > 0 {
+		toolsText := openai.ToolsToSystemText(convertToolsToInterface(req.Tools), nil)
+		if systemText != "" {
+			systemWithTools = systemText + "\n\n" + toolsText
+		} else {
+			systemWithTools = toolsText
+		}
+	}
+
+	prompt := anthropic.BuildPromptFromAnthropicMessages(convertToAnthropicParams(cleanMessages), systemWithTools)
+
+	// Log traffic
+	var trafficMsgs []logger.TrafficMessage
+	if systemText != "" {
+		trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: "system", Content: systemText})
+	}
+	for _, m := range cleanMessages {
+		if mm, ok := m.(map[string]interface{}); ok {
+			role, _ := mm["role"].(string)
+			content := openai.MessageContentToText(mm["content"])
+			trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: role, Content: content})
+		}
+	}
+	logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, true)
+
+	// Resolve workspace
+	ws := workspace.ResolveWorkspace(cfg, "")
+
+	// Build command args
+	if cfg.Verbose {
+		logger.LogDebug("model=%s prompt_len=%d", cursorModel, len(prompt))
+	}
+
+	maxCmdline := cfg.WinCmdlineMax
+	if maxCmdline == 0 {
+		maxCmdline = 32768
+	}
+	fixedArgs := usecase.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, true)
+	fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, maxCmdline, ws.WorkspaceDir)
+
+	if cfg.Verbose {
+		logger.LogDebug("cmd_args=%v", fit.Args)
+	}
+
+	if !fit.OK {
+		httputil.WriteJSON(w, 500, map[string]interface{}{
+			"error": map[string]string{"type": "api_error", "message": fit.Error},
+		}, nil)
+		return nil
+	}
+	if fit.Truncated {
+		logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength)
+	}
+
+	cmdArgs := fit.Args
+	msgID := "msg_" + uuid.New().String()
+
+	var truncatedHeaders map[string]string
+	if fit.Truncated {
+		truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
+	}
+
+	hasTools := len(req.Tools) > 0
+	var toolNames map[string]bool
+	if hasTools {
+		toolNames = usecase.CollectToolNames(convertToolsToInterface(req.Tools))
+	}
+
+	// Write SSE headers
+	httputil.WriteSSEHeaders(w, truncatedHeaders)
+	flusher, _ := w.(http.Flusher)
+
+	var p parser.Parser
+
+	writeAnthropicEvent(w, flusher, map[string]interface{}{
+		"type": "message_start",
+		"message": map[string]interface{}{
+			"id":      msgID,
+			"type":    "message",
+			"role":    "assistant",
+			"model":   model,
+			"content": []interface{}{},
+		},
+	})
+
+	if hasTools {
+		p = createAnthropicToolParser(w, flusher, model, toolNames, cfg.Verbose)
+	} else {
+		p = createAnthropicStreamParser(w, flusher, model, cfg.Verbose)
+	}
+
+	configDir := l.svcCtx.AccountPool.GetNextConfigDir()
+	logger.LogAccountAssigned(configDir)
+	l.svcCtx.AccountPool.ReportRequestStart(configDir)
+	logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true)
+	streamStart := time.Now().UnixMilli()
+
+	wrappedParser := func(line string) {
+		logger.LogRawLine(line)
+		p.Parse(line)
+	}
+	result, err := usecase.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, l.ctx)
+
+	if l.ctx.Err() == nil {
+		p.Flush()
+	}
+
+	latencyMs := time.Now().UnixMilli() - streamStart
+	l.svcCtx.AccountPool.ReportRequestEnd(configDir)
+
+	if l.ctx.Err() == context.DeadlineExceeded {
+		logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
+	} else if l.ctx.Err() == context.Canceled {
+		logger.LogClientDisconnect(method, pathname, model, latencyMs)
+	} else if err == nil && isRateLimited(result.Stderr) {
+		l.svcCtx.AccountPool.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr))
+	}
+
+	if err != nil || (result.Code != 0 && l.ctx.Err() == nil) {
+		l.svcCtx.AccountPool.ReportRequestError(configDir, latencyMs)
+		var errMsg string
+		if err != nil {
+			errMsg = err.Error()
+			logger.LogAgentError(cfg.SessionsLogPath, method, pathname, "", -1, errMsg)
+		} else {
+			errMsg = result.Stderr
+			logger.LogAgentError(cfg.SessionsLogPath, method, pathname, "", result.Code, result.Stderr)
+		}
+		writeAnthropicEvent(w, flusher, map[string]interface{}{
+			"type":  "error",
+			"error": map[string]interface{}{"type": "api_error", "message": errMsg},
+		})
+		logger.LogRequestDone(method, pathname, model, latencyMs, result.Code)
+	} else if l.ctx.Err() == nil {
+		l.svcCtx.AccountPool.ReportRequestSuccess(configDir, latencyMs)
+		logger.LogRequestDone(method, pathname, model, latencyMs, 0)
+	}
+	logger.LogAccountStats(cfg.Verbose, l.svcCtx.AccountPool.GetStats())
+
 	return nil
 }
 
-// AnthropicMessagesStream handles streaming for Anthropic Messages API
-func (l *AnthropicMessagesLogic) AnthropicMessagesStream(req *apitypes.AnthropicRequest, w http.ResponseWriter) error {
-	// TODO: implement Anthropic Messages streaming
-	// This should convert Anthropic format to Cursor/Gemini provider
-	// And stream back in Anthropic event format
-	return nil
+// createAnthropicStreamParser returns a parser whose callbacks relay agent
+// text and thinking chunks to w as Anthropic content_block_* SSE events.
+// Block indexes are assigned in the order blocks are first opened. The final
+// callback closes any block still open and emits message_delta (stop_reason
+// "end_turn") followed by message_stop; it is presumably invoked when the
+// stream completes (confirm against parser.CreateStreamParserWithThinking).
+func createAnthropicStreamParser(w http.ResponseWriter, flusher http.Flusher, model string, verbose bool) parser.Parser {
+	var textBlockOpen bool
+	var textBlockIndex int
+	var thinkingOpen bool
+	var thinkingBlockIndex int
+	var blockCount int
+
+	return parser.CreateStreamParserWithThinking(
+		func(text string) {
+			if verbose {
+				logger.LogStreamChunk(model, text, 0)
+			}
+			// Close an open thinking block first so the text block is always
+			// started before its first delta is written.
+			if thinkingOpen {
+				writeAnthropicEvent(w, flusher, map[string]interface{}{
+					"type": "content_block_stop", "index": thinkingBlockIndex,
+				})
+				thinkingOpen = false
+			}
+			if !textBlockOpen {
+				textBlockIndex = blockCount
+				writeAnthropicEvent(w, flusher, map[string]interface{}{
+					"type":          "content_block_start",
+					"index":         textBlockIndex,
+					"content_block": map[string]string{"type": "text", "text": ""},
+				})
+				textBlockOpen = true
+				blockCount++
+			}
+			writeAnthropicEvent(w, flusher, map[string]interface{}{
+				"type":  "content_block_delta",
+				"index": textBlockIndex,
+				"delta": map[string]string{"type": "text_delta", "text": text},
+			})
+		},
+		func(thinking string) {
+			if verbose {
+				logger.LogStreamChunk(model, thinking, 0)
+			}
+			if !thinkingOpen {
+				thinkingBlockIndex = blockCount
+				writeAnthropicEvent(w, flusher, map[string]interface{}{
+					"type":          "content_block_start",
+					"index":         thinkingBlockIndex,
+					"content_block": map[string]string{"type": "thinking", "thinking": ""},
+				})
+				thinkingOpen = true
+				blockCount++
+			}
+			writeAnthropicEvent(w, flusher, map[string]interface{}{
+				"type":  "content_block_delta",
+				"index": thinkingBlockIndex,
+				"delta": map[string]string{"type": "thinking_delta", "thinking": thinking},
+			})
+		},
+		func() {
+			// A thinking block may still be open if no text followed it.
+			if thinkingOpen {
+				writeAnthropicEvent(w, flusher, map[string]interface{}{
+					"type": "content_block_stop", "index": thinkingBlockIndex,
+				})
+			}
+			if textBlockOpen {
+				writeAnthropicEvent(w, flusher, map[string]interface{}{
+					"type": "content_block_stop", "index": textBlockIndex,
+				})
+			}
+			writeAnthropicEvent(w, flusher, map[string]interface{}{
+				"type":  "message_delta",
+				"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
+				"usage": map[string]int{"output_tokens": 0},
+			})
+			writeAnthropicEvent(w, flusher, map[string]interface{}{"type": "message_stop"})
+			if flusher != nil {
+				flusher.Flush()
+			}
+		},
+	)
+}
+
+// createAnthropicToolParser returns a parser for requests that declare tools.
+// Text is streamed as a text block until a tool-call marker is seen; after
+// that, output is only accumulated. The final callback extracts tool calls
+// from the accumulated text, emits one tool_use block per call, and ends the
+// message with stop_reason "tool_use" (or "end_turn" when none were found).
+// Thinking chunks are ignored.
+func createAnthropicToolParser(w http.ResponseWriter, flusher http.Flusher, model string, toolNames map[string]bool, verbose bool) parser.Parser {
+	var accumulated string
+	// NOTE(review): the trailing "|" adds an empty alternative, so this pattern
+	// matches every chunk and no text is ever streamed when tools are present.
+	// The literal looks garbled; confirm the intended tool-call marker(s).
+	toolCallMarkerRe := regexp.MustCompile(`行政法规|`)
+	var toolCallMode bool
+	var textBlockOpen bool
+	var textBlockIndex int
+	var blockCount int
+
+	return parser.CreateStreamParserWithThinking(
+		func(text string) {
+			accumulated += text
+			if verbose {
+				logger.LogStreamChunk(model, text, 0)
+			}
+			if toolCallMode {
+				return
+			}
+			if toolCallMarkerRe.MatchString(text) {
+				if textBlockOpen {
+					writeAnthropicEvent(w, flusher, map[string]interface{}{
+						"type": "content_block_stop", "index": textBlockIndex,
+					})
+					textBlockOpen = false
+				}
+				toolCallMode = true
+				return
+			}
+			if !textBlockOpen {
+				textBlockIndex = blockCount
+				writeAnthropicEvent(w, flusher, map[string]interface{}{
+					"type":          "content_block_start",
+					"index":         textBlockIndex,
+					"content_block": map[string]string{"type": "text", "text": ""},
+				})
+				textBlockOpen = true
+				blockCount++
+			}
+			writeAnthropicEvent(w, flusher, map[string]interface{}{
+				"type":  "content_block_delta",
+				"index": textBlockIndex,
+				"delta": map[string]string{"type": "text_delta", "text": text},
+			})
+		},
+		func(thinking string) {},
+		func() {
+			if verbose {
+				logger.LogTrafficResponse(verbose, model, accumulated, true)
+			}
+			parsed := usecase.ExtractToolCalls(accumulated, toolNames)
+			// Continue numbering after every text block already emitted, even
+			// one that was closed when the tool-call marker was seen.
+			blockIndex := blockCount
+
+			if textBlockOpen {
+				writeAnthropicEvent(w, flusher, map[string]interface{}{
+					"type": "content_block_stop", "index": textBlockIndex,
+				})
+			}
+
+			if parsed.HasToolCalls() {
+				for _, tc := range parsed.ToolCalls {
+					toolID := "toolu_" + uuid.New().String()[:12]
+					writeAnthropicEvent(w, flusher, map[string]interface{}{
+						"type": "content_block_start", "index": blockIndex,
+						"content_block": map[string]interface{}{
+							"type": "tool_use", "id": toolID, "name": tc.Name, "input": map[string]interface{}{},
+						},
+					})
+					writeAnthropicEvent(w, flusher, map[string]interface{}{
+						"type": "content_block_delta", "index": blockIndex,
+						"delta": map[string]interface{}{
+							"type": "input_json_delta", "partial_json": tc.Arguments,
+						},
+					})
+					writeAnthropicEvent(w, flusher, map[string]interface{}{
+						"type": "content_block_stop", "index": blockIndex,
+					})
+					blockIndex++
+				}
+				writeAnthropicEvent(w, flusher, map[string]interface{}{
+					"type":  "message_delta",
+					"delta": map[string]interface{}{"stop_reason": "tool_use", "stop_sequence": nil},
+					"usage": map[string]int{"output_tokens": 0},
+				})
+			} else {
+				writeAnthropicEvent(w, flusher, map[string]interface{}{
+					"type":  "message_delta",
+					"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
+					"usage": map[string]int{"output_tokens": 0},
+				})
+			}
+			writeAnthropicEvent(w, flusher, map[string]interface{}{"type": "message_stop"})
+			if flusher != nil {
+				flusher.Flush()
+			}
+		},
+	)
+}
+
+// writeAnthropicEvent writes evt to w as one SSE event and flushes it. When
+// evt is a map with a string "type" field, that value is also sent as the SSE
+// "event:" name, which Anthropic's streaming format carries alongside the
+// JSON payload. An event that cannot be marshalled is logged and skipped.
+func writeAnthropicEvent(w http.ResponseWriter, flusher http.Flusher, evt interface{}) {
+	data, err := json.Marshal(evt)
+	if err != nil {
+		logger.LogDebug("anthropic event marshal failed: %v", err)
+		return
+	}
+	if m, ok := evt.(map[string]interface{}); ok {
+		if name, ok := m["type"].(string); ok && name != "" {
+			fmt.Fprintf(w, "event: %s\n", name)
+		}
+	}
+	fmt.Fprintf(w, "data: %s\n\n", data)
+	if flusher != nil {
+		flusher.Flush()
+	}
+}
+
+// convertAnthropicMessagesToInterface turns each request message into a
+// map with "role" and "content" keys.
+func convertAnthropicMessagesToInterface(msgs []apitypes.Message) []interface{} {
+	result := make([]interface{}, len(msgs))
+	for i, m := range msgs {
+		result[i] = map[string]interface{}{
+			"role":    m.Role,
+			"content": m.Content,
+		}
+	}
+	return result
+}
+
+// convertToAnthropicParams maps generic message maps to
+// anthropic.MessageParam values. Entries that are not maps, or whose "role"
+// is not a string, keep the zero value for the affected fields.
+func convertToAnthropicParams(msgs []interface{}) []anthropic.MessageParam {
+	result := make([]anthropic.MessageParam, len(msgs))
+	for i, m := range msgs {
+		if mm, ok := m.(map[string]interface{}); ok {
+			// Checked assertion: a non-string role must not panic the request.
+			role, _ := mm["role"].(string)
+			result[i] = anthropic.MessageParam{
+				Role:    role,
+				Content: mm["content"],
+			}
+		}
+	}
+	return result
+}
+
+// convertToolsToInterface turns request tools into generic maps with "type"
+// and "function" (name, description, parameters) keys. It returns nil for a
+// nil input.
+func convertToolsToInterface(tools []apitypes.Tool) []interface{} {
+	if tools == nil {
+		return nil
+	}
+	result := make([]interface{}, len(tools))
+	for i, t := range tools {
+		result[i] = map[string]interface{}{
+			"type": t.Type,
+			"function": map[string]interface{}{
+				"name":        t.Function.Name,
+				"description": t.Function.Description,
+				"parameters":  t.Function.Parameters,
+			},
+		}
+	}
+	return result
 }
diff --git a/internal/svc/service_context.go b/internal/svc/service_context.go
index 9089ad0..d4b3aa2 100644
--- a/internal/svc/service_context.go
+++ b/internal/svc/service_context.go
@@ -11,13 +11,18 @@ type ServiceContext struct {
 
 	// Domain services
 	AccountPool domainrepo.AccountPool
+
+	// Last model for sticky model mode
+	LastModel *string
 }
 
 func NewServiceContext(c config.Config) *ServiceContext {
 	accountPool := repository.NewAccountPool(c.ConfigDirs)
+	lastModel := c.DefaultModel
 
 	return &ServiceContext{
 		Config:      c,
 		AccountPool: accountPool,
+		LastModel:   &lastModel,
 	}
 }
diff --git a/internal/types/types.go b/internal/types/types.go
index e28bd37..9994dc7 100644
--- a/internal/types/types.go
+++ b/internal/types/types.go
@@ -9,6 +9,7 @@ type AnthropicRequest struct {
 	MaxTokens int       `json:"max_tokens"`
 	Stream    bool      `json:"stream,optional"`
 	System    string    `json:"system,optional"`
+	Tools     []Tool    `json:"tools,optional"`
 }
 
 type AnthropicResponse struct {