332 lines
9.2 KiB
Go
332 lines
9.2 KiB
Go
|
|
package server_test
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"cursor-api-proxy/internal/config"
|
|||
|
|
"cursor-api-proxy/internal/server"
|
|||
|
|
"encoding/json"
|
|||
|
|
"fmt"
|
|||
|
|
"io"
|
|||
|
|
"net"
|
|||
|
|
"context"
|
|||
|
|
"net/http"
|
|||
|
|
"os"
|
|||
|
|
"strings"
|
|||
|
|
"testing"
|
|||
|
|
"time"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// freePort asks the kernel for an ephemeral TCP port on 127.0.0.1 and returns
// its number. The probe listener is closed before freePort returns, so the
// port is only probably free: another process may bind it before the caller
// does.
func freePort(t *testing.T) int {
	t.Helper()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()
	return ln.Addr().(*net.TCPAddr).Port
}
|
|||
|
|
|
|||
|
|
// makeFakeAgentBin writes an executable /bin/sh script named "agent" into a
// fresh per-test temp directory and returns its path. The script imitates the
// agent CLI with a fixed reply, syncOutput:
//   - if any argument is --stream-json, it prints two JSON lines: an
//     "assistant" event whose single text content is syncOutput, then a
//     "result" event with subtype "success";
//   - otherwise it prints syncOutput verbatim, with no trailing newline.
//
// syncOutput may contain any characters. It is JSON-encoded for the stream
// event and single-quoted for the shell, so quotes, backslashes, '$' and '%'
// come out literally instead of breaking the script or the JSON.
func makeFakeAgentBin(t *testing.T, syncOutput string) string {
	t.Helper()
	textJSON, err := json.Marshal(syncOutput)
	if err != nil {
		t.Fatal(err)
	}
	assistantEvent := `{"type":"assistant","message":{"content":[{"type":"text","text":` + string(textJSON) + `}]}}`
	resultEvent := `{"type":"result","subtype":"success"}`

	script := t.TempDir() + "/agent"
	// %%s becomes the shell printf's own %s; each %s receives a shell-quoted word.
	content := fmt.Sprintf(`#!/bin/sh
# 若有 --stream-json 則輸出 stream 格式
for arg; do
if [ "$arg" = "--stream-json" ]; then
printf '%%s\n' %s
printf '%%s\n' %s
exit 0
fi
done
# 否則輸出 sync 格式
printf '%%s' %s
`, shellSingleQuote(assistantEvent), shellSingleQuote(resultEvent), shellSingleQuote(syncOutput))
	if err := os.WriteFile(script, []byte(content), 0755); err != nil {
		t.Fatal(err)
	}
	return script
}

// shellSingleQuote wraps s in POSIX single quotes so the shell treats it as one
// literal word. Nothing can be escaped inside '...', so each embedded single
// quote is emitted as '\'' (close the span, add an escaped quote, reopen).
func shellSingleQuote(s string) string {
	return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
}
|
|||
|
|
|
|||
|
|
// makeFakeAgentBinWithModels writes an executable /bin/sh script named "agent"
// into a fresh per-test temp directory and returns its path. The script scans
// its arguments in order and acts on the first recognised flag:
//   - --list-models: prints two "id - Display Name" lines and exits;
//   - --stream-json: prints an assistant event with text "Hello" and then a
//     success result event, one JSON object per line, and exits.
//
// With neither flag it prints "Hello from agent" with no trailing newline.
func makeFakeAgentBinWithModels(t *testing.T) string {
	t.Helper()
	scriptLines := []string{
		"#!/bin/sh",
		"for arg; do",
		`if [ "$arg" = "--list-models" ]; then`,
		`printf 'claude-3-opus - Claude 3 Opus\n'`,
		`printf 'claude-3-sonnet - Claude 3 Sonnet\n'`,
		"exit 0",
		"fi",
		`if [ "$arg" = "--stream-json" ]; then`,
		`printf '%s\n' '{"type":"assistant","message":{"content":[{"type":"text","text":"Hello"}]}}'`,
		`printf '%s\n' '{"type":"result","subtype":"success"}'`,
		"exit 0",
		"fi",
		"done",
		"printf 'Hello from agent'",
	}
	binPath := t.TempDir() + "/agent"
	err := os.WriteFile(binPath, []byte(strings.Join(scriptLines, "\n")+"\n"), 0755)
	if err != nil {
		t.Fatal(err)
	}
	return binPath
}
|
|||
|
|
|
|||
|
|
func makeTestConfig(agentBin string, port int, overrides ...func(*config.BridgeConfig)) config.BridgeConfig {
|
|||
|
|
cfg := config.BridgeConfig{
|
|||
|
|
AgentBin: agentBin,
|
|||
|
|
Host: "127.0.0.1",
|
|||
|
|
Port: port,
|
|||
|
|
DefaultModel: "auto",
|
|||
|
|
Mode: "ask",
|
|||
|
|
Force: false,
|
|||
|
|
ApproveMcps: false,
|
|||
|
|
StrictModel: true,
|
|||
|
|
Workspace: os.TempDir(),
|
|||
|
|
TimeoutMs: 30000,
|
|||
|
|
SessionsLogPath: os.TempDir() + "/test-sessions.log",
|
|||
|
|
ChatOnlyWorkspace: true,
|
|||
|
|
Verbose: false,
|
|||
|
|
MaxMode: false,
|
|||
|
|
ConfigDirs: []string{},
|
|||
|
|
MultiPort: false,
|
|||
|
|
WinCmdlineMax: 30000,
|
|||
|
|
}
|
|||
|
|
for _, fn := range overrides {
|
|||
|
|
fn(&cfg)
|
|||
|
|
}
|
|||
|
|
return cfg
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// waitListening repeatedly dials host:port over TCP (50ms per attempt, 20ms
// pause between attempts) until a connection succeeds. It fails the test if no
// dial succeeds before timeout elapses.
func waitListening(t *testing.T, host string, port int, timeout time.Duration) {
	t.Helper()
	// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:port"). Plain
	// "%s:%d" formatting yields an address Dial cannot parse for them.
	addr := net.JoinHostPort(host, fmt.Sprint(port))
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		conn, err := net.DialTimeout("tcp", addr, 50*time.Millisecond)
		if err == nil {
			conn.Close()
			return
		}
		time.Sleep(20 * time.Millisecond)
	}
	t.Fatalf("server at %s did not start within %v", addr, timeout)
}
|
|||
|
|
|
|||
|
|
func doRequest(t *testing.T, method, url, body string, headers map[string]string) (int, string) {
|
|||
|
|
t.Helper()
|
|||
|
|
var reqBody io.Reader
|
|||
|
|
if body != "" {
|
|||
|
|
reqBody = strings.NewReader(body)
|
|||
|
|
}
|
|||
|
|
req, err := http.NewRequest(method, url, reqBody)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatal(err)
|
|||
|
|
}
|
|||
|
|
if body != "" {
|
|||
|
|
req.Header.Set("Content-Type", "application/json")
|
|||
|
|
}
|
|||
|
|
for k, v := range headers {
|
|||
|
|
req.Header.Set(k, v)
|
|||
|
|
}
|
|||
|
|
resp, err := http.DefaultClient.Do(req)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatal(err)
|
|||
|
|
}
|
|||
|
|
defer resp.Body.Close()
|
|||
|
|
data, _ := io.ReadAll(resp.Body)
|
|||
|
|
return resp.StatusCode, string(data)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TestBridgeServer_Health starts a single bridge server and checks that
// GET /health returns 200 with a JSON body whose "ok" field is true and whose
// "version" field echoes ServerOptions.Version.
func TestBridgeServer_Health(t *testing.T) {
	port := freePort(t)
	agentBin := makeFakeAgentBinWithModels(t)
	cfg := makeTestConfig(agentBin, port)

	srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
	// Register shutdown before waiting so the servers are stopped even when
	// waitListening (or any later check) fails the test.
	t.Cleanup(func() {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		for _, s := range srvs {
			s.Shutdown(ctx) // best-effort teardown; any returned error is ignored
		}
	})
	waitListening(t, "127.0.0.1", port, 3*time.Second)

	status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", nil)
	if status != 200 {
		t.Fatalf("status = %d, want 200; body: %s", status, body)
	}
	var result map[string]interface{}
	if err := json.Unmarshal([]byte(body), &result); err != nil {
		t.Fatalf("decoding body %q: %v", body, err)
	}
	if result["ok"] != true {
		t.Errorf("ok = %v, want true", result["ok"])
	}
	if result["version"] != "1.0.0" {
		t.Errorf("version = %v, want 1.0.0", result["version"])
	}
}
|
|||
|
|
|
|||
|
|
// TestBridgeServer_Models checks that GET /v1/models returns 200 with an
// OpenAI-style list object ({"object":"list","data":[...]}) containing at least
// the two models the fake agent prints for --list-models.
func TestBridgeServer_Models(t *testing.T) {
	port := freePort(t)
	agentBin := makeFakeAgentBinWithModels(t)
	cfg := makeTestConfig(agentBin, port)

	srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
	// Register shutdown before waiting so the servers are stopped even when
	// waitListening (or any later check) fails the test.
	t.Cleanup(func() {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		for _, s := range srvs {
			s.Shutdown(ctx) // best-effort teardown; any returned error is ignored
		}
	})
	waitListening(t, "127.0.0.1", port, 3*time.Second)

	status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/v1/models", port), "", nil)
	if status != 200 {
		t.Fatalf("status = %d, want 200; body: %s", status, body)
	}
	var result map[string]interface{}
	if err := json.Unmarshal([]byte(body), &result); err != nil {
		t.Fatalf("decoding body %q: %v", body, err)
	}
	if result["object"] != "list" {
		t.Errorf("object = %v, want list", result["object"])
	}
	// Checked assertion: a malformed body should fail the test, not panic it.
	data, ok := result["data"].([]interface{})
	if !ok {
		t.Fatalf("data = %v (%T), want array; body: %s", result["data"], result["data"], body)
	}
	if len(data) < 2 {
		t.Errorf("data len = %d, want >= 2", len(data))
	}
}
|
|||
|
|
|
|||
|
|
// TestBridgeServer_Unauthorized configures RequiredKey and checks that a
// request to /health without an Authorization header is rejected with 401 and
// an error object whose message is "Invalid API key".
func TestBridgeServer_Unauthorized(t *testing.T) {
	port := freePort(t)
	agentBin := makeFakeAgentBinWithModels(t)
	cfg := makeTestConfig(agentBin, port, func(c *config.BridgeConfig) {
		c.RequiredKey = "secret123"
	})

	srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
	// Register shutdown before waiting so the servers are stopped even when
	// waitListening (or any later check) fails the test.
	t.Cleanup(func() {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		for _, s := range srvs {
			s.Shutdown(ctx) // best-effort teardown; any returned error is ignored
		}
	})
	waitListening(t, "127.0.0.1", port, 3*time.Second)

	status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", nil)
	if status != 401 {
		t.Fatalf("status = %d, want 401; body: %s", status, body)
	}
	var result map[string]interface{}
	if err := json.Unmarshal([]byte(body), &result); err != nil {
		t.Fatalf("decoding body %q: %v", body, err)
	}
	errObj, ok := result["error"].(map[string]interface{})
	if !ok {
		t.Fatalf("error = %v, want object; body: %s", result["error"], body)
	}
	if errObj["message"] != "Invalid API key" {
		t.Errorf("message = %v, want 'Invalid API key'", errObj["message"])
	}
}
|
|||
|
|
|
|||
|
|
// TestBridgeServer_AuthorizedKey configures RequiredKey and checks that a
// request to /health carrying "Authorization: Bearer <key>" is accepted (200).
func TestBridgeServer_AuthorizedKey(t *testing.T) {
	port := freePort(t)
	agentBin := makeFakeAgentBinWithModels(t)
	cfg := makeTestConfig(agentBin, port, func(c *config.BridgeConfig) {
		c.RequiredKey = "secret123"
	})

	srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
	// Register shutdown before waiting so the servers are stopped even when
	// waitListening (or any later check) fails the test.
	t.Cleanup(func() {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		for _, s := range srvs {
			s.Shutdown(ctx) // best-effort teardown; any returned error is ignored
		}
	})
	waitListening(t, "127.0.0.1", port, 3*time.Second)

	status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", map[string]string{
		"Authorization": "Bearer secret123",
	})
	if status != 200 {
		t.Errorf("status = %d, want 200; body: %s", status, body)
	}
}
|
|||
|
|
|
|||
|
|
// TestBridgeServer_NotFound checks that an unknown path returns 404 with an
// error object whose code is "not_found".
func TestBridgeServer_NotFound(t *testing.T) {
	port := freePort(t)
	agentBin := makeFakeAgentBinWithModels(t)
	cfg := makeTestConfig(agentBin, port)

	srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
	// Register shutdown before waiting so the servers are stopped even when
	// waitListening (or any later check) fails the test.
	t.Cleanup(func() {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		for _, s := range srvs {
			s.Shutdown(ctx) // best-effort teardown; any returned error is ignored
		}
	})
	waitListening(t, "127.0.0.1", port, 3*time.Second)

	status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/unknown", port), "", nil)
	if status != 404 {
		t.Fatalf("status = %d, want 404; body: %s", status, body)
	}
	var result map[string]interface{}
	if err := json.Unmarshal([]byte(body), &result); err != nil {
		t.Fatalf("decoding body %q: %v", body, err)
	}
	errObj, ok := result["error"].(map[string]interface{})
	if !ok {
		t.Fatalf("error = %v, want object; body: %s", result["error"], body)
	}
	if errObj["code"] != "not_found" {
		t.Errorf("code = %v, want not_found", errObj["code"])
	}
}
|
|||
|
|
|
|||
|
|
// TestBridgeServer_ChatCompletions_Sync posts a non-streaming chat completion
// request and checks that the response is a "chat.completion" object whose
// first choice's message content is the fake agent's sync output.
func TestBridgeServer_ChatCompletions_Sync(t *testing.T) {
	port := freePort(t)
	agentBin := makeFakeAgentBin(t, "Hello from agent")
	cfg := makeTestConfig(agentBin, port)

	srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
	// Register shutdown before waiting so the servers are stopped even when
	// waitListening (or any later check) fails the test.
	t.Cleanup(func() {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		for _, s := range srvs {
			s.Shutdown(ctx) // best-effort teardown; any returned error is ignored
		}
	})
	waitListening(t, "127.0.0.1", port, 3*time.Second)

	reqBody := `{"model":"claude-3-opus","messages":[{"role":"user","content":"Hi"}]}`
	status, body := doRequest(t, "POST", fmt.Sprintf("http://127.0.0.1:%d/v1/chat/completions", port), reqBody, nil)
	if status != 200 {
		t.Fatalf("status = %d, want 200; body: %s", status, body)
	}
	var result map[string]interface{}
	if err := json.Unmarshal([]byte(body), &result); err != nil {
		t.Fatalf("decoding body %q: %v", body, err)
	}
	if result["object"] != "chat.completion" {
		t.Errorf("object = %v, want chat.completion", result["object"])
	}
	// Checked assertions and a length guard: a malformed body should fail the
	// test with a message, not panic it.
	choices, ok := result["choices"].([]interface{})
	if !ok || len(choices) == 0 {
		t.Fatalf("choices = %v, want non-empty array; body: %s", result["choices"], body)
	}
	choice, ok := choices[0].(map[string]interface{})
	if !ok {
		t.Fatalf("choices[0] = %v, want object; body: %s", choices[0], body)
	}
	msg, ok := choice["message"].(map[string]interface{})
	if !ok {
		t.Fatalf("choices[0].message = %v, want object; body: %s", choice["message"], body)
	}
	if msg["content"] != "Hello from agent" {
		t.Errorf("content = %v, want 'Hello from agent'", msg["content"])
	}
}
|
|||
|
|
|
|||
|
|
// TestBridgeServer_MultiPort enables MultiPort with two ConfigDirs and checks
// that StartBridgeServer returns two servers. Port assignment is not asserted.
func TestBridgeServer_MultiPort(t *testing.T) {
	basePort := freePort(t)
	agentBin := makeFakeAgentBinWithModels(t)

	dir1 := t.TempDir()
	dir2 := t.TempDir()

	cfg := makeTestConfig(agentBin, basePort, func(c *config.BridgeConfig) {
		c.ConfigDirs = []string{dir1, dir2}
		c.MultiPort = true
	})

	srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
	// Register shutdown immediately so every returned server is stopped even
	// if the count check below fails the test.
	t.Cleanup(func() {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		for _, s := range srvs {
			s.Shutdown(ctx) // best-effort teardown; any returned error is ignored
		}
	})
	if len(srvs) != 2 {
		t.Fatalf("got %d servers, want 2", len(srvs))
	}

	// Give both servers time to start before teardown. The ports may collide,
	// so port assignment is deliberately not checked here.
	time.Sleep(200 * time.Millisecond)
}
|