357 lines
9.8 KiB
Go
357 lines
9.8 KiB
Go
package provider
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
app "haixun-backend/internal/library/errors"
|
|
"haixun-backend/internal/library/errors/code"
|
|
"haixun-backend/internal/model/ai/domain/enum"
|
|
domai "haixun-backend/internal/model/ai/domain/usecase"
|
|
)
|
|
|
|
type OpenAICompatible struct {
|
|
id enum.ProviderID
|
|
baseURL string
|
|
client *http.Client
|
|
}
|
|
|
|
func NewOpenAICompatible(id enum.ProviderID, baseURL string) *OpenAICompatible {
|
|
return &OpenAICompatible{
|
|
id: id,
|
|
baseURL: strings.TrimRight(baseURL, "/"),
|
|
client: &http.Client{Timeout: 10 * time.Minute},
|
|
}
|
|
}
|
|
|
|
func (p *OpenAICompatible) ID() enum.ProviderID {
|
|
return p.id
|
|
}
|
|
|
|
func (p *OpenAICompatible) ListModels(ctx context.Context, credential domai.Credential) ([]string, error) {
|
|
if strings.TrimSpace(credential.APIKey) == "" {
|
|
return nil, app.For(code.AI).InputMissingRequired("missing AI provider token")
|
|
}
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, p.baseURL+"/models", nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
httpReq.Header.Set("Authorization", "Bearer "+credential.APIKey)
|
|
httpReq.Header.Set("Accept", "application/json")
|
|
|
|
resp, err := p.client.Do(httpReq)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
data, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, providerHTTPError("AI provider models request failed", resp.StatusCode, resp.Status)
|
|
}
|
|
|
|
var payload openAIModelsResponse
|
|
if err := json.Unmarshal(data, &payload); err != nil {
|
|
return nil, app.For(code.AI).SvcThirdParty("failed to parse AI provider models response")
|
|
}
|
|
|
|
models := make([]string, 0, len(payload.Data))
|
|
for _, item := range payload.Data {
|
|
id := strings.TrimSpace(item.ID)
|
|
if id == "" {
|
|
continue
|
|
}
|
|
models = append(models, id)
|
|
}
|
|
sort.Strings(models)
|
|
return models, nil
|
|
}
|
|
|
|
func (p *OpenAICompatible) GenerateText(ctx context.Context, req domai.GenerateRequest) (*domai.GenerateResult, error) {
|
|
if strings.TrimSpace(req.Credential.APIKey) == "" {
|
|
return nil, app.For(code.AI).InputMissingRequired("missing AI provider token")
|
|
}
|
|
|
|
body := p.buildChatCompletionRequest(req, false)
|
|
payload, err := json.Marshal(body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/chat/completions", bytes.NewReader(payload))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
httpReq.Header.Set("Authorization", "Bearer "+req.Credential.APIKey)
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
httpReq.Header.Set("Accept", "application/json")
|
|
|
|
resp, err := p.client.Do(httpReq)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
data, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, providerHTTPError("AI provider request failed", resp.StatusCode, resp.Status)
|
|
}
|
|
|
|
var chatResp openAIChatResponse
|
|
if err := json.Unmarshal(data, &chatResp); err != nil {
|
|
return nil, app.For(code.AI).SvcThirdParty("failed to parse AI provider chat response")
|
|
}
|
|
if len(chatResp.Choices) == 0 {
|
|
return nil, app.For(code.AI).SvcThirdParty("AI provider returned no choices")
|
|
}
|
|
|
|
choice := chatResp.Choices[0]
|
|
return &domai.GenerateResult{
|
|
Text: extractAssistantText(choice.Message),
|
|
FinishReason: choice.FinishReason,
|
|
}, nil
|
|
}
|
|
|
|
func (p *OpenAICompatible) StreamText(ctx context.Context, req domai.GenerateRequest) (<-chan domai.StreamEvent, error) {
|
|
if strings.TrimSpace(req.Credential.APIKey) == "" {
|
|
return nil, app.For(code.AI).InputMissingRequired("missing AI provider token")
|
|
}
|
|
|
|
body := p.buildChatCompletionRequest(req, true)
|
|
payload, err := json.Marshal(body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/chat/completions", bytes.NewReader(payload))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
httpReq.Header.Set("Authorization", "Bearer "+req.Credential.APIKey)
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
httpReq.Header.Set("Accept", "text/event-stream")
|
|
|
|
resp, err := p.client.Do(httpReq)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
defer resp.Body.Close()
|
|
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
|
|
return nil, providerHTTPError("AI provider request failed", resp.StatusCode, resp.Status)
|
|
}
|
|
|
|
out := make(chan domai.StreamEvent)
|
|
go func() {
|
|
defer close(out)
|
|
defer resp.Body.Close()
|
|
parseOpenAIStream(resp.Body, out)
|
|
}()
|
|
return out, nil
|
|
}
|
|
|
|
type chatCompletionRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []chatMessage `json:"messages"`
|
|
Stream bool `json:"stream"`
|
|
Temperature *float64 `json:"temperature,omitempty"`
|
|
MaxTokens *int `json:"max_tokens,omitempty"`
|
|
Thinking *thinkingConfig `json:"thinking,omitempty"`
|
|
}
|
|
|
|
// thinkingConfig is the optional "thinking" object of a chat completion
// request. An empty Type is omitted, so the zero value encodes as {}.
type thinkingConfig struct {
	Type string `json:"type,omitempty"`
}
|
|
|
|
// assistantMessage is the "message" object of a non-streaming chat choice.
// Content is kept as raw JSON because parseMessageContent accepts a string,
// an array of text parts, or null. ReasoningContent and Reasoning capture two
// alternative field names for reasoning text (presumably used by different
// providers; confirm against upstream APIs).
type assistantMessage struct {
	Content          json.RawMessage `json:"content"`
	ReasoningContent string          `json:"reasoning_content"`
	Reasoning        string          `json:"reasoning"`
}
|
|
|
|
type openAIChatResponse struct {
|
|
Choices []struct {
|
|
Message assistantMessage `json:"message"`
|
|
FinishReason string `json:"finish_reason"`
|
|
} `json:"choices"`
|
|
}
|
|
|
|
func (p *OpenAICompatible) buildChatCompletionRequest(req domai.GenerateRequest, stream bool) chatCompletionRequest {
|
|
return chatCompletionRequest{
|
|
Model: req.Model,
|
|
Messages: toChatMessages(req),
|
|
Stream: stream,
|
|
Temperature: normalizeTemperature(p.id, req.Model, req.Temperature),
|
|
MaxTokens: normalizeMaxTokens(p.id, req.Model, req.MaxTokens),
|
|
Thinking: normalizeThinking(p.id, req.Model),
|
|
}
|
|
}
|
|
|
|
func extractAssistantText(msg assistantMessage) string {
|
|
if text := parseMessageContent(msg.Content); strings.TrimSpace(text) != "" {
|
|
return text
|
|
}
|
|
if strings.TrimSpace(msg.ReasoningContent) != "" {
|
|
return msg.ReasoningContent
|
|
}
|
|
return msg.Reasoning
|
|
}
|
|
|
|
// parseMessageContent turns a chat message "content" value into plain text:
//   - absent or null: ""
//   - a JSON string: returned verbatim
//   - an array of objects: each element's "text" field, concatenated in order
//     (elements without text contribute nothing; other fields are ignored)
//   - any other shape, or an array that fails to decode as objects: ""
func parseMessageContent(raw json.RawMessage) string {
	if len(raw) == 0 || string(raw) == "null" {
		return ""
	}

	var asString string
	if json.Unmarshal(raw, &asString) == nil {
		return asString
	}

	var asParts []struct {
		Text string `json:"text"`
	}
	if json.Unmarshal(raw, &asParts) != nil {
		return ""
	}

	var joined strings.Builder
	for _, part := range asParts {
		// Writing an empty string is a no-op, so parts without text vanish.
		joined.WriteString(part.Text)
	}
	return joined.String()
}
|
|
|
|
// chatMessage is one entry of the request's "messages" array. It carries a
// role and plain-string content. Both fields are always serialized.
type chatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|
|
|
|
func toChatMessages(req domai.GenerateRequest) []chatMessage {
|
|
messages := make([]chatMessage, 0, len(req.Messages)+1)
|
|
if strings.TrimSpace(req.System) != "" {
|
|
messages = append(messages, chatMessage{Role: "system", Content: req.System})
|
|
}
|
|
for _, msg := range req.Messages {
|
|
messages = append(messages, chatMessage{Role: msg.Role, Content: msg.Content})
|
|
}
|
|
return messages
|
|
}
|
|
|
|
func normalizeTemperature(provider enum.ProviderID, model string, requested *float64) *float64 {
|
|
if provider == enum.ProviderOpenCode && strings.HasPrefix(model, "kimi-") {
|
|
v := 1.0
|
|
return &v
|
|
}
|
|
return requested
|
|
}
|
|
|
|
func normalizeMaxTokens(provider enum.ProviderID, model string, requested *int) *int {
|
|
if provider != enum.ProviderOpenCode || !strings.HasPrefix(model, "deepseek-") {
|
|
return requested
|
|
}
|
|
|
|
// DeepSeek V4 thinking can consume the full budget before content appears.
|
|
const minDeepSeekTokens = 2048
|
|
if requested == nil || *requested < minDeepSeekTokens {
|
|
v := minDeepSeekTokens
|
|
return &v
|
|
}
|
|
return requested
|
|
}
|
|
|
|
func normalizeThinking(provider enum.ProviderID, model string) *thinkingConfig {
|
|
// DeepSeek V4 on OpenCode Go defaults to thinking mode and may spend the entire
|
|
// max_tokens budget on reasoning_content before emitting content.
|
|
if provider == enum.ProviderOpenCode && strings.HasPrefix(model, "deepseek-") {
|
|
return &thinkingConfig{Type: "disabled"}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func streamDeltaText(delta streamDelta) string {
|
|
if delta.Content != "" {
|
|
return delta.Content
|
|
}
|
|
if delta.Text != "" {
|
|
return delta.Text
|
|
}
|
|
if delta.ReasoningContent != "" {
|
|
return delta.ReasoningContent
|
|
}
|
|
return delta.Reasoning
|
|
}
|
|
|
|
func parseOpenAIStream(body io.Reader, out chan<- domai.StreamEvent) {
|
|
scanner := bufio.NewScanner(body)
|
|
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
|
for scanner.Scan() {
|
|
line := strings.TrimSpace(scanner.Text())
|
|
if line == "" || !strings.HasPrefix(line, "data:") {
|
|
continue
|
|
}
|
|
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
|
if data == "[DONE]" {
|
|
out <- domai.StreamEvent{Type: "done"}
|
|
return
|
|
}
|
|
var chunk openAIStreamChunk
|
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
|
out <- domai.StreamEvent{Type: "error", Error: "failed to parse AI stream payload"}
|
|
return
|
|
}
|
|
for _, choice := range chunk.Choices {
|
|
if text := streamDeltaText(choice.Delta); text != "" {
|
|
out <- domai.StreamEvent{Type: "delta", Text: text}
|
|
}
|
|
if choice.FinishReason != "" {
|
|
out <- domai.StreamEvent{Type: "done", FinishReason: choice.FinishReason}
|
|
return
|
|
}
|
|
}
|
|
}
|
|
if err := scanner.Err(); err != nil {
|
|
out <- domai.StreamEvent{Type: "error", Error: err.Error()}
|
|
}
|
|
}
|
|
|
|
// streamDelta is the "delta" object of a streamed chat chunk. It captures the
// four text-bearing field names that streamDeltaText consults. JSON null or a
// missing field decodes as "".
type streamDelta struct {
	Content          string `json:"content"`
	Text             string `json:"text"`
	ReasoningContent string `json:"reasoning_content"`
	Reasoning        string `json:"reasoning"`
}
|
|
|
|
type openAIStreamChunk struct {
|
|
Choices []struct {
|
|
Delta streamDelta `json:"delta"`
|
|
FinishReason string `json:"finish_reason"`
|
|
} `json:"choices"`
|
|
}
|
|
|
|
// openAIModelsResponse decodes the /models listing. Only the "id" of each
// entry in the "data" array is kept; all other fields are ignored.
type openAIModelsResponse struct {
	Data []struct {
		ID string `json:"id"`
	} `json:"data"`
}
|
|
|
|
func providerHTTPError(prefix string, statusCode int, statusLine string) error {
|
|
return app.For(code.AI).SvcThirdParty(fmt.Sprintf("%s: HTTP %d %s", prefix, statusCode, statusLine))
|
|
}
|