package handlers import ( "context" "cursor-api-proxy/internal/apitypes" "cursor-api-proxy/internal/config" "cursor-api-proxy/internal/httputil" "cursor-api-proxy/internal/logger" "cursor-api-proxy/internal/providers/geminiweb" "encoding/json" "fmt" "net/http" "time" "github.com/google/uuid" ) func HandleGeminiChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, rawBody, method, pathname, remoteAddress string) { _ = context.Background() // 確保 context 被使用 var bodyMap map[string]interface{} if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil { httputil.WriteJSON(w, 400, map[string]interface{}{ "error": map[string]string{"message": "invalid JSON body", "code": "bad_request"}, }, nil) return } rawModel, _ := bodyMap["model"].(string) if rawModel == "" { rawModel = "gemini-2.0-flash" } var messages []interface{} if m, ok := bodyMap["messages"].([]interface{}); ok { messages = m } isStream := false if s, ok := bodyMap["stream"].(bool); ok { isStream = s } // 轉換 messages 為 apitypes.Message var apiMessages []apitypes.Message for _, m := range messages { if msgMap, ok := m.(map[string]interface{}); ok { role, _ := msgMap["role"].(string) content := "" if c, ok := msgMap["content"].(string); ok { content = c } apiMessages = append(apiMessages, apitypes.Message{ Role: role, Content: content, }) } } logger.LogRequestStart(method, pathname, rawModel, cfg.TimeoutMs, isStream) start := time.Now().UnixMilli() // 創建 Gemini provider (使用 Playwright) provider, provErr := geminiweb.NewPlaywrightProvider(cfg) if provErr != nil { logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, provErr.Error()) httputil.WriteJSON(w, 500, map[string]interface{}{ "error": map[string]string{"message": provErr.Error(), "code": "provider_error"}, }, nil) return } if isStream { httputil.WriteSSEHeaders(w, nil) flusher, _ := w.(http.Flusher) id := "chatcmpl_" + uuid.New().String() created := time.Now().Unix() var accumulated string err := provider.Generate(r.Context(), rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) { if chunk.Type == apitypes.ChunkText { accumulated += chunk.Text respChunk := map[string]interface{}{ "id": id, "object": "chat.completion.chunk", "created": created, "model": rawModel, "choices": []map[string]interface{}{ { "index": 0, "delta": map[string]string{"content": chunk.Text}, "finish_reason": nil, }, }, } data, _ := json.Marshal(respChunk) fmt.Fprintf(w, "data: %s\n\n", data) if flusher != nil { flusher.Flush() } } else if chunk.Type == apitypes.ChunkThinking { respChunk := map[string]interface{}{ "id": id, "object": "chat.completion.chunk", "created": created, "model": rawModel, "choices": []map[string]interface{}{ { "index": 0, "delta": map[string]interface{}{"reasoning_content": chunk.Thinking}, "finish_reason": nil, }, }, } data, _ := json.Marshal(respChunk) fmt.Fprintf(w, "data: %s\n\n", data) if flusher != nil { flusher.Flush() } } else if chunk.Type == apitypes.ChunkDone { stopChunk := map[string]interface{}{ "id": id, "object": "chat.completion.chunk", "created": created, "model": rawModel, "choices": []map[string]interface{}{ { "index": 0, "delta": map[string]interface{}{}, "finish_reason": "stop", }, }, } data, _ := json.Marshal(stopChunk) fmt.Fprintf(w, "data: %s\n\n", data) fmt.Fprintf(w, "data: [DONE]\n\n") if flusher != nil { flusher.Flush() } } }) latencyMs := time.Now().UnixMilli() - start if err != nil { logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error()) logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1) return } logger.LogTrafficResponse(cfg.Verbose, rawModel, accumulated, true) logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0) return } // 非串流模式 var resultText string var resultThinking string err := provider.Generate(r.Context(), rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) { if chunk.Type == apitypes.ChunkText { resultText += chunk.Text } else if chunk.Type == apitypes.ChunkThinking { resultThinking += chunk.Thinking } }) latencyMs := time.Now().UnixMilli() - start if err != nil { logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error()) logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1) httputil.WriteJSON(w, 500, map[string]interface{}{ "error": map[string]string{"message": err.Error(), "code": "gemini_error"}, }, nil) return } logger.LogTrafficResponse(cfg.Verbose, rawModel, resultText, false) logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0) id := "chatcmpl_" + uuid.New().String() created := time.Now().Unix() resp := map[string]interface{}{ "id": id, "object": "chat.completion", "created": created, "model": rawModel, "choices": []map[string]interface{}{ { "index": 0, "message": map[string]interface{}{ "role": "assistant", "content": resultText, }, "finish_reason": "stop", }, }, "usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, } httputil.WriteJSON(w, 200, resp, nil) }