191 lines
5.0 KiB
Go
191 lines
5.0 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/daniel/cursor-adapter/internal/converter"
|
|
"github.com/daniel/cursor-adapter/internal/types"
|
|
)
|
|
|
|
// handleAnthropicMessages serves the Anthropic Messages endpoint.
//
// It decodes the JSON request body and rejects it with 400
// "invalid_request_error" when the body is malformed, max_tokens is not
// positive, or messages is empty. Otherwise it resolves the requested model
// (falling back to s.cfg.DefaultModel when empty) to a Cursor model name,
// obtains the session key via ensureSessionHeader, derives a time-based
// message ID and a prompt, and delegates to the streaming or non-streaming
// responder according to req.Stream.
func (s *Server) handleAnthropicMessages(w http.ResponseWriter, r *http.Request) {
	var req types.AnthropicMessagesRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeJSON(w, http.StatusBadRequest, types.NewErrorResponse("invalid request body: "+err.Error(), "invalid_request_error", ""))
		return
	}
	defer r.Body.Close()

	// Validation checks run in a fixed order; the first failing one wins.
	var problem string
	switch {
	case req.MaxTokens <= 0:
		problem = "max_tokens is required"
	case len(req.Messages) == 0:
		problem = "messages must not be empty"
	}
	if problem != "" {
		writeJSON(w, http.StatusBadRequest, types.NewErrorResponse(problem, "invalid_request_error", ""))
		return
	}

	displayModel := req.Model
	if displayModel == "" {
		displayModel = s.cfg.DefaultModel
	}
	cursorModel := converter.ResolveToCursorModel(displayModel)
	sessionKey := ensureSessionHeader(w, r)

	msgID := fmt.Sprintf("msg_%d", time.Now().UnixNano())
	prompt := buildPromptFromAnthropicMessages(req)

	// Both responders share one signature, so pick one and call it once.
	respond := s.nonStreamAnthropicMessages
	if req.Stream {
		respond = s.streamAnthropicMessages
	}
	respond(w, r, prompt, cursorModel, displayModel, msgID, sessionKey)
}
|
|
|
|
// streamAnthropicMessages runs prompt through the bridge and relays its output
// to the client as an Anthropic Messages server-sent event stream.
//
// The event sequence is message_start, content_block_start (a single text
// block at index 0), one content_block_delta per text fragment,
// content_block_stop, message_delta and message_stop. If the parser, the
// bridge (via errChan) or the request deadline reports a failure, an "error"
// event is written instead and the stream ends without the closing events.
//
// cursorModel is the model name handed to the bridge; displayModel is the name
// echoed back to the client. The reported output token count is a rough
// estimate of four bytes per token (minimum 1).
func (s *Server) streamAnthropicMessages(w http.ResponseWriter, r *http.Request, prompt, cursorModel, displayModel, msgID, sessionKey string) {
	sse := NewSSEWriter(w)
	parser := converter.NewStreamParser(msgID)

	// ctx derives from the request context, so a client disconnect already
	// cancels the bridge call; no watcher goroutine is needed. A non-positive
	// Timeout disables the deadline instead of yielding an already-expired
	// context that would fail every request.
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()
	if s.cfg.Timeout > 0 {
		var cancelTimeout context.CancelFunc
		ctx, cancelTimeout = context.WithTimeout(ctx, time.Duration(s.cfg.Timeout)*time.Second)
		defer cancelTimeout()
	}

	outputChan, errChan := s.br.Execute(ctx, prompt, cursorModel, sessionKey)

	// sendError writes a terminal Anthropic "error" event.
	sendError := func(msg string) {
		writeAnthropicSSE(sse, map[string]interface{}{
			"type":  "error",
			"error": map[string]interface{}{"type": "api_error", "message": msg},
		})
	}

	var accumulated strings.Builder
	// sendText records a text fragment and forwards it as a text_delta.
	sendText := func(text string) {
		accumulated.WriteString(text)
		writeAnthropicSSE(sse, map[string]interface{}{
			"type":  "content_block_delta",
			"index": 0,
			"delta": map[string]interface{}{"type": "text_delta", "text": text},
		})
	}

	writeAnthropicSSE(sse, map[string]interface{}{
		"type": "message_start",
		"message": map[string]interface{}{
			"id":            msgID,
			"type":          "message",
			"role":          "assistant",
			"model":         displayModel,
			"content":       []interface{}{},
			"stop_reason":   nil,
			"stop_sequence": nil,
			// Anthropic SDK stream accumulators take message.usage from
			// message_start and only update output_tokens on message_delta,
			// so the object must be present here.
			"usage": map[string]interface{}{
				"input_tokens":  estimateUsage(prompt, "").PromptTokens,
				"output_tokens": 0,
			},
		},
	})
	writeAnthropicSSE(sse, map[string]interface{}{
		"type":          "content_block_start",
		"index":         0,
		"content_block": map[string]interface{}{"type": "text", "text": ""},
	})

	for line := range outputChan {
		result := parser.Parse(line)
		if result.Skip {
			continue
		}

		// Lines that fail JSON decoding are re-read as plain text output. If
		// that succeeds, the result is handled like any other; previously a
		// content-less, error-free fallback result fell through to
		// result.Error.Error() and dereferenced a nil error.
		if result.Error != nil && strings.Contains(result.Error.Error(), "unmarshal error") {
			result = parser.ParseRawText(line)
			if result.Skip {
				continue
			}
		}

		if result.Error != nil {
			sendError(result.Error.Error())
			return
		}

		if result.Chunk != nil && len(result.Chunk.Choices) > 0 {
			if c := result.Chunk.Choices[0].Delta.Content; c != nil {
				sendText(*c)
			}
		}

		if result.Done {
			break
		}
	}

	// Surface a failure instead of claiming a clean end_turn. The errChan
	// receive is non-blocking because the bridge may never send on it;
	// ctx.Err() catches a deadline (or cancellation) that cut output short.
	var streamErr error
	select {
	case err := <-errChan:
		streamErr = err
	default:
	}
	if streamErr == nil {
		streamErr = ctx.Err()
	}
	if streamErr != nil {
		sendError(streamErr.Error())
		return
	}

	outTokens := maxInt(1, accumulated.Len()/4)

	writeAnthropicSSE(sse, map[string]interface{}{
		"type":  "content_block_stop",
		"index": 0,
	})
	writeAnthropicSSE(sse, map[string]interface{}{
		"type":  "message_delta",
		"delta": map[string]interface{}{"stop_reason": "end_turn", "stop_sequence": nil},
		"usage": map[string]interface{}{"output_tokens": outTokens},
	})
	writeAnthropicSSE(sse, map[string]interface{}{
		"type": "message_stop",
	})
}
|
|
|
|
func (s *Server) nonStreamAnthropicMessages(w http.ResponseWriter, r *http.Request, prompt, cursorModel, displayModel, msgID, sessionKey string) {
|
|
ctx, cancel := context.WithTimeout(r.Context(), time.Duration(s.cfg.Timeout)*time.Second)
|
|
defer cancel()
|
|
go func() {
|
|
<-r.Context().Done()
|
|
cancel()
|
|
}()
|
|
|
|
content, err := s.br.ExecuteSync(ctx, prompt, cursorModel, sessionKey)
|
|
if err != nil {
|
|
writeJSON(w, http.StatusInternalServerError, types.NewErrorResponse(err.Error(), "api_error", ""))
|
|
return
|
|
}
|
|
usage := estimateUsage(prompt, content)
|
|
|
|
resp := types.AnthropicMessagesResponse{
|
|
ID: msgID,
|
|
Type: "message",
|
|
Role: "assistant",
|
|
Content: []types.AnthropicTextBlock{{Type: "text", Text: content}},
|
|
Model: displayModel,
|
|
StopReason: "end_turn",
|
|
Usage: types.AnthropicUsage{
|
|
InputTokens: usage.PromptTokens,
|
|
OutputTokens: usage.CompletionTokens,
|
|
},
|
|
}
|
|
writeJSON(w, http.StatusOK, resp)
|
|
}
|
|
|
|
func writeAnthropicSSE(sse *SSEWriter, event interface{}) {
|
|
data, err := json.Marshal(event)
|
|
if err != nil {
|
|
return
|
|
}
|
|
fmt.Fprintf(sse.w, "data: %s\n\n", data)
|
|
if sse.flush != nil {
|
|
sse.flush.Flush()
|
|
}
|
|
}
|