// opencode-cursor-agent/internal/handlers/chat.go
package handlers
import (
	"encoding/json"
	"fmt"
	"net/http"
	"regexp"
	"strings"
	"time"

	"github.com/google/uuid"

	"cursor-api-proxy/internal/agent"
	"cursor-api-proxy/internal/config"
	"cursor-api-proxy/internal/httputil"
	"cursor-api-proxy/internal/logger"
	"cursor-api-proxy/internal/models"
	"cursor-api-proxy/internal/openai"
	"cursor-api-proxy/internal/parser"
	"cursor-api-proxy/internal/pool"
	"cursor-api-proxy/internal/sanitize"
	"cursor-api-proxy/internal/winlimit"
	"cursor-api-proxy/internal/workspace"
)
// rateLimitRe matches common rate-limit signals in agent stderr output,
// case-insensitively: a standalone "429" status code, "rate limit" in its
// spaced / hyphenated / joined spellings, or the phrase "too many requests".
var rateLimitRe = regexp.MustCompile(`(?i)\b429\b|rate.?limit|too many requests`)

// isRateLimited reports whether the given stderr text from the Cursor CLI
// indicates that the upstream account was rate limited.
func isRateLimited(stderrText string) bool {
	return rateLimitRe.FindStringIndex(stderrText) != nil
}
// HandleChatCompletions implements an OpenAI-compatible
// POST /v1/chat/completions endpoint by shelling out to the Cursor CLI agent.
//
// It parses the OpenAI-style JSON body, resolves the requested model to a
// Cursor model, flattens messages (plus any tools/functions definitions) into
// a single prompt, then runs the agent either in streaming (SSE) or
// synchronous mode. Account selection and rate-limit/success/error accounting
// are delegated to the pool package.
//
// Parameters:
//   - cfg: bridge configuration (agent binary, workspace, verbosity, limits).
//   - lastModelRef: mutable reference consulted/updated by ResolveModel so a
//     model choice can persist across requests.
//   - rawBody: the raw JSON request body (already read by the caller).
//   - method, pathname, remoteAddress: request metadata used only for
//     session/error logging.
func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
	var bodyMap map[string]interface{}
	if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil {
		httputil.WriteJSON(w, 400, map[string]interface{}{
			"error": map[string]string{"message": "invalid JSON body", "code": "bad_request"},
		}, nil)
		return
	}

	// Model resolution: normalize the OpenAI-style id, resolve it against
	// config/lastModelRef, then map to a Cursor CLI model name, falling back
	// to the resolved id itself when no mapping exists.
	rawModel, _ := bodyMap["model"].(string)
	requested := openai.NormalizeModelID(rawModel)
	model := ResolveModel(requested, lastModelRef, cfg)
	cursorModel := models.ResolveToCursorModel(model)
	if cursorModel == "" {
		cursorModel = model
	}

	// Optional request fields; missing or mistyped fields degrade to nil.
	var messages []interface{}
	if m, ok := bodyMap["messages"].([]interface{}); ok {
		messages = m
	}
	var tools []interface{}
	if t, ok := bodyMap["tools"].([]interface{}); ok {
		tools = t
	}
	var funcs []interface{}
	if f, ok := bodyMap["functions"].([]interface{}); ok {
		funcs = f
	}

	// Sanitize the conversation, then expose tool/function definitions to the
	// agent by prepending them as a synthetic system message before the
	// messages are flattened into one prompt string.
	cleanMessages := sanitize.SanitizeMessages(messages)
	toolsText := openai.ToolsToSystemText(tools, funcs)
	messagesWithTools := cleanMessages
	if toolsText != "" {
		messagesWithTools = append([]interface{}{map[string]interface{}{"role": "system", "content": toolsText}}, cleanMessages...)
	}
	prompt := openai.BuildPromptFromMessages(messagesWithTools)

	// Collect role/content pairs for traffic logging (sanitized messages only;
	// the synthetic tools system message is intentionally excluded).
	var trafficMsgs []logger.TrafficMessage
	for _, raw := range cleanMessages {
		if m, ok := raw.(map[string]interface{}); ok {
			role, _ := m["role"].(string)
			content := openai.MessageContentToText(m["content"])
			trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: role, Content: content})
		}
	}
	isStream := false
	if s, ok := bodyMap["stream"].(bool); ok {
		isStream = s
	}
	logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, isStream)

	// Resolve the workspace (per-request override via the x-cursor-workspace
	// header) and build the agent command line. On Windows the prompt may need
	// truncation to fit the CreateProcess command-line length limit.
	headerWs := r.Header.Get("x-cursor-workspace")
	ws := workspace.ResolveWorkspace(cfg, headerWs)
	fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, isStream)
	fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir)
	if !fit.OK {
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"message": fit.Error, "code": "windows_cmdline_limit"},
		}, nil)
		return
	}
	if fit.Truncated {
		fmt.Printf("[%s] Windows: prompt truncated for CreateProcess limit (%d -> %d chars, tail preserved).\n",
			time.Now().UTC().Format(time.RFC3339), fit.OriginalLength, fit.FinalPromptLength)
	}
	cmdArgs := fit.Args

	// Response identity shared by both the streaming and sync paths.
	id := "chatcmpl_" + uuid.New().String()
	created := time.Now().Unix()
	var truncatedHeaders map[string]string
	if fit.Truncated {
		// Surface prompt truncation to the client via a response header.
		truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
	}

	if isStream {
		// Streaming path: emit OpenAI "chat.completion.chunk" objects as SSE.
		httputil.WriteSSEHeaders(w, truncatedHeaders)
		flusher, _ := w.(http.Flusher) // nil when the writer can't flush; guarded at each use
		var accumulated string
		parseLine := parser.CreateStreamParser(
			// Text callback: forward each delta as a chunk and flush so the
			// client sees tokens as they arrive.
			func(text string) {
				accumulated += text
				chunk := map[string]interface{}{
					"id":      id,
					"object":  "chat.completion.chunk",
					"created": created,
					"model":   model,
					"choices": []map[string]interface{}{
						{"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil},
					},
				}
				data, _ := json.Marshal(chunk)
				fmt.Fprintf(w, "data: %s\n\n", data)
				if flusher != nil {
					flusher.Flush()
				}
			},
			// Done callback: log the accumulated text, then send the
			// terminating stop chunk followed by the SSE "[DONE]" sentinel.
			func() {
				logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
				stopChunk := map[string]interface{}{
					"id":      id,
					"object":  "chat.completion.chunk",
					"created": created,
					"model":   model,
					"choices": []map[string]interface{}{
						{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"},
					},
				}
				data, _ := json.Marshal(stopChunk)
				fmt.Fprintf(w, "data: %s\n\n", data)
				fmt.Fprintf(w, "data: [DONE]\n\n")
				if flusher != nil {
					flusher.Flush()
				}
			},
		)

		// Pick an account from the pool and run the agent, bracketing the run
		// with Start/End so the pool can track concurrent load.
		configDir := pool.GetNextAccountConfigDir()
		logger.LogAccountAssigned(configDir)
		pool.ReportRequestStart(configDir)
		streamStart := time.Now().UnixMilli()
		ctx := r.Context()
		result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, parseLine, ws.TempDir, configDir, ctx)
		latencyMs := time.Now().UnixMilli() - streamStart
		pool.ReportRequestEnd(configDir)
		if err == nil && isRateLimited(result.Stderr) {
			// Park this account for 60s when stderr shows rate limiting.
			pool.ReportRateLimit(configDir, 60000)
		}
		// A non-zero exit caused by client cancellation (ctx.Err() != nil) is
		// not counted against the account.
		if err != nil || (result.Code != 0 && ctx.Err() == nil) {
			pool.ReportRequestError(configDir, latencyMs)
			if err != nil {
				logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
			} else {
				logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr)
			}
		} else {
			pool.ReportRequestSuccess(configDir, latencyMs)
		}
		logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
		return
	}

	// Synchronous path: run the agent to completion and return one
	// "chat.completion" object.
	configDir := pool.GetNextAccountConfigDir()
	logger.LogAccountAssigned(configDir)
	pool.ReportRequestStart(configDir)
	syncStart := time.Now().UnixMilli()
	out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context())
	syncLatency := time.Now().UnixMilli() - syncStart
	pool.ReportRequestEnd(configDir)
	if err != nil {
		pool.ReportRequestError(configDir, syncLatency)
		logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"message": err.Error(), "code": "cursor_cli_error"},
		}, nil)
		return
	}
	if isRateLimited(out.Stderr) {
		// Park this account for 60s when stderr shows rate limiting.
		pool.ReportRateLimit(configDir, 60000)
	}
	if out.Code != 0 {
		pool.ReportRequestError(configDir, syncLatency)
		logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
		errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr)
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"message": errMsg, "code": "cursor_cli_error"},
		}, nil)
		return
	}
	pool.ReportRequestSuccess(configDir, syncLatency)
	content := trimSpace(out.Stdout)
	logger.LogTrafficResponse(cfg.Verbose, model, content, false)
	logger.LogAccountStats(cfg.Verbose, pool.GetAccountStats())
	// Token usage is not tracked by the CLI bridge; report zeros so clients
	// expecting the OpenAI schema still parse the response.
	httputil.WriteJSON(w, 200, map[string]interface{}{
		"id":      id,
		"object":  "chat.completion",
		"created": created,
		"model":   model,
		"choices": []map[string]interface{}{
			{
				"index":         0,
				"message":       map[string]string{"role": "assistant", "content": content},
				"finish_reason": "stop",
			},
		},
		"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
	}, truncatedHeaders)
}
// trimSpace returns s with leading and trailing whitespace removed.
//
// The previous implementation hand-rolled an ASCII scan that only handled
// space, tab, newline and carriage return (and carried a dead `result := ""`
// assignment). Delegating to strings.TrimSpace also trims \v, \f and Unicode
// whitespace, which is the right behavior for cleaning CLI stdout before
// returning it to the client.
func trimSpace(s string) string {
	return strings.TrimSpace(s)
}