From 19985dd47636607c853c87b6a31fcdf3fd789101 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=8E=8B=E6=80=A7=E9=A9=8A?=
Date: Fri, 3 Apr 2026 00:36:48 +0800
Subject: [PATCH] feat: Route chat completions to Gemini Web provider when
 configured

- Add HandleGeminiChatCompletions for Gemini Web provider requests
- Update router to route requests based on cfg.Provider
- Support both streaming and non-streaming modes for Gemini
- Map stream chunks to OpenAI-compatible SSE format
---
 internal/handlers/gemini_handler.go | 196 ++++++++++++++++++++++++++++
 internal/router/router.go           |  21 ++-
 2 files changed, 211 insertions(+), 6 deletions(-)
 create mode 100644 internal/handlers/gemini_handler.go

diff --git a/internal/handlers/gemini_handler.go b/internal/handlers/gemini_handler.go
new file mode 100644
index 0000000..8a11483
--- /dev/null
+++ b/internal/handlers/gemini_handler.go
@@ -0,0 +1,196 @@
+package handlers
+
+import (
+	"context"
+	"cursor-api-proxy/internal/apitypes"
+	"cursor-api-proxy/internal/config"
+	"cursor-api-proxy/internal/httputil"
+	"cursor-api-proxy/internal/logger"
+	"cursor-api-proxy/internal/providers/geminiweb"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"time"
+
+	"github.com/google/uuid"
+)
+
+func HandleGeminiChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, rawBody, method, pathname, remoteAddress string) {
+	_ = context.Background() // 確保 context 被使用
+	var bodyMap map[string]interface{}
+	if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil {
+		httputil.WriteJSON(w, 400, map[string]interface{}{
+			"error": map[string]string{"message": "invalid JSON body", "code": "bad_request"},
+		}, nil)
+		return
+	}
+
+	rawModel, _ := bodyMap["model"].(string)
+	if rawModel == "" {
+		rawModel = "gemini-2.0-flash"
+	}
+
+	var messages []interface{}
+	if m, ok := bodyMap["messages"].([]interface{}); ok {
+		messages = m
+	}
+
+	isStream := false
+	if s, ok := bodyMap["stream"].(bool); ok {
+		isStream = s
+	}
+
+	// 轉換 messages 為 apitypes.Message
+	var apiMessages []apitypes.Message
+	for _, m := range messages {
+		if msgMap, ok := m.(map[string]interface{}); ok {
+			role, _ := msgMap["role"].(string)
+			content := ""
+			if c, ok := msgMap["content"].(string); ok {
+				content = c
+			}
+			apiMessages = append(apiMessages, apitypes.Message{
+				Role:    role,
+				Content: content,
+			})
+		}
+	}
+
+	logger.LogRequestStart(method, pathname, rawModel, cfg.TimeoutMs, isStream)
+	start := time.Now().UnixMilli()
+
+	// 創建 Gemini provider
+	provider := geminiweb.NewProvider(cfg)
+
+	if isStream {
+		httputil.WriteSSEHeaders(w, nil)
+		flusher, _ := w.(http.Flusher)
+
+		id := "chatcmpl_" + uuid.New().String()
+		created := time.Now().Unix()
+		var accumulated string
+
+		err := provider.Generate(r.Context(), rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) {
+			if chunk.Type == apitypes.ChunkText {
+				accumulated += chunk.Text
+				respChunk := map[string]interface{}{
+					"id":      id,
+					"object":  "chat.completion.chunk",
+					"created": created,
+					"model":   rawModel,
+					"choices": []map[string]interface{}{
+						{
+							"index":         0,
+							"delta":         map[string]string{"content": chunk.Text},
+							"finish_reason": nil,
+						},
+					},
+				}
+				data, _ := json.Marshal(respChunk)
+				fmt.Fprintf(w, "data: %s\n\n", data)
+				if flusher != nil {
+					flusher.Flush()
+				}
+			} else if chunk.Type == apitypes.ChunkThinking {
+				respChunk := map[string]interface{}{
+					"id":      id,
+					"object":  "chat.completion.chunk",
+					"created": created,
+					"model":   rawModel,
+					"choices": []map[string]interface{}{
+						{
+							"index":         0,
+							"delta":         map[string]interface{}{"reasoning_content": chunk.Thinking},
+							"finish_reason": nil,
+						},
+					},
+				}
+				data, _ := json.Marshal(respChunk)
+				fmt.Fprintf(w, "data: %s\n\n", data)
+				if flusher != nil {
+					flusher.Flush()
+				}
+			} else if chunk.Type == apitypes.ChunkDone {
+				stopChunk := map[string]interface{}{
+					"id":      id,
+					"object":  "chat.completion.chunk",
+					"created": created,
+					"model":   rawModel,
+					"choices": []map[string]interface{}{
+						{
+							"index":         0,
+							"delta":         map[string]interface{}{},
+							"finish_reason": "stop",
+						},
+					},
+				}
+				data, _ := json.Marshal(stopChunk)
+				fmt.Fprintf(w, "data: %s\n\n", data)
+				fmt.Fprintf(w, "data: [DONE]\n\n")
+				if flusher != nil {
+					flusher.Flush()
+				}
+			}
+		})
+
+		latencyMs := time.Now().UnixMilli() - start
+		if err != nil {
+			logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
+			logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1)
+			return
+		}
+
+		logger.LogTrafficResponse(cfg.Verbose, rawModel, accumulated, true)
+		logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0)
+		return
+	}
+
+	// 非串流模式
+	var resultText string
+	var resultThinking string
+
+	err := provider.Generate(r.Context(), rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) {
+		if chunk.Type == apitypes.ChunkText {
+			resultText += chunk.Text
+		} else if chunk.Type == apitypes.ChunkThinking {
+			resultThinking += chunk.Thinking
+		}
+	})
+
+	latencyMs := time.Now().UnixMilli() - start
+
+	if err != nil {
+		logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
+		logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1)
+		httputil.WriteJSON(w, 500, map[string]interface{}{
+			"error": map[string]string{"message": err.Error(), "code": "gemini_error"},
+		}, nil)
+		return
+	}
+
+	logger.LogTrafficResponse(cfg.Verbose, rawModel, resultText, false)
+	logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0)
+
+	id := "chatcmpl_" + uuid.New().String()
+	created := time.Now().Unix()
+
+	resp := map[string]interface{}{
+		"id":      id,
+		"object":  "chat.completion",
+		"created": created,
+		"model":   rawModel,
+		"choices": []map[string]interface{}{
+			{
+				"index": 0,
+				"message": map[string]interface{}{
+					"role":    "assistant",
+					"content": resultText,
+				},
+				"finish_reason": "stop",
+			},
+		},
+		"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+	}
+
+	httputil.WriteJSON(w, 200, resp, nil)
+}
diff --git a/internal/router/router.go b/internal/router/router.go
index 58e8552..745f958 100644
--- a/internal/router/router.go
+++ b/internal/router/router.go
@@ -13,11 +13,11 @@ import (
 )
 
 type RouterOptions struct {
-	Version string
-	Config config.BridgeConfig
-	ModelCache *handlers.ModelCacheRef
-	LastModel *string
-	Pool pool.PoolHandle
+	Version    string
+	Config     config.BridgeConfig
+	ModelCache *handlers.ModelCacheRef
+	LastModel  *string
+	Pool       pool.PoolHandle
 }
 
 func NewRouter(opts RouterOptions) http.HandlerFunc {
@@ -61,7 +61,16 @@ func NewRouter(opts RouterOptions) http.HandlerFunc {
 			}, nil)
 			return
 		}
-		handlers.HandleChatCompletions(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress)
+		// 根據 Provider 選擇處理方式
+		provider := cfg.Provider
+		if provider == "" {
+			provider = "cursor"
+		}
+		if provider == "gemini-web" {
+			handlers.HandleGeminiChatCompletions(w, r, cfg, raw, method, pathname, remoteAddress)
+		} else {
+			handlers.HandleChatCompletions(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress)
+		}
 	case method == "POST" && pathname == "/v1/messages":
 		raw, err := httputil.ReadBody(r)