58 lines
1.3 KiB
Go
58 lines
1.3 KiB
Go
|
|
package server
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"encoding/json"
|
|||
|
|
"fmt"
|
|||
|
|
"net/http"
|
|||
|
|
|
|||
|
|
"github.com/daniel/cursor-adapter/internal/types"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// SSEWriter wraps an http.ResponseWriter to stream Server-Sent Events (SSE).
// Construct it with NewSSEWriter, which also sets the required response headers.
type SSEWriter struct {
	// w is the underlying HTTP response that events are written to.
	w http.ResponseWriter
	// flush is w viewed as an http.Flusher, or nil when w cannot flush;
	// in that case writes are left to w's own buffering.
	flush http.Flusher
}
|
|||
|
|
|
|||
|
|
// NewSSEWriter 建立 SSEWriter,設定必要的 headers。
|
|||
|
|
func NewSSEWriter(w http.ResponseWriter) *SSEWriter {
|
|||
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|||
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|||
|
|
w.Header().Set("Connection", "keep-alive")
|
|||
|
|
w.Header().Set("X-Accel-Buffering", "no")
|
|||
|
|
|
|||
|
|
flusher, _ := w.(http.Flusher)
|
|||
|
|
|
|||
|
|
return &SSEWriter{w: w, flush: flusher}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// WriteChunk 寫入一個 SSE chunk。
|
|||
|
|
func (s *SSEWriter) WriteChunk(chunk types.ChatCompletionChunk) error {
|
|||
|
|
data, err := json.Marshal(chunk)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("marshal chunk: %w", err)
|
|||
|
|
}
|
|||
|
|
fmt.Fprintf(s.w, "data: %s\n\n", data)
|
|||
|
|
if s.flush != nil {
|
|||
|
|
s.flush.Flush()
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// WriteDone 寫入 SSE 結束標記。
|
|||
|
|
func (s *SSEWriter) WriteDone() {
|
|||
|
|
fmt.Fprint(s.w, "data: [DONE]\n\n")
|
|||
|
|
if s.flush != nil {
|
|||
|
|
s.flush.Flush()
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// WriteError 寫入 SSE 格式的錯誤。
|
|||
|
|
func (s *SSEWriter) WriteError(errMsg string) {
|
|||
|
|
stopReason := "stop"
|
|||
|
|
chunk := types.NewChatCompletionChunk("error", 0, "", types.Delta{Content: &errMsg})
|
|||
|
|
chunk.Choices[0].FinishReason = &stopReason
|
|||
|
|
s.WriteChunk(chunk)
|
|||
|
|
s.WriteDone()
|
|||
|
|
}
|