// Source: opencode-cursor-agent/internal/logic/chat/anthropic_messages_logic.go
// (paste metadata removed: 460 lines, 13 KiB, Go)

// Code scaffolded by goctl. Safe to edit.
// goctl 1.10.1
package chat
import (
"context"
"encoding/json"
"fmt"
"net/http"
"regexp"
"time"
"cursor-api-proxy/internal/svc"
apitypes "cursor-api-proxy/internal/types"
"cursor-api-proxy/pkg/adapter/anthropic"
"cursor-api-proxy/pkg/adapter/openai"
"cursor-api-proxy/pkg/domain/types"
"cursor-api-proxy/pkg/infrastructure/httputil"
"cursor-api-proxy/pkg/infrastructure/logger"
"cursor-api-proxy/pkg/infrastructure/parser"
"cursor-api-proxy/pkg/infrastructure/winlimit"
"cursor-api-proxy/pkg/infrastructure/workspace"
"cursor-api-proxy/pkg/usecase"
"github.com/google/uuid"
"github.com/zeromicro/go-zero/core/logx"
)
// AnthropicMessagesLogic handles Anthropic Messages API requests,
// bridging them onto the local cursor agent and streaming the result
// back as Anthropic-style SSE events.
type AnthropicMessagesLogic struct {
	logx.Logger
	// ctx carries request-scoped cancellation/deadline state.
	ctx context.Context
	// svcCtx provides shared service dependencies (config, account pool, last-model state).
	svcCtx *svc.ServiceContext
}
// NewAnthropicMessagesLogic constructs a logic handler bound to the given
// request context and shared service context.
func NewAnthropicMessagesLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AnthropicMessagesLogic {
	logic := AnthropicMessagesLogic{
		Logger: logx.WithContext(ctx),
		ctx:    ctx,
		svcCtx: svcCtx,
	}
	return &logic
}
// resolveModel decides which model a request should use.
//
// Precedence:
//  1. "auto" is passed through unchanged.
//  2. An explicit, non-empty model is used and remembered in lastModelRef.
//  3. Otherwise the last remembered model is reused, if any.
//  4. Finally the configured default model is used.
func (l *AnthropicMessagesLogic) resolveModel(requested string, lastModelRef *string) string {
	cfg := l.svcCtx.Config
	if requested == "auto" {
		return "auto"
	}
	if requested != "" {
		// Remember the explicit choice so later requests without a model reuse it.
		*lastModelRef = requested
		return requested
	}
	// NOTE(review): the original also gated on cfg.StrictModel here, but that
	// branch was redundant — both it and this fallback returned *lastModelRef.
	if *lastModelRef != "" {
		return *lastModelRef
	}
	return cfg.DefaultModel
}
// AnthropicMessages is the non-streaming entry point. It is intentionally
// unimplemented: clients must request stream=true.
func (l *AnthropicMessagesLogic) AnthropicMessages(req *apitypes.AnthropicRequest, w http.ResponseWriter, method, pathname string) error {
	err := fmt.Errorf("non-streaming not implemented for Anthropic Messages API, use stream=true")
	return err
}
// AnthropicMessagesStream handles a streaming Anthropic Messages request: it
// resolves the model, builds a prompt from the conversation, runs the local
// cursor agent, and relays its output to the client as Anthropic SSE events
// (message_start / content_block_* / message_delta / message_stop).
// It always returns nil; failures are reported to the client either as a
// JSON error (before SSE headers are sent) or as an SSE error event.
func (l *AnthropicMessagesLogic) AnthropicMessagesStream(req *apitypes.AnthropicRequest, w http.ResponseWriter, method, pathname string) error {
	cfg := l.svcCtx.Config.ToBridgeConfig()
	requested := openai.NormalizeModelID(req.Model)
	model := l.resolveModel(requested, l.svcCtx.LastModel)
	// Map the public model name to the cursor-internal one; fall back to the
	// resolved name itself when no mapping exists.
	cursorModel := types.ResolveToCursorModel(model)
	if cursorModel == "" {
		cursorModel = model
	}
	// Convert messages to the generic map form and sanitize them.
	cleanMessages := convertAnthropicMessagesToInterface(req.Messages)
	cleanMessages = usecase.SanitizeMessages(cleanMessages)
	// Build the prompt; tool definitions are folded into the system text.
	systemText := req.System
	var systemWithTools interface{} = systemText
	if len(req.Tools) > 0 {
		toolsText := openai.ToolsToSystemText(convertToolsToInterface(req.Tools), nil)
		if systemText != "" {
			systemWithTools = systemText + "\n\n" + toolsText
		} else {
			systemWithTools = toolsText
		}
	}
	prompt := anthropic.BuildPromptFromAnthropicMessages(convertToAnthropicParams(cleanMessages), systemWithTools)
	// max_tokens is required by the Anthropic API; reject before streaming starts.
	if req.MaxTokens == 0 {
		httputil.WriteJSON(w, 400, map[string]interface{}{
			"error": map[string]string{"type": "invalid_request_error", "message": "max_tokens is required"},
		}, nil)
		return nil
	}
	// Collect per-message entries for request traffic logging.
	var trafficMsgs []logger.TrafficMessage
	if systemText != "" {
		trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: "system", Content: systemText})
	}
	for _, m := range cleanMessages {
		if mm, ok := m.(map[string]interface{}); ok {
			role, _ := mm["role"].(string)
			content := openai.MessageContentToText(mm["content"])
			trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: role, Content: content})
		}
	}
	logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, true)
	// Resolve the workspace directory the agent will run in.
	ws := workspace.ResolveWorkspace(cfg, "")
	if cfg.Verbose {
		logger.LogDebug("model=%s prompt_len=%d", cursorModel, len(prompt))
	}
	// Windows imposes a hard command-line length limit; default to 32 KiB.
	maxCmdline := cfg.WinCmdlineMax
	if maxCmdline == 0 {
		maxCmdline = 32768
	}
	fixedArgs := usecase.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, true)
	// Truncate the prompt if necessary so the full command line fits the limit.
	fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, maxCmdline, ws.WorkspaceDir)
	if cfg.Verbose {
		logger.LogDebug("cmd_args=%v", fit.Args)
	}
	if !fit.OK {
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"type": "api_error", "message": fit.Error},
		}, nil)
		return nil
	}
	if fit.Truncated {
		logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength)
	}
	cmdArgs := fit.Args
	msgID := "msg_" + uuid.New().String()
	// Surface prompt truncation to the client via a response header.
	var truncatedHeaders map[string]string
	if fit.Truncated {
		truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
	}
	hasTools := len(req.Tools) > 0
	var toolNames map[string]bool
	if hasTools {
		toolNames = usecase.CollectToolNames(convertToolsToInterface(req.Tools))
	}
	// From this point the response is SSE; errors must be sent as SSE events.
	httputil.WriteSSEHeaders(w, truncatedHeaders)
	flusher, _ := w.(http.Flusher)
	var p parser.Parser
	writeAnthropicEvent(w, flusher, map[string]interface{}{
		"type": "message_start",
		"message": map[string]interface{}{
			"id":      msgID,
			"type":    "message",
			"role":    "assistant",
			"model":   model,
			"content": []interface{}{},
		},
	})
	// Tool-aware parsing buffers tool calls; plain parsing streams text directly.
	if hasTools {
		p = createAnthropicToolParser(w, flusher, model, toolNames, cfg.Verbose)
	} else {
		p = createAnthropicStreamParser(w, flusher, model, cfg.Verbose)
	}
	// Pick an account from the pool and track this request against it.
	configDir := l.svcCtx.AccountPool.GetNextConfigDir()
	logger.LogAccountAssigned(configDir)
	l.svcCtx.AccountPool.ReportRequestStart(configDir)
	logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true)
	streamStart := time.Now().UnixMilli()
	wrappedParser := func(line string) {
		logger.LogRawLine(line)
		p.Parse(line)
	}
	result, err := usecase.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, l.ctx)
	// Only flush trailing parser output if the client is still connected.
	if l.ctx.Err() == nil {
		p.Flush()
	}
	latencyMs := time.Now().UnixMilli() - streamStart
	l.svcCtx.AccountPool.ReportRequestEnd(configDir)
	if l.ctx.Err() == context.DeadlineExceeded {
		logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
	} else if l.ctx.Err() == context.Canceled {
		logger.LogClientDisconnect(method, pathname, model, latencyMs)
	} else if err == nil && isRateLimited(result.Stderr) {
		// Agent finished but stderr indicates rate limiting; cool this account
		// down for the advertised retry interval.
		l.svcCtx.AccountPool.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr))
	}
	if err != nil || (result.Code != 0 && l.ctx.Err() == nil) {
		l.svcCtx.AccountPool.ReportRequestError(configDir, latencyMs)
		errMsg := "unknown error"
		if err != nil {
			errMsg = err.Error()
			logger.LogAgentError(cfg.SessionsLogPath, method, pathname, "", -1, errMsg)
		} else {
			errMsg = result.Stderr
			logger.LogAgentError(cfg.SessionsLogPath, method, pathname, "", result.Code, result.Stderr)
		}
		// Report the failure in-stream; SSE headers are already sent.
		writeAnthropicEvent(w, flusher, map[string]interface{}{
			"type":  "error",
			"error": map[string]interface{}{"type": "api_error", "message": errMsg},
		})
		logger.LogRequestDone(method, pathname, model, latencyMs, result.Code)
	} else if l.ctx.Err() == nil {
		l.svcCtx.AccountPool.ReportRequestSuccess(configDir, latencyMs)
		logger.LogRequestDone(method, pathname, model, latencyMs, 0)
	}
	logger.LogAccountStats(cfg.Verbose, l.svcCtx.AccountPool.GetStats())
	return nil
}
// createAnthropicStreamParser builds a parser that translates agent output
// into Anthropic SSE streaming events for responses without tools. Text and
// "thinking" output each get their own content block, with block indices
// increasing monotonically as the Anthropic Messages stream format expects.
func createAnthropicStreamParser(w http.ResponseWriter, flusher http.Flusher, model string, verbose bool) parser.Parser {
	var textBlockOpen bool
	var textBlockIndex int
	var thinkingOpen bool
	var thinkingBlockIndex int
	var blockCount int
	return parser.CreateStreamParserWithThinking(
		// Text chunk handler.
		func(text string) {
			if verbose {
				logger.LogStreamChunk(model, text, 0)
			}
			// Close a still-open thinking block BEFORE the text-open check.
			// The original closed it after, so text arriving right after
			// thinking emitted text deltas for a block whose
			// content_block_start was never sent.
			if thinkingOpen {
				writeAnthropicEvent(w, flusher, map[string]interface{}{
					"type": "content_block_stop", "index": thinkingBlockIndex,
				})
				thinkingOpen = false
			}
			if !textBlockOpen {
				textBlockIndex = blockCount
				writeAnthropicEvent(w, flusher, map[string]interface{}{
					"type":          "content_block_start",
					"index":         textBlockIndex,
					"content_block": map[string]string{"type": "text", "text": ""},
				})
				textBlockOpen = true
				blockCount++
			}
			writeAnthropicEvent(w, flusher, map[string]interface{}{
				"type":  "content_block_delta",
				"index": textBlockIndex,
				"delta": map[string]string{"type": "text_delta", "text": text},
			})
		},
		// Thinking chunk handler.
		func(thinking string) {
			if verbose {
				logger.LogStreamChunk(model, thinking, 0)
			}
			if !thinkingOpen {
				thinkingBlockIndex = blockCount
				writeAnthropicEvent(w, flusher, map[string]interface{}{
					"type":          "content_block_start",
					"index":         thinkingBlockIndex,
					"content_block": map[string]string{"type": "thinking", "thinking": ""},
				})
				thinkingOpen = true
				blockCount++
			}
			writeAnthropicEvent(w, flusher, map[string]interface{}{
				"type":  "content_block_delta",
				"index": thinkingBlockIndex,
				"delta": map[string]string{"type": "thinking_delta", "thinking": thinking},
			})
		},
		// Completion handler: close any open blocks, then finish the message.
		func() {
			// The original never closed a trailing thinking block, leaving the
			// stream with an unterminated content block when the agent ended
			// while still "thinking".
			if thinkingOpen {
				writeAnthropicEvent(w, flusher, map[string]interface{}{
					"type": "content_block_stop", "index": thinkingBlockIndex,
				})
			}
			if textBlockOpen {
				writeAnthropicEvent(w, flusher, map[string]interface{}{
					"type": "content_block_stop", "index": textBlockIndex,
				})
			}
			writeAnthropicEvent(w, flusher, map[string]interface{}{
				"type":  "message_delta",
				"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
				"usage": map[string]int{"output_tokens": 0},
			})
			writeAnthropicEvent(w, flusher, map[string]interface{}{"type": "message_stop"})
			if flusher != nil {
				flusher.Flush()
			}
		},
	)
}
// createAnthropicToolParser builds a parser for responses where tools are
// declared. It streams text normally until a tool-call marker appears, then
// stops streaming and accumulates the remainder; at end-of-stream it extracts
// tool calls from the accumulated text and emits them as tool_use blocks.
func createAnthropicToolParser(w http.ResponseWriter, flusher http.Flusher, model string, toolNames map[string]bool, verbose bool) parser.Parser {
	var accumulated string
	// NOTE(review): "行政法规" looks like a mis-encoded/garbled marker rather
	// than an intentional sentinel — confirm the intended tool-call marker.
	toolCallMarkerRe := regexp.MustCompile(`行政法规|<function_calls>`)
	var toolCallMode bool
	var textBlockOpen bool
	var textBlockIndex int
	var blockCount int
	return parser.CreateStreamParserWithThinking(
		// Text chunk handler: stream until a tool-call marker is seen.
		func(text string) {
			accumulated += text
			if verbose {
				logger.LogStreamChunk(model, text, 0)
			}
			// Once in tool-call mode, only accumulate; nothing is streamed.
			if toolCallMode {
				return
			}
			if toolCallMarkerRe.MatchString(text) {
				if textBlockOpen {
					writeAnthropicEvent(w, flusher, map[string]interface{}{
						"type": "content_block_stop", "index": textBlockIndex,
					})
					textBlockOpen = false
				}
				toolCallMode = true
				return
			}
			if !textBlockOpen {
				textBlockIndex = blockCount
				writeAnthropicEvent(w, flusher, map[string]interface{}{
					"type":          "content_block_start",
					"index":         textBlockIndex,
					"content_block": map[string]string{"type": "text", "text": ""},
				})
				textBlockOpen = true
				blockCount++
			}
			writeAnthropicEvent(w, flusher, map[string]interface{}{
				"type":  "content_block_delta",
				"index": textBlockIndex,
				"delta": map[string]string{"type": "text_delta", "text": text},
			})
		},
		// Thinking output is intentionally dropped in tool mode.
		func(thinking string) {},
		// Completion handler: extract tool calls and finish the message.
		func() {
			if verbose {
				logger.LogTrafficResponse(verbose, model, accumulated, true)
			}
			parsed := usecase.ExtractToolCalls(accumulated, toolNames)
			if textBlockOpen {
				writeAnthropicEvent(w, flusher, map[string]interface{}{
					"type": "content_block_stop", "index": textBlockIndex,
				})
			}
			// Start tool blocks after every block already emitted. The
			// original reset to index 0 whenever the text block had already
			// been closed by the marker, so the first tool_use block collided
			// with the text block's index.
			blockIndex := blockCount
			if parsed.HasToolCalls() {
				for _, tc := range parsed.ToolCalls {
					toolID := "toolu_" + uuid.New().String()[:12]
					writeAnthropicEvent(w, flusher, map[string]interface{}{
						"type": "content_block_start", "index": blockIndex,
						"content_block": map[string]interface{}{
							"type": "tool_use", "id": toolID, "name": tc.Name, "input": map[string]interface{}{},
						},
					})
					writeAnthropicEvent(w, flusher, map[string]interface{}{
						"type": "content_block_delta", "index": blockIndex,
						"delta": map[string]interface{}{
							"type": "input_json_delta", "partial_json": tc.Arguments,
						},
					})
					writeAnthropicEvent(w, flusher, map[string]interface{}{
						"type": "content_block_stop", "index": blockIndex,
					})
					blockIndex++
				}
				writeAnthropicEvent(w, flusher, map[string]interface{}{
					"type":  "message_delta",
					"delta": map[string]interface{}{"stop_reason": "tool_use", "stop_sequence": nil},
					"usage": map[string]int{"output_tokens": 0},
				})
			} else {
				writeAnthropicEvent(w, flusher, map[string]interface{}{
					"type":  "message_delta",
					"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
					"usage": map[string]int{"output_tokens": 0},
				})
			}
			writeAnthropicEvent(w, flusher, map[string]interface{}{"type": "message_stop"})
			if flusher != nil {
				flusher.Flush()
			}
		},
	)
}
func writeAnthropicEvent(w http.ResponseWriter, flusher http.Flusher, evt interface{}) {
data, _ := json.Marshal(evt)
fmt.Fprintf(w, "data: %s\n\n", data)
if flusher != nil {
flusher.Flush()
}
}
// convertAnthropicMessagesToInterface converts typed Anthropic messages into
// the generic map form consumed by the sanitizer and prompt builder.
func convertAnthropicMessagesToInterface(msgs []apitypes.Message) []interface{} {
	out := make([]interface{}, 0, len(msgs))
	for _, msg := range msgs {
		out = append(out, map[string]interface{}{
			"role":    msg.Role,
			"content": msg.Content,
		})
	}
	return out
}
// convertToAnthropicParams converts generic message maps into typed
// anthropic.MessageParam values. Entries that are not maps remain
// zero-valued params, matching the previous behavior.
func convertToAnthropicParams(msgs []interface{}) []anthropic.MessageParam {
	result := make([]anthropic.MessageParam, len(msgs))
	for i, m := range msgs {
		mm, ok := m.(map[string]interface{})
		if !ok {
			continue
		}
		// Two-value assertion: a non-string role yields "" instead of a panic
		// (the original used an unchecked assertion).
		role, _ := mm["role"].(string)
		result[i] = anthropic.MessageParam{
			Role:    role,
			Content: mm["content"],
		}
	}
	return result
}
// convertToolsToInterface converts typed tool definitions into the generic
// OpenAI-style map form (type + function{name, description, parameters}).
// A nil input yields nil.
func convertToolsToInterface(tools []apitypes.Tool) []interface{} {
	if tools == nil {
		return nil
	}
	converted := make([]interface{}, 0, len(tools))
	for _, tool := range tools {
		fn := map[string]interface{}{
			"name":        tool.Function.Name,
			"description": tool.Function.Description,
			"parameters":  tool.Function.Parameters,
		}
		converted = append(converted, map[string]interface{}{
			"type":     tool.Type,
			"function": fn,
		})
	}
	return converted
}