460 lines
13 KiB
Go
460 lines
13 KiB
Go
// Code scaffolded by goctl. Safe to edit.
|
|
// goctl 1.10.1
|
|
|
|
package chat
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"regexp"
|
|
"time"
|
|
|
|
"cursor-api-proxy/internal/svc"
|
|
apitypes "cursor-api-proxy/internal/types"
|
|
"cursor-api-proxy/pkg/adapter/anthropic"
|
|
"cursor-api-proxy/pkg/adapter/openai"
|
|
"cursor-api-proxy/pkg/domain/types"
|
|
"cursor-api-proxy/pkg/infrastructure/httputil"
|
|
"cursor-api-proxy/pkg/infrastructure/logger"
|
|
"cursor-api-proxy/pkg/infrastructure/parser"
|
|
"cursor-api-proxy/pkg/infrastructure/winlimit"
|
|
"cursor-api-proxy/pkg/infrastructure/workspace"
|
|
"cursor-api-proxy/pkg/usecase"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|
)
|
|
|
|
// AnthropicMessagesLogic serves Anthropic Messages API requests by bridging
// them to the underlying cursor agent process.
type AnthropicMessagesLogic struct {
	logx.Logger
	// ctx carries request-scoped cancellation/deadline state.
	ctx context.Context
	// svcCtx provides shared service dependencies (config, account pool, last-model state).
	svcCtx *svc.ServiceContext
}
|
|
|
|
func NewAnthropicMessagesLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AnthropicMessagesLogic {
|
|
return &AnthropicMessagesLogic{
|
|
Logger: logx.WithContext(ctx),
|
|
ctx: ctx,
|
|
svcCtx: svcCtx,
|
|
}
|
|
}
|
|
|
|
func (l *AnthropicMessagesLogic) resolveModel(requested string, lastModelRef *string) string {
|
|
cfg := l.svcCtx.Config
|
|
isAuto := requested == "auto"
|
|
var explicitModel string
|
|
if requested != "" && !isAuto {
|
|
explicitModel = requested
|
|
}
|
|
if explicitModel != "" {
|
|
*lastModelRef = explicitModel
|
|
}
|
|
if isAuto {
|
|
return "auto"
|
|
}
|
|
if explicitModel != "" {
|
|
return explicitModel
|
|
}
|
|
if cfg.StrictModel && *lastModelRef != "" {
|
|
return *lastModelRef
|
|
}
|
|
if *lastModelRef != "" {
|
|
return *lastModelRef
|
|
}
|
|
return cfg.DefaultModel
|
|
}
|
|
|
|
// AnthropicMessages handles non-streaming Messages requests. The bridge only
// supports streaming responses, so this always returns an error directing the
// caller to set stream=true.
func (l *AnthropicMessagesLogic) AnthropicMessages(req *apitypes.AnthropicRequest, w http.ResponseWriter, method, pathname string) error {
	return fmt.Errorf("non-streaming not implemented for Anthropic Messages API, use stream=true")
}
|
|
|
|
// AnthropicMessagesStream serves a streaming Anthropic Messages request over
// SSE. It resolves the model, flattens messages and tool definitions into a
// single agent prompt, runs the cursor agent under an account checked out of
// the pool, and translates the agent's output lines into Anthropic streaming
// events (message_start, content_block_*, message_delta, message_stop).
// It always returns nil: once SSE headers are written, errors are reported
// in-stream as "error" events rather than as HTTP error statuses.
func (l *AnthropicMessagesLogic) AnthropicMessagesStream(req *apitypes.AnthropicRequest, w http.ResponseWriter, method, pathname string) error {
	cfg := l.svcCtx.Config.ToBridgeConfig()

	// Resolve the requested model (falling back to the last explicit model or
	// the configured default), then map it to the agent's model identifier.
	requested := openai.NormalizeModelID(req.Model)
	model := l.resolveModel(requested, l.svcCtx.LastModel)
	cursorModel := types.ResolveToCursorModel(model)
	if cursorModel == "" {
		cursorModel = model
	}

	// Convert messages into the generic map form shared by the adapters.
	cleanMessages := convertAnthropicMessagesToInterface(req.Messages)
	cleanMessages = usecase.SanitizeMessages(cleanMessages)

	// Build prompt. Tool definitions are appended to the system text because
	// the agent has no native tool-call channel.
	systemText := req.System
	var systemWithTools interface{} = systemText
	if len(req.Tools) > 0 {
		toolsText := openai.ToolsToSystemText(convertToolsToInterface(req.Tools), nil)
		if systemText != "" {
			systemWithTools = systemText + "\n\n" + toolsText
		} else {
			systemWithTools = toolsText
		}
	}

	prompt := anthropic.BuildPromptFromAnthropicMessages(convertToAnthropicParams(cleanMessages), systemWithTools)

	// Validate max_tokens (required by the Anthropic Messages API).
	// NOTE(review): this check could run before the prompt is built.
	if req.MaxTokens == 0 {
		httputil.WriteJSON(w, 400, map[string]interface{}{
			"error": map[string]string{"type": "invalid_request_error", "message": "max_tokens is required"},
		}, nil)
		return nil
	}

	// Log the incoming traffic (system + conversation messages).
	var trafficMsgs []logger.TrafficMessage
	if systemText != "" {
		trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: "system", Content: systemText})
	}
	for _, m := range cleanMessages {
		if mm, ok := m.(map[string]interface{}); ok {
			role, _ := mm["role"].(string)
			content := openai.MessageContentToText(mm["content"])
			trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: role, Content: content})
		}
	}
	logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, true)

	// Resolve the workspace the agent will run in.
	ws := workspace.ResolveWorkspace(cfg, "")

	// Build command args.
	if cfg.Verbose {
		logger.LogDebug("model=%s prompt_len=%d", cursorModel, len(prompt))
	}

	maxCmdline := cfg.WinCmdlineMax
	if maxCmdline == 0 {
		maxCmdline = 32768 // default Windows command-line length limit
	}
	fixedArgs := usecase.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, true)
	// Fit the prompt within the command-line limit, truncating if necessary.
	fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, maxCmdline, ws.WorkspaceDir)

	if cfg.Verbose {
		logger.LogDebug("cmd_args=%v", fit.Args)
	}

	if !fit.OK {
		// Headers not written yet, so a plain HTTP 500 is still possible here.
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"type": "api_error", "message": fit.Error},
		}, nil)
		return nil
	}
	if fit.Truncated {
		logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength)
	}

	cmdArgs := fit.Args
	msgID := "msg_" + uuid.New().String()

	// Surface prompt truncation to the client via a response header.
	var truncatedHeaders map[string]string
	if fit.Truncated {
		truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
	}

	hasTools := len(req.Tools) > 0
	var toolNames map[string]bool
	if hasTools {
		toolNames = usecase.CollectToolNames(convertToolsToInterface(req.Tools))
	}

	// Write SSE headers; from here on, errors must be reported in-stream.
	httputil.WriteSSEHeaders(w, truncatedHeaders)
	flusher, _ := w.(http.Flusher) // nil flusher is tolerated by writeAnthropicEvent

	var p parser.Parser

	// Open the Anthropic stream with the message_start envelope.
	writeAnthropicEvent(w, flusher, map[string]interface{}{
		"type": "message_start",
		"message": map[string]interface{}{
			"id":      msgID,
			"type":    "message",
			"role":    "assistant",
			"model":   model,
			"content": []interface{}{},
		},
	})

	// Pick the parser: tool-aware when the request declared tools.
	if hasTools {
		p = createAnthropicToolParser(w, flusher, model, toolNames, cfg.Verbose)
	} else {
		p = createAnthropicStreamParser(w, flusher, model, cfg.Verbose)
	}

	// Check an account out of the pool for this request.
	configDir := l.svcCtx.AccountPool.GetNextConfigDir()
	logger.LogAccountAssigned(configDir)
	l.svcCtx.AccountPool.ReportRequestStart(configDir)
	logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true)
	streamStart := time.Now().UnixMilli()

	// Feed every raw agent output line through the raw-line log and the parser.
	wrappedParser := func(line string) {
		logger.LogRawLine(line)
		p.Parse(line)
	}
	result, err := usecase.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, l.ctx)

	// Only emit terminal events if the request context is still alive.
	if l.ctx.Err() == nil {
		p.Flush()
	}

	latencyMs := time.Now().UnixMilli() - streamStart
	l.svcCtx.AccountPool.ReportRequestEnd(configDir)

	if l.ctx.Err() == context.DeadlineExceeded {
		logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
	} else if l.ctx.Err() == context.Canceled {
		logger.LogClientDisconnect(method, pathname, model, latencyMs)
	} else if err == nil && isRateLimited(result.Stderr) {
		// Agent exited but stderr indicates rate limiting: cool this account down.
		l.svcCtx.AccountPool.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr))
	}

	if err != nil || (result.Code != 0 && l.ctx.Err() == nil) {
		l.svcCtx.AccountPool.ReportRequestError(configDir, latencyMs)
		errMsg := "unknown error"
		if err != nil {
			errMsg = err.Error()
			logger.LogAgentError(cfg.SessionsLogPath, method, pathname, "", -1, errMsg)
		} else {
			errMsg = result.Stderr
			logger.LogAgentError(cfg.SessionsLogPath, method, pathname, "", result.Code, result.Stderr)
		}
		// Report the failure in-stream; SSE headers are already out, so a
		// normal HTTP error status is no longer possible.
		writeAnthropicEvent(w, flusher, map[string]interface{}{
			"type":  "error",
			"error": map[string]interface{}{"type": "api_error", "message": errMsg},
		})
		logger.LogRequestDone(method, pathname, model, latencyMs, result.Code)
	} else if l.ctx.Err() == nil {
		l.svcCtx.AccountPool.ReportRequestSuccess(configDir, latencyMs)
		logger.LogRequestDone(method, pathname, model, latencyMs, 0)
	}
	logger.LogAccountStats(cfg.Verbose, l.svcCtx.AccountPool.GetStats())

	return nil
}
|
|
|
|
func createAnthropicStreamParser(w http.ResponseWriter, flusher http.Flusher, model string, verbose bool) parser.Parser {
|
|
var textBlockOpen bool
|
|
var textBlockIndex int
|
|
var thinkingOpen bool
|
|
var thinkingBlockIndex int
|
|
var blockCount int
|
|
|
|
return parser.CreateStreamParserWithThinking(
|
|
func(text string) {
|
|
if verbose {
|
|
logger.LogStreamChunk(model, text, 0)
|
|
}
|
|
if !textBlockOpen && !thinkingOpen {
|
|
textBlockIndex = blockCount
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{
|
|
"type": "content_block_start",
|
|
"index": textBlockIndex,
|
|
"content_block": map[string]string{"type": "text", "text": ""},
|
|
})
|
|
textBlockOpen = true
|
|
blockCount++
|
|
}
|
|
if thinkingOpen {
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{
|
|
"type": "content_block_stop", "index": thinkingBlockIndex,
|
|
})
|
|
thinkingOpen = false
|
|
}
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{
|
|
"type": "content_block_delta",
|
|
"index": textBlockIndex,
|
|
"delta": map[string]string{"type": "text_delta", "text": text},
|
|
})
|
|
},
|
|
func(thinking string) {
|
|
if verbose {
|
|
logger.LogStreamChunk(model, thinking, 0)
|
|
}
|
|
if !thinkingOpen {
|
|
thinkingBlockIndex = blockCount
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{
|
|
"type": "content_block_start",
|
|
"index": thinkingBlockIndex,
|
|
"content_block": map[string]string{"type": "thinking", "thinking": ""},
|
|
})
|
|
thinkingOpen = true
|
|
blockCount++
|
|
}
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{
|
|
"type": "content_block_delta",
|
|
"index": thinkingBlockIndex,
|
|
"delta": map[string]string{"type": "thinking_delta", "thinking": thinking},
|
|
})
|
|
},
|
|
func() {
|
|
if textBlockOpen {
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{
|
|
"type": "content_block_stop", "index": textBlockIndex,
|
|
})
|
|
}
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{
|
|
"type": "message_delta",
|
|
"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
|
|
"usage": map[string]int{"output_tokens": 0},
|
|
})
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{"type": "message_stop"})
|
|
if flusher != nil {
|
|
flusher.Flush()
|
|
}
|
|
},
|
|
)
|
|
}
|
|
|
|
func createAnthropicToolParser(w http.ResponseWriter, flusher http.Flusher, model string, toolNames map[string]bool, verbose bool) parser.Parser {
|
|
var accumulated string
|
|
toolCallMarkerRe := regexp.MustCompile(`行政法规|<function_calls>`)
|
|
var toolCallMode bool
|
|
var textBlockOpen bool
|
|
var textBlockIndex int
|
|
var blockCount int
|
|
|
|
return parser.CreateStreamParserWithThinking(
|
|
func(text string) {
|
|
accumulated += text
|
|
if verbose {
|
|
logger.LogStreamChunk(model, text, 0)
|
|
}
|
|
if toolCallMode {
|
|
return
|
|
}
|
|
if toolCallMarkerRe.MatchString(text) {
|
|
if textBlockOpen {
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{
|
|
"type": "content_block_stop", "index": textBlockIndex,
|
|
})
|
|
textBlockOpen = false
|
|
}
|
|
toolCallMode = true
|
|
return
|
|
}
|
|
if !textBlockOpen {
|
|
textBlockIndex = blockCount
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{
|
|
"type": "content_block_start",
|
|
"index": textBlockIndex,
|
|
"content_block": map[string]string{"type": "text", "text": ""},
|
|
})
|
|
textBlockOpen = true
|
|
blockCount++
|
|
}
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{
|
|
"type": "content_block_delta",
|
|
"index": textBlockIndex,
|
|
"delta": map[string]string{"type": "text_delta", "text": text},
|
|
})
|
|
},
|
|
func(thinking string) {},
|
|
func() {
|
|
if verbose {
|
|
logger.LogTrafficResponse(verbose, model, accumulated, true)
|
|
}
|
|
parsed := usecase.ExtractToolCalls(accumulated, toolNames)
|
|
blockIndex := 0
|
|
|
|
if textBlockOpen {
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{
|
|
"type": "content_block_stop", "index": textBlockIndex,
|
|
})
|
|
blockIndex = textBlockIndex + 1
|
|
}
|
|
|
|
if parsed.HasToolCalls() {
|
|
for _, tc := range parsed.ToolCalls {
|
|
toolID := "toolu_" + uuid.New().String()[:12]
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{
|
|
"type": "content_block_start", "index": blockIndex,
|
|
"content_block": map[string]interface{}{
|
|
"type": "tool_use", "id": toolID, "name": tc.Name, "input": map[string]interface{}{},
|
|
},
|
|
})
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{
|
|
"type": "content_block_delta", "index": blockIndex,
|
|
"delta": map[string]interface{}{
|
|
"type": "input_json_delta", "partial_json": tc.Arguments,
|
|
},
|
|
})
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{
|
|
"type": "content_block_stop", "index": blockIndex,
|
|
})
|
|
blockIndex++
|
|
}
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{
|
|
"type": "message_delta",
|
|
"delta": map[string]interface{}{"stop_reason": "tool_use", "stop_sequence": nil},
|
|
"usage": map[string]int{"output_tokens": 0},
|
|
})
|
|
} else {
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{
|
|
"type": "message_delta",
|
|
"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
|
|
"usage": map[string]int{"output_tokens": 0},
|
|
})
|
|
}
|
|
writeAnthropicEvent(w, flusher, map[string]interface{}{"type": "message_stop"})
|
|
if flusher != nil {
|
|
flusher.Flush()
|
|
}
|
|
},
|
|
)
|
|
}
|
|
|
|
func writeAnthropicEvent(w http.ResponseWriter, flusher http.Flusher, evt interface{}) {
|
|
data, _ := json.Marshal(evt)
|
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
if flusher != nil {
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
|
|
func convertAnthropicMessagesToInterface(msgs []apitypes.Message) []interface{} {
|
|
result := make([]interface{}, len(msgs))
|
|
for i, m := range msgs {
|
|
result[i] = map[string]interface{}{
|
|
"role": m.Role,
|
|
"content": m.Content,
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func convertToAnthropicParams(msgs []interface{}) []anthropic.MessageParam {
|
|
result := make([]anthropic.MessageParam, len(msgs))
|
|
for i, m := range msgs {
|
|
if mm, ok := m.(map[string]interface{}); ok {
|
|
result[i] = anthropic.MessageParam{
|
|
Role: mm["role"].(string),
|
|
Content: mm["content"],
|
|
}
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func convertToolsToInterface(tools []apitypes.Tool) []interface{} {
|
|
if tools == nil {
|
|
return nil
|
|
}
|
|
result := make([]interface{}, len(tools))
|
|
for i, t := range tools {
|
|
result[i] = map[string]interface{}{
|
|
"type": t.Type,
|
|
"function": map[string]interface{}{
|
|
"name": t.Function.Name,
|
|
"description": t.Function.Description,
|
|
"parameters": t.Function.Parameters,
|
|
},
|
|
}
|
|
}
|
|
return result
|
|
}
|