201 lines
5.9 KiB
Go
201 lines
5.9 KiB
Go
|
|
package server
|
||
|
|
|
||
|
|
import (
|
||
|
|
"encoding/json"
|
||
|
|
"fmt"
|
||
|
|
"strings"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Sentinels the brain is instructed to wrap each tool call in. XML-style tags
// are used instead of markdown fences because they are unambiguous and can be
// recognised mid-stream without being mistaken for ordinary code blocks.
const (
	// toolCallOpen marks the start of a tool-call block.
	toolCallOpen = "<tool_call>"
	// toolCallClose marks the end of a tool-call block.
	toolCallClose = "</tool_call>"
)
|
||
|
|
|
||
|
|
// ParsedToolCall is one tool invocation successfully extracted from the
// brain's text stream.
type ParsedToolCall struct {
	// Name identifies the tool to invoke.
	Name string
	// Input carries the call's arguments as undecoded JSON.
	Input json.RawMessage
}
|
||
|
|
|
||
|
|
// ToolCallStreamParser is a small streaming state machine that separates an
// incoming text stream into:
//   - plain text that is safe to emit (everything outside <tool_call>...</tool_call>)
//   - zero or more ParsedToolCall values (everything between the sentinels)
//
// Only as many trailing bytes as are needed to avoid emitting half of an
// opening sentinel as text are held back.
//
// The zero value is ready to use. Because it contains a strings.Builder, a
// parser must not be copied after first use; pass it by pointer.
type ToolCallStreamParser struct {
	// buf holds fed input that has not yet been emitted as text or consumed
	// as a tool call. While inToolCall is set it holds the (possibly partial)
	// payload that follows the opening sentinel.
	buf strings.Builder
	// inToolCall is true once an opening sentinel has been consumed and
	// until the matching closing sentinel is seen.
	inToolCall bool
}
|
||
|
|
|
||
|
|
// NewToolCallStreamParser returns a fresh parser.
|
||
|
|
func NewToolCallStreamParser() *ToolCallStreamParser {
|
||
|
|
return &ToolCallStreamParser{}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Feed appends s to the parser's buffer and returns:
|
||
|
|
// - emitText: text safe to forward as text_delta to the caller now
|
||
|
|
// - calls: tool calls fully extracted in this Feed
|
||
|
|
// - err: a malformed tool_call block (invalid JSON inside sentinels)
|
||
|
|
//
|
||
|
|
// Feed never returns text that could be the prefix of an opening sentinel —
|
||
|
|
// such bytes stay buffered until the next Feed/Flush.
|
||
|
|
func (p *ToolCallStreamParser) Feed(s string) (emitText string, calls []ParsedToolCall, err error) {
|
||
|
|
p.buf.WriteString(s)
|
||
|
|
var emitted strings.Builder
|
||
|
|
|
||
|
|
for {
|
||
|
|
current := p.buf.String()
|
||
|
|
if p.inToolCall {
|
||
|
|
closeIdx := strings.Index(current, toolCallClose)
|
||
|
|
if closeIdx < 0 {
|
||
|
|
return emitted.String(), calls, nil
|
||
|
|
}
|
||
|
|
payload := current[:closeIdx]
|
||
|
|
call, perr := parseToolCallPayload(payload)
|
||
|
|
rest := current[closeIdx+len(toolCallClose):]
|
||
|
|
rest = strings.TrimPrefix(rest, "\r")
|
||
|
|
rest = strings.TrimPrefix(rest, "\n")
|
||
|
|
p.buf.Reset()
|
||
|
|
p.buf.WriteString(rest)
|
||
|
|
p.inToolCall = false
|
||
|
|
if perr != nil {
|
||
|
|
return emitted.String(), calls, perr
|
||
|
|
}
|
||
|
|
calls = append(calls, call)
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
openIdx := strings.Index(current, toolCallOpen)
|
||
|
|
if openIdx >= 0 {
|
||
|
|
emitted.WriteString(current[:openIdx])
|
||
|
|
rest := current[openIdx+len(toolCallOpen):]
|
||
|
|
rest = strings.TrimPrefix(rest, "\r")
|
||
|
|
rest = strings.TrimPrefix(rest, "\n")
|
||
|
|
p.buf.Reset()
|
||
|
|
p.buf.WriteString(rest)
|
||
|
|
p.inToolCall = true
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
// No open sentinel yet. Emit everything except a potential prefix
|
||
|
|
// of `<tool_call>` lurking at the tail of the buffer.
|
||
|
|
hold := potentialSentinelSuffix(current, toolCallOpen)
|
||
|
|
if hold == 0 {
|
||
|
|
emitted.WriteString(current)
|
||
|
|
p.buf.Reset()
|
||
|
|
return emitted.String(), calls, nil
|
||
|
|
}
|
||
|
|
emitted.WriteString(current[:len(current)-hold])
|
||
|
|
tail := current[len(current)-hold:]
|
||
|
|
p.buf.Reset()
|
||
|
|
p.buf.WriteString(tail)
|
||
|
|
return emitted.String(), calls, nil
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Flush returns any remaining buffered text and resets the parser. If we
|
||
|
|
// ended mid-`<tool_call>` block (no closing sentinel), the partial content
|
||
|
|
// is returned as plain text — better the caller sees something than data
|
||
|
|
// loss.
|
||
|
|
func (p *ToolCallStreamParser) Flush() (string, error) {
|
||
|
|
leftover := p.buf.String()
|
||
|
|
p.buf.Reset()
|
||
|
|
if p.inToolCall {
|
||
|
|
p.inToolCall = false
|
||
|
|
return toolCallOpen + leftover, fmt.Errorf("unterminated %s block", toolCallOpen)
|
||
|
|
}
|
||
|
|
return leftover, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// ExtractAllToolCalls is the non-streaming counterpart: scan the full text
|
||
|
|
// once, return cleaned text (with tool_call blocks removed) plus extracted
|
||
|
|
// calls. Any malformed block is preserved verbatim in the returned text.
|
||
|
|
func ExtractAllToolCalls(text string) (cleanText string, calls []ParsedToolCall) {
|
||
|
|
var out strings.Builder
|
||
|
|
rest := text
|
||
|
|
for {
|
||
|
|
i := strings.Index(rest, toolCallOpen)
|
||
|
|
if i < 0 {
|
||
|
|
out.WriteString(rest)
|
||
|
|
break
|
||
|
|
}
|
||
|
|
out.WriteString(rest[:i])
|
||
|
|
after := rest[i+len(toolCallOpen):]
|
||
|
|
j := strings.Index(after, toolCallClose)
|
||
|
|
if j < 0 {
|
||
|
|
// Unterminated; keep the rest verbatim.
|
||
|
|
out.WriteString(toolCallOpen)
|
||
|
|
out.WriteString(after)
|
||
|
|
break
|
||
|
|
}
|
||
|
|
payload := after[:j]
|
||
|
|
if call, err := parseToolCallPayload(payload); err == nil {
|
||
|
|
calls = append(calls, call)
|
||
|
|
} else {
|
||
|
|
// Keep malformed block as-is so the user can see it.
|
||
|
|
out.WriteString(toolCallOpen)
|
||
|
|
out.WriteString(payload)
|
||
|
|
out.WriteString(toolCallClose)
|
||
|
|
}
|
||
|
|
rest = strings.TrimPrefix(after[j+len(toolCallClose):], "\n")
|
||
|
|
}
|
||
|
|
return strings.TrimSpace(out.String()), calls
|
||
|
|
}
|
||
|
|
|
||
|
|
func parseToolCallPayload(payload string) (ParsedToolCall, error) {
|
||
|
|
trimmed := strings.TrimSpace(payload)
|
||
|
|
// Allow the brain to wrap the JSON in ```json fences too.
|
||
|
|
trimmed = strings.TrimPrefix(trimmed, "```json")
|
||
|
|
trimmed = strings.TrimPrefix(trimmed, "```")
|
||
|
|
trimmed = strings.TrimSuffix(trimmed, "```")
|
||
|
|
trimmed = strings.TrimSpace(trimmed)
|
||
|
|
if trimmed == "" {
|
||
|
|
return ParsedToolCall{}, fmt.Errorf("empty tool_call body")
|
||
|
|
}
|
||
|
|
var raw struct {
|
||
|
|
Name string `json:"name"`
|
||
|
|
Tool string `json:"tool"`
|
||
|
|
Input json.RawMessage `json:"input"`
|
||
|
|
Args json.RawMessage `json:"arguments"`
|
||
|
|
}
|
||
|
|
if err := json.Unmarshal([]byte(trimmed), &raw); err != nil {
|
||
|
|
return ParsedToolCall{}, fmt.Errorf("invalid tool_call json: %w", err)
|
||
|
|
}
|
||
|
|
name := raw.Name
|
||
|
|
if name == "" {
|
||
|
|
name = raw.Tool
|
||
|
|
}
|
||
|
|
if name == "" {
|
||
|
|
return ParsedToolCall{}, fmt.Errorf("tool_call missing name")
|
||
|
|
}
|
||
|
|
input := raw.Input
|
||
|
|
if len(input) == 0 {
|
||
|
|
input = raw.Args
|
||
|
|
}
|
||
|
|
if len(input) == 0 {
|
||
|
|
input = json.RawMessage(`{}`)
|
||
|
|
}
|
||
|
|
return ParsedToolCall{Name: name, Input: input}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// potentialSentinelSuffix reports how many trailing bytes of s could be the
// beginning of sentinel. It returns the length of the longest suffix of s that
// is a proper prefix of sentinel (never the whole sentinel), or 0 when there is
// none. Feed uses it to hold back an opening sentinel that may be split
// across chunks.
func potentialSentinelSuffix(s, sentinel string) int {
	n := len(sentinel) - 1
	if len(s) < n {
		n = len(s)
	}
	for ; n > 0; n-- {
		if strings.HasSuffix(s, sentinel[:n]) {
			return n
		}
	}
	return 0
}
|