// opencode-cursor-agent/internal/server/handlers.go

package server

import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"math"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/daniel/cursor-adapter/internal/converter"
	"github.com/daniel/cursor-adapter/internal/sanitize"
	"github.com/daniel/cursor-adapter/internal/types"
)
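
// Package-level cache for the bridge's model list, guarded by modelCacheMu and
// refreshed at most once per modelCacheTTL.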
var (
	modelCacheMu   sync.Mutex
	modelCacheData []string
	modelCacheAt   time.Time
	modelCacheTTL  = 5 * time.Minute
)
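
// handleListModels serves the cached Cursor model list, plus the Anthropic-style
// aliases produced by converter.GetAnthropicModelAliases, as an OpenAI-compatible
// model list.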
func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
	models, err := s.cachedListModels(r.Context())
	if err != nil {
		writeJSON(w, http.StatusInternalServerError, types.NewErrorResponse(err.Error(), "internal_error", ""))
		return
	}
	ts := time.Now().Unix()
	data := make([]types.ModelInfo, 0, len(models)*2)
	for _, m := range models {
		data = append(data, types.ModelInfo{ID: m, Object: "model", Created: ts, OwnedBy: "cursor"})
	}
	aliases := converter.GetAnthropicModelAliases(models)
	for _, a := range aliases {
		data = append(data, types.ModelInfo{ID: a.ID, Object: "model", Created: ts, OwnedBy: "cursor"})
	}
	writeJSON(w, http.StatusOK, types.ModelList{Object: "list", Data: data})
}
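
// cachedListModels returns the bridge's model list, reusing the cached copy while
// it is younger than modelCacheTTL and refreshing it from the bridge otherwise.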
func (s *Server) cachedListModels(ctx context.Context) ([]string, error) {
	modelCacheMu.Lock()
	defer modelCacheMu.Unlock()
	if modelCacheData != nil && time.Since(modelCacheAt) < modelCacheTTL {
		return modelCacheData, nil
	}
	models, err := s.br.ListModels(ctx)
	if err != nil {
		return nil, err
	}
	modelCacheData = models
	modelCacheAt = time.Now()
	return models, nil
}
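
// handleChatCompletions implements an OpenAI-style chat completions endpoint. It
// flattens the conversation into a single "role: text" prompt (dropping client
// system messages and embedded <system-reminder> blocks, and prefixing the
// configured system prompt), resolves the requested model to a Cursor model, and
// dispatches to streamChat or nonStreamChat depending on req.Stream.
//
// A minimal request body, assuming the usual OpenAI field names on
// types.ChatCompletionRequest, would look like:
//
//	{"model": "<model id>", "stream": true, "messages": [{"role": "user", "content": "hi"}]}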
func (s *Server) handleChatCompletions(w http.ResponseWriter, r *http.Request) {
	var req types.ChatCompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeJSON(w, http.StatusBadRequest, types.NewErrorResponse("invalid request body: "+err.Error(), "invalid_request_error", ""))
		return
	}
	defer r.Body.Close()
	if len(req.Messages) == 0 {
		writeJSON(w, http.StatusBadRequest, types.NewErrorResponse("messages must not be empty", "invalid_request_error", ""))
		return
	}
	// --- Pure brain: only our system prompt, drop the client's ---
	var parts []string
	if s.cfg.SystemPrompt != "" {
		parts = append(parts, "system: "+s.cfg.SystemPrompt)
	}
	for _, m := range req.Messages {
		// Drop client system messages (mode descriptions, tool schemas).
		if m.Role == "system" {
			continue
		}
		text := sanitize.Text(string(m.Content))
		// Strip <system-reminder> blocks embedded in messages.
		text = systemReminderRe.ReplaceAllString(text, "")
		text = strings.TrimSpace(text)
		if text == "" {
			continue
		}
		parts = append(parts, fmt.Sprintf("%s: %s", m.Role, text))
	}
	prompt := strings.Join(parts, "\n")
	model := req.Model
	if model == "" {
		model = s.cfg.DefaultModel
	}
	cursorModel := converter.ResolveToCursorModel(model)
	sessionKey := ensureSessionHeader(w, r)
	if r.Header.Get(workspaceHeaderName) == "" {
		if detected := detectOpenAICwd(req); detected != "" {
			slog.Debug("workspace detected from prompt", "path", detected)
			r.Header.Set(workspaceHeaderName, detected)
		}
	}
	chatID := fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano())
	created := time.Now().Unix()
	if req.Stream {
		s.streamChat(w, r, prompt, cursorModel, model, chatID, created, sessionKey)
		return
	}
	s.nonStreamChat(w, r, prompt, cursorModel, model, chatID, created, sessionKey)
}
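
// streamChat runs the prompt through the bridge and relays its output as SSE
// chat.completion.chunk events: an initial role chunk, content deltas produced by
// converter.StreamParser (with a raw-text fallback for lines that fail to
// unmarshal), and a final chunk carrying the finish reason and estimated usage.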
func (s *Server) streamChat(w http.ResponseWriter, r *http.Request, prompt, cursorModel, displayModel, chatID string, created int64, sessionKey string) {
	sse := NewSSEWriter(w)
	parser := converter.NewStreamParser(chatID)
	ctx, cancel := context.WithTimeout(requestContext(r), time.Duration(s.cfg.Timeout)*time.Second)
	defer cancel()
	go func() {
		<-r.Context().Done()
		cancel()
	}()
	outputChan, errChan := s.br.Execute(ctx, prompt, cursorModel, sessionKey)
	roleAssistant := "assistant"
	initChunk := types.NewChatCompletionChunk(chatID, created, displayModel, types.Delta{
		Role: &roleAssistant,
	})
	if err := sse.WriteChunk(initChunk); err != nil {
		return
	}
	var accumulated strings.Builder
	for line := range outputChan {
		result := parser.Parse(line)
		if result.Skip {
			continue
		}
		if result.Error != nil {
			parseErr := result.Error
			if strings.Contains(parseErr.Error(), "unmarshal error") {
				// Fall back to treating the line as raw text.
				result = parser.ParseRawText(line)
				if result.Skip {
					continue
				}
				if result.Chunk != nil {
					result.Chunk.Created = created
					result.Chunk.Model = displayModel
					if len(result.Chunk.Choices) > 0 {
						if c := result.Chunk.Choices[0].Delta.Content; c != nil {
							accumulated.WriteString(*c)
						}
					}
					if err := sse.WriteChunk(*result.Chunk); err != nil {
						return
					}
					continue
				}
			}
			// Report the original parse error; the fallback result may carry no error.
			sse.WriteError(parseErr.Error())
			return
		}
		if result.Chunk != nil {
			result.Chunk.Created = created
			result.Chunk.Model = displayModel
			if len(result.Chunk.Choices) > 0 {
				if c := result.Chunk.Choices[0].Delta.Content; c != nil {
					accumulated.WriteString(*c)
				}
			}
			if err := sse.WriteChunk(*result.Chunk); err != nil {
				return
			}
		}
		if result.Done {
			break
		}
	}
	usage := estimateUsage(prompt, accumulated.String())
	select {
	case err := <-errChan:
		if err != nil {
			slog.Error("stream bridge error", "err", err)
			sse.WriteError(err.Error())
			return
		}
	default:
	}
	stopReason := "stop"
	finalChunk := types.NewChatCompletionChunk(chatID, created, displayModel, types.Delta{})
	finalChunk.Choices[0].FinishReason = &stopReason
	finalChunk.Usage = &usage
	sse.WriteChunk(finalChunk)
	sse.WriteDone()
}
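
// nonStreamChat runs the prompt through the bridge synchronously and returns a
// single chat.completion response with estimated token usage.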
func (s *Server) nonStreamChat(w http.ResponseWriter, r *http.Request, prompt, cursorModel, displayModel, chatID string, created int64, sessionKey string) {
	ctx, cancel := context.WithTimeout(requestContext(r), time.Duration(s.cfg.Timeout)*time.Second)
	defer cancel()
	go func() {
		<-r.Context().Done()
		cancel()
	}()
	content, err := s.br.ExecuteSync(ctx, prompt, cursorModel, sessionKey)
	if err != nil {
		writeJSON(w, http.StatusInternalServerError, types.NewErrorResponse(err.Error(), "internal_error", ""))
		return
	}
	usage := estimateUsage(prompt, content)
	stopReason := "stop"
	resp := types.ChatCompletionResponse{
		ID:      chatID,
		Object:  "chat.completion",
		Created: created,
		Model:   displayModel,
		Choices: []types.Choice{
			{
				Index:        0,
				Message:      types.ChatMessage{Role: "assistant", Content: types.ChatMessageContent(content)},
				FinishReason: &stopReason,
			},
		},
		Usage: usage,
	}
	writeJSON(w, http.StatusOK, resp)
}
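
// estimateUsage approximates token counts with a rough four-characters-per-token
// heuristic, clamping each side to at least one token.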
func estimateUsage(prompt, content string) types.Usage {
	promptTokens := maxInt(1, int(math.Round(float64(len(prompt))/4.0)))
	completionTokens := maxInt(1, int(math.Round(float64(len(content))/4.0)))
	return types.Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      promptTokens + completionTokens,
	}
}
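
// maxInt returns the larger of a and b.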
func maxInt(a, b int) int {
	if a > b {
		return a
	}
	return b
}
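
// handleHealth always responds 200 with status "ok"; the cursor_cli field reports
// whether the bridge's CLI health check succeeded.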
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
	defer cancel()
	status := "ok"
	cliStatus := "available"
	if err := s.br.CheckHealth(ctx); err != nil {
		cliStatus = fmt.Sprintf("unavailable: %v", err)
	}
	writeJSON(w, http.StatusOK, map[string]string{
		"status":     status,
		"cursor_cli": cliStatus,
		"version":    "0.2.0",
	})
}
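
// writeJSON writes v as a JSON response with the given status code; encoding
// errors are ignored.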
func writeJSON(w http.ResponseWriter, status int, v interface{}) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	json.NewEncoder(w).Encode(v)
}