// Path: opencode-cursor-agent/internal/logic/chat/chat_completions_logic.go
//
// (Original file listing: 484 lines, 14 KiB, Go.)

// Code scaffolded by goctl. Safe to edit.
// goctl 1.10.1
package chat
import (
"context"
"encoding/json"
"fmt"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/svc"
apitypes "cursor-api-proxy/internal/types"
"cursor-api-proxy/pkg/adapter/openai"
"cursor-api-proxy/pkg/domain/types"
"cursor-api-proxy/pkg/infrastructure/httputil"
"cursor-api-proxy/pkg/infrastructure/logger"
"cursor-api-proxy/pkg/infrastructure/parser"
"cursor-api-proxy/pkg/infrastructure/winlimit"
"cursor-api-proxy/pkg/infrastructure/workspace"
"cursor-api-proxy/pkg/usecase"
"github.com/google/uuid"
"github.com/zeromicro/go-zero/core/logx"
)
// rateLimitRe matches agent stderr that signals rate limiting, case-insensitively:
// a standalone "429", "rate" and "limit" optionally separated by one character
// (e.g. "rate limit", "rate-limit", "ratelimit"), or "too many requests".
var rateLimitRe = regexp.MustCompile(`(?i)\b429\b|rate.?limit|too many requests`)

// retryAfterRe captures the integer seconds value following "retry-after:".
var retryAfterRe = regexp.MustCompile(`(?i)retry-after:\s*(\d+)`)

// defaultRetryAfterMs is the back-off returned when stderr has no usable
// retry-after value.
const defaultRetryAfterMs int64 = 60000

// maxRetryAfterSecs is the largest seconds value whose millisecond equivalent
// still fits in an int64; larger values would overflow when multiplied by 1000.
const maxRetryAfterSecs = (1<<63 - 1) / 1000

// isRateLimited reports whether stderr contains any rate-limit indicator
// matched by rateLimitRe.
func isRateLimited(stderr string) bool {
	return rateLimitRe.MatchString(stderr)
}

// extractRetryAfterMs returns the retry-after delay found in stderr, converted
// from seconds to milliseconds. It returns defaultRetryAfterMs when there is no
// retry-after value, or when the value is zero or does not fit in int64 milliseconds.
func extractRetryAfterMs(stderr string) int64 {
	m := retryAfterRe.FindStringSubmatch(stderr)
	if len(m) < 2 {
		return defaultRetryAfterMs
	}
	secs, err := strconv.ParseInt(m[1], 10, 64)
	// Out-of-range values are treated the same as unparseable ones: ParseInt
	// already rejects values beyond int64, and secs*1000 must not overflow.
	if err != nil || secs <= 0 || secs > maxRetryAfterSecs {
		return defaultRetryAfterMs
	}
	return secs * 1000
}
// ChatCompletionsLogic handles one chat-completions request. It embeds a
// logx.Logger bound to the request context and keeps both the request context
// and the service context.
type ChatCompletionsLogic struct {
	logx.Logger

	// ctx is the request context; its cancellation/deadline is inspected
	// after the agent run.
	ctx context.Context
	// svcCtx provides the configuration (Config) and account pool (AccountPool).
	svcCtx *svc.ServiceContext
}

// NewChatCompletionsLogic returns a ChatCompletionsLogic bound to ctx and
// svcCtx, with its logger derived from ctx via logx.WithContext.
func NewChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ChatCompletionsLogic {
	logic := &ChatCompletionsLogic{ctx: ctx, svcCtx: svcCtx}
	logic.Logger = logx.WithContext(ctx)
	return logic
}
// resolveModel chooses the model name to use for a request.
//
//   - "auto" is returned unchanged and is not remembered.
//   - Any other non-empty requested model is returned and stored in
//     *lastModelRef.
//   - An empty request falls back to the remembered *lastModelRef if set,
//     otherwise to the configured DefaultModel.
//
// cfg.StrictModel does not change the outcome: in both modes a remembered model
// takes precedence over DefaultModel. lastModelRef may be nil, in which case
// nothing is remembered or read back.
func (l *ChatCompletionsLogic) resolveModel(requested string, lastModelRef *string) string {
	switch {
	case requested == "auto":
		return "auto"
	case requested != "":
		if lastModelRef != nil {
			*lastModelRef = requested
		}
		return requested
	case lastModelRef != nil && *lastModelRef != "":
		return *lastModelRef
	default:
		return l.svcCtx.Config.DefaultModel
	}
}
// ChatCompletions is the non-streaming entry point. Non-streaming responses are
// not implemented: it always returns a nil response and an error telling the
// caller to request stream=true. req is ignored.
func (l *ChatCompletionsLogic) ChatCompletions(req *apitypes.ChatCompletionRequest) (*apitypes.ChatCompletionResponse, error) {
	var resp *apitypes.ChatCompletionResponse
	err := fmt.Errorf("non-streaming not yet implemented, use stream=true")
	return resp, err
}
// streamToolCallMarkerRe matches the markers that switch a tools-enabled stream
// into tool-call mode: an ASCII record separator (0x1E) or a literal
// "<function_calls>" tag. Compiled once here rather than on every request.
var streamToolCallMarkerRe = regexp.MustCompile(`\x1e|<function_calls>`)

// ChatCompletionsStream serves a stream=true chat completion request as
// OpenAI-style server-sent events written to w.
//
// The request messages are sanitized and flattened into a single prompt. When
// tools or legacy functions are supplied, their text description is prepended
// as a system message. The prompt is fitted to the Windows command-line limit
// (cfg.WinCmdlineMax, or 32768 when unset). If it cannot be fitted, an HTTP 500
// JSON error with code "windows_cmdline_limit" is written instead of a stream.
// Otherwise the agent runs with a config dir taken from the account pool. Its
// output is relayed as "chat.completion.chunk" events (content and
// reasoning_content deltas), ending with a finish chunk and "data: [DONE]".
//
// With tools present, text is streamed until a chunk matches
// streamToolCallMarkerRe; text from that chunk on is withheld. When the agent
// finishes, the full output is parsed for tool calls. Any found are emitted as
// tool_calls deltas with finish_reason "tool_calls". If none are found, the
// withheld text is sent as content before the "stop" finish chunk.
//
// Agent failures, timeouts and client disconnects are logged and reported to
// the account pool rather than returned; the returned error is always nil.
func (l *ChatCompletionsLogic) ChatCompletionsStream(req *apitypes.ChatCompletionRequest, w http.ResponseWriter, method, pathname string) error {
	cfg := configToBridge(l.svcCtx.Config)

	// A fresh ref per request, so no model is carried over between requests here.
	lastModelRef := new(string)
	model := l.resolveModel(openai.NormalizeModelID(req.Model), lastModelRef)
	cursorModel := types.ResolveToCursorModel(model)
	if cursorModel == "" {
		cursorModel = model
	}

	tools := convertTools(req.Tools)
	functions := convertFunctions(req.Functions)
	cleanMessages := usecase.SanitizeMessages(convertMessages(req.Messages))
	messagesWithTools := cleanMessages
	if toolsText := openai.ToolsToSystemText(tools, functions); toolsText != "" {
		messagesWithTools = append([]interface{}{
			map[string]interface{}{"role": "system", "content": toolsText},
		}, cleanMessages...)
	}
	prompt := openai.BuildPromptFromMessages(messagesWithTools)

	var trafficMsgs []logger.TrafficMessage
	for _, raw := range cleanMessages {
		m, ok := raw.(map[string]interface{})
		if !ok {
			continue
		}
		role, _ := m["role"].(string)
		trafficMsgs = append(trafficMsgs, logger.TrafficMessage{
			Role:    role,
			Content: openai.MessageContentToText(m["content"]),
		})
	}
	logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, true)

	ws := workspace.ResolveWorkspace(cfg, "")
	if cfg.Verbose {
		if len(prompt) > 200 {
			logger.LogDebug("model=%s prompt_len=%d prompt_start=%q", cursorModel, len(prompt), prompt[:200])
		} else {
			logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, len(prompt), prompt)
		}
	}

	maxCmdline := cfg.WinCmdlineMax
	if maxCmdline == 0 {
		maxCmdline = 32768
	}
	fixedArgs := usecase.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, true)
	fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, maxCmdline, ws.WorkspaceDir)
	if cfg.Verbose {
		logger.LogDebug("cmd=%s args=%v", cfg.AgentBin, fit.Args)
	}
	if !fit.OK {
		httputil.WriteJSON(w, http.StatusInternalServerError, map[string]interface{}{
			"error": map[string]string{"message": fit.Error, "code": "windows_cmdline_limit"},
		}, nil)
		return nil
	}
	var truncatedHeaders map[string]string
	if fit.Truncated {
		logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength)
		truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
	}
	cmdArgs := fit.Args
	id := "chatcmpl_" + uuid.New().String()
	created := time.Now().Unix()

	hasTools := len(tools) > 0 || len(functions) > 0
	var toolNames map[string]bool
	if hasTools {
		toolNames = usecase.CollectToolNames(tools)
		for _, f := range functions {
			fm, ok := f.(map[string]interface{})
			if !ok {
				continue
			}
			if name, ok := fm["name"].(string); ok {
				// CollectToolNames may return nil (e.g. functions-only requests);
				// writing to a nil map would panic.
				if toolNames == nil {
					toolNames = make(map[string]bool)
				}
				toolNames[name] = true
			}
		}
	}

	httputil.WriteSSEHeaders(w, truncatedHeaders)
	// w may not implement http.Flusher; in that case events are just written.
	flusher, _ := w.(http.Flusher)
	flush := func() {
		if flusher != nil {
			flusher.Flush()
		}
	}
	// writeChunk writes one chat.completion.chunk event with a single choice.
	// A nil finishReason is serialized as JSON null (a non-final chunk).
	writeChunk := func(delta map[string]interface{}, finishReason interface{}) {
		// Marshal error ignored: streaming is best-effort, as elsewhere in this handler.
		data, _ := json.Marshal(map[string]interface{}{
			"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
			"choices": []map[string]interface{}{
				{"index": 0, "delta": delta, "finish_reason": finishReason},
			},
		})
		fmt.Fprintf(w, "data: %s\n\n", data)
	}
	sendDelta := func(delta map[string]interface{}) {
		writeChunk(delta, nil)
		flush()
	}
	// sendDone writes the final chunk carrying finishReason, then the [DONE] sentinel.
	sendDone := func(finishReason string) {
		writeChunk(map[string]interface{}{}, finishReason)
		fmt.Fprint(w, "data: [DONE]\n\n")
		flush()
	}

	var accumulated strings.Builder
	var chunkNum int
	// streamedLen is the number of bytes of accumulated already sent as content
	// deltas; in tool-call mode everything past it has been withheld.
	var streamedLen int
	var toolCallMode bool

	onText := func(text string) {
		accumulated.WriteString(text)
		chunkNum++
		logger.LogStreamChunk(model, text, chunkNum)
		if hasTools {
			if toolCallMode {
				return
			}
			if streamToolCallMarkerRe.MatchString(text) {
				toolCallMode = true
				return
			}
		}
		sendDelta(map[string]interface{}{"content": text})
		streamedLen = accumulated.Len()
	}
	onThinking := func(thinking string) {
		sendDelta(map[string]interface{}{"reasoning_content": thinking})
	}
	onDone := func() {
		full := accumulated.String()
		logger.LogTrafficResponse(cfg.Verbose, model, full, true)
		if !hasTools {
			sendDone("stop")
			return
		}
		parsed := usecase.ExtractToolCalls(full, toolNames)
		if !parsed.HasToolCalls() {
			// A marker was seen but no tool call was parsed: release the withheld
			// text instead of silently dropping it.
			if rest := full[streamedLen:]; toolCallMode && rest != "" {
				sendDelta(map[string]interface{}{"content": rest})
			}
			sendDone("stop")
			return
		}
		if parsed.TextContent != "" && toolCallMode {
			sendDelta(map[string]interface{}{"role": "assistant", "content": parsed.TextContent})
		}
		for i, tc := range parsed.ToolCalls {
			sendDelta(map[string]interface{}{
				"tool_calls": []map[string]interface{}{
					{
						"index": i,
						"id":    "call_" + uuid.New().String()[:8],
						"type":  "function",
						"function": map[string]interface{}{
							"name":      tc.Name,
							"arguments": tc.Arguments,
						},
					},
				},
			})
		}
		sendDone("tool_calls")
	}
	var p parser.Parser = parser.CreateStreamParserWithThinking(onText, onThinking, onDone)

	configDir := l.svcCtx.AccountPool.GetNextConfigDir()
	logger.LogAccountAssigned(configDir)
	l.svcCtx.AccountPool.ReportRequestStart(configDir)
	logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true)

	streamStart := time.Now().UnixMilli()
	wrappedParser := func(line string) {
		logger.LogRawLine(line)
		p.Parse(line)
	}
	result, err := usecase.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, l.ctx)
	// Flush the parser (emitting the finish events) only while the request
	// context is still live.
	if l.ctx.Err() == nil {
		p.Flush()
	}
	latencyMs := time.Now().UnixMilli() - streamStart
	l.svcCtx.AccountPool.ReportRequestEnd(configDir)

	ctxErr := l.ctx.Err()
	switch {
	case ctxErr == context.DeadlineExceeded:
		logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
	case ctxErr == context.Canceled:
		logger.LogClientDisconnect(method, pathname, model, latencyMs)
	case err == nil && isRateLimited(result.Stderr):
		l.svcCtx.AccountPool.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr))
	}

	switch {
	case err != nil:
		// result is not consulted when the run itself failed; -1 marks "no exit code".
		l.svcCtx.AccountPool.ReportRequestError(configDir, latencyMs)
		logger.LogAgentError(cfg.SessionsLogPath, method, pathname, "", -1, err.Error())
		logger.LogRequestDone(method, pathname, model, latencyMs, -1)
	case result.Code != 0 && ctxErr == nil:
		l.svcCtx.AccountPool.ReportRequestError(configDir, latencyMs)
		logger.LogAgentError(cfg.SessionsLogPath, method, pathname, "", result.Code, result.Stderr)
		logger.LogRequestDone(method, pathname, model, latencyMs, result.Code)
	case ctxErr == nil:
		l.svcCtx.AccountPool.ReportRequestSuccess(configDir, latencyMs)
		logger.LogRequestDone(method, pathname, model, latencyMs, 0)
	}
	logger.LogAccountStats(cfg.Verbose, l.svcCtx.AccountPool.GetStats())
	return nil
}
// convertMessages maps each request message to a generic
// map{"role": ..., "content": ...}, preserving order. Only Role and Content are
// carried over. A nil or empty input yields an empty, non-nil slice.
func convertMessages(msgs []apitypes.Message) []interface{} {
	out := make([]interface{}, 0, len(msgs))
	for _, msg := range msgs {
		out = append(out, map[string]interface{}{"role": msg.Role, "content": msg.Content})
	}
	return out
}
// convertTools maps each request tool to a generic map of the form
// {"type": ..., "function": {"name", "description", "parameters"}}, preserving
// order. A nil input returns nil; a non-nil empty input returns an empty slice.
func convertTools(tools []apitypes.Tool) []interface{} {
	if tools == nil {
		return nil
	}
	out := make([]interface{}, 0, len(tools))
	for _, tool := range tools {
		fn := tool.Function
		out = append(out, map[string]interface{}{
			"type": tool.Type,
			"function": map[string]interface{}{
				"name":        fn.Name,
				"description": fn.Description,
				"parameters":  fn.Parameters,
			},
		})
	}
	return out
}
// convertFunctions maps each legacy request function to a generic map with
// "name", "description" and "parameters" keys, preserving order. A nil input
// returns nil; a non-nil empty input returns an empty slice.
func convertFunctions(funcs []apitypes.Function) []interface{} {
	if funcs == nil {
		return nil
	}
	out := make([]interface{}, 0, len(funcs))
	for _, fn := range funcs {
		out = append(out, map[string]interface{}{
			"name":        fn.Name,
			"description": fn.Description,
			"parameters":  fn.Parameters,
		})
	}
	return out
}
// configToBridge copies the service configuration into a BridgeConfig.
// Mode is always "ask", and an empty Host becomes "0.0.0.0"; every other
// field is copied unchanged.
func configToBridge(c config.Config) config.BridgeConfig {
	bridge := config.BridgeConfig{
		AgentBin:             c.AgentBin,
		Host:                 c.Host,
		Port:                 c.Port,
		RequiredKey:          c.RequiredKey,
		DefaultModel:         c.DefaultModel,
		Mode:                 "ask",
		Provider:             c.Provider,
		Force:                c.Force,
		ApproveMcps:          c.ApproveMcps,
		StrictModel:          c.StrictModel,
		Workspace:            c.Workspace,
		TimeoutMs:            c.TimeoutMs,
		TLSCertPath:          c.TLSCertPath,
		TLSKeyPath:           c.TLSKeyPath,
		SessionsLogPath:      c.SessionsLogPath,
		ChatOnlyWorkspace:    c.ChatOnlyWorkspace,
		Verbose:              c.Verbose,
		MaxMode:              c.MaxMode,
		ConfigDirs:           c.ConfigDirs,
		MultiPort:            c.MultiPort,
		WinCmdlineMax:        c.WinCmdlineMax,
		GeminiAccountDir:     c.GeminiAccountDir,
		GeminiBrowserVisible: c.GeminiBrowserVisible,
		GeminiMaxSessions:    c.GeminiMaxSessions,
	}
	if bridge.Host == "" {
		bridge.Host = "0.0.0.0"
	}
	return bridge
}
// StringsToMapSlice wraps each string in ss in a single-entry map
// {"content": s}, preserving order. A nil or empty input yields an empty,
// non-nil slice.
func StringsToMapSlice(ss []string) []map[string]string {
	out := make([]map[string]string, 0, len(ss))
	for _, s := range ss {
		out = append(out, map[string]string{"content": s})
	}
	return out
}
// JoinStrings returns the elements of ss separated by a single "\n", with no
// trailing newline. An empty or nil slice yields "".
func JoinStrings(ss []string) string {
	const separator = "\n"
	return strings.Join(ss, separator)
}