73 lines
1.9 KiB
Go
73 lines
1.9 KiB
Go
package ai
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
|
|
"haixun-backend/internal/logic/ai"
|
|
"haixun-backend/internal/response"
|
|
"haixun-backend/internal/svc"
|
|
"haixun-backend/internal/types"
|
|
|
|
"github.com/zeromicro/go-zero/rest/httpx"
|
|
)
|
|
|
|
func ChatStreamHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
var req types.AIChatReq
|
|
if err := httpx.Parse(r, &req); err != nil {
|
|
response.Write(r.Context(), w, nil, response.WrapRequestError(err))
|
|
return
|
|
}
|
|
if err := svcCtx.Validator.ValidateAll(&req); err != nil {
|
|
response.Write(r.Context(), w, nil, response.WrapRequestError(err))
|
|
return
|
|
}
|
|
|
|
token, err := ai.BearerToken(r)
|
|
if err != nil {
|
|
response.Write(r.Context(), w, nil, err)
|
|
return
|
|
}
|
|
|
|
l := ai.NewChatStreamLogic(r.Context(), svcCtx)
|
|
stream, err := l.ChatStream(&req, token)
|
|
if err != nil {
|
|
response.Write(r.Context(), w, nil, err)
|
|
return
|
|
}
|
|
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
response.Write(r.Context(), w, nil, response.WrapRequestError(fmt.Errorf("server does not support streaming")))
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
w.Header().Set("X-Accel-Buffering", "no")
|
|
|
|
for event := range stream {
|
|
writeSSE(w, event.Type, event)
|
|
flusher.Flush()
|
|
if event.Type == "done" || event.Type == "error" {
|
|
return
|
|
}
|
|
}
|
|
writeSSE(w, "done", map[string]string{"finish_reason": "stop"})
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
|
|
// writeSSE emits one Server-Sent Events frame to w. The frame has three parts:
//   - an "event:" line carrying eventName;
//   - a "data:" line holding the JSON encoding of data;
//   - a terminating blank line.
//
// If data cannot be JSON-encoded, the frame is sent under the "error" event
// name with a fixed JSON error payload instead. Write errors are ignored;
// this is a best-effort write.
func writeSSE(w http.ResponseWriter, eventName string, data any) {
	name := eventName
	body, marshalErr := json.Marshal(data)
	if marshalErr != nil {
		// Fall back to a static payload that is always valid JSON.
		name = "error"
		body = []byte(`{"type":"error","error":"failed to serialize SSE payload"}`)
	}

	_, _ = w.Write([]byte("event: " + name + "\n"))
	_, _ = w.Write([]byte("data: " + string(body) + "\n\n"))
}
|