opencode-cursor-agent/internal/server/anthropic_handlers.go

337 lines
8.6 KiB
Go

package server
import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/daniel/cursor-adapter/internal/converter"
"github.com/daniel/cursor-adapter/internal/types"
)
// handleAnthropicMessages decodes an Anthropic-style messages request,
// validates it, and dispatches it to the streaming or non-streaming responder.
//
// A body that cannot be read or decoded, or a request without messages, is
// answered with 400 / "invalid_request_error". A non-positive max_tokens is
// replaced by 16384, and an empty or "auto" model falls back to the configured
// default before being resolved to a Cursor model name.
func (s *Server) handleAnthropicMessages(w http.ResponseWriter, r *http.Request) {
	// All validation failures share the same status code and error type.
	reject := func(msg string) {
		writeJSON(w, http.StatusBadRequest, types.NewErrorResponse(msg, "invalid_request_error", ""))
	}

	raw, err := io.ReadAll(r.Body)
	if err != nil {
		reject("read body: " + err.Error())
		return
	}
	r.Body.Close()

	var req types.AnthropicMessagesRequest
	if decodeErr := json.Unmarshal(raw, &req); decodeErr != nil {
		reject("invalid request body: " + decodeErr.Error())
		return
	}
	if req.MaxTokens <= 0 {
		req.MaxTokens = 16384
	}
	if len(req.Messages) == 0 {
		reject("messages must not be empty")
		return
	}

	model := req.Model
	switch model {
	case "", "auto":
		model = s.cfg.DefaultModel
	}
	cursorModel := converter.ResolveToCursorModel(model)
	sessionKey := ensureSessionHeader(w, r)

	// Log the tool names the caller offered so it is visible which executors
	// the model is about to see.
	if n := len(req.Tools); n > 0 {
		names := make([]string, n)
		for i := range req.Tools {
			names[i] = req.Tools[i].Name
		}
		log.Printf("[tools] caller has %d executors: %v", len(names), names)
	}

	// Without an explicit workspace header, fall back to any host directory
	// mentioned in the prompt. The detected directory is written onto the
	// request header so the downstream bridge picks it up through the same
	// override path as an explicit header.
	if r.Header.Get(workspaceHeaderName) == "" {
		detected := detectAnthropicCwd(req)
		if detected != "" {
			log.Printf("[workspace] detected caller cwd from prompt: %s", detected)
			r.Header.Set(workspaceHeaderName, detected)
		}
	}

	msgID := fmt.Sprintf("msg_%d", time.Now().UnixNano())
	prompt := buildPromptFromAnthropicMessages(req, s.cfg.SystemPrompt)

	// Both responders share one signature, so pick one and call it once.
	respond := s.nonStreamAnthropicMessages
	if req.Stream {
		respond = s.streamAnthropicMessages
	}
	respond(w, r, prompt, cursorModel, model, msgID, sessionKey)
}
// streamAnthropicMessages runs prompt through the bridge and relays the output
// to the client as an Anthropic-style event stream: message_start, a sequence
// of content blocks (text and tool_use), message_delta, then message_stop.
//
// Bridge output lines are decoded by the converter stream parser; text deltas
// are passed through the tool-call stream parser, which separates plain text
// from embedded tool calls. A parse failure that cannot be recovered as raw
// text is reported to the client as an "error" event and ends the stream.
//
// Content block indices are assigned sequentially from 0 in the order blocks
// are opened, so every block gets a unique index that matches its position in
// the assembled message content.
func (s *Server) streamAnthropicMessages(w http.ResponseWriter, r *http.Request, prompt, cursorModel, displayModel, msgID, sessionKey string) {
	sse := NewSSEWriter(w)
	parser := converter.NewStreamParser(msgID)
	tcParser := NewToolCallStreamParser()

	ctx, cancel := context.WithTimeout(requestContext(r), time.Duration(s.cfg.Timeout)*time.Second)
	defer cancel()
	// Stop the bridge call as soon as the HTTP request context ends, whether
	// or not requestContext derives from it.
	go func() {
		<-r.Context().Done()
		cancel()
	}()

	outputChan, errChan := s.br.Execute(ctx, prompt, cursorModel, sessionKey)

	writeAnthropicSSE(sse, map[string]interface{}{
		"type": "message_start",
		"message": map[string]interface{}{
			"id":      msgID,
			"type":    "message",
			"role":    "assistant",
			"model":   displayModel,
			"content": []interface{}{},
		},
	})

	st := &anthropicStreamState{
		sse:        sse,
		blockIndex: 0,
	}

	// nextIndex is the index the next opened content block will receive.
	// st.blockIndex always holds the index of the most recently opened block,
	// which is what the text open/close helpers put on the wire.
	nextIndex := 0
	claimIndex := func() {
		st.blockIndex = nextIndex
		nextIndex++
	}

	emitText := func(text string) {
		if text == "" {
			return
		}
		if !st.textOpen {
			// A new text block is about to be opened; give it its own index
			// instead of reusing the one of the previous (closed) block.
			claimIndex()
		}
		st.ensureTextBlockOpen()
		writeAnthropicSSE(sse, map[string]interface{}{
			"type":  "content_block_delta",
			"index": st.blockIndex,
			"delta": map[string]interface{}{"type": "text_delta", "text": text},
		})
		st.outChars += len(text)
	}

	emitToolCall := func(call ParsedToolCall) {
		st.closeTextBlockIfOpen()
		claimIndex()
		toolID := newToolUseID()
		writeAnthropicSSE(sse, map[string]interface{}{
			"type":  "content_block_start",
			"index": st.blockIndex,
			"content_block": map[string]interface{}{
				"type":  "tool_use",
				"id":    toolID,
				"name":  call.Name,
				"input": map[string]interface{}{},
			},
		})
		// The whole input is sent as a single partial_json fragment.
		writeAnthropicSSE(sse, map[string]interface{}{
			"type":  "content_block_delta",
			"index": st.blockIndex,
			"delta": map[string]interface{}{
				"type":         "input_json_delta",
				"partial_json": string(call.Input),
			},
		})
		writeAnthropicSSE(sse, map[string]interface{}{
			"type":  "content_block_stop",
			"index": st.blockIndex,
		})
		st.toolCallsEmitted++
	}

	// feedDelta pushes one text delta through the tool-call parser and emits
	// whatever text and completed tool calls it yields. Parser errors are
	// logged and do not interrupt the stream.
	feedDelta := func(content string) {
		emit, calls, err := tcParser.Feed(content)
		emitText(emit)
		for _, c := range calls {
			emitToolCall(c)
		}
		if err != nil {
			log.Printf("[tool_call] parse error: %v", err)
		}
	}

	for line := range outputChan {
		result := parser.Parse(line)
		if result.Skip {
			continue
		}
		if parseErr := result.Error; parseErr != nil {
			if strings.Contains(parseErr.Error(), "unmarshal error") {
				// The line is not valid JSON; retry it as raw text.
				raw := parser.ParseRawText(line)
				if raw.Skip {
					continue
				}
				if raw.Chunk != nil && len(raw.Chunk.Choices) > 0 {
					if c := raw.Chunk.Choices[0].Delta.Content; c != nil {
						feedDelta(*c)
						continue
					}
				}
				// Prefer the raw-text error when there is one, but never
				// lose the original parse error if the retry carried none.
				if raw.Error != nil {
					parseErr = raw.Error
				}
			}
			writeAnthropicSSE(sse, map[string]interface{}{
				"type":  "error",
				"error": map[string]interface{}{"type": "api_error", "message": parseErr.Error()},
			})
			return
		}
		if result.Chunk != nil && len(result.Chunk.Choices) > 0 {
			if c := result.Chunk.Choices[0].Delta.Content; c != nil {
				feedDelta(*c)
			}
		}
		if result.Done {
			break
		}
	}

	// Emit whatever the tool-call parser was still holding back. The flush
	// error is logged even when there is no leftover text.
	leftover, flushErr := tcParser.Flush()
	emitText(leftover)
	if flushErr != nil {
		log.Printf("[tool_call] flush warning: %v", flushErr)
	}
	st.closeTextBlockIfOpen()

	stopReason := "end_turn"
	if st.toolCallsEmitted > 0 {
		stopReason = "tool_use"
	}
	// Usage is an estimate: one token per four bytes of emitted text, at
	// least one.
	outTokens := maxInt(1, st.outChars/4)
	writeAnthropicSSE(sse, map[string]interface{}{
		"type":  "message_delta",
		"delta": map[string]interface{}{"stop_reason": stopReason, "stop_sequence": nil},
		"usage": map[string]interface{}{"output_tokens": outTokens},
	})
	writeAnthropicSSE(sse, map[string]interface{}{
		"type": "message_stop",
	})

	// Non-blocking on purpose: the stream is already complete, so only record
	// a bridge error if one is ready rather than waiting for it.
	select {
	case err := <-errChan:
		if err != nil {
			log.Printf("[bridge] stream finished with error: %v", err)
		}
	default:
	}
}
// nonStreamAnthropicMessages runs prompt through the bridge synchronously and
// answers with a single Anthropic-style message object.
//
// The raw bridge output is split into plain text and tool calls. The response
// carries the text block first (when non-empty) followed by one tool_use block
// per call; if neither exists, a single empty text block is returned so the
// content array is never empty. stop_reason is "tool_use" when at least one
// tool call was extracted, otherwise "end_turn". A bridge failure is answered
// with 500 / "api_error".
func (s *Server) nonStreamAnthropicMessages(w http.ResponseWriter, r *http.Request, prompt, cursorModel, displayModel, msgID, sessionKey string) {
	timeout := time.Duration(s.cfg.Timeout) * time.Second
	ctx, cancel := context.WithTimeout(requestContext(r), timeout)
	defer cancel()
	// Abort the bridge call once the HTTP request context ends.
	go func() {
		<-r.Context().Done()
		cancel()
	}()

	rawContent, err := s.br.ExecuteSync(ctx, prompt, cursorModel, sessionKey)
	if err != nil {
		errBody := types.NewErrorResponse(err.Error(), "api_error", "")
		writeJSON(w, http.StatusInternalServerError, errBody)
		return
	}

	cleanText, calls := ExtractAllToolCalls(rawContent)
	usage := estimateUsage(prompt, rawContent)

	// At most one text block plus one block per tool call.
	blocks := make([]types.AnthropicResponseBlock, 0, len(calls)+1)
	if cleanText != "" {
		blocks = append(blocks, types.AnthropicResponseBlock{Type: "text", Text: cleanText})
	}
	for i := range calls {
		blocks = append(blocks, types.AnthropicResponseBlock{
			Type:  "tool_use",
			ID:    newToolUseID(),
			Name:  calls[i].Name,
			Input: calls[i].Input,
		})
	}
	if len(blocks) == 0 {
		blocks = append(blocks, types.AnthropicResponseBlock{Type: "text", Text: ""})
	}

	stopReason := "tool_use"
	if len(calls) == 0 {
		stopReason = "end_turn"
	}

	writeJSON(w, http.StatusOK, types.AnthropicMessagesResponse{
		ID:         msgID,
		Type:       "message",
		Role:       "assistant",
		Content:    blocks,
		Model:      displayModel,
		StopReason: stopReason,
		Usage: types.AnthropicUsage{
			InputTokens:  usage.PromptTokens,
			OutputTokens: usage.CompletionTokens,
		},
	})
}
// anthropicStreamState tracks per-request streaming state: which content
// block index we are on, whether the current text block is open, output
// character count for usage estimation, and how many tool_use blocks were
// emitted so we can pick stop_reason.
type anthropicStreamState struct {
// sse is the writer every event of this stream is sent through.
sse *SSEWriter
// blockIndex is the "index" value placed on content block events for the
// current block.
blockIndex int
// textOpen is true between the content_block_start and content_block_stop
// of a text block; it is set and cleared by ensureTextBlockOpen and
// closeTextBlockIfOpen.
textOpen bool
// outChars accumulates len(text) of every emitted text delta (a byte count)
// and feeds the output token estimate.
outChars int
// toolCallsEmitted counts tool_use blocks sent so far; a non-zero value
// makes the stream finish with stop_reason "tool_use".
toolCallsEmitted int
}
// ensureTextBlockOpen emits a content_block_start event for an empty text
// block at the current block index, unless a text block is already open, and
// records that the block is open. Calling it repeatedly is harmless.
func (st *anthropicStreamState) ensureTextBlockOpen() {
	if !st.textOpen {
		startEvent := map[string]interface{}{
			"type":          "content_block_start",
			"index":         st.blockIndex,
			"content_block": map[string]interface{}{"type": "text", "text": ""},
		}
		writeAnthropicSSE(st.sse, startEvent)
		st.textOpen = true
	}
}
// closeTextBlockIfOpen emits a content_block_stop event for the current block
// index when a text block is open and marks it closed; otherwise it does
// nothing.
func (st *anthropicStreamState) closeTextBlockIfOpen() {
	if st.textOpen {
		stopEvent := map[string]interface{}{
			"type":  "content_block_stop",
			"index": st.blockIndex,
		}
		writeAnthropicSSE(st.sse, stopEvent)
		st.textOpen = false
	}
}
// newToolUseID returns an identifier for a tool_use block: the prefix
// "toolu_" followed by 12 random bytes from crypto/rand in lowercase hex.
// If the random source fails, the current Unix time in nanoseconds is used
// as the suffix instead.
func newToolUseID() string {
	const prefix = "toolu_"

	entropy := make([]byte, 12)
	if _, err := rand.Read(entropy); err == nil {
		return prefix + hex.EncodeToString(entropy)
	}
	// Fallback when no randomness is available.
	return fmt.Sprintf("toolu_%d", time.Now().UnixNano())
}
// writeAnthropicSSE serializes event as JSON, writes it to the stream as one
// server-sent event, and flushes the writer when a flusher is available.
//
// When event is a map with a non-empty string "type" entry, the frame is named
// after it ("event: <type>") ahead of the data line, which is the framing
// Anthropic's streaming API uses and which clients that dispatch on the SSE
// event name depend on. Any other payload is written as a data-only frame.
//
// An event that cannot be marshalled is logged and dropped so the rest of the
// stream can continue.
func writeAnthropicSSE(sse *SSEWriter, event interface{}) {
	data, err := json.Marshal(event)
	if err != nil {
		log.Printf("[sse] dropping event, marshal failed: %v", err)
		return
	}

	name := ""
	if fields, ok := event.(map[string]interface{}); ok {
		name, _ = fields["type"].(string)
	}

	// Write the frame in a single call so a named event is never split
	// between its "event:" and "data:" lines.
	if name != "" {
		_, err = fmt.Fprintf(sse.w, "event: %s\ndata: %s\n\n", name, data)
	} else {
		_, err = fmt.Fprintf(sse.w, "data: %s\n\n", data)
	}
	if err != nil {
		// The write failed (typically the client went away); there is
		// nothing to flush and the caller has no way to recover here.
		return
	}
	if sse.flush != nil {
		sse.flush.Flush()
	}
}