merge: refactor/adapter
This commit is contained in:
commit
ef4b6519f5
|
|
@ -0,0 +1,174 @@
|
|||
package anthropic
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/pkg/adapter/openai"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type MessageParam struct {
|
||||
Role string `json:"role"`
|
||||
Content interface{} `json:"content"`
|
||||
}
|
||||
|
||||
type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System interface{} `json:"system"`
|
||||
Stream bool `json:"stream"`
|
||||
Tools []interface{} `json:"tools"`
|
||||
}
|
||||
|
||||
func systemToText(system interface{}) string {
|
||||
if system == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case []interface{}:
|
||||
var parts []string
|
||||
for _, p := range v {
|
||||
if m, ok := p.(map[string]interface{}); ok {
|
||||
if m["type"] == "text" {
|
||||
if t, ok := m["text"].(string); ok {
|
||||
parts = append(parts, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func anthropicBlockToText(p interface{}) string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := p.(type) {
|
||||
case string:
|
||||
return v
|
||||
case map[string]interface{}:
|
||||
typ, _ := v["type"].(string)
|
||||
switch typ {
|
||||
case "text":
|
||||
if t, ok := v["text"].(string); ok {
|
||||
return t
|
||||
}
|
||||
case "image":
|
||||
if src, ok := v["source"].(map[string]interface{}); ok {
|
||||
srcType, _ := src["type"].(string)
|
||||
switch srcType {
|
||||
case "base64":
|
||||
mt, _ := src["media_type"].(string)
|
||||
if mt == "" {
|
||||
mt = "image"
|
||||
}
|
||||
return "[Image: base64 " + mt + "]"
|
||||
case "url":
|
||||
url, _ := src["url"].(string)
|
||||
return "[Image: " + url + "]"
|
||||
}
|
||||
}
|
||||
return "[Image]"
|
||||
case "document":
|
||||
title, _ := v["title"].(string)
|
||||
if title == "" {
|
||||
if src, ok := v["source"].(map[string]interface{}); ok {
|
||||
title, _ = src["url"].(string)
|
||||
}
|
||||
}
|
||||
if title != "" {
|
||||
return "[Document: " + title + "]"
|
||||
}
|
||||
return "[Document]"
|
||||
case "tool_use":
|
||||
name, _ := v["name"].(string)
|
||||
id, _ := v["id"].(string)
|
||||
input := v["input"]
|
||||
inputJSON, _ := json.Marshal(input)
|
||||
if inputJSON == nil {
|
||||
inputJSON = []byte("{}")
|
||||
}
|
||||
tag := fmt.Sprintf("<tool_call>\n{\"name\": \"%s\", \"arguments\": %s}\n</tool_call>", name, string(inputJSON))
|
||||
if id != "" {
|
||||
tag = fmt.Sprintf("[tool_use_id=%s] ", id) + tag
|
||||
}
|
||||
return tag
|
||||
case "tool_result":
|
||||
toolUseID, _ := v["tool_use_id"].(string)
|
||||
content := v["content"]
|
||||
var contentText string
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
contentText = c
|
||||
case []interface{}:
|
||||
var parts []string
|
||||
for _, block := range c {
|
||||
if bm, ok := block.(map[string]interface{}); ok {
|
||||
if bm["type"] == "text" {
|
||||
if t, ok := bm["text"].(string); ok {
|
||||
parts = append(parts, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
contentText = strings.Join(parts, "\n")
|
||||
}
|
||||
label := "Tool result"
|
||||
if toolUseID != "" {
|
||||
label += " [id=" + toolUseID + "]"
|
||||
}
|
||||
return label + ": " + contentText
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func anthropicContentToText(content interface{}) string {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []interface{}:
|
||||
var parts []string
|
||||
for _, p := range v {
|
||||
if t := anthropicBlockToText(p); t != "" {
|
||||
parts = append(parts, t)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func BuildPromptFromAnthropicMessages(messages []MessageParam, system interface{}) string {
|
||||
var oaiMessages []interface{}
|
||||
|
||||
systemText := systemToText(system)
|
||||
if systemText != "" {
|
||||
oaiMessages = append(oaiMessages, map[string]interface{}{
|
||||
"role": "system",
|
||||
"content": systemText,
|
||||
})
|
||||
}
|
||||
|
||||
for _, m := range messages {
|
||||
text := anthropicContentToText(m.Content)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
role := m.Role
|
||||
if role != "user" && role != "assistant" {
|
||||
role = "user"
|
||||
}
|
||||
oaiMessages = append(oaiMessages, map[string]interface{}{
|
||||
"role": role,
|
||||
"content": text,
|
||||
})
|
||||
}
|
||||
|
||||
return openai.BuildPromptFromMessages(oaiMessages)
|
||||
}
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
package anthropic_test
|
||||
|
||||
import (
|
||||
"cursor-api-proxy/pkg/adapter/anthropic"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_Simple(t *testing.T) {
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "Hi there"},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
||||
if !strings.Contains(prompt, "Hello") {
|
||||
t.Errorf("prompt missing user message: %q", prompt)
|
||||
}
|
||||
if !strings.Contains(prompt, "Hi there") {
|
||||
t.Errorf("prompt missing assistant message: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_WithSystem(t *testing.T) {
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: "ping"},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, "You are a helpful bot.")
|
||||
if !strings.Contains(prompt, "You are a helpful bot.") {
|
||||
t.Errorf("prompt missing system: %q", prompt)
|
||||
}
|
||||
if !strings.Contains(prompt, "ping") {
|
||||
t.Errorf("prompt missing user: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_SystemArray(t *testing.T) {
|
||||
system := []interface{}{
|
||||
map[string]interface{}{"type": "text", "text": "Part A"},
|
||||
map[string]interface{}{"type": "text", "text": "Part B"},
|
||||
}
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: "test"},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, system)
|
||||
if !strings.Contains(prompt, "Part A") {
|
||||
t.Errorf("prompt missing Part A: %q", prompt)
|
||||
}
|
||||
if !strings.Contains(prompt, "Part B") {
|
||||
t.Errorf("prompt missing Part B: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_ContentBlocks(t *testing.T) {
|
||||
content := []interface{}{
|
||||
map[string]interface{}{"type": "text", "text": "block one"},
|
||||
map[string]interface{}{"type": "text", "text": "block two"},
|
||||
}
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: content},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
||||
if !strings.Contains(prompt, "block one") {
|
||||
t.Errorf("prompt missing 'block one': %q", prompt)
|
||||
}
|
||||
if !strings.Contains(prompt, "block two") {
|
||||
t.Errorf("prompt missing 'block two': %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_ImageBlock(t *testing.T) {
|
||||
content := []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "image",
|
||||
"source": map[string]interface{}{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": "abc123",
|
||||
},
|
||||
},
|
||||
}
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: content},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
||||
if !strings.Contains(prompt, "[Image") {
|
||||
t.Errorf("prompt missing [Image]: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_EmptyContentSkipped(t *testing.T) {
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "user", Content: ""},
|
||||
{Role: "assistant", Content: "response"},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
||||
if !strings.Contains(prompt, "response") {
|
||||
t.Errorf("prompt missing 'response': %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromAnthropicMessages_UnknownRoleBecomesUser(t *testing.T) {
|
||||
messages := []anthropic.MessageParam{
|
||||
{Role: "system", Content: "system-as-user"},
|
||||
}
|
||||
prompt := anthropic.BuildPromptFromAnthropicMessages(messages, nil)
|
||||
if !strings.Contains(prompt, "system-as-user") {
|
||||
t.Errorf("prompt missing 'system-as-user': %q", prompt)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,243 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []interface{} `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
Tools []interface{} `json:"tools"`
|
||||
ToolChoice interface{} `json:"tool_choice"`
|
||||
Functions []interface{} `json:"functions"`
|
||||
FunctionCall interface{} `json:"function_call"`
|
||||
}
|
||||
|
||||
func NormalizeModelID(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
parts := strings.Split(trimmed, "/")
|
||||
last := parts[len(parts)-1]
|
||||
if last == "" {
|
||||
return ""
|
||||
}
|
||||
return last
|
||||
}
|
||||
|
||||
func imageURLToText(imageURL interface{}) string {
|
||||
if imageURL == nil {
|
||||
return "[Image]"
|
||||
}
|
||||
var url string
|
||||
switch v := imageURL.(type) {
|
||||
case string:
|
||||
url = v
|
||||
case map[string]interface{}:
|
||||
if u, ok := v["url"].(string); ok {
|
||||
url = u
|
||||
}
|
||||
}
|
||||
if url == "" {
|
||||
return "[Image]"
|
||||
}
|
||||
if strings.HasPrefix(url, "data:") {
|
||||
end := strings.Index(url, ";")
|
||||
mime := "image"
|
||||
if end > 5 {
|
||||
mime = url[5:end]
|
||||
}
|
||||
return "[Image: base64 " + mime + "]"
|
||||
}
|
||||
return "[Image: " + url + "]"
|
||||
}
|
||||
|
||||
func MessageContentToText(content interface{}) string {
|
||||
if content == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []interface{}:
|
||||
var parts []string
|
||||
for _, p := range v {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
switch part := p.(type) {
|
||||
case string:
|
||||
parts = append(parts, part)
|
||||
case map[string]interface{}:
|
||||
typ, _ := part["type"].(string)
|
||||
switch typ {
|
||||
case "text":
|
||||
if t, ok := part["text"].(string); ok {
|
||||
parts = append(parts, t)
|
||||
}
|
||||
case "image_url":
|
||||
parts = append(parts, imageURLToText(part["image_url"]))
|
||||
case "image":
|
||||
src := part["source"]
|
||||
if src == nil {
|
||||
src = part["url"]
|
||||
}
|
||||
parts = append(parts, imageURLToText(src))
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func ToolsToSystemText(tools []interface{}, functions []interface{}) string {
|
||||
var defs []interface{}
|
||||
|
||||
for _, t := range tools {
|
||||
if m, ok := t.(map[string]interface{}); ok {
|
||||
if m["type"] == "function" {
|
||||
if fn := m["function"]; fn != nil {
|
||||
defs = append(defs, fn)
|
||||
}
|
||||
} else {
|
||||
defs = append(defs, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
defs = append(defs, functions...)
|
||||
|
||||
if len(defs) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, "Available tools (respond with a JSON object to call one):", "")
|
||||
|
||||
for _, raw := range defs {
|
||||
fn, ok := raw.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, _ := fn["name"].(string)
|
||||
desc, _ := fn["description"].(string)
|
||||
params := "{}"
|
||||
if p := fn["parameters"]; p != nil {
|
||||
if b, err := json.MarshalIndent(p, "", " "); err == nil {
|
||||
params = string(b)
|
||||
}
|
||||
} else if p := fn["input_schema"]; p != nil {
|
||||
if b, err := json.MarshalIndent(p, "", " "); err == nil {
|
||||
params = string(b)
|
||||
}
|
||||
}
|
||||
lines = append(lines, "Function: "+name+"\nDescription: "+desc+"\nParameters: "+params)
|
||||
}
|
||||
|
||||
lines = append(lines, "",
|
||||
"When you want to call a tool, use this EXACT format:",
|
||||
"",
|
||||
"<tool_call>",
|
||||
`{"name": "function_name", "arguments": {"param1": "value1"}}`,
|
||||
"</tool_call>",
|
||||
"",
|
||||
"Rules:",
|
||||
"- Write your reasoning BEFORE the tool call",
|
||||
"- You may make multiple tool calls by using multiple <tool_call> blocks",
|
||||
"- STOP writing after the last </tool_call> tag",
|
||||
"- If no tool is needed, respond normally without <tool_call> tags",
|
||||
)
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
type SimpleMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func BuildPromptFromMessages(messages []interface{}) string {
|
||||
var systemParts []string
|
||||
var convo []string
|
||||
|
||||
for _, raw := range messages {
|
||||
m, ok := raw.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role, _ := m["role"].(string)
|
||||
text := MessageContentToText(m["content"])
|
||||
|
||||
switch role {
|
||||
case "system", "developer":
|
||||
if text != "" {
|
||||
systemParts = append(systemParts, text)
|
||||
}
|
||||
case "user":
|
||||
if text != "" {
|
||||
convo = append(convo, "User: "+text)
|
||||
}
|
||||
case "assistant":
|
||||
toolCalls, _ := m["tool_calls"].([]interface{})
|
||||
if len(toolCalls) > 0 {
|
||||
var parts []string
|
||||
if text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
for _, tc := range toolCalls {
|
||||
tcMap, ok := tc.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
fn, _ := tcMap["function"].(map[string]interface{})
|
||||
if fn == nil {
|
||||
continue
|
||||
}
|
||||
name, _ := fn["name"].(string)
|
||||
args, _ := fn["arguments"].(string)
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("<tool_call>\n{\"name\": \"%s\", \"arguments\": %s}\n</tool_call>", name, args))
|
||||
}
|
||||
if len(parts) > 0 {
|
||||
convo = append(convo, "Assistant: "+strings.Join(parts, "\n"))
|
||||
}
|
||||
} else if text != "" {
|
||||
convo = append(convo, "Assistant: "+text)
|
||||
}
|
||||
case "tool", "function":
|
||||
name, _ := m["name"].(string)
|
||||
toolCallID, _ := m["tool_call_id"].(string)
|
||||
label := "Tool result"
|
||||
if name != "" {
|
||||
label = "Tool result (" + name + ")"
|
||||
}
|
||||
if toolCallID != "" {
|
||||
label += " [id=" + toolCallID + "]"
|
||||
}
|
||||
if text != "" {
|
||||
convo = append(convo, label+": "+text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
system := ""
|
||||
if len(systemParts) > 0 {
|
||||
system = "System:\n" + strings.Join(systemParts, "\n\n") + "\n\n"
|
||||
}
|
||||
transcript := strings.Join(convo, "\n\n")
|
||||
return system + transcript + "\n\nAssistant:"
|
||||
}
|
||||
|
||||
func BuildPromptFromSimpleMessages(messages []SimpleMessage) string {
|
||||
ifaces := make([]interface{}, len(messages))
|
||||
for i, m := range messages {
|
||||
ifaces[i] = map[string]interface{}{"role": m.Role, "content": m.Content}
|
||||
}
|
||||
return BuildPromptFromMessages(ifaces)
|
||||
}
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
package openai
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeModelID(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"gpt-4", "gpt-4"},
|
||||
{"openai/gpt-4", "gpt-4"},
|
||||
{"anthropic/claude-3", "claude-3"},
|
||||
{"", ""},
|
||||
{" ", ""},
|
||||
{"a/b/c", "c"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got := NormalizeModelID(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("NormalizeModelID(%q) = %q, want %q", tc.input, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptFromMessages(t *testing.T) {
|
||||
messages := []interface{}{
|
||||
map[string]interface{}{"role": "system", "content": "You are helpful."},
|
||||
map[string]interface{}{"role": "user", "content": "Hello"},
|
||||
map[string]interface{}{"role": "assistant", "content": "Hi there"},
|
||||
}
|
||||
got := BuildPromptFromMessages(messages)
|
||||
if got == "" {
|
||||
t.Fatal("expected non-empty prompt")
|
||||
}
|
||||
containsSystem := false
|
||||
containsUser := false
|
||||
containsAssistant := false
|
||||
for i := 0; i < len(got)-10; i++ {
|
||||
if got[i:i+6] == "System" {
|
||||
containsSystem = true
|
||||
}
|
||||
if got[i:i+4] == "User" {
|
||||
containsUser = true
|
||||
}
|
||||
if got[i:i+9] == "Assistant" {
|
||||
containsAssistant = true
|
||||
}
|
||||
}
|
||||
if !containsSystem || !containsUser || !containsAssistant {
|
||||
t.Errorf("prompt missing sections: system=%v user=%v assistant=%v\n%s",
|
||||
containsSystem, containsUser, containsAssistant, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolsToSystemText(t *testing.T) {
|
||||
tools := []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": map[string]interface{}{"type": "object"},
|
||||
},
|
||||
},
|
||||
}
|
||||
got := ToolsToSystemText(tools, nil)
|
||||
if got == "" {
|
||||
t.Fatal("expected non-empty tools text")
|
||||
}
|
||||
if len(got) < 10 {
|
||||
t.Errorf("tools text too short: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolsToSystemTextEmpty(t *testing.T) {
|
||||
got := ToolsToSystemText(nil, nil)
|
||||
if got != "" {
|
||||
t.Errorf("expected empty string for no tools, got %q", got)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue