feature/gemini-web-provider #1
|
|
@ -0,0 +1,196 @@
|
||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"cursor-api-proxy/internal/apitypes"
|
||||||
|
"cursor-api-proxy/internal/config"
|
||||||
|
"cursor-api-proxy/internal/httputil"
|
||||||
|
"cursor-api-proxy/internal/logger"
|
||||||
|
"cursor-api-proxy/internal/providers/geminiweb"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
func HandleGeminiChatCompletions(w http.ResponseWriter, r *http.Request, cfg config.BridgeConfig, rawBody, method, pathname, remoteAddress string) {
|
||||||
|
_ = context.Background() // 確保 context 被使用
|
||||||
|
var bodyMap map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(rawBody), &bodyMap); err != nil {
|
||||||
|
httputil.WriteJSON(w, 400, map[string]interface{}{
|
||||||
|
"error": map[string]string{"message": "invalid JSON body", "code": "bad_request"},
|
||||||
|
}, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawModel, _ := bodyMap["model"].(string)
|
||||||
|
if rawModel == "" {
|
||||||
|
rawModel = "gemini-2.0-flash"
|
||||||
|
}
|
||||||
|
|
||||||
|
var messages []interface{}
|
||||||
|
if m, ok := bodyMap["messages"].([]interface{}); ok {
|
||||||
|
messages = m
|
||||||
|
}
|
||||||
|
|
||||||
|
isStream := false
|
||||||
|
if s, ok := bodyMap["stream"].(bool); ok {
|
||||||
|
isStream = s
|
||||||
|
}
|
||||||
|
|
||||||
|
// 轉換 messages 為 apitypes.Message
|
||||||
|
var apiMessages []apitypes.Message
|
||||||
|
for _, m := range messages {
|
||||||
|
if msgMap, ok := m.(map[string]interface{}); ok {
|
||||||
|
role, _ := msgMap["role"].(string)
|
||||||
|
content := ""
|
||||||
|
if c, ok := msgMap["content"].(string); ok {
|
||||||
|
content = c
|
||||||
|
}
|
||||||
|
apiMessages = append(apiMessages, apitypes.Message{
|
||||||
|
Role: role,
|
||||||
|
Content: content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.LogRequestStart(method, pathname, rawModel, cfg.TimeoutMs, isStream)
|
||||||
|
start := time.Now().UnixMilli()
|
||||||
|
|
||||||
|
// 創建 Gemini provider
|
||||||
|
provider := geminiweb.NewProvider(cfg)
|
||||||
|
|
||||||
|
if isStream {
|
||||||
|
httputil.WriteSSEHeaders(w, nil)
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
|
||||||
|
id := "chatcmpl_" + uuid.New().String()
|
||||||
|
created := time.Now().Unix()
|
||||||
|
var accumulated string
|
||||||
|
|
||||||
|
err := provider.Generate(r.Context(), rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) {
|
||||||
|
if chunk.Type == apitypes.ChunkText {
|
||||||
|
accumulated += chunk.Text
|
||||||
|
respChunk := map[string]interface{}{
|
||||||
|
"id": id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": created,
|
||||||
|
"model": rawModel,
|
||||||
|
"choices": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]string{"content": chunk.Text},
|
||||||
|
"finish_reason": nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(respChunk)
|
||||||
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||||
|
if flusher != nil {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
} else if chunk.Type == apitypes.ChunkThinking {
|
||||||
|
respChunk := map[string]interface{}{
|
||||||
|
"id": id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": created,
|
||||||
|
"model": rawModel,
|
||||||
|
"choices": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]interface{}{"reasoning_content": chunk.Thinking},
|
||||||
|
"finish_reason": nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(respChunk)
|
||||||
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||||
|
if flusher != nil {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
} else if chunk.Type == apitypes.ChunkDone {
|
||||||
|
stopChunk := map[string]interface{}{
|
||||||
|
"id": id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": created,
|
||||||
|
"model": rawModel,
|
||||||
|
"choices": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]interface{}{},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(stopChunk)
|
||||||
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||||
|
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||||
|
if flusher != nil {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
latencyMs := time.Now().UnixMilli() - start
|
||||||
|
if err != nil {
|
||||||
|
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
|
||||||
|
logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.LogTrafficResponse(cfg.Verbose, rawModel, accumulated, true)
|
||||||
|
logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 非串流模式
|
||||||
|
var resultText string
|
||||||
|
var resultThinking string
|
||||||
|
|
||||||
|
err := provider.Generate(r.Context(), rawModel, apiMessages, nil, func(chunk apitypes.StreamChunk) {
|
||||||
|
if chunk.Type == apitypes.ChunkText {
|
||||||
|
resultText += chunk.Text
|
||||||
|
} else if chunk.Type == apitypes.ChunkThinking {
|
||||||
|
resultThinking += chunk.Thinking
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
latencyMs := time.Now().UnixMilli() - start
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.LogAgentError(cfg.SessionsLogPath, method, pathname, remoteAddress, -1, err.Error())
|
||||||
|
logger.LogRequestDone(method, pathname, rawModel, latencyMs, -1)
|
||||||
|
httputil.WriteJSON(w, 500, map[string]interface{}{
|
||||||
|
"error": map[string]string{"message": err.Error(), "code": "gemini_error"},
|
||||||
|
}, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.LogTrafficResponse(cfg.Verbose, rawModel, resultText, false)
|
||||||
|
logger.LogRequestDone(method, pathname, rawModel, latencyMs, 0)
|
||||||
|
|
||||||
|
id := "chatcmpl_" + uuid.New().String()
|
||||||
|
created := time.Now().Unix()
|
||||||
|
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"id": id,
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": created,
|
||||||
|
"model": rawModel,
|
||||||
|
"choices": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": map[string]interface{}{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": resultText,
|
||||||
|
},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
httputil.WriteJSON(w, 200, resp, nil)
|
||||||
|
}
|
||||||
|
|
@ -61,7 +61,16 @@ func NewRouter(opts RouterOptions) http.HandlerFunc {
|
||||||
}, nil)
|
}, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 根據 Provider 選擇處理方式
|
||||||
|
provider := cfg.Provider
|
||||||
|
if provider == "" {
|
||||||
|
provider = "cursor"
|
||||||
|
}
|
||||||
|
if provider == "gemini-web" {
|
||||||
|
handlers.HandleGeminiChatCompletions(w, r, cfg, raw, method, pathname, remoteAddress)
|
||||||
|
} else {
|
||||||
handlers.HandleChatCompletions(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress)
|
handlers.HandleChatCompletions(w, r, cfg, opts.Pool, opts.LastModel, raw, method, pathname, remoteAddress)
|
||||||
|
}
|
||||||
|
|
||||||
case method == "POST" && pathname == "/v1/messages":
|
case method == "POST" && pathname == "/v1/messages":
|
||||||
raw, err := httputil.ReadBody(r)
|
raw, err := httputil.ReadBody(r)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue