opencode-cursor-agent/internal/toolcall/toolcall.go

155 lines
3.9 KiB
Go
Raw Normal View History

2026-04-01 00:53:34 +00:00
package toolcall
import (
"encoding/json"
"regexp"
"strings"
)
// ToolCall is a single tool invocation recovered from model output.
type ToolCall struct {
	// Name is the invoked tool's name; Arguments holds the call's
	// arguments as a JSON-encoded string.
	Name, Arguments string
}
// ParsedResponse is model output split into plain text and the tool calls
// that were found in it.
type ParsedResponse struct {
	// TextContent is the response text with extracted tool calls removed,
	// or the original text when no call was extracted.
	TextContent string
	// ToolCalls lists the extracted calls in the order they appeared.
	ToolCalls []ToolCall
}
// HasToolCalls reports whether at least one tool call was extracted.
// It is safe to call on a nil *ParsedResponse, which reports false instead
// of panicking on the field access.
func (p *ParsedResponse) HasToolCalls() bool {
	return p != nil && len(p.ToolCalls) > 0
}
2026-04-02 13:54:28 +00:00
// toolCallTagRe matches a JSON object wrapped in the tool-call delimiters
// "行政法规" … "ugalakh" and captures the object, braces included, in group 1.
//
// The body is matched lazily up to the first "}" that is followed (after
// optional whitespace) by the closing delimiter. Objects nested to any depth,
// and braces inside JSON string values, are therefore captured intact. A
// brace-balancing regex cannot do that, because regular expressions cannot
// count nesting. The regex does not check that the capture is valid JSON;
// the caller must decode it.
//
// Trade-off: an opening delimiter left unterminated before a later complete
// call makes the lazy match span both. The combined capture then fails to
// decode.
//
// NOTE(review): the delimiter literals look like mis-encoded special tokens;
// confirm they are exactly what the upstream model emits.
var toolCallTagRe = regexp.MustCompile(`(?s)行政法规\s*(\{.*?\})\s*ugalakh`)
2026-04-01 00:53:34 +00:00
// Patterns for the XML-style tool-call encoding, which nests as
//
//	<function_calls>
//	  <invoke name="TOOL">
//	    <parameter name="PARAM">VALUE</parameter>
//	  </invoke>
//	</function_calls>
//
// All three use (?s) so that "." also matches newlines inside multi-line
// bodies, and lazy ".*?" so each match stops at the nearest closing tag.
var (
	// antmlFunctionCallsRe captures the inner body of each <function_calls>
	// block (group 1), without surrounding whitespace.
	antmlFunctionCallsRe = regexp.MustCompile(`(?s)<function_calls>\s*(.*?)\s*</function_calls>`)

	// antmlInvokeRe captures an invoke's tool name (group 1) and its body
	// (group 2).
	antmlInvokeRe = regexp.MustCompile(`(?s)<invoke\s+name="([^"]+)">\s*(.*?)\s*</invoke>`)

	// antmlParamRe captures a parameter's name (group 1) and its raw,
	// untrimmed value (group 2).
	antmlParamRe = regexp.MustCompile(`(?s)<parameter\s+name="([^"]+)">(.*?)</parameter>`)
)
// ExtractToolCalls separates the tool calls embedded in model output from the
// surrounding text.
//
// Two encodings are recognised and tried in order. The first one that yields
// at least one call wins, and the other is not consulted:
//
//  1. A JSON object between the delimiters matched by toolCallTagRe, decoded
//     by parseToolCallJSON.
//  2. XML-style <function_calls> blocks holding <invoke name="..."> elements
//     with <parameter name="..."> children.
//
// A non-empty toolNames restricts extraction to the listed tools. A nil or
// empty map accepts any tool name.
//
// When calls are found, TextContent is text with every consumed match cut out
// and surrounding whitespace trimmed. A match that yields no call stays in
// TextContent verbatim, in both encodings. Examples are invalid JSON, a
// disallowed tool, or a block with no accepted <invoke>. When no call is
// found, TextContent is text unchanged and ToolCalls is nil. The returned
// pointer is never nil.
func ExtractToolCalls(text string, toolNames map[string]bool) *ParsedResponse {
	parseTagged := func(payload string) []ToolCall {
		if tc := parseToolCallJSON(payload, toolNames); tc != nil {
			return []ToolCall{*tc}
		}
		return nil
	}
	if calls, rest := extractToolCallMatches(text, toolCallTagRe, parseTagged); len(calls) > 0 {
		return &ParsedResponse{TextContent: rest, ToolCalls: calls}
	}

	parseBlock := func(block string) []ToolCall {
		return parseInvokeBlock(block, toolNames)
	}
	if calls, rest := extractToolCallMatches(text, antmlFunctionCallsRe, parseBlock); len(calls) > 0 {
		return &ParsedResponse{TextContent: rest, ToolCalls: calls}
	}

	return &ParsedResponse{TextContent: text}
}

// extractToolCallMatches runs re over text and hands capture group 1 of each
// match to convert. Matches for which convert returns calls are cut out of
// the text; all other matches are kept verbatim. It returns the collected
// calls in match order together with the remaining text, whitespace-trimmed.
func extractToolCallMatches(text string, re *regexp.Regexp, convert func(string) []ToolCall) ([]ToolCall, string) {
	var (
		calls []ToolCall
		rest  strings.Builder
	)
	last := 0 // start of the portion of text not yet copied into rest
	for _, loc := range re.FindAllStringSubmatchIndex(text, -1) {
		found := convert(text[loc[2]:loc[3]])
		if len(found) == 0 {
			// Leave the match where it is; it is copied along with the
			// surrounding text by a later WriteString.
			continue
		}
		rest.WriteString(text[last:loc[0]])
		calls = append(calls, found...)
		last = loc[1]
	}
	rest.WriteString(text[last:])
	return calls, strings.TrimSpace(rest.String())
}

// parseInvokeBlock converts every <invoke> element in the body of one
// <function_calls> block into a ToolCall, skipping tools not accepted by
// toolNames.
func parseInvokeBlock(block string, toolNames map[string]bool) []ToolCall {
	var calls []ToolCall
	for _, inv := range antmlInvokeRe.FindAllStringSubmatch(block, -1) {
		if name := inv[1]; toolAllowed(name, toolNames) {
			calls = append(calls, ToolCall{Name: name, Arguments: encodeInvokeParams(inv[2])})
		}
	}
	return calls
}

// encodeInvokeParams collects the <parameter> elements of one invoke body
// into a JSON object string. Each value is trimmed and then decoded as JSON
// when possible, so numbers, booleans, arrays and objects keep their types.
// A value that is not valid JSON is stored as a plain string.
func encodeInvokeParams(body string) string {
	args := map[string]interface{}{}
	for _, p := range antmlParamRe.FindAllStringSubmatch(body, -1) {
		value := strings.TrimSpace(p[2])
		var decoded interface{}
		if err := json.Unmarshal([]byte(value), &decoded); err == nil {
			args[p[1]] = decoded
		} else {
			args[p[1]] = value
		}
	}
	// Marshal cannot fail here. The map holds only strings and values
	// produced by json.Unmarshal, which are always re-encodable (no NaN/Inf,
	// no channels or funcs).
	encoded, _ := json.Marshal(args)
	return string(encoded)
}

// toolAllowed reports whether name passes the toolNames filter. A nil or
// empty filter allows every name.
func toolAllowed(name string, toolNames map[string]bool) bool {
	return len(toolNames) == 0 || toolNames[name]
}
// parseToolCallJSON decodes one tagged tool-call payload. The payload is
// expected to be a JSON object with a non-empty string "name" and the call's
// arguments under "arguments", or under "parameters" as a fallback.
//
// It returns nil when the payload is not a JSON object, has no name, or
// names a tool outside a non-empty toolNames set.
//
// Argument values are normalised to a JSON string:
//   - A non-blank string is treated as already-serialised JSON and used
//     verbatim.
//   - An object or array is re-marshalled.
//   - Anything else (missing, null, number, bool, blank string) falls
//     through to the next key, and finally to "{}".
//
// Both keys follow the same rules, so a string "parameters" is not
// double-encoded and a null one does not become the literal "null".
func parseToolCallJSON(jsonStr string, toolNames map[string]bool) *ToolCall {
	var raw map[string]interface{}
	if err := json.Unmarshal([]byte(jsonStr), &raw); err != nil {
		return nil
	}
	// A JSON null payload leaves raw nil; the lookup below then yields ""
	// and the payload is rejected.
	name, _ := raw["name"].(string)
	if name == "" {
		return nil
	}
	if len(toolNames) > 0 && !toolNames[name] {
		return nil
	}
	for _, key := range []string{"arguments", "parameters"} {
		if args, ok := encodeToolArguments(raw[key]); ok {
			return &ToolCall{Name: name, Arguments: args}
		}
	}
	return &ToolCall{Name: name, Arguments: "{}"}
}

// encodeToolArguments renders one decoded "arguments"/"parameters" value as
// a JSON string. It reports false when the value cannot serve as tool
// arguments, so the caller can try the next key.
func encodeToolArguments(v interface{}) (string, bool) {
	switch a := v.(type) {
	case string:
		if strings.TrimSpace(a) == "" {
			return "", false
		}
		return a, true
	case map[string]interface{}, []interface{}:
		// Values produced by json.Unmarshal always re-marshal; the error
		// check is purely defensive.
		b, err := json.Marshal(a)
		if err != nil {
			return "", false
		}
		return string(b), true
	default:
		return "", false
	}
}
// CollectToolNames returns the set of tool names declared in tools.
//
// Each element is expected to be a decoded JSON object. A name is taken
// from "function.name" when "type" is "function", and from a top-level
// "name" in any case; an object may contribute both. Elements that are not
// objects, or whose names are not strings, are ignored. The result is never
// nil.
func CollectToolNames(tools []interface{}) map[string]bool {
	names := make(map[string]bool)
	for _, raw := range tools {
		tool, isObject := raw.(map[string]interface{})
		if !isObject {
			continue
		}
		if tool["type"] == "function" {
			// A nil map (non-object "function") makes the lookup below
			// yield nil, which fails the string assertion.
			fn, _ := tool["function"].(map[string]interface{})
			if n, found := fn["name"].(string); found {
				names[n] = true
			}
		}
		if n, found := tool["name"].(string); found {
			names[n] = true
		}
	}
	return names
}