154 lines
3.8 KiB
Go
154 lines
3.8 KiB
Go
package toolcall
|
|
|
|
import (
|
|
"encoding/json"
|
|
"regexp"
|
|
"strings"
|
|
)
|
|
|
|
// ToolCall is one tool invocation recognized in response text.
type ToolCall struct {
	// Name is the name of the tool being invoked.
	Name string
	// Arguments holds the call's arguments encoded as a JSON string.
	Arguments string
}
|
|
|
|
type ParsedResponse struct {
|
|
TextContent string
|
|
ToolCalls []ToolCall
|
|
}
|
|
|
|
func (p *ParsedResponse) HasToolCalls() bool {
|
|
return len(p.ToolCalls) > 0
|
|
}
|
|
|
|
// Patterns for the two tool-call markup formats handled by ExtractToolCalls.
// They are compiled once at package initialization.
var (
	// toolCallTagRe matches a <tool_call> element; group 1 is its JSON object
	// payload (from the first "{" to the last "}" before the closing tag).
	toolCallTagRe = regexp.MustCompile(`(?s)<tool_call>\s*(\{.*?\})\s*</tool_call>`)

	// antmlFunctionCallsRe matches a <function_calls> wrapper; group 1 is its
	// inner content with surrounding whitespace excluded.
	antmlFunctionCallsRe = regexp.MustCompile(`(?s)<function_calls>\s*(.*?)\s*</function_calls>`)

	// antmlInvokeRe matches one <invoke name="..."> element; group 1 is the
	// tool name and group 2 is the element body.
	antmlInvokeRe = regexp.MustCompile(`(?s)<invoke\s+name="([^"]+)">\s*(.*?)\s*</invoke>`)

	// antmlParamRe matches one <parameter name="..."> element; group 1 is the
	// parameter name and group 2 is its raw, untrimmed value.
	antmlParamRe = regexp.MustCompile(`(?s)<parameter\s+name="([^"]+)">(.*?)</parameter>`)
)
|
|
|
|
func ExtractToolCalls(text string, toolNames map[string]bool) *ParsedResponse {
|
|
result := &ParsedResponse{}
|
|
|
|
if locs := toolCallTagRe.FindAllStringSubmatchIndex(text, -1); len(locs) > 0 {
|
|
var calls []ToolCall
|
|
var textParts []string
|
|
last := 0
|
|
for _, loc := range locs {
|
|
if loc[0] > last {
|
|
textParts = append(textParts, text[last:loc[0]])
|
|
}
|
|
jsonStr := text[loc[2]:loc[3]]
|
|
if tc := parseToolCallJSON(jsonStr, toolNames); tc != nil {
|
|
calls = append(calls, *tc)
|
|
} else {
|
|
textParts = append(textParts, text[loc[0]:loc[1]])
|
|
}
|
|
last = loc[1]
|
|
}
|
|
if last < len(text) {
|
|
textParts = append(textParts, text[last:])
|
|
}
|
|
if len(calls) > 0 {
|
|
result.TextContent = strings.TrimSpace(strings.Join(textParts, ""))
|
|
result.ToolCalls = calls
|
|
return result
|
|
}
|
|
}
|
|
|
|
if locs := antmlFunctionCallsRe.FindAllStringSubmatchIndex(text, -1); len(locs) > 0 {
|
|
var calls []ToolCall
|
|
var textParts []string
|
|
last := 0
|
|
for _, loc := range locs {
|
|
if loc[0] > last {
|
|
textParts = append(textParts, text[last:loc[0]])
|
|
}
|
|
block := text[loc[2]:loc[3]]
|
|
invokes := antmlInvokeRe.FindAllStringSubmatch(block, -1)
|
|
for _, inv := range invokes {
|
|
name := inv[1]
|
|
if toolNames != nil && len(toolNames) > 0 && !toolNames[name] {
|
|
continue
|
|
}
|
|
body := inv[2]
|
|
args := map[string]interface{}{}
|
|
params := antmlParamRe.FindAllStringSubmatch(body, -1)
|
|
for _, p := range params {
|
|
paramName := p[1]
|
|
paramValue := strings.TrimSpace(p[2])
|
|
var jsonVal interface{}
|
|
if err := json.Unmarshal([]byte(paramValue), &jsonVal); err == nil {
|
|
args[paramName] = jsonVal
|
|
} else {
|
|
args[paramName] = paramValue
|
|
}
|
|
}
|
|
argsJSON, _ := json.Marshal(args)
|
|
calls = append(calls, ToolCall{Name: name, Arguments: string(argsJSON)})
|
|
}
|
|
last = loc[1]
|
|
}
|
|
if last < len(text) {
|
|
textParts = append(textParts, text[last:])
|
|
}
|
|
if len(calls) > 0 {
|
|
result.TextContent = strings.TrimSpace(strings.Join(textParts, ""))
|
|
result.ToolCalls = calls
|
|
return result
|
|
}
|
|
}
|
|
|
|
result.TextContent = text
|
|
return result
|
|
}
|
|
|
|
func parseToolCallJSON(jsonStr string, toolNames map[string]bool) *ToolCall {
|
|
var raw map[string]interface{}
|
|
if err := json.Unmarshal([]byte(jsonStr), &raw); err != nil {
|
|
return nil
|
|
}
|
|
name, _ := raw["name"].(string)
|
|
if name == "" {
|
|
return nil
|
|
}
|
|
if toolNames != nil && len(toolNames) > 0 && !toolNames[name] {
|
|
return nil
|
|
}
|
|
var argsStr string
|
|
switch a := raw["arguments"].(type) {
|
|
case string:
|
|
argsStr = a
|
|
case map[string]interface{}, []interface{}:
|
|
b, _ := json.Marshal(a)
|
|
argsStr = string(b)
|
|
default:
|
|
if p, ok := raw["parameters"]; ok {
|
|
b, _ := json.Marshal(p)
|
|
argsStr = string(b)
|
|
} else {
|
|
argsStr = "{}"
|
|
}
|
|
}
|
|
return &ToolCall{Name: name, Arguments: argsStr}
|
|
}
|
|
|
|
// CollectToolNames gathers tool names from a list of tool definitions into a
// set, in the map[string]bool shape that ExtractToolCalls takes as toolNames.
//
// Every entry that is a map[string]interface{} is checked for both shapes:
//
//   - {"type": "function", "function": {"name": "..."}}
//   - {"name": "..."}
//
// Other entries, and names that are not strings, are ignored. The result is
// never nil.
func CollectToolNames(tools []interface{}) map[string]bool {
	names := map[string]bool{}
	add := func(v interface{}) {
		if s, isString := v.(string); isString {
			names[s] = true
		}
	}

	for _, tool := range tools {
		entry, isObject := tool.(map[string]interface{})
		if !isObject {
			continue
		}
		if entry["type"] == "function" {
			if fn, isFn := entry["function"].(map[string]interface{}); isFn {
				add(fn["name"])
			}
		}
		add(entry["name"])
	}
	return names
}
|