332 lines
9.2 KiB
Go
332 lines
9.2 KiB
Go
package server_test
|
||
|
||
import (
|
||
"cursor-api-proxy/internal/config"
|
||
"cursor-api-proxy/internal/server"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net"
|
||
"context"
|
||
"net/http"
|
||
"os"
|
||
"strings"
|
||
"testing"
|
||
"time"
|
||
)
|
||
|
||
// freePort asks the kernel for an ephemeral TCP port on 127.0.0.1 and returns
// its number after releasing the listener. The port is only very likely free:
// another process could bind it between this call and the caller's own Listen.
func freePort(t *testing.T) int {
	t.Helper()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()
	return ln.Addr().(*net.TCPAddr).Port
}
|
||
|
||
// makeFakeAgentBin writes an executable /bin/sh script that stands in for the
// agent CLI and returns its path. The script's output is fixed:
//   - if any argument is "--stream-json", it prints two JSON lines: an
//     "assistant" message whose single text part is syncOutput, followed by
//     {"type":"result","subtype":"success"};
//   - otherwise it prints syncOutput verbatim, without a trailing newline.
//
// syncOutput is JSON-encoded for the stream line and single-quote escaped for
// the shell, so it may safely contain quotes, backslashes or newlines.
func makeFakeAgentBin(t *testing.T, syncOutput string) string {
	t.Helper()

	// quote wraps s in single quotes for sh, closing and reopening the quote
	// around each embedded ' so the value reaches printf literally.
	quote := func(s string) string {
		return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
	}

	// Encode the text as a JSON string so quotes/backslashes stay valid JSON.
	text, err := json.Marshal(syncOutput)
	if err != nil {
		t.Fatal(err)
	}
	assistantLine := `{"type":"assistant","message":{"content":[{"type":"text","text":` + string(text) + `}]}}`
	resultLine := `{"type":"result","subtype":"success"}`

	// %%s survives Sprintf as a literal %s for the shell's printf; the quoted
	// values are passed as printf arguments, so '%' in them is not interpreted.
	content := fmt.Sprintf(`#!/bin/sh
# With --stream-json, emit stream-format JSON lines.
for arg; do
  if [ "$arg" = "--stream-json" ]; then
    printf '%%s\n' %s
    printf '%%s\n' %s
    exit 0
  fi
done
# Otherwise emit the sync output as-is.
printf '%%s' %s
`, quote(assistantLine), quote(resultLine), quote(syncOutput))

	script := t.TempDir() + "/agent"
	if err := os.WriteFile(script, []byte(content), 0755); err != nil {
		t.Fatal(err)
	}
	return script
}
|
||
|
||
// makeFakeAgentBinWithModels writes a fake agent shell script that, in addition
// to fixed chat replies, answers --list-models, and returns the script path.
// Arguments are scanned in order and the first one that matches decides:
//   - "--list-models": prints two "<id> - <name>" lines and exits;
//   - "--stream-json": prints an assistant JSON line with text "Hello",
//     then a result/success line, and exits;
//   - no match: prints "Hello from agent" with no trailing newline.
func makeFakeAgentBinWithModels(t *testing.T) string {
	t.Helper()

	const fakeAgent = `#!/bin/sh
for arg; do
  if [ "$arg" = "--list-models" ]; then
    printf 'claude-3-opus - Claude 3 Opus\n'
    printf 'claude-3-sonnet - Claude 3 Sonnet\n'
    exit 0
  fi
  if [ "$arg" = "--stream-json" ]; then
    printf '%s\n' '{"type":"assistant","message":{"content":[{"type":"text","text":"Hello"}]}}'
    printf '%s\n' '{"type":"result","subtype":"success"}'
    exit 0
  fi
done
printf 'Hello from agent'
`

	path := t.TempDir() + "/agent"
	err := os.WriteFile(path, []byte(fakeAgent), 0o755)
	if err != nil {
		t.Fatal(err)
	}
	return path
}
|
||
|
||
func makeTestConfig(agentBin string, port int, overrides ...func(*config.BridgeConfig)) config.BridgeConfig {
|
||
cfg := config.BridgeConfig{
|
||
AgentBin: agentBin,
|
||
Host: "127.0.0.1",
|
||
Port: port,
|
||
DefaultModel: "auto",
|
||
Mode: "ask",
|
||
Force: false,
|
||
ApproveMcps: false,
|
||
StrictModel: true,
|
||
Workspace: os.TempDir(),
|
||
TimeoutMs: 30000,
|
||
SessionsLogPath: os.TempDir() + "/test-sessions.log",
|
||
ChatOnlyWorkspace: true,
|
||
Verbose: false,
|
||
MaxMode: false,
|
||
ConfigDirs: []string{},
|
||
MultiPort: false,
|
||
WinCmdlineMax: 30000,
|
||
}
|
||
for _, fn := range overrides {
|
||
fn(&cfg)
|
||
}
|
||
return cfg
|
||
}
|
||
|
||
// waitListening polls until a TCP connection to host:port succeeds. If none
// succeeds before timeout elapses, it fails the test. Each dial attempt is
// bounded to 50ms, with a 20ms pause between attempts.
func waitListening(t *testing.T, host string, port int, timeout time.Duration) {
	t.Helper()
	// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:port"); a plain
	// "%s:%d" format would yield an address net.Dial cannot parse.
	addr := net.JoinHostPort(host, fmt.Sprint(port))
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		conn, err := net.DialTimeout("tcp", addr, 50*time.Millisecond)
		if err == nil {
			conn.Close()
			return
		}
		time.Sleep(20 * time.Millisecond)
	}
	t.Fatalf("server on %s did not start within %v", addr, timeout)
}
|
||
|
||
// doRequest sends an HTTP request and returns the response status code and the
// full response body. A non-empty body is sent with Content-Type
// application/json. Entries in headers are set afterwards and may override it.
// Request construction, transport, or body-read failures fail the test.
// Requests time out after 30s so a hung server cannot stall the test run.
func doRequest(t *testing.T, method, url, body string, headers map[string]string) (int, string) {
	t.Helper()
	var reqBody io.Reader
	if body != "" {
		reqBody = strings.NewReader(body)
	}
	req, err := http.NewRequest(method, url, reqBody)
	if err != nil {
		t.Fatal(err)
	}
	if body != "" {
		req.Header.Set("Content-Type", "application/json")
	}
	for k, v := range headers {
		req.Header.Set(k, v)
	}
	// A nil Transport means http.DefaultTransport, so connections are still
	// pooled across calls; only the overall timeout differs from DefaultClient.
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading response body of %s %s: %v", method, url, err)
	}
	return resp.StatusCode, string(data)
}
|
||
|
||
func TestBridgeServer_Health(t *testing.T) {
|
||
port := freePort(t)
|
||
agentBin := makeFakeAgentBinWithModels(t)
|
||
cfg := makeTestConfig(agentBin, port)
|
||
|
||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||
defer func() {
|
||
for _, s := range srvs {
|
||
s.Shutdown(context.Background())
|
||
}
|
||
}()
|
||
|
||
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", nil)
|
||
if status != 200 {
|
||
t.Fatalf("status = %d, want 200; body: %s", status, body)
|
||
}
|
||
var result map[string]interface{}
|
||
json.Unmarshal([]byte(body), &result)
|
||
if result["ok"] != true {
|
||
t.Errorf("ok = %v, want true", result["ok"])
|
||
}
|
||
if result["version"] != "1.0.0" {
|
||
t.Errorf("version = %v, want 1.0.0", result["version"])
|
||
}
|
||
}
|
||
|
||
func TestBridgeServer_Models(t *testing.T) {
|
||
port := freePort(t)
|
||
agentBin := makeFakeAgentBinWithModels(t)
|
||
cfg := makeTestConfig(agentBin, port)
|
||
|
||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||
defer func() {
|
||
for _, s := range srvs {
|
||
s.Shutdown(context.Background())
|
||
}
|
||
}()
|
||
|
||
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/v1/models", port), "", nil)
|
||
if status != 200 {
|
||
t.Fatalf("status = %d, want 200; body: %s", status, body)
|
||
}
|
||
var result map[string]interface{}
|
||
json.Unmarshal([]byte(body), &result)
|
||
if result["object"] != "list" {
|
||
t.Errorf("object = %v, want list", result["object"])
|
||
}
|
||
data := result["data"].([]interface{})
|
||
if len(data) < 2 {
|
||
t.Errorf("data len = %d, want >= 2", len(data))
|
||
}
|
||
}
|
||
|
||
func TestBridgeServer_Unauthorized(t *testing.T) {
|
||
port := freePort(t)
|
||
agentBin := makeFakeAgentBinWithModels(t)
|
||
cfg := makeTestConfig(agentBin, port, func(c *config.BridgeConfig) {
|
||
c.RequiredKey = "secret123"
|
||
})
|
||
|
||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||
defer func() {
|
||
for _, s := range srvs {
|
||
s.Shutdown(context.Background())
|
||
}
|
||
}()
|
||
|
||
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", nil)
|
||
if status != 401 {
|
||
t.Fatalf("status = %d, want 401; body: %s", status, body)
|
||
}
|
||
var result map[string]interface{}
|
||
json.Unmarshal([]byte(body), &result)
|
||
errObj := result["error"].(map[string]interface{})
|
||
if errObj["message"] != "Invalid API key" {
|
||
t.Errorf("message = %v, want 'Invalid API key'", errObj["message"])
|
||
}
|
||
}
|
||
|
||
func TestBridgeServer_AuthorizedKey(t *testing.T) {
|
||
port := freePort(t)
|
||
agentBin := makeFakeAgentBinWithModels(t)
|
||
cfg := makeTestConfig(agentBin, port, func(c *config.BridgeConfig) {
|
||
c.RequiredKey = "secret123"
|
||
})
|
||
|
||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||
defer func() {
|
||
for _, s := range srvs {
|
||
s.Shutdown(context.Background())
|
||
}
|
||
}()
|
||
|
||
status, _ := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/health", port), "", map[string]string{
|
||
"Authorization": "Bearer secret123",
|
||
})
|
||
if status != 200 {
|
||
t.Errorf("status = %d, want 200", status)
|
||
}
|
||
}
|
||
|
||
func TestBridgeServer_NotFound(t *testing.T) {
|
||
port := freePort(t)
|
||
agentBin := makeFakeAgentBinWithModels(t)
|
||
cfg := makeTestConfig(agentBin, port)
|
||
|
||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||
defer func() {
|
||
for _, s := range srvs {
|
||
s.Shutdown(context.Background())
|
||
}
|
||
}()
|
||
|
||
status, body := doRequest(t, "GET", fmt.Sprintf("http://127.0.0.1:%d/unknown", port), "", nil)
|
||
if status != 404 {
|
||
t.Fatalf("status = %d, want 404; body: %s", status, body)
|
||
}
|
||
var result map[string]interface{}
|
||
json.Unmarshal([]byte(body), &result)
|
||
errObj := result["error"].(map[string]interface{})
|
||
if errObj["code"] != "not_found" {
|
||
t.Errorf("code = %v, want not_found", errObj["code"])
|
||
}
|
||
}
|
||
|
||
func TestBridgeServer_ChatCompletions_Sync(t *testing.T) {
|
||
port := freePort(t)
|
||
agentBin := makeFakeAgentBin(t, "Hello from agent")
|
||
cfg := makeTestConfig(agentBin, port)
|
||
|
||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||
waitListening(t, "127.0.0.1", port, 3*time.Second)
|
||
defer func() {
|
||
for _, s := range srvs {
|
||
s.Shutdown(context.Background())
|
||
}
|
||
}()
|
||
|
||
reqBody := `{"model":"claude-3-opus","messages":[{"role":"user","content":"Hi"}]}`
|
||
status, body := doRequest(t, "POST", fmt.Sprintf("http://127.0.0.1:%d/v1/chat/completions", port), reqBody, nil)
|
||
if status != 200 {
|
||
t.Fatalf("status = %d, want 200; body: %s", status, body)
|
||
}
|
||
var result map[string]interface{}
|
||
json.Unmarshal([]byte(body), &result)
|
||
if result["object"] != "chat.completion" {
|
||
t.Errorf("object = %v, want chat.completion", result["object"])
|
||
}
|
||
choices := result["choices"].([]interface{})
|
||
msg := choices[0].(map[string]interface{})["message"].(map[string]interface{})
|
||
if msg["content"] != "Hello from agent" {
|
||
t.Errorf("content = %v, want 'Hello from agent'", msg["content"])
|
||
}
|
||
}
|
||
|
||
func TestBridgeServer_MultiPort(t *testing.T) {
|
||
basePort := freePort(t)
|
||
agentBin := makeFakeAgentBinWithModels(t)
|
||
|
||
dir1 := t.TempDir()
|
||
dir2 := t.TempDir()
|
||
|
||
cfg := makeTestConfig(agentBin, basePort, func(c *config.BridgeConfig) {
|
||
c.ConfigDirs = []string{dir1, dir2}
|
||
c.MultiPort = true
|
||
})
|
||
|
||
srvs := server.StartBridgeServer(server.ServerOptions{Version: "1.0.0", Config: cfg})
|
||
if len(srvs) != 2 {
|
||
t.Fatalf("got %d servers, want 2", len(srvs))
|
||
}
|
||
|
||
// 等待兩個 server 啟動(port 可能會衝突,這裡不嚴格測試 port 分配)
|
||
time.Sleep(200 * time.Millisecond)
|
||
|
||
defer func() {
|
||
for _, s := range srvs {
|
||
s.Shutdown(context.Background())
|
||
}
|
||
}()
|
||
}
|