179 lines
5.9 KiB
Go
179 lines
5.9 KiB
Go
|
|
package provider
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"encoding/json"
|
||
|
|
"io"
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"strings"
|
||
|
|
"testing"
|
||
|
|
|
||
|
|
"haixun-backend/internal/model/ai/domain/enum"
|
||
|
|
domai "haixun-backend/internal/model/ai/domain/usecase"
|
||
|
|
)
|
||
|
|
|
||
|
|
// TestOpenAICompatible_ListModels checks that ListModels sends an
// authenticated GET /models request, drops entries whose id is empty, and
// returns the remaining ids in ascending order (the fake server deliberately
// returns them in a different order).
func TestOpenAICompatible_ListModels(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, not the test goroutine, so
		// t.Fatalf (FailNow) must not be used here. Record the failure with
		// t.Errorf and fail the HTTP request so the client call surfaces it.
		if r.Method != http.MethodGet || r.URL.Path != "/models" {
			t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path)
			http.Error(w, "unexpected request", http.StatusNotFound)
			return
		}
		if got := r.Header.Get("Authorization"); got != "Bearer test-token" {
			t.Errorf("unexpected auth header: %q", got)
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return
		}
		// The empty id is expected to be filtered out by ListModels.
		_, _ = w.Write([]byte(`{"data":[{"id":"grok-3-fast"},{"id":"grok-3"},{"id":""}]}`))
	}))
	defer server.Close()

	p := NewOpenAICompatible(enum.ProviderXAI, server.URL)
	models, err := p.ListModels(context.Background(), domai.Credential{APIKey: "test-token"})
	if err != nil {
		t.Fatalf("ListModels() error = %v", err)
	}

	want := []string{"grok-3", "grok-3-fast"}
	if len(models) != len(want) {
		t.Fatalf("models = %q, want %q", models, want)
	}
	for i, id := range want {
		if models[i] != id {
			t.Fatalf("models[%d] = %q, want %q", i, models[i], id)
		}
	}
}
|
||
|
|
|
||
|
|
// TestOpenAICompatible_ListModels_MissingToken checks that ListModels returns
// an error when the credential carries no API key.
func TestOpenAICompatible_ListModels_MissingToken(t *testing.T) {
	provider := NewOpenAICompatible(enum.ProviderXAI, "https://example.com")

	if _, err := provider.ListModels(context.Background(), domai.Credential{}); err == nil {
		t.Fatal("expected error for missing token")
	}
}
|
||
|
|
|
||
|
|
// TestOpenAICompatible_GenerateText_UsesContent checks that GenerateText on the
// OpenCode provider with a deepseek model sends a non-streaming
// POST /chat/completions request, with thinking disabled and max_tokens raised
// to at least 2048 even though the caller asked for 256. It also checks that
// the assistant message content and finish_reason are returned.
func TestOpenAICompatible_GenerateText_UsesContent(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on a server goroutine, so it reports through
		// t.Errorf. t.Fatal/t.Fatalf (FailNow) may only be called from the
		// test goroutine.
		if r.Method != http.MethodPost || r.URL.Path != "/chat/completions" {
			t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path)
			http.Error(w, "unexpected request", http.StatusNotFound)
			return
		}
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("read request body: %v", err)
			http.Error(w, "bad request body", http.StatusBadRequest)
			return
		}
		var req chatCompletionRequest
		if err := json.Unmarshal(body, &req); err != nil {
			t.Errorf("unmarshal request: %v", err)
			http.Error(w, "bad request body", http.StatusBadRequest)
			return
		}

		// Request-shape assertions are non-fatal. The response is still
		// written, so every mismatch is reported in one run.
		if req.Stream {
			t.Error("expected non-stream request")
		}
		if req.Thinking == nil || req.Thinking.Type != "disabled" {
			t.Errorf("thinking = %+v, want disabled for deepseek on opencode", req.Thinking)
		}
		if req.MaxTokens == nil || *req.MaxTokens < 2048 {
			t.Errorf("max_tokens = %+v, want >= 2048 for deepseek on opencode", req.MaxTokens)
		}
		_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"hello"},"finish_reason":"stop"}]}`))
	}))
	defer server.Close()

	p := NewOpenAICompatible(enum.ProviderOpenCode, server.URL)
	// Deliberately below the 2048 floor asserted by the handler.
	maxTokens := 256
	result, err := p.GenerateText(context.Background(), domai.GenerateRequest{
		Provider:   enum.ProviderOpenCode,
		Model:      "deepseek-v4-flash",
		Credential: domai.Credential{APIKey: "test-token"},
		Messages:   []domai.Message{{Role: "user", Content: "hi"}},
		MaxTokens:  &maxTokens,
	})
	if err != nil {
		t.Fatalf("GenerateText() error = %v", err)
	}
	if result.Text != "hello" || result.FinishReason != "stop" {
		t.Fatalf("result = %+v, want text=hello finish_reason=stop", result)
	}
}
|
||
|
|
|
||
|
|
func TestOpenAICompatible_GenerateText_FallsBackToReasoningContent(t *testing.T) {
|
||
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"","reasoning_content":"thinking only"},"finish_reason":"length"}]}`))
|
||
|
|
}))
|
||
|
|
defer server.Close()
|
||
|
|
|
||
|
|
p := NewOpenAICompatible(enum.ProviderXAI, server.URL)
|
||
|
|
result, err := p.GenerateText(context.Background(), domai.GenerateRequest{
|
||
|
|
Model: "deepseek-test",
|
||
|
|
Credential: domai.Credential{APIKey: "test-token"},
|
||
|
|
Messages: []domai.Message{{Role: "user", Content: "hi"}},
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("GenerateText() error = %v", err)
|
||
|
|
}
|
||
|
|
if result.Text != "thinking only" {
|
||
|
|
t.Fatalf("result.Text = %q, want thinking only", result.Text)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// TestOpenAICompatible_StreamText_ParsesReasoningContent feeds StreamText a
// server-sent-events body with three frames:
//   - a reasoning_content delta
//   - a content delta carrying finish_reason "stop"
//   - the [DONE] terminator
//
// It expects the two delta texts, concatenated in arrival order, to equal
// "thinkanswer", and the "done" event to report finish reason "stop".
func TestOpenAICompatible_StreamText_ParsesReasoningContent(t *testing.T) {
	frames := []string{
		"data: {\"choices\":[{\"delta\":{\"reasoning_content\":\"think\"}}]}\n\n",
		"data: {\"choices\":[{\"delta\":{\"content\":\"answer\"},\"finish_reason\":\"stop\"}]}\n\n",
		"data: [DONE]\n\n",
	}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		for _, frame := range frames {
			_, _ = w.Write([]byte(frame))
		}
	}))
	defer server.Close()

	provider := NewOpenAICompatible(enum.ProviderXAI, server.URL)
	stream, err := provider.StreamText(context.Background(), domai.GenerateRequest{
		Model:      "deepseek-test",
		Credential: domai.Credential{APIKey: "test-token"},
		Messages:   []domai.Message{{Role: "user", Content: "hi"}},
	})
	if err != nil {
		t.Fatalf("StreamText() error = %v", err)
	}

	var (
		deltas       []string
		finishReason string
	)
	for ev := range stream {
		switch ev.Type {
		case "delta":
			deltas = append(deltas, ev.Text)
		case "done":
			finishReason = ev.FinishReason
		case "error":
			t.Fatalf("stream error: %s", ev.Error)
		}
	}

	if joined := strings.Join(deltas, ""); joined != "thinkanswer" {
		t.Fatalf("deltas = %v, want thinkanswer", deltas)
	}
	if finishReason != "stop" {
		t.Fatalf("finishReason = %q, want stop", finishReason)
	}
}
|
||
|
|
|
||
|
|
// TestBuildChatCompletionRequest_DisablesThinkingForOpenCodeDeepSeek checks the
// request builder directly. For an OpenCode deepseek model, thinking must be
// "disabled" and the requested max_tokens (128 here) must be raised to at
// least 2048.
func TestBuildChatCompletionRequest_DisablesThinkingForOpenCodeDeepSeek(t *testing.T) {
	limit := 128
	provider := NewOpenAICompatible(enum.ProviderOpenCode, "https://example.com")
	request := domai.GenerateRequest{
		Provider:  enum.ProviderOpenCode,
		Model:     "deepseek-v4-flash",
		Messages:  []domai.Message{{Role: "user", Content: "hi"}},
		MaxTokens: &limit,
	}

	// NOTE(review): the boolean argument is presumably the stream flag — confirm.
	body := provider.buildChatCompletionRequest(request, true)

	switch {
	case body.Thinking == nil || body.Thinking.Type != "disabled":
		t.Fatalf("thinking = %+v, want disabled", body.Thinking)
	case body.MaxTokens == nil || *body.MaxTokens < 2048:
		t.Fatalf("max_tokens = %+v, want >= 2048", body.MaxTokens)
	}
}
|
||
|
|
|
||
|
|
func TestExtractAssistantText_ContentArray(t *testing.T) {
|
||
|
|
msg := assistantMessage{
|
||
|
|
Content: json.RawMessage(`[{"type":"text","text":"array content"}]`),
|
||
|
|
}
|
||
|
|
if got := extractAssistantText(msg); got != "array content" {
|
||
|
|
t.Fatalf("extractAssistantText() = %q, want array content", got)
|
||
|
|
}
|
||
|
|
}
|