feat: Route chat completions to Gemini Web provider when configured

- Add HandleGeminiChatCompletions for Gemini Web provider requests
- Update router to route requests based on cfg.Provider
- Support both streaming and non-streaming modes for Gemini
- Map stream chunks to OpenAI-compatible SSE format
This commit is contained in:
王性驊 2026-04-03 00:36:48 +08:00
parent f33353897c
commit 19985dd476
2 changed files with 211 additions and 6 deletions

View File

@ -0,0 +1,196 @@
package handlers
import (
"context"
"cursor-api-proxy/internal/apitypes"
"cursor-api-proxy/internal/config"
"cursor-api-proxy/internal/httputil"
"cursor-api-proxy/internal/logger"
"cursor-api-proxy/internal/providers/geminiweb"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/google/uuid"
)
// HandleGeminiChatCompletions serves POST /v1/chat/completions requests by
// forwarding them to the Gemini Web provider and translating the provider's
// stream chunks into OpenAI-compatible responses: SSE chunks when the client
// requested streaming, a single chat.completion object otherwise.
//
// rawBody is the already-read request body; method, pathname and
// remoteAddress are passed through for logging only.
func HandleGeminiChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, rawBody, method, pathname, remoteAddress string) {
	// Derive a cancellable request-scoped context so provider work stops as
	// soon as the client disconnects or the handler returns.
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()

	var bodyMap map[string]interface{}
	if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil {
		httputil.WriteJSON(w, 400, map[string]interface{}{
			"error": map[string]string{"message": "invalid JSON body", "code": "bad_request"},
		}, nil)
		return
	}

	rawModel, _ := bodyMap["model"].(string)
	if rawModel == "" {
		// Default model when the client did not specify one.
		rawModel = "gemini-2.0-flash"
	}
	isStream, _ := bodyMap["stream"].(bool)

	// Convert the loosely typed OpenAI messages into apitypes.Message.
	// Non-object entries are skipped and non-string contents become "",
	// matching the original tolerant input handling.
	var apiMessages []apitypes.Message
	if messages, ok := bodyMap["messages"].([]interface{}); ok {
		for _, m := range messages {
			msgMap, ok := m.(map[string]interface{})
			if !ok {
				continue
			}
			role, _ := msgMap["role"].(string)
			content, _ := msgMap["content"].(string)
			apiMessages = append(apiMessages, apitypes.Message{
				Role:    role,
				Content: content,
			})
		}
	}

	logger.LogRequestStart(method, pathname, rawModel, cfg.TimeoutMs, isStream)
	start := time.Now().UnixMilli()

	// Create the Gemini Web provider for this request.
	provider := geminiweb.NewProvider(cfg)

	if isStream {
		httputil.WriteSSEHeaders(w, nil)
		flusher, _ := w.(http.Flusher)
		id := "chatcmpl_" + uuid.New().String()
		created := time.Now().Unix()

		var accumulated string
		err := provider.Generate(ctx, rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) {
			switch chunk.Type {
			case apitypes.ChunkText:
				accumulated += chunk.Text
				writeGeminiSSE(w, flusher, geminiChunkEvent(id, created, rawModel,
					map[string]interface{}{"content": chunk.Text}, nil))
			case apitypes.ChunkThinking:
				writeGeminiSSE(w, flusher, geminiChunkEvent(id, created, rawModel,
					map[string]interface{}{"reasoning_content": chunk.Thinking}, nil))
			case apitypes.ChunkDone:
				writeGeminiSSE(w, flusher, geminiChunkEvent(id, created, rawModel,
					map[string]interface{}{}, "stop"))
				fmt.Fprintf(w, "data: [DONE]\n\n")
				if flusher != nil {
					flusher.Flush()
				}
			}
		})
		latencyMs := time.Now().UnixMilli() - start
		if err != nil {
			// The SSE headers are already on the wire, so the status code
			// cannot be changed at this point; we can only log the failure.
			logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
			logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1)
			return
		}
		logger.LogTrafficResponse(cfg.Verbose, rawModel, accumulated, true)
		logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0)
		return
	}

	// Non-streaming mode: accumulate all text chunks, then reply with a
	// single chat.completion object.
	var resultText string
	err := provider.Generate(ctx, rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) {
		if chunk.Type == apitypes.ChunkText {
			resultText += chunk.Text
		}
	})
	latencyMs := time.Now().UnixMilli() - start
	if err != nil {
		logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
		logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1)
		httputil.WriteJSON(w, 500, map[string]interface{}{
			"error": map[string]string{"message": err.Error(), "code": "gemini_error"},
		}, nil)
		return
	}
	logger.LogTrafficResponse(cfg.Verbose, rawModel, resultText, false)
	logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0)

	resp := map[string]interface{}{
		"id":      "chatcmpl_" + uuid.New().String(),
		"object":  "chat.completion",
		"created": time.Now().Unix(),
		"model":   rawModel,
		"choices": []map[string]interface{}{
			{
				"index": 0,
				"message": map[string]interface{}{
					"role":    "assistant",
					"content": resultText,
				},
				"finish_reason": "stop",
			},
		},
		// Token accounting is not available from the Gemini Web provider,
		// so usage is reported as zero.
		"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
	}
	httputil.WriteJSON(w, 200, resp, nil)
}

// geminiChunkEvent builds one OpenAI-compatible chat.completion.chunk
// payload. finishReason is nil while streaming and "stop" on the final chunk.
func geminiChunkEvent(id string, created int64, model string, delta map[string]interface{}, finishReason interface{}) map[string]interface{} {
	return map[string]interface{}{
		"id":      id,
		"object":  "chat.completion.chunk",
		"created": created,
		"model":   model,
		"choices": []map[string]interface{}{
			{
				"index":         0,
				"delta":         delta,
				"finish_reason": finishReason,
			},
		},
	}
}

// writeGeminiSSE marshals payload and writes it as a single SSE data event,
// flushing immediately so clients see incremental output. A marshal failure
// skips the event rather than corrupting the stream with partial data.
func writeGeminiSSE(w http.ResponseWriter, flusher http.Flusher, payload map[string]interface{}) {
	data, err := json.Marshal(payload)
	if err != nil {
		return
	}
	fmt.Fprintf(w, "data: %s\n\n", data)
	if flusher != nil {
		flusher.Flush()
	}
}

View File

@ -61,7 +61,16 @@ func NewRouter(opts RouterOptions) http.HandlerFunc {
}, nil)
return
}
// 根據 Provider 選擇處理方式
provider := cfg.Provider
if provider == "" {
provider = "cursor"
}
if provider == "gemini-web" {
handlers.HandleGeminiChatCompletions(w, r, cfg, raw, method, pathname, remoteAddress)
} else {
handlers.HandleChatCompletions(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress)
}
case method == "POST" && pathname == "/v1/messages":
raw, err := httputil.ReadBody(r)