package handlers

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"regexp"
	"strconv"
	"strings"
	"time"

	"github.com/google/uuid"

	"cursor-api-proxy/internal/agent"
	"cursor-api-proxy/internal/config"
	"cursor-api-proxy/internal/httputil"
	"cursor-api-proxy/internal/logger"
	"cursor-api-proxy/internal/models"
	"cursor-api-proxy/internal/openai"
	"cursor-api-proxy/internal/parser"
	"cursor-api-proxy/internal/pool"
	"cursor-api-proxy/internal/sanitize"
	"cursor-api-proxy/internal/toolcall"
	"cursor-api-proxy/internal/winlimit"
	"cursor-api-proxy/internal/workspace"
)
// rateLimitRe matches stderr output that signals an upstream rate limit:
// a bare 429 status code, any spelling of "rate limit" ("ratelimit",
// "rate-limit", ...), or the phrase "too many requests" (case-insensitive).
var rateLimitRe = regexp.MustCompile(`(?i)\b429\b|rate.?limit|too many requests`)

// retryAfterRe captures the integer seconds value of a "Retry-After: N"
// header echoed in stderr (case-insensitive).
var retryAfterRe = regexp.MustCompile(`(?i)retry-after:\s*(\d+)`)

// isRateLimited reports whether stderr looks like a rate-limit failure.
func isRateLimited(stderr string) bool {
	return rateLimitRe.MatchString(stderr)
}

// extractRetryAfterMs returns the retry delay, in milliseconds, advertised by
// a Retry-After header found in stderr. When no usable (positive integer)
// value is present it falls back to a 60-second default.
func extractRetryAfterMs(stderr string) int64 {
	const defaultDelayMs = 60000

	match := retryAfterRe.FindStringSubmatch(stderr)
	if len(match) < 2 {
		return defaultDelayMs
	}
	secs, err := strconv.ParseInt(match[1], 10, 64)
	if err != nil || secs <= 0 {
		return defaultDelayMs
	}
	return secs * 1000
}
// HandleChatCompletions serves an OpenAI-compatible POST /v1/chat/completions
// request by translating it into a cursor-agent CLI invocation.
//
// Flow: parse the JSON body, resolve the requested model to a Cursor-side
// model, sanitize the messages, fold tool/function definitions into a leading
// system message, fit the resulting prompt to the Windows command-line limit,
// then run the agent either in SSE streaming mode or synchronously and
// convert its output back into OpenAI chat.completion(.chunk) objects.
// Tool calls embedded in the model output are extracted and re-emitted as
// OpenAI tool_calls structures.
//
// Account-pool bookkeeping (ph.Report*) and traffic logging happen on both
// paths; rate-limit hints found in agent stderr feed back into the pool.
func HandleChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, ph pool.PoolHandle, lastModelRef *string, rawBody, method, pathname, remoteAddress string) {
	// Reject anything that is not a valid JSON object up front.
	var bodyMap map[string]interface{}
	if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil {
		httputil.WriteJSON(w, 400, map[string]interface{}{
			"error": map[string]string{"message": "invalid JSON body", "code": "bad_request"},
		}, nil)
		return
	}

	// Resolve the caller-supplied model name; fall back to the resolved name
	// itself when no Cursor-side mapping exists.
	rawModel, _ := bodyMap["model"].(string)
	requested := openai.NormalizeModelID(rawModel)
	model := ResolveModel(requested, lastModelRef, cfg)
	cursorModel := models.ResolveToCursorModel(model)
	if cursorModel == "" {
		cursorModel = model
	}

	// Optional request fields: messages, tools (new-style) and functions
	// (legacy). Missing or mistyped fields simply yield nil slices.
	var messages []interface{}
	if m, ok := bodyMap["messages"].([]interface{}); ok {
		messages = m
	}

	var tools []interface{}
	if t, ok := bodyMap["tools"].([]interface{}); ok {
		tools = t
	}
	var funcs []interface{}
	if f, ok := bodyMap["functions"].([]interface{}); ok {
		funcs = f
	}

	cleanMessages := sanitize.SanitizeMessages(messages)

	// Tool/function definitions are injected as a synthetic system message so
	// the CLI-backed model can see them in plain text.
	toolsText := openai.ToolsToSystemText(tools, funcs)
	messagesWithTools := cleanMessages
	if toolsText != "" {
		messagesWithTools = append([]interface{}{map[string]interface{}{"role": "system", "content": toolsText}}, cleanMessages...)
	}
	prompt := openai.BuildPromptFromMessages(messagesWithTools)

	// Flatten sanitized messages for traffic logging.
	var trafficMsgs []logger.TrafficMessage
	for _, raw := range cleanMessages {
		if m, ok := raw.(map[string]interface{}); ok {
			role, _ := m["role"].(string)
			content := openai.MessageContentToText(m["content"])
			trafficMsgs = append(trafficMsgs, logger.TrafficMessage{Role: role, Content: content})
		}
	}

	isStream := false
	if s, ok := bodyMap["stream"].(bool); ok {
		isStream = s
	}

	logger.LogTrafficRequest(cfg.Verbose, model, trafficMsgs, isStream)

	// Workspace may be overridden per-request via the x-cursor-workspace header.
	headerWs := r.Header.Get("x-cursor-workspace")
	ws := workspace.ResolveWorkspace(cfg, headerWs)

	promptLen := len(prompt)
	if cfg.Verbose {
		if promptLen > 200 {
			logger.LogDebug("model=%s prompt_len=%d prompt_start=%q", cursorModel, promptLen, prompt[:200])
		} else {
			logger.LogDebug("model=%s prompt_len=%d prompt=%q", cursorModel, promptLen, prompt)
		}
	}

	// The prompt travels on the agent command line, so it must fit within the
	// Windows command-line length limit; FitPromptToWinCmdline may truncate it.
	fixedArgs := agent.BuildAgentFixedArgs(cfg, ws.WorkspaceDir, cursorModel, isStream)
	fit := winlimit.FitPromptToWinCmdline(cfg.AgentBin, fixedArgs, prompt, cfg.WinCmdlineMax, ws.WorkspaceDir)

	if cfg.Verbose {
		logger.LogDebug("cmd=%s args=%v", cfg.AgentBin, fit.Args)
	}

	if !fit.OK {
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"message": fit.Error, "code": "windows_cmdline_limit"},
		}, nil)
		return
	}
	if fit.Truncated {
		logger.LogTruncation(fit.OriginalLength, fit.FinalPromptLength)
	}

	cmdArgs := fit.Args
	id := "chatcmpl_" + uuid.New().String()
	created := time.Now().Unix()

	// Surface truncation to the client via a response header.
	var truncatedHeaders map[string]string
	if fit.Truncated {
		truncatedHeaders = map[string]string{"X-Cursor-Proxy-Prompt-Truncated": "true"}
	}

	// Collect the set of declared tool names; ExtractToolCalls uses it to
	// recognize calls in the model output.
	hasTools := len(tools) > 0 || len(funcs) > 0
	var toolNames map[string]bool
	if hasTools {
		toolNames = toolcall.CollectToolNames(tools)
		for _, f := range funcs {
			if fm, ok := f.(map[string]interface{}); ok {
				if name, ok := fm["name"].(string); ok {
					toolNames[name] = true
				}
			}
		}
	}

	if isStream {
		// --- Streaming (SSE) path ---
		httputil.WriteSSEHeaders(w, truncatedHeaders)
		flusher, _ := w.(http.Flusher) // nil when the writer can't flush; guarded below

		var accumulated string
		var chunkNum int
		var p parser.Parser

		// toolCallMarkerRe detects the opening tool-call marker; once it
		// appears, live output stops and accumulation mode begins.
		toolCallMarkerRe := regexp.MustCompile(`<tool_call>|<function_calls>`)
		if hasTools {
			var toolCallMode bool // whether tool-call accumulation mode has been entered
			p = parser.CreateStreamParserWithThinking(
				// onText: stream plain text chunks until a tool-call marker shows up.
				func(text string) {
					accumulated += text
					chunkNum++
					logger.LogStreamChunk(model, text, chunkNum)
					if toolCallMode {
						// Already in accumulation mode; suppress live output.
						return
					}
					if toolCallMarkerRe.MatchString(text) {
						// Tool-call marker detected; switch to accumulation mode.
						toolCallMode = true
						return
					}
					chunk := map[string]interface{}{
						"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
						"choices": []map[string]interface{}{
							{"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil},
						},
					}
					data, _ := json.Marshal(chunk)
					fmt.Fprintf(w, "data: %s\n\n", data)
					if flusher != nil {
						flusher.Flush()
					}
				},
				func(_ string) {}, // thinking ignored in tools mode
				// onDone: parse accumulated output for tool calls and finish the stream.
				func() {
					logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
					parsed := toolcall.ExtractToolCalls(accumulated, toolNames)

					if parsed.HasToolCalls() {
						if parsed.TextContent != "" && toolCallMode {
							// Some text was already streamed live; emit only the
							// remaining portion.
							chunk := map[string]interface{}{
								"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
								"choices": []map[string]interface{}{
									{"index": 0, "delta": map[string]interface{}{"role": "assistant", "content": parsed.TextContent}, "finish_reason": nil},
								},
							}
							data, _ := json.Marshal(chunk)
							fmt.Fprintf(w, "data: %s\n\n", data)
							if flusher != nil {
								flusher.Flush()
							}
						}
						// One chunk per tool call, in OpenAI delta.tool_calls form.
						for i, tc := range parsed.ToolCalls {
							callID := "call_" + uuid.New().String()[:8]
							chunk := map[string]interface{}{
								"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
								"choices": []map[string]interface{}{
									{"index": 0, "delta": map[string]interface{}{
										"tool_calls": []map[string]interface{}{
											{
												"index": i,
												"id": callID,
												"type": "function",
												"function": map[string]interface{}{
													"name": tc.Name,
													"arguments": tc.Arguments,
												},
											},
										},
									}, "finish_reason": nil},
								},
							}
							data, _ := json.Marshal(chunk)
							fmt.Fprintf(w, "data: %s\n\n", data)
							if flusher != nil {
								flusher.Flush()
							}
						}
						stopChunk := map[string]interface{}{
							"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
							"choices": []map[string]interface{}{
								{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "tool_calls"},
							},
						}
						data, _ := json.Marshal(stopChunk)
						fmt.Fprintf(w, "data: %s\n\n", data)
						fmt.Fprintf(w, "data: [DONE]\n\n")
						if flusher != nil {
							flusher.Flush()
						}
					} else {
						// No tool calls found: close the stream with a normal stop.
						stopChunk := map[string]interface{}{
							"id": id, "object": "chat.completion.chunk", "created": created, "model": model,
							"choices": []map[string]interface{}{
								{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"},
							},
						}
						data, _ := json.Marshal(stopChunk)
						fmt.Fprintf(w, "data: %s\n\n", data)
						fmt.Fprintf(w, "data: [DONE]\n\n")
						if flusher != nil {
							flusher.Flush()
						}
					}
				},
			)
		} else {
			// No tools declared: stream text and reasoning chunks directly.
			p = parser.CreateStreamParserWithThinking(
				func(text string) {
					accumulated += text
					chunkNum++
					logger.LogStreamChunk(model, text, chunkNum)
					chunk := map[string]interface{}{
						"id": id,
						"object": "chat.completion.chunk",
						"created": created,
						"model": model,
						"choices": []map[string]interface{}{
							{"index": 0, "delta": map[string]string{"content": text}, "finish_reason": nil},
						},
					}
					data, _ := json.Marshal(chunk)
					fmt.Fprintf(w, "data: %s\n\n", data)
					if flusher != nil {
						flusher.Flush()
					}
				},
				// Thinking output is forwarded as reasoning_content deltas.
				func(thinking string) {
					chunk := map[string]interface{}{
						"id": id,
						"object": "chat.completion.chunk",
						"created": created,
						"model": model,
						"choices": []map[string]interface{}{
							{"index": 0, "delta": map[string]interface{}{"reasoning_content": thinking}, "finish_reason": nil},
						},
					}
					data, _ := json.Marshal(chunk)
					fmt.Fprintf(w, "data: %s\n\n", data)
					if flusher != nil {
						flusher.Flush()
					}
				},
				func() {
					logger.LogTrafficResponse(cfg.Verbose, model, accumulated, true)
					stopChunk := map[string]interface{}{
						"id": id,
						"object": "chat.completion.chunk",
						"created": created,
						"model": model,
						"choices": []map[string]interface{}{
							{"index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop"},
						},
					}
					data, _ := json.Marshal(stopChunk)
					fmt.Fprintf(w, "data: %s\n\n", data)
					fmt.Fprintf(w, "data: [DONE]\n\n")
					if flusher != nil {
						flusher.Flush()
					}
				},
			)
		}

		// Acquire an account from the pool and run the agent, feeding each
		// raw output line through the stream parser.
		configDir := ph.GetNextConfigDir()
		logger.LogAccountAssigned(configDir)
		ph.ReportRequestStart(configDir)
		logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, true)
		streamStart := time.Now().UnixMilli()

		ctx := r.Context()
		wrappedParser := func(line string) {
			logger.LogRawLine(line)
			p.Parse(line)
		}
		result, err := agent.RunAgentStreamWithContext(cfg, ws.WorkspaceDir, cmdArgs, wrappedParser, ws.TempDir, configDir, ctx)

		// After the agent exits, if no result/success signal was received,
		// force a flush so the SSE stream is terminated correctly.
		if ctx.Err() == nil {
			p.Flush()
		}

		latencyMs := time.Now().UnixMilli() - streamStart
		ph.ReportRequestEnd(configDir)

		// Classify the outcome: timeout, client disconnect, or a rate-limit
		// hint in stderr (only consulted when the run itself succeeded).
		if ctx.Err() == context.DeadlineExceeded {
			logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
		} else if ctx.Err() == context.Canceled {
			logger.LogClientDisconnect(method, pathname, model, latencyMs)
		} else if err == nil && isRateLimited(result.Stderr) {
			ph.ReportRateLimit(configDir, extractRetryAfterMs(result.Stderr))
		}

		if err != nil || (result.Code != 0 && ctx.Err() == nil) {
			ph.ReportRequestError(configDir, latencyMs)
			if err != nil {
				logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
			} else {
				logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, result.Code, result.Stderr)
			}
			logger.LogRequestDone(method, pathname, model, latencyMs, result.Code)
		} else if ctx.Err() == nil {
			ph.ReportRequestSuccess(configDir, latencyMs)
			logger.LogRequestDone(method, pathname, model, latencyMs, 0)
		}
		logger.LogAccountStats(cfg.Verbose, ph.GetStats())
		return
	}

	// --- Synchronous (non-streaming) path ---
	configDir := ph.GetNextConfigDir()
	logger.LogAccountAssigned(configDir)
	ph.ReportRequestStart(configDir)
	logger.LogRequestStart(method, pathname, model, cfg.TimeoutMs, false)
	syncStart := time.Now().UnixMilli()

	out, err := agent.RunAgentSync(cfg, ws.WorkspaceDir, cmdArgs, ws.TempDir, configDir, r.Context())
	syncLatency := time.Now().UnixMilli() - syncStart
	ph.ReportRequestEnd(configDir)

	ctx := r.Context()
	if ctx.Err() == context.DeadlineExceeded {
		logger.LogRequestTimeout(method, pathname, model, cfg.TimeoutMs)
		httputil.WriteJSON(w, 504, map[string]interface{}{
			"error": map[string]string{"message": fmt.Sprintf("request timed out after %dms", cfg.TimeoutMs), "code": "timeout"},
		}, nil)
		return
	}
	if ctx.Err() == context.Canceled {
		// Client went away; nothing useful can be written back.
		logger.LogClientDisconnect(method, pathname, model, syncLatency)
		return
	}

	if err != nil {
		ph.ReportRequestError(configDir, syncLatency)
		logger.LogAccountStats(cfg.Verbose, ph.GetStats())
		logger.LogRequestDone(method, pathname, model, syncLatency, -1)
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"message": err.Error(), "code": "cursor_cli_error"},
		}, nil)
		return
	}

	// Feed rate-limit hints back to the pool even on otherwise-successful runs.
	if isRateLimited(out.Stderr) {
		ph.ReportRateLimit(configDir, extractRetryAfterMs(out.Stderr))
	}

	if out.Code != 0 {
		ph.ReportRequestError(configDir, syncLatency)
		logger.LogAccountStats(cfg.Verbose, ph.GetStats())
		errMsg := logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, out.Code, out.Stderr)
		logger.LogRequestDone(method, pathname, model, syncLatency, out.Code)
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"message": errMsg, "code": "cursor_cli_error"},
		}, nil)
		return
	}

	ph.ReportRequestSuccess(configDir, syncLatency)
	content := strings.TrimSpace(out.Stdout)
	logger.LogTrafficResponse(cfg.Verbose, model, content, false)
	logger.LogAccountStats(cfg.Verbose, ph.GetStats())
	logger.LogRequestDone(method, pathname, model, syncLatency, 0)

	// When tools were declared, check the full output for tool calls and, if
	// found, answer with finish_reason "tool_calls".
	if hasTools {
		parsed := toolcall.ExtractToolCalls(content, toolNames)
		if parsed.HasToolCalls() {
			msg := map[string]interface{}{"role": "assistant"}
			if parsed.TextContent != "" {
				msg["content"] = parsed.TextContent
			} else {
				msg["content"] = nil
			}
			var tcArr []map[string]interface{}
			for _, tc := range parsed.ToolCalls {
				callID := "call_" + uuid.New().String()[:8]
				tcArr = append(tcArr, map[string]interface{}{
					"id": callID,
					"type": "function",
					"function": map[string]interface{}{
						"name": tc.Name,
						"arguments": tc.Arguments,
					},
				})
			}
			msg["tool_calls"] = tcArr
			httputil.WriteJSON(w, 200, map[string]interface{}{
				"id": id,
				"object": "chat.completion",
				"created": created,
				"model": model,
				"choices": []map[string]interface{}{
					{"index": 0, "message": msg, "finish_reason": "tool_calls"},
				},
				// Token usage is not tracked by the CLI backend; report zeros.
				"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
			}, truncatedHeaders)
			return
		}
	}

	// Plain text completion response.
	httputil.WriteJSON(w, 200, map[string]interface{}{
		"id": id,
		"object": "chat.completion",
		"created": created,
		"model": model,
		"choices": []map[string]interface{}{
			{
				"index": 0,
				"message": map[string]string{"role": "assistant", "content": content},
				"finish_reason": "stop",
			},
		},
		// Token usage is not tracked by the CLI backend; report zeros.
		"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
	}, truncatedHeaders)
}