269 lines
6.9 KiB
Go
269 lines
6.9 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"math"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/daniel/cursor-adapter/internal/converter"
|
|
"github.com/daniel/cursor-adapter/internal/sanitize"
|
|
"github.com/daniel/cursor-adapter/internal/types"
|
|
)
|
|
|
|
// Package-level cache of the bridge's model list, shared by every request.
// All fields are guarded by modelCacheMu; a cached value is considered
// fresh for modelCacheTTL after modelCacheAt.
var (
	modelCacheMu sync.Mutex
	modelCacheData []string // nil until the first successful fetch
	modelCacheAt time.Time // time of the last successful fetch
	modelCacheTTL = 5 * time.Minute
)
|
|
|
|
func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
|
|
models, err := s.cachedListModels(r.Context())
|
|
if err != nil {
|
|
writeJSON(w, http.StatusInternalServerError, types.NewErrorResponse(err.Error(), "internal_error", ""))
|
|
return
|
|
}
|
|
|
|
ts := time.Now().Unix()
|
|
data := make([]types.ModelInfo, 0, len(models)*2)
|
|
for _, m := range models {
|
|
data = append(data, types.ModelInfo{ID: m, Object: "model", Created: ts, OwnedBy: "cursor"})
|
|
}
|
|
aliases := converter.GetAnthropicModelAliases(models)
|
|
for _, a := range aliases {
|
|
data = append(data, types.ModelInfo{ID: a.ID, Object: "model", Created: ts, OwnedBy: "cursor"})
|
|
}
|
|
writeJSON(w, http.StatusOK, types.ModelList{Object: "list", Data: data})
|
|
}
|
|
|
|
func (s *Server) cachedListModels(ctx context.Context) ([]string, error) {
|
|
modelCacheMu.Lock()
|
|
defer modelCacheMu.Unlock()
|
|
|
|
if modelCacheData != nil && time.Since(modelCacheAt) < modelCacheTTL {
|
|
return modelCacheData, nil
|
|
}
|
|
|
|
models, err := s.br.ListModels(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
modelCacheData = models
|
|
modelCacheAt = time.Now()
|
|
return models, nil
|
|
}
|
|
|
|
func (s *Server) handleChatCompletions(w http.ResponseWriter, r *http.Request) {
|
|
var req types.ChatCompletionRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
writeJSON(w, http.StatusBadRequest, types.NewErrorResponse("invalid request body: "+err.Error(), "invalid_request_error", ""))
|
|
return
|
|
}
|
|
defer r.Body.Close()
|
|
|
|
if len(req.Messages) == 0 {
|
|
writeJSON(w, http.StatusBadRequest, types.NewErrorResponse("messages must not be empty", "invalid_request_error", ""))
|
|
return
|
|
}
|
|
|
|
var parts []string
|
|
for _, m := range req.Messages {
|
|
text := sanitize.Text(string(m.Content))
|
|
parts = append(parts, fmt.Sprintf("%s: %s", m.Role, text))
|
|
}
|
|
prompt := strings.Join(parts, "\n")
|
|
|
|
model := req.Model
|
|
if model == "" {
|
|
model = s.cfg.DefaultModel
|
|
}
|
|
cursorModel := converter.ResolveToCursorModel(model)
|
|
sessionKey := ensureSessionHeader(w, r)
|
|
|
|
chatID := fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano())
|
|
created := time.Now().Unix()
|
|
|
|
if req.Stream {
|
|
s.streamChat(w, r, prompt, cursorModel, model, chatID, created, sessionKey)
|
|
return
|
|
}
|
|
|
|
s.nonStreamChat(w, r, prompt, cursorModel, model, chatID, created, sessionKey)
|
|
}
|
|
|
|
func (s *Server) streamChat(w http.ResponseWriter, r *http.Request, prompt, cursorModel, displayModel, chatID string, created int64, sessionKey string) {
|
|
sse := NewSSEWriter(w)
|
|
parser := converter.NewStreamParser(chatID)
|
|
|
|
ctx, cancel := context.WithTimeout(r.Context(), time.Duration(s.cfg.Timeout)*time.Second)
|
|
defer cancel()
|
|
go func() {
|
|
<-r.Context().Done()
|
|
cancel()
|
|
}()
|
|
|
|
outputChan, errChan := s.br.Execute(ctx, prompt, cursorModel, sessionKey)
|
|
|
|
roleAssistant := "assistant"
|
|
initChunk := types.NewChatCompletionChunk(chatID, created, displayModel, types.Delta{
|
|
Role: &roleAssistant,
|
|
})
|
|
if err := sse.WriteChunk(initChunk); err != nil {
|
|
return
|
|
}
|
|
|
|
var accumulated strings.Builder
|
|
|
|
for line := range outputChan {
|
|
result := parser.Parse(line)
|
|
|
|
if result.Skip {
|
|
continue
|
|
}
|
|
|
|
if result.Error != nil {
|
|
if strings.Contains(result.Error.Error(), "unmarshal error") {
|
|
result = parser.ParseRawText(line)
|
|
if result.Skip {
|
|
continue
|
|
}
|
|
if result.Chunk != nil {
|
|
result.Chunk.Created = created
|
|
result.Chunk.Model = displayModel
|
|
if c := result.Chunk.Choices[0].Delta.Content; c != nil {
|
|
accumulated.WriteString(*c)
|
|
}
|
|
if err := sse.WriteChunk(*result.Chunk); err != nil {
|
|
return
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
sse.WriteError(result.Error.Error())
|
|
return
|
|
}
|
|
|
|
if result.Chunk != nil {
|
|
result.Chunk.Created = created
|
|
result.Chunk.Model = displayModel
|
|
if len(result.Chunk.Choices) > 0 {
|
|
if c := result.Chunk.Choices[0].Delta.Content; c != nil {
|
|
accumulated.WriteString(*c)
|
|
}
|
|
}
|
|
if err := sse.WriteChunk(*result.Chunk); err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
if result.Done {
|
|
break
|
|
}
|
|
}
|
|
|
|
promptTokens := maxInt(1, int(math.Round(float64(len(prompt))/4.0)))
|
|
completionTokens := maxInt(1, int(math.Round(float64(accumulated.Len())/4.0)))
|
|
usage := &types.Usage{
|
|
PromptTokens: promptTokens,
|
|
CompletionTokens: completionTokens,
|
|
TotalTokens: promptTokens + completionTokens,
|
|
}
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
if err != nil {
|
|
slog.Error("stream bridge error", "err", err)
|
|
sse.WriteError(err.Error())
|
|
return
|
|
}
|
|
default:
|
|
}
|
|
|
|
stopReason := "stop"
|
|
finalChunk := types.NewChatCompletionChunk(chatID, created, displayModel, types.Delta{})
|
|
finalChunk.Choices[0].FinishReason = &stopReason
|
|
finalChunk.Usage = usage
|
|
sse.WriteChunk(finalChunk)
|
|
sse.WriteDone()
|
|
}
|
|
|
|
func (s *Server) nonStreamChat(w http.ResponseWriter, r *http.Request, prompt, cursorModel, displayModel, chatID string, created int64, sessionKey string) {
|
|
ctx, cancel := context.WithTimeout(r.Context(), time.Duration(s.cfg.Timeout)*time.Second)
|
|
defer cancel()
|
|
go func() {
|
|
<-r.Context().Done()
|
|
cancel()
|
|
}()
|
|
|
|
content, err := s.br.ExecuteSync(ctx, prompt, cursorModel, sessionKey)
|
|
if err != nil {
|
|
writeJSON(w, http.StatusInternalServerError, types.NewErrorResponse(err.Error(), "internal_error", ""))
|
|
return
|
|
}
|
|
|
|
usage := estimateUsage(prompt, content)
|
|
stopReason := "stop"
|
|
|
|
resp := types.ChatCompletionResponse{
|
|
ID: chatID,
|
|
Object: "chat.completion",
|
|
Created: created,
|
|
Model: displayModel,
|
|
Choices: []types.Choice{
|
|
{
|
|
Index: 0,
|
|
Message: types.ChatMessage{Role: "assistant", Content: types.ChatMessageContent(content)},
|
|
FinishReason: &stopReason,
|
|
},
|
|
},
|
|
Usage: usage,
|
|
}
|
|
writeJSON(w, http.StatusOK, resp)
|
|
}
|
|
|
|
func estimateUsage(prompt, content string) types.Usage {
|
|
promptTokens := maxInt(1, int(math.Round(float64(len(prompt))/4.0)))
|
|
completionTokens := maxInt(1, int(math.Round(float64(len(content))/4.0)))
|
|
return types.Usage{
|
|
PromptTokens: promptTokens,
|
|
CompletionTokens: completionTokens,
|
|
TotalTokens: promptTokens + completionTokens,
|
|
}
|
|
}
|
|
|
|
// maxInt returns the larger of a and b.
//
// Kept for call-site compatibility; it delegates to the built-in max
// (Go 1.21+, which the file's log/slog usage already implies).
func maxInt(a, b int) int {
	return max(a, b)
}
|
|
|
|
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
|
|
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
status := "ok"
|
|
cliStatus := "available"
|
|
if err := s.br.CheckHealth(ctx); err != nil {
|
|
cliStatus = fmt.Sprintf("unavailable: %v", err)
|
|
}
|
|
|
|
writeJSON(w, http.StatusOK, map[string]string{
|
|
"status": status,
|
|
"cursor_cli": cliStatus,
|
|
"version": "0.2.0",
|
|
})
|
|
}
|
|
|
|
// writeJSON serializes v as the JSON response body with the given HTTP
// status. Headers must not have been written yet.
func writeJSON(w http.ResponseWriter, status int, v any) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status line is already sent, so an encode failure can only be
	// logged, never reported to the client; previously it was dropped.
	if err := json.NewEncoder(w).Encode(v); err != nil {
		slog.Error("encode json response", "err", err)
	}
}
|