thread-master/internal/model/ai/provider/openai_compatible.go

357 lines
9.8 KiB
Go

package provider
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"sort"
"strings"
"time"
app "haixun-backend/internal/library/errors"
"haixun-backend/internal/library/errors/code"
"haixun-backend/internal/model/ai/domain/enum"
domai "haixun-backend/internal/model/ai/domain/usecase"
)
// OpenAICompatible is a provider client for HTTP APIs that expose the
// OpenAI-style "/models" and "/chat/completions" endpoints with Bearer-token
// authentication.
type OpenAICompatible struct {
// id is the provider identifier returned by ID; it is also passed to the
// normalize* helpers to pick provider-specific request adjustments.
id enum.ProviderID
// baseURL is the API root; the constructor strips trailing "/" characters.
baseURL string
// client is the HTTP client used for every request made through this value.
client *http.Client
}
// NewOpenAICompatible builds a provider client for the given provider id.
//
// Trailing "/" characters are removed from baseURL so endpoint paths such as
// "/models" can be appended directly. The embedded HTTP client is created
// with a 10 minute overall timeout.
func NewOpenAICompatible(id enum.ProviderID, baseURL string) *OpenAICompatible {
	p := new(OpenAICompatible)
	p.id = id
	p.baseURL = strings.TrimRight(baseURL, "/")
	p.client = &http.Client{Timeout: 10 * time.Minute}
	return p
}
// ID reports the provider identifier stored on this client.
func (p *OpenAICompatible) ID() enum.ProviderID {
	return p.id
}
// ListModels calls the provider's "/models" endpoint and returns the
// non-blank model IDs, whitespace-trimmed and sorted in ascending order.
//
// A credential whose API key is blank is rejected before any request is
// made. Transport and body-read failures are returned unchanged. A non-2xx
// status or a body that cannot be decoded as JSON yields a third-party
// service error. At most 1 MiB of the response body is read.
func (p *OpenAICompatible) ListModels(ctx context.Context, credential domai.Credential) ([]string, error) {
	apiKey := credential.APIKey
	if strings.TrimSpace(apiKey) == "" {
		return nil, app.For(code.AI).InputMissingRequired("missing AI provider token")
	}

	endpoint := p.baseURL + "/models"
	request, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}
	request.Header.Set("Authorization", "Bearer "+apiKey)
	request.Header.Set("Accept", "application/json")

	response, err := p.client.Do(request)
	if err != nil {
		return nil, err
	}
	defer response.Body.Close()

	// The body is consumed before the status is inspected, so a read failure
	// takes precedence over a bad status code.
	const maxBody = 1 << 20
	raw, err := io.ReadAll(io.LimitReader(response.Body, maxBody))
	if err != nil {
		return nil, err
	}
	if ok := response.StatusCode >= 200 && response.StatusCode < 300; !ok {
		return nil, providerHTTPError("AI provider models request failed", response.StatusCode, response.Status)
	}

	var listing openAIModelsResponse
	if json.Unmarshal(raw, &listing) != nil {
		return nil, app.For(code.AI).SvcThirdParty("failed to parse AI provider models response")
	}

	ids := make([]string, 0, len(listing.Data))
	for i := range listing.Data {
		if trimmed := strings.TrimSpace(listing.Data[i].ID); trimmed != "" {
			ids = append(ids, trimmed)
		}
	}
	sort.Strings(ids)
	return ids, nil
}
// GenerateText sends a non-streaming chat-completion request to
// "/chat/completions" and returns the text and finish reason of the first
// choice in the response.
//
// A blank API key is rejected before any request is made. Marshalling,
// transport and body-read failures are returned unchanged. A non-2xx status,
// an undecodable body, or a response with no choices yields a third-party
// service error. At most 1 MiB of the response body is read.
func (p *OpenAICompatible) GenerateText(ctx context.Context, req domai.GenerateRequest) (*domai.GenerateResult, error) {
	apiKey := req.Credential.APIKey
	if strings.TrimSpace(apiKey) == "" {
		return nil, app.For(code.AI).InputMissingRequired("missing AI provider token")
	}

	encoded, err := json.Marshal(p.buildChatCompletionRequest(req, false))
	if err != nil {
		return nil, err
	}

	endpoint := p.baseURL + "/chat/completions"
	request, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(encoded))
	if err != nil {
		return nil, err
	}
	request.Header.Set("Authorization", "Bearer "+apiKey)
	request.Header.Set("Content-Type", "application/json")
	request.Header.Set("Accept", "application/json")

	response, err := p.client.Do(request)
	if err != nil {
		return nil, err
	}
	defer response.Body.Close()

	// The body is consumed before the status is inspected, so a read failure
	// takes precedence over a bad status code.
	const maxBody = 1 << 20
	raw, err := io.ReadAll(io.LimitReader(response.Body, maxBody))
	if err != nil {
		return nil, err
	}
	if ok := response.StatusCode >= 200 && response.StatusCode < 300; !ok {
		return nil, providerHTTPError("AI provider request failed", response.StatusCode, response.Status)
	}

	var parsed openAIChatResponse
	if json.Unmarshal(raw, &parsed) != nil {
		return nil, app.For(code.AI).SvcThirdParty("failed to parse AI provider chat response")
	}
	if len(parsed.Choices) == 0 {
		return nil, app.For(code.AI).SvcThirdParty("AI provider returned no choices")
	}

	// Only the first choice is surfaced to the caller.
	first := parsed.Choices[0]
	result := &domai.GenerateResult{
		Text:         extractAssistantText(first.Message),
		FinishReason: first.FinishReason,
	}
	return result, nil
}
// StreamText sends a streaming chat-completion request to "/chat/completions"
// and returns an unbuffered channel carrying the events decoded from the
// provider's "data:" stream. The channel is closed once the stream has been
// fully processed, and the response body is closed at the same time.
//
// A blank API key is rejected before any request is made. Marshalling and
// transport failures are returned unchanged. A non-2xx status yields a
// third-party service error after at most 4 KiB of the body is drained.
//
// If ctx is cancelled while the caller is no longer receiving, the background
// goroutines stop and the channel is closed, possibly without a terminal
// "done" or "error" event; callers that cancel should consult ctx.Err().
func (p *OpenAICompatible) StreamText(ctx context.Context, req domai.GenerateRequest) (<-chan domai.StreamEvent, error) {
	if strings.TrimSpace(req.Credential.APIKey) == "" {
		return nil, app.For(code.AI).InputMissingRequired("missing AI provider token")
	}
	body := p.buildChatCompletionRequest(req, true)
	payload, err := json.Marshal(body)
	if err != nil {
		return nil, err
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/chat/completions", bytes.NewReader(payload))
	if err != nil {
		return nil, err
	}
	httpReq.Header.Set("Authorization", "Bearer "+req.Credential.APIKey)
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Accept", "text/event-stream")
	resp, err := p.client.Do(httpReq)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		defer resp.Body.Close()
		// Best-effort drain of a small error body before closing; a failure
		// here is irrelevant because the status error is returned regardless.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
		return nil, providerHTTPError("AI provider request failed", resp.StatusCode, resp.Status)
	}
	out := make(chan domai.StreamEvent)
	go forwardStreamEvents(ctx, resp.Body, out)
	return out, nil
}

// forwardStreamEvents runs parseOpenAIStream over body and relays its events
// to out, giving up as soon as ctx is cancelled and nobody is receiving. It
// closes both out and body before returning.
//
// Without this relay the parser would block forever on its unbuffered send if
// the consumer walked away, leaking the goroutine and the HTTP connection.
func forwardStreamEvents(ctx context.Context, body io.ReadCloser, out chan<- domai.StreamEvent) {
	defer close(out)
	defer body.Close()

	events := make(chan domai.StreamEvent)
	go func() {
		defer close(events)
		parseOpenAIStream(body, events)
	}()

	for event := range events {
		// Prefer delivery when a receiver is already waiting, so an active
		// consumer still sees the final event even if ctx has just ended.
		select {
		case out <- event:
			continue
		default:
		}
		select {
		case out <- event:
		case <-ctx.Done():
			// The request was built with ctx, so cancellation aborts the
			// parser's pending reads; drain whatever it still emits so its
			// sends never block, then let the deferred cleanup run.
			for range events {
			}
			return
		}
	}
}
// chatCompletionRequest is the JSON body posted to "/chat/completions".
// The pointer fields are tagged omitempty, so they are left out of the
// encoded payload when nil.
type chatCompletionRequest struct {
// Model is the model name copied from the generate request.
Model string `json:"model"`
// Messages is the conversation, including any leading system message.
Messages []chatMessage `json:"messages"`
// Stream is true for StreamText requests and false for GenerateText.
Stream bool `json:"stream"`
// Temperature is the sampling temperature after provider-specific adjustment.
Temperature *float64 `json:"temperature,omitempty"`
// MaxTokens is the token budget after provider-specific adjustment.
MaxTokens *int `json:"max_tokens,omitempty"`
// Thinking is only set for models where normalizeThinking returns a config.
Thinking *thinkingConfig `json:"thinking,omitempty"`
}
// thinkingConfig is the optional "thinking" object of a chat-completion
// request.
type thinkingConfig struct {
// Type is the thinking mode; the only value set in this file is "disabled".
Type string `json:"type,omitempty"`
}
// assistantMessage is the "message" object of a non-streaming chat choice.
type assistantMessage struct {
// Content is kept raw because parseMessageContent accepts either a JSON
// string or an array of objects carrying a "text" field.
Content json.RawMessage `json:"content"`
// ReasoningContent is used by extractAssistantText when Content is blank.
ReasoningContent string `json:"reasoning_content"`
// Reasoning is the last fallback used by extractAssistantText.
Reasoning string `json:"reasoning"`
}
// openAIChatResponse is the subset of a non-streaming chat-completion
// response that GenerateText decodes; only the first choice is used there.
type openAIChatResponse struct {
Choices []struct {
// Message is the assistant reply for this choice.
Message assistantMessage `json:"message"`
// FinishReason is passed through unchanged to the generate result.
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
// buildChatCompletionRequest converts a generate request into the JSON
// payload for "/chat/completions". The stream flag is copied verbatim, while
// temperature, max tokens and thinking go through the provider/model specific
// normalize helpers.
func (p *OpenAICompatible) buildChatCompletionRequest(req domai.GenerateRequest, stream bool) chatCompletionRequest {
	var out chatCompletionRequest
	out.Model = req.Model
	out.Messages = toChatMessages(req)
	out.Stream = stream
	out.Temperature = normalizeTemperature(p.id, req.Model, req.Temperature)
	out.MaxTokens = normalizeMaxTokens(p.id, req.Model, req.MaxTokens)
	out.Thinking = normalizeThinking(p.id, req.Model)
	return out
}
// extractAssistantText picks the text to return for an assistant message.
// It prefers the parsed content, then reasoning_content, and finally returns
// reasoning as-is (which may be empty). The first two are skipped when they
// contain only whitespace.
func extractAssistantText(msg assistantMessage) string {
	content := parseMessageContent(msg.Content)
	switch {
	case strings.TrimSpace(content) != "":
		return content
	case strings.TrimSpace(msg.ReasoningContent) != "":
		return msg.ReasoningContent
	default:
		return msg.Reasoning
	}
}
// parseMessageContent decodes the "content" field of an assistant message.
//
// Two shapes are understood: a JSON string, returned as-is, and a JSON array
// of objects, whose "text" fields are concatenated in order. Empty input, a
// literal null, and any other shape produce the empty string.
func parseMessageContent(raw json.RawMessage) string {
	switch {
	case len(raw) == 0, string(raw) == "null":
		return ""
	}

	var plain string
	if json.Unmarshal(raw, &plain) == nil {
		return plain
	}

	var segments []struct {
		Text string `json:"text"`
	}
	if json.Unmarshal(raw, &segments) != nil {
		return ""
	}

	var joined strings.Builder
	for i := range segments {
		// Segments without text contribute an empty string, i.e. nothing.
		joined.WriteString(segments[i].Text)
	}
	return joined.String()
}
// chatMessage is one entry of the "messages" array in a chat-completion
// request.
type chatMessage struct {
// Role is "system" for the injected system prompt, otherwise the role
// copied from the caller's message.
Role string `json:"role"`
// Content is the message text.
Content string `json:"content"`
}
// toChatMessages flattens a generate request into the wire-format message
// list. A non-blank system prompt becomes a leading "system" message; the
// request's own messages follow in their original order.
func toChatMessages(req domai.GenerateRequest) []chatMessage {
	out := make([]chatMessage, 0, 1+len(req.Messages))
	if system := req.System; strings.TrimSpace(system) != "" {
		out = append(out, chatMessage{Role: "system", Content: system})
	}
	for _, turn := range req.Messages {
		out = append(out, chatMessage{
			Role:    turn.Role,
			Content: turn.Content,
		})
	}
	return out
}
// normalizeTemperature returns the temperature to send for the given
// provider and model. Models whose name starts with "kimi-" on the OpenCode
// provider always get 1.0, regardless of what was requested; every other
// combination passes the requested pointer through untouched (nil included).
func normalizeTemperature(provider enum.ProviderID, model string, requested *float64) *float64 {
	if provider != enum.ProviderOpenCode || !strings.HasPrefix(model, "kimi-") {
		return requested
	}
	forced := 1.0
	return &forced
}
// normalizeMaxTokens returns the max_tokens value to send for the given
// provider and model. For "deepseek-" models on the OpenCode provider the
// budget is raised to at least 2048 (a missing value counts as too low);
// every other combination passes the requested pointer through untouched.
func normalizeMaxTokens(provider enum.ProviderID, model string, requested *int) *int {
	// DeepSeek V4 thinking can consume the full budget before content appears.
	const minDeepSeekTokens = 2048

	isDeepSeek := provider == enum.ProviderOpenCode && strings.HasPrefix(model, "deepseek-")
	switch {
	case !isDeepSeek:
		return requested
	case requested != nil && *requested >= minDeepSeekTokens:
		return requested
	default:
		floor := minDeepSeekTokens
		return &floor
	}
}
// normalizeThinking returns the "thinking" request option for the given
// provider and model, or nil when none should be sent.
//
// DeepSeek V4 on OpenCode Go defaults to thinking mode and may spend the
// entire max_tokens budget on reasoning_content before emitting content, so
// thinking is explicitly disabled for "deepseek-" models on that provider.
func normalizeThinking(provider enum.ProviderID, model string) *thinkingConfig {
	if provider != enum.ProviderOpenCode {
		return nil
	}
	if !strings.HasPrefix(model, "deepseek-") {
		return nil
	}
	return &thinkingConfig{Type: "disabled"}
}
// streamDeltaText returns the text carried by a streaming delta. The first
// non-empty field wins, checked in the order content, text,
// reasoning_content, reasoning; the result is empty when all four are.
func streamDeltaText(delta streamDelta) string {
	switch {
	case delta.Content != "":
		return delta.Content
	case delta.Text != "":
		return delta.Text
	case delta.ReasoningContent != "":
		return delta.ReasoningContent
	default:
		return delta.Reasoning
	}
}
// parseOpenAIStream reads body line by line and turns "data:" lines into
// stream events sent on out. It does not close out.
//
// Lines that do not start with "data:" after trimming are ignored. The
// payload "[DONE]" emits a "done" event and stops. A payload that is not
// valid JSON emits an "error" event and stops. Otherwise each choice's delta
// text, when non-empty, is emitted as a "delta" event, and the first choice
// with a finish reason emits a "done" event carrying that reason and stops.
// If scanning itself fails, its error text is emitted as an "error" event.
// Lines may be up to 1 MiB long.
func parseOpenAIStream(body io.Reader, out chan<- domai.StreamEvent) {
	const dataPrefix = "data:"

	lines := bufio.NewScanner(body)
	lines.Buffer(make([]byte, 0, 64*1024), 1024*1024)

	for lines.Scan() {
		line := strings.TrimSpace(lines.Text())
		// An empty line has no prefix either, so this also skips blanks.
		if !strings.HasPrefix(line, dataPrefix) {
			continue
		}

		payload := strings.TrimSpace(line[len(dataPrefix):])
		if payload == "[DONE]" {
			out <- domai.StreamEvent{Type: "done"}
			return
		}

		var chunk openAIStreamChunk
		if json.Unmarshal([]byte(payload), &chunk) != nil {
			out <- domai.StreamEvent{Type: "error", Error: "failed to parse AI stream payload"}
			return
		}

		for i := range chunk.Choices {
			choice := &chunk.Choices[i]
			if text := streamDeltaText(choice.Delta); text != "" {
				out <- domai.StreamEvent{Type: "delta", Text: text}
			}
			if choice.FinishReason == "" {
				continue
			}
			out <- domai.StreamEvent{Type: "done", FinishReason: choice.FinishReason}
			return
		}
	}

	if scanErr := lines.Err(); scanErr != nil {
		out <- domai.StreamEvent{Type: "error", Error: scanErr.Error()}
	}
}
// streamDelta is the "delta" object of a streaming choice. streamDeltaText
// reads these fields in declaration order and returns the first non-empty one.
type streamDelta struct {
// Content is the regular content fragment.
Content string `json:"content"`
// Text is an alternative field checked when Content is empty.
Text string `json:"text"`
// ReasoningContent is checked when Content and Text are empty.
ReasoningContent string `json:"reasoning_content"`
// Reasoning is the final fallback.
Reasoning string `json:"reasoning"`
}
// openAIStreamChunk is the subset of one streamed "data:" JSON payload that
// parseOpenAIStream decodes.
type openAIStreamChunk struct {
Choices []struct {
// Delta carries the incremental text for this choice.
Delta streamDelta `json:"delta"`
// FinishReason, when non-empty, ends the stream with a "done" event.
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
// openAIModelsResponse is the subset of the "/models" response that
// ListModels decodes.
type openAIModelsResponse struct {
Data []struct {
// ID is the model identifier; ListModels trims it and skips blank values.
ID string `json:"id"`
} `json:"data"`
}
// providerHTTPError builds the third-party service error reported when the
// provider answers with an unexpected HTTP status.
//
// The message has the form "<prefix>: HTTP <code> <reason>". Callers pass
// http.Response.Status as statusLine, which already starts with the numeric
// code (e.g. "404 Not Found"); that leading code is stripped so the message
// reads "HTTP 404 Not Found" rather than "HTTP 404 404 Not Found". When no
// reason text remains, the message ends after the code.
func providerHTTPError(prefix string, statusCode int, statusLine string) error {
	codeText := fmt.Sprintf("%d", statusCode)
	reason := strings.TrimSpace(statusLine)
	if strings.HasPrefix(reason, codeText) {
		rest := reason[len(codeText):]
		// Only strip when the code is a whole token, not a prefix of a
		// longer number.
		if rest == "" || rest[0] == ' ' {
			reason = strings.TrimSpace(rest)
		}
	}

	message := fmt.Sprintf("%s: HTTP %d", prefix, statusCode)
	if reason != "" {
		message += " " + reason
	}
	return app.For(code.AI).SvcThirdParty(message)
}