feat: Route chat completions to Gemini Web provider when configured
- Add HandleGeminiChatCompletions for Gemini Web provider requests - Update router to route requests based on cfg.Provider - Support both streaming and non-streaming modes for Gemini - Map stream chunks to OpenAI-compatible SSE format
This commit is contained in:
parent
f33353897c
commit
19985dd476
|
|
@ -0,0 +1,196 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"cursor-api-proxy/internal/apitypes"
|
||||
"cursor-api-proxy/internal/config"
|
||||
"cursor-api-proxy/internal/httputil"
|
||||
"cursor-api-proxy/internal/logger"
|
||||
"cursor-api-proxy/internal/providers/geminiweb"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func HandleGeminiChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, rawBody, method, pathname, remoteAddress string) {
|
||||
_ = context.Background() // 確保 context 被使用
|
||||
var bodyMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil {
|
||||
httputil.WriteJSON(w, 400, map[string]interface{}{
|
||||
"error": map[string]string{"message": "invalid JSON body", "code": "bad_request"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
rawModel, _ := bodyMap["model"].(string)
|
||||
if rawModel == "" {
|
||||
rawModel = "gemini-2.0-flash"
|
||||
}
|
||||
|
||||
var messages []interface{}
|
||||
if m, ok := bodyMap["messages"].([]interface{}); ok {
|
||||
messages = m
|
||||
}
|
||||
|
||||
isStream := false
|
||||
if s, ok := bodyMap["stream"].(bool); ok {
|
||||
isStream = s
|
||||
}
|
||||
|
||||
// 轉換 messages 為 apitypes.Message
|
||||
var apiMessages []apitypes.Message
|
||||
for _, m := range messages {
|
||||
if msgMap, ok := m.(map[string]interface{}); ok {
|
||||
role, _ := msgMap["role"].(string)
|
||||
content := ""
|
||||
if c, ok := msgMap["content"].(string); ok {
|
||||
content = c
|
||||
}
|
||||
apiMessages = append(apiMessages, apitypes.Message{
|
||||
Role: role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
logger.LogRequestStart(method, pathname, rawModel, cfg.TimeoutMs, isStream)
|
||||
start := time.Now().UnixMilli()
|
||||
|
||||
// 創建 Gemini provider
|
||||
provider := geminiweb.NewProvider(cfg)
|
||||
|
||||
if isStream {
|
||||
httputil.WriteSSEHeaders(w, nil)
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
id := "chatcmpl_" + uuid.New().String()
|
||||
created := time.Now().Unix()
|
||||
var accumulated string
|
||||
|
||||
err := provider.Generate(r.Context(), rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) {
|
||||
if chunk.Type == apitypes.ChunkText {
|
||||
accumulated += chunk.Text
|
||||
respChunk := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": rawModel,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]string{"content": chunk.Text},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(respChunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
} else if chunk.Type == apitypes.ChunkThinking {
|
||||
respChunk := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": rawModel,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{"reasoning_content": chunk.Thinking},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(respChunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
} else if chunk.Type == apitypes.ChunkDone {
|
||||
stopChunk := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": rawModel,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(stopChunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
latencyMs := time.Now().UnixMilli() - start
|
||||
if err != nil {
|
||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
|
||||
logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1)
|
||||
return
|
||||
}
|
||||
|
||||
logger.LogTrafficResponse(cfg.Verbose, rawModel, accumulated, true)
|
||||
logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0)
|
||||
return
|
||||
}
|
||||
|
||||
// 非串流模式
|
||||
var resultText string
|
||||
var resultThinking string
|
||||
|
||||
err := provider.Generate(r.Context(), rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) {
|
||||
if chunk.Type == apitypes.ChunkText {
|
||||
resultText += chunk.Text
|
||||
} else if chunk.Type == apitypes.ChunkThinking {
|
||||
resultThinking += chunk.Thinking
|
||||
}
|
||||
})
|
||||
|
||||
latencyMs := time.Now().UnixMilli() - start
|
||||
|
||||
if err != nil {
|
||||
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
|
||||
logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1)
|
||||
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||
"error": map[string]string{"message": err.Error(), "code": "gemini_error"},
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
logger.LogTrafficResponse(cfg.Verbose, rawModel, resultText, false)
|
||||
logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0)
|
||||
|
||||
id := "chatcmpl_" + uuid.New().String()
|
||||
created := time.Now().Unix()
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion",
|
||||
"created": created,
|
||||
"model": rawModel,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"message": map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": resultText,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
|
||||
httputil.WriteJSON(w, 200, resp, nil)
|
||||
}
|
||||
|
|
@ -61,7 +61,16 @@ func NewRouter(opts RouterOptions) http.HandlerFunc {
|
|||
}, nil)
|
||||
return
|
||||
}
|
||||
// 根據 Provider 選擇處理方式
|
||||
provider := cfg.Provider
|
||||
if provider == "" {
|
||||
provider = "cursor"
|
||||
}
|
||||
if provider == "gemini-web" {
|
||||
handlers.HandleGeminiChatCompletions(w, r, cfg, raw, method, pathname, remoteAddress)
|
||||
} else {
|
||||
handlers.HandleChatCompletions(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress)
|
||||
}
|
||||
|
||||
case method == "POST" && pathname == "/v1/messages":
|
||||
raw, err := httputil.ReadBody(r)
|
||||
|
|
|
|||
Loading…
Reference in New Issue