fix api server update version

This commit is contained in:
王性驊 2025-12-31 17:36:02 +08:00
commit 909ba157d0
92 changed files with 14152 additions and 0 deletions

217
.gitignore vendored Normal file
View File

@ -0,0 +1,217 @@
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
*.test
*.out
chat
chat-api
# Go workspace file
go.work
# Go build cache
.cache/
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool
*.out
coverage.txt
coverage.html
# Dependency directories
vendor/
# Go module cache
go.sum
# IDE - VSCode
.vscode/
*.code-workspace
# IDE - GoLand / IntelliJ IDEA
.idea/
*.iml
*.iws
*.ipr
# IDE - Vim
*.swp
*.swo
*~
.vim/
# IDE - Emacs
*~
\#*\#
.\#*
# OS - macOS
.DS_Store
.AppleDouble
.LSOverride
._*
# OS - Windows
Thumbs.db
ehthumbs.db
Desktop.ini
$RECYCLE.BIN/
*.lnk
# OS - Linux
*~
.directory
.Trash-*
# Logs
*.log
logs/
*.log.*
# Environment variables
.env
.env.local
.env.*.local
# Configuration files with sensitive data
# 注意:如果配置文件包含敏感信息,應該使用環境變量或配置模板
# etc/*.yaml # 如果包含敏感信息,取消註釋這行
# Temporary files
tmp/
temp/
*.tmp
*.bak
*.swp
*.swo
# Frontend
frontend/node_modules/
frontend/.next/
frontend/.nuxt/
frontend/dist/
frontend/build/
frontend/.cache/
frontend/.parcel-cache/
frontend/.vite/
frontend/.svelte-kit/
frontend/.pnpm-store/
# Frontend package manager files
frontend/package-lock.json
frontend/yarn.lock
frontend/pnpm-lock.yaml
# Frontend environment
frontend/.env
frontend/.env.local
frontend/.env.*.local
# Docker
docker-compose.override.yaml
.dockerignore
# Database
*.db
*.sqlite
*.sqlite3
# Redis dump
dump.rdb
# Cassandra
*.cql
data/
# Deployment
deployment/*.log
deployment/data/
deployment/volumes/
# Generated files
*.pb.go
*.pb.gw.go
*.swagger.json
*.swagger.yaml
# go-zero generated files (if you want to ignore them)
# internal/handler/chat/*.go # 如果不想提交自動生成的文件,取消註釋
# internal/logic/chat/*.go # 如果不想提交自動生成的文件,取消註釋
# internal/types/types.go # 如果不想提交自動生成的文件,取消註釋
# Keep generated files but ignore specific patterns
# internal/handler/chat/*handler.go # 只忽略 handler
# internal/logic/chat/*logic.go # 只忽略 logic
# Build artifacts
bin/
build/
out/
# Profiling data
*.prof
*.pprof
# Local development
.local/
local/
# Secrets and keys
*.pem
*.key
*.crt
*.p12
*.pfx
secrets/
*.secret
# Backup files
*.backup
*.old
# JetBrains IDEs
.idea/
*.iml
*.iws
*.ipr
# VS Code
.vscode/
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
# Sublime Text
*.sublime-project
*.sublime-workspace
# Vim
[._]*.s[a-v][a-z]
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]
Session.vim
.netrwhist
*~
# Emacs
*~
\#*\#
/.emacs.desktop
/.emacs.desktop.lock
*.elc
auto-save-list
tramp
.\#*
# Project specific
chat.json
*.json.bak

187
api/chat.api Normal file
View File

@ -0,0 +1,187 @@
syntax = "v1"
info (
title: "GaoBinYou"
desc: "this service will manager chat services"
version: "0.0.1"
contactName: "daniel wang"
contactEmail: "igs170911@gmail.com"
consumes: "application/json"
produces: "application/json"
schemes: "http"
host: "127.0.0.1:8091"
)
type BaseReq {}
type BaseResp {}
type BaseResponse {
Code string `json:"code"` // 狀態碼
Message string `json:"message"` // 訊息
Data interface{} `json:"data,omitempty"` // 資料
Error interface{} `json:"error,omitempty"` // 可選的錯誤信息
}
type AuthHeader {
Authorization string `header:"Authorization" binding:"required"`
}
type AnonLoginReq {
Name string `json:"name" required:"required"`
}
type AnonLoginResp {
UID string `json:"uid"`
Token string `json:"token"`
CentrifugoToken string `json:"centrifugo_token"` // Centrifugo WebSocket 連線用的 token
RefreshToken string `json:"refresh_token"`
ExpireAt int64 `json:"expire_at"`
}
type RefreshTokenReq {
Token string `json:"token"` // 舊的 token可以是已過期的
}
type SendMessageReq {
AuthHeader
RoomID string `path:"room_id"`
Content string `json:"content"`
ClientMsgID string `json:"client_msg_id"`
}
type Message {
MessageID string `json:"message_id"`
UID string `json:"uid"`
Content string `json:"content"`
Timestamp int64 `json:"timestamp"`
}
type ListMessageReq {
AuthHeader
RoomID string `path:"room_id"`
PageSize int64 `form:"page_size,default=20"`
PageIndex int64 `form:"page_index,default=1"`
}
type Pagination {
Total int64 `json:"total,example=100"`
Page int64 `json:"page,example=1"`
PageSize int64 `json:"pageSize,example=10"`
TotalPages int64 `json:"totalPages,example=10"`
}
type ListMessageResp {
Pager Pagination `json:"pager"`
Data []Message `json:"data"`
}
// 加入配對池
type MatchJoinReq {
AuthHeader
}
type MatchJoinResp {
Status string `json:"status"` // waiting | matched
}
// 查詢配對結果Polling 版,最快能上)
type MatchStatusResp {
Status string `json:"status"` // waiting | matched
RoomID string `json:"room_id,omitempty"`
}
@server (
group: chat
prefix: /api/v1
schemes: https
timeout: 10s
)
service chat-api {
@doc (
summary: "匿名登入"
description: "拿到匿名的token使得確認聊天的身分"
)
/*
@respdoc-200 (AnonLoginResp)
@respdoc-400 (BaseResponse) "請求參數格式錯誤"
@respdoc-500 (BaseResponse) // 伺服器內部錯誤
*/
@handler AnonLogin
post /auth/anon (AnonLoginReq) returns (AnonLoginResp)
@doc (
summary: "刷新 Token"
description: "使用現有的 token 來刷新,同時更新 API token 和 Centrifugo token"
)
/*
@respdoc-200 (AnonLoginResp) "返回新的 token"
@respdoc-400 (BaseResponse) "請求參數格式錯誤"
@respdoc-401 (BaseResponse) "Token 無效或已過期"
@respdoc-500 (BaseResponse) // 伺服器內部錯誤
*/
@handler RefreshToken
post /auth/refresh (RefreshTokenReq) returns (AnonLoginResp)
}
@server (
group: chat
prefix: /api/v1
schemes: https
timeout: 10s
middleware: AnonMiddleware // 所有此 group 的路由都需要經過 JWT 驗證
)
service chat-api {
@doc (
summary: "傳送訊息"
description: "傳送訊息"
)
/*
@respdoc-201
@respdoc-400 (BaseResponse) "請求參數格式錯誤"
@respdoc-401 (BaseResponse) "未授權或 Token 無效"
@respdoc-500 (BaseResponse) // 伺服器內部錯誤
*/
@handler SendMessage
post /rooms/:room_id/messages (SendMessageReq) returns (BaseResp)
@doc (
summary: "取得訊息"
description: "取得訊息"
)
/*
@respdoc-200 (ListMessageResp) "取得聊天訊息"
@respdoc-400 (BaseResponse) "請求參數格式錯誤"
@respdoc-401 (BaseResponse) "未授權或 Token 無效"
@respdoc-500 (BaseResponse) // 伺服器內部錯誤
*/
@handler ListMessages
get /rooms/:room_id/messages (ListMessageReq) returns (ListMessageResp)
@doc (
summary: "加入等待序列"
description: "加入等待序列"
)
/*
@respdoc-201 (MatchJoinResp) ""
@respdoc-400 (BaseResponse) "請求參數格式錯誤"
@respdoc-401 (BaseResponse) "未授權或 Token 無效"
@respdoc-500 (BaseResponse) // 伺服器內部錯誤
*/
@handler MatchJoin
post /matchmaking/join (MatchJoinReq) returns (MatchJoinResp)
@doc (
summary: "取得房間資訊"
description: "取得房間資訊"
)
/*
@respdoc-201 (MatchStatusResp) "取得房間資訊"
@respdoc-400 (BaseResponse) "請求參數格式錯誤"
@respdoc-401 (BaseResponse) "未授權或 Token 無效"
@respdoc-500 (BaseResponse) // 伺服器內部錯誤
*/
@handler MatchStatus
get /matchmaking/status (MatchJoinReq) returns (MatchStatusResp)
}

52
chat.go Normal file
View File

@ -0,0 +1,52 @@
package main
import (
"flag"
"net/http"
"github.com/zeromicro/go-zero/core/logx"
"chat/internal/config"
"chat/internal/handler"
"chat/internal/svc"
"github.com/zeromicro/go-zero/core/conf"
"github.com/zeromicro/go-zero/rest"
)
var configFile = flag.String("f", "etc/chat-api.yaml", "the config file")
func main() {
flag.Parse()
var c config.Config
conf.MustLoad(*configFile, &c)
server := rest.MustNewServer(c.RestConf)
defer server.Stop()
// 全局處理 OPTIONS 請求CORS 預檢請求)
server.Use(func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// 設置 CORS 標頭
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
w.Header().Set("Access-Control-Max-Age", "3600")
// 處理預檢請求
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}
next(w, r)
}
})
ctx := svc.NewServiceContext(c)
handler.RegisterHandlers(server, ctx)
logx.Infof("Starting server at %s:%d...\n", c.Host, c.Port)
server.Start()
}

64
deployment/README.md Normal file
View File

@ -0,0 +1,64 @@
# Docker Compose 部署說明
## 服務說明
### Redis
- 端口6379
- 用途:配對佇列、使用者狀態、房間成員管理
### Centrifugo
- HTTP API 端口8000
- WebSocket 端口8001
- 用途:即時訊息推送
- 配置文件:`centrifugo.json`
### Cassandra
- 端口9042
- 用途:聊天歷史訊息儲存
## 配置說明
### Centrifugo 配置
1. **更新 `centrifugo.json`**
- `token_hmac_secret_key`:必須與 `etc/chat-api.yaml` 中的 `JWT.CentrifugoSecret``JWT.Secret` 相同
- `api_key`:必須與 `etc/chat-api.yaml` 中的 `Centrifugo.APIKey` 相同
2. **範例配置**
```json
{
"token_hmac_secret_key": "your-secret-key-change-in-production",
"api_key": "api-key"
}
```
3. **對應的 `etc/chat-api.yaml`**
```yaml
Centrifugo:
APIURL: http://localhost:8000/api
APIKey: "api-key"
JWT:
Secret: "your-secret-key-change-in-production"
CentrifugoSecret: "" # 留空則使用 Secret
```
## 啟動服務
```bash
cd deployment
docker-compose up -d
```
## 驗證服務
- Redis: `redis-cli ping`
- Centrifugo: `curl http://localhost:8000/health`
- Cassandra: `docker exec -it cassandra cqlsh`
## 注意事項
1. **生產環境**:請務必修改所有預設密碼和 secret key
2. **Centrifugo Redis**:目前配置使用 Redis 作為 Centrifugo 的後端,確保 Redis 先啟動
3. **網路配置**:如果應用程序也在 Docker 中運行,請使用服務名稱(如 `centrifugo:8000`)而不是 `localhost`

View File

@ -0,0 +1,44 @@
{
"token_hmac_secret_key": "your-secret-key-change-in-production",
"admin_password": "admin",
"admin_secret": "admin-secret",
"api_key": "api-key",
"allowed_origins": [
"*"
],
"log_level": "info",
"log_handler": "stdout",
"websocket_compression": true,
"websocket_read_buffer_size": 1024,
"websocket_write_buffer_size": 1024,
"namespaces": [
{
"name": "default",
"publish": true,
"subscribe_to_publish": true,
"presence": true,
"join_leave": true,
"history_size": 100,
"history_ttl": "300s"
},
{
"name": "room",
"publish": true,
"subscribe_to_publish": true,
"presence": true,
"join_leave": true,
"history_size": 100,
"history_ttl": "300s"
},
{
"name": "user",
"publish": false,
"subscribe_to_publish": false,
"presence": false,
"join_leave": false,
"history_size": 0,
"history_ttl": "0s"
}
],
"redis_address": "redis:6379"
}

View File

@ -0,0 +1,42 @@
services:
redis:
image: redis:7.0
container_name: redis
restart: always
ports:
- "6379:6379"
centrifugo:
image: centrifugo/centrifugo:v5
container_name: centrifugo
restart: always
ports:
- "8000:8000" # HTTP API
- "8001:8001" # WebSocket
volumes:
- ./centrifugo.json:/centrifugo/config.json:ro
command: centrifugo --config=/centrifugo/config.json
healthcheck:
test: ["CMD", "wget", "--quiet", "--tries=1", "--spider", "http://localhost:8000/health"]
interval: 10s
timeout: 5s
retries: 3
depends_on:
- redis
cassandra:
image: cassandra:5.0.4
restart: always
ports:
- "9042:9042"
environment:
TZ: ${TIMEZONE:-UTC}
MAX_HEAP_SIZE: 4G
HEAP_NEWSIZE: 2G
healthcheck:
test: ["CMD", "cqlsh", "-k", "sccflex"]
interval: 10s
timeout: 10s
retries: 12
mem_limit: 8g # <--- 單機 docker-compose up 時建議明確加這行
memswap_limit: 8g # <--- 關掉 swap

0
doc.md Normal file
View File

29
etc/chat-api.yaml Normal file
View File

@ -0,0 +1,29 @@
Name: chat-api
Host: 0.0.0.0
Port: 8888
Redis:
Host: localhost
Port: 6379
Password: ""
DB: 0
Centrifugo:
APIURL: http://localhost:8000/api
APIKey: "api-key"
Cassandra:
Hosts:
- localhost
Port: 9042
Keyspace: chat
Username: "cassandra"
Password: "cassandra"
UseAuth: false
JWT:
Secret: "your-secret-key-change-in-production"
Expire: 86400 # API token 和 Centrifugo token 共用此過期時間(秒)
# CentrifugoSecret: "" # 如果為空或未設置,會自動使用與 Secret 相同的值(簡化配置)
# 如果需要分開管理(例如安全隔離),可以設置不同的值
CentrifugoSecret: ""

4
frontend/.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
node_modules/
.DS_Store
*.log

95
frontend/README.md Normal file
View File

@ -0,0 +1,95 @@
# GaoBinYou 前端 POC
這是一個簡單的前端應用,用於測試 GaoBinYou 聊天系統的完整功能。
## 功能
- ✅ 匿名登入
- ✅ Token 刷新
- ✅ 隨機配對
- ✅ 即時訊息WebSocket
- ✅ 歷史訊息載入
- ✅ 自動狀態輪詢
- ✅ CORS 支持(後端已配置)
## 使用方式
### 方法 1使用啟動腳本推薦
```bash
cd frontend
./start.sh 3000
```
然後在瀏覽器打開 `http://localhost:3000`
### 方法 2手動啟動 HTTP 服務器
```bash
# 在 frontend 目錄下啟動一個簡單的 HTTP 服務器
cd frontend
# 使用 Python
python3 -m http.server 3000
# 或使用 Node.js
npx http-server -p 3000
```
然後在瀏覽器打開 `http://localhost:3000`
### 方法 3使用 Live ServerVS Code 擴展)
1. 安裝 Live Server 擴展
2. 右鍵點擊 `index.html`
3. 選擇 "Open with Live Server"
## 配置
`app.js` 中修改以下配置:
```javascript
const API_BASE_URL = 'http://localhost:8888/api/v1'; // 後端 API 地址
const CENTRIFUGO_WS_URL = 'ws://localhost:8001/connection/websocket'; // Centrifugo WebSocket 地址
```
## 使用流程
1. **登入**:輸入暱稱(可選),點擊「開始聊天」
2. **配對**:點擊「加入配對」,等待系統匹配
3. **聊天**:配對成功後自動進入聊天室,可以發送和接收訊息
4. **刷新 Token**:如果 Token 即將過期,點擊「刷新 Token」按鈕
## 注意事項
1. 確保後端服務正在運行(`http://localhost:8888`
2. 確保 Centrifugo 服務正在運行(`ws://localhost:8001`
3. 後端已配置 CORS允許所有來源開發環境
4. 生產環境建議限制 CORS 來源
## 瀏覽器兼容性
- Chrome/Edge (推薦)
- Firefox
- Safari
## 專案結構
```
frontend/
├── index.html # 主頁面
├── styles.css # 樣式表
├── app.js # 應用邏輯
├── start.sh # 啟動腳本
├── README.md # 說明文檔
└── .gitignore # Git 忽略文件
```
## 未來改進
- [ ] 自動重連機制(已部分實現)
- [ ] 訊息已讀狀態
- [ ] 打字指示器
- [ ] 表情符號支持
- [ ] 圖片/文件上傳
- [ ] 使用 Centrifugo 官方 JavaScript SDK

586
frontend/app.js Normal file
View File

@ -0,0 +1,586 @@
// API 配置
const API_BASE_URL = 'http://localhost:8888/api/v1';
// Centrifugo WebSocket URL - 默認使用 8000 端口(與 HTTP API 同端口)
// 如果配置了不同的 WebSocket 端口,請修改此處
const CENTRIFUGO_WS_URL = 'ws://localhost:8000/connection/websocket';
// 應用狀態
let appState = {
uid: null,
token: null,
centrifugoToken: null,
expireAt: null,
roomID: null,
centrifugoClient: null,
matchStatus: null,
wsRetryCount: 0
};
// 工具函數
function log(message, type = 'info') {
const logContainer = document.getElementById('logContainer');
const entry = document.createElement('div');
entry.className = `log-entry ${type}`;
entry.textContent = `[${new Date().toLocaleTimeString()}] ${message}`;
logContainer.appendChild(entry);
logContainer.scrollTop = logContainer.scrollHeight;
console.log(`[${type.toUpperCase()}]`, message);
}
function showSection(sectionId) {
document.querySelectorAll('.section').forEach(section => {
section.classList.add('hidden');
});
document.getElementById(sectionId).classList.remove('hidden');
}
// API 調用函數
async function apiCall(endpoint, method = 'GET', body = null, needAuth = false) {
const options = {
method,
headers: {
'Content-Type': 'application/json',
}
};
if (needAuth && appState.token) {
options.headers['Authorization'] = `Bearer ${appState.token}`;
}
if (body) {
options.body = JSON.stringify(body);
}
try {
const response = await fetch(`${API_BASE_URL}${endpoint}`, options);
// 先讀取響應文本
const text = await response.text();
// 嘗試解析 JSON
let data;
try {
data = text ? JSON.parse(text) : {};
} catch (parseError) {
// 如果不是有效的 JSON使用原始文本作為錯誤信息
log(`響應不是有效的 JSON: ${text}`, 'error');
if (!response.ok) {
throw new Error(text || `HTTP ${response.status}`);
}
data = {};
}
if (!response.ok) {
const errorMsg = data.message || data.error || data.error?.message || text || `HTTP ${response.status}`;
log(`API 錯誤: ${errorMsg}`, 'error');
throw new Error(errorMsg);
}
return data;
} catch (error) {
if (error instanceof TypeError && error.message.includes('fetch')) {
log(`網路錯誤: ${error.message}`, 'error');
throw new Error('網路連接失敗,請檢查後端服務是否運行');
}
log(`API 錯誤: ${error.message}`, 'error');
throw error;
}
}
// 匿名登入
async function handleLogin() {
const userName = document.getElementById('userName').value || 'Anonymous';
const btn = document.getElementById('loginBtn');
btn.disabled = true;
btn.textContent = '登入中...';
try {
log('開始匿名登入...', 'info');
const response = await apiCall('/auth/anon', 'POST', { name: userName });
// 調試:檢查響應內容
log(`登入響應: uid=${response.uid}, hasToken=${!!response.token}, hasCentrifugoToken=${!!response.centrifugo_token}`, 'info');
appState.uid = response.uid;
appState.token = response.token;
appState.centrifugoToken = response.centrifugo_token;
appState.expireAt = response.expire_at;
// 驗證 token 是否正確獲取
if (!appState.centrifugoToken) {
log('警告:未獲取到 Centrifugo token請檢查後端響應', 'error');
log(`完整響應: ${JSON.stringify(response)}`, 'error');
} else {
const tokenPreview = appState.centrifugoToken.substring(0, 20) + '...';
log(`已獲取 Centrifugo token (前20字符): ${tokenPreview}`, 'info');
}
document.getElementById('userUID').textContent = appState.uid;
log(`登入成功UID: ${appState.uid}`, 'success');
showSection('matchingSection');
} catch (error) {
log(`登入失敗: ${error.message}`, 'error');
alert('登入失敗,請重試');
} finally {
btn.disabled = false;
btn.textContent = '開始聊天';
}
}
// 加入配對
async function handleJoinMatch() {
const btn = document.getElementById('joinMatchBtn');
btn.disabled = true;
btn.textContent = '配對中...';
try {
log('加入配對佇列...', 'info');
const response = await apiCall('/matchmaking/join', 'POST', null, true);
appState.matchStatus = response.status;
document.getElementById('matchStatus').textContent = response.status === 'waiting' ? '等待配對中...' : '已配對!';
log(`配對狀態: ${response.status}`, response.status === 'matched' ? 'success' : 'info');
if (response.status === 'matched') {
// 如果立即配對成功,需要查詢 roomID
await handleCheckStatus();
} else {
// 開始輪詢狀態
startStatusPolling();
}
} catch (error) {
log(`加入配對失敗: ${error.message}`, 'error');
} finally {
btn.disabled = false;
btn.textContent = '加入配對';
}
}
// 檢查配對狀態
async function handleCheckStatus() {
try {
const response = await apiCall('/matchmaking/status', 'GET', null, true);
appState.matchStatus = response.status;
document.getElementById('matchStatus').textContent =
response.status === 'waiting' ? '等待配對中...' : '已配對!';
if (response.status === 'matched' && response.room_id) {
appState.roomID = response.room_id;
document.getElementById('roomID').textContent = response.room_id;
log(`配對成功!房間 ID: ${response.room_id}`, 'success');
// 停止輪詢,進入聊天室
stopStatusPolling();
await enterChatRoom();
}
} catch (error) {
log(`檢查狀態失敗: ${error.message}`, 'error');
}
}
// 開始狀態輪詢
let pollingInterval = null;
function startStatusPolling() {
if (pollingInterval) return;
log('開始輪詢配對狀態...', 'info');
pollingInterval = setInterval(async () => {
await handleCheckStatus();
}, 2000); // 每 2 秒檢查一次
}
function stopStatusPolling() {
if (pollingInterval) {
clearInterval(pollingInterval);
pollingInterval = null;
log('停止輪詢配對狀態', 'info');
}
}
// 進入聊天室
async function enterChatRoom() {
showSection('chatSection');
// 進入聊天室前先刷新 token以確保獲取到最新的房間權限
// 因為匿名登入時還不知道房間ID所以當時的 token 沒有房間權限
await handleRefreshToken();
// 連接到 Centrifugo
connectToCentrifugo();
// 載入歷史訊息
await loadHistoryMessages();
}
// 連接到 Centrifugo
function connectToCentrifugo() {
if (!appState.centrifugoToken || !appState.roomID) {
log(`無法連接 Centrifugo缺少 token 或 roomID (token: ${appState.centrifugoToken ? '存在' : '不存在'}, roomID: ${appState.roomID || '不存在'})`, 'error');
return;
}
try {
// 關閉舊連接
if (appState.centrifugoClient) {
appState.centrifugoClient.close();
}
// 檢查 token 是否過期
if (appState.expireAt) {
const now = Math.floor(Date.now() / 1000);
if (now >= appState.expireAt) {
log('Centrifugo token 已過期,請先刷新 token', 'error');
return;
}
const timeUntilExpiry = appState.expireAt - now;
log(`Centrifugo token 將在 ${Math.floor(timeUntilExpiry / 60)} 分鐘後過期`, 'info');
}
// 記錄 token 前幾個字符用於調試(不記錄完整 token 以保護安全)
const tokenPreview = appState.centrifugoToken.substring(0, 20) + '...';
log(`正在連接 Centrifugo使用 token: ${tokenPreview}`, 'info');
// Centrifugo WebSocket 連線(使用 JSON 格式)
// Centrifugo v5: 必須先發送 connect 命令進行認證
const wsUrl = `${CENTRIFUGO_WS_URL}?format=json`;
log(`WebSocket URL: ${wsUrl}`, 'info');
const ws = new WebSocket(wsUrl);
let messageId = 1;
let isConnected = false;
ws.onopen = () => {
log('WebSocket 連接已建立,正在發送認證請求...', 'info');
// 重置重試計數
appState.wsRetryCount = 0;
// Centrifugo v5: 必須先發送 connect 命令進行認證
const connectMsg = {
id: messageId++,
connect: {
token: appState.centrifugoToken
}
};
ws.send(JSON.stringify(connectMsg));
log('已發送 connect 認證請求', 'info');
};
ws.onmessage = (event) => {
try {
const data = JSON.parse(event.data);
// 跳過空回應(心跳 pong和只有 id 的回應
const isEmptyOrPong = Object.keys(data).length === 0 ||
(Object.keys(data).length === 1 && data.id);
// 只記錄有意義的訊息(非心跳、非空回應)
if (!isEmptyOrPong && !data.subscribe) {
log(`收到 Centrifugo 訊息: ${JSON.stringify(data)}`, 'info');
}
// Centrifugo v5 JSON 協議回應格式
// 處理 connect 回應
if (data.connect) {
isConnected = true;
appState.wsRetryCount = 0; // 重置重連計數
log(`已成功連接到 Centrifugoclient: ${data.connect.client}`, 'success');
// 設置心跳定時器Centrifugo 要求客戶端發送 ping
const pingInterval = (data.connect.ping || 25) * 1000;
if (appState.pingTimer) {
clearInterval(appState.pingTimer);
}
appState.pingTimer = setInterval(() => {
if (ws.readyState === WebSocket.OPEN) {
// 發送空物件作為 ping
ws.send('{}');
}
}, pingInterval);
log(`已設置心跳間隔: ${pingInterval / 1000}`, 'info');
// 檢查是否已經通過 token 自動訂閱了房間頻道
const roomChannel = `room:${appState.roomID}`;
if (data.connect.subs && data.connect.subs[roomChannel]) {
log(`已通過 token 自動訂閱頻道: ${roomChannel}`, 'success');
} else {
// 需要手動訂閱
const subscribeMsg = {
id: messageId++,
subscribe: {
channel: roomChannel
}
};
ws.send(JSON.stringify(subscribeMsg));
log(`已發送訂閱請求: ${roomChannel}`, 'info');
}
}
// 處理 subscribe 回應
else if (data.subscribe) {
log(`成功訂閱頻道`, 'success');
}
// 處理錯誤
else if (data.error) {
// code 105 = already subscribed這不是真正的錯誤
if (data.error.code === 105) {
log(`頻道已訂閱 (這是正常的)`, 'info');
} else {
log(`Centrifugo 錯誤: code=${data.error.code}, message=${data.error.message}`, 'error');
if (data.error.code === 109) {
log('Token 驗證失敗,請檢查 token 是否正確', 'error');
} else if (data.error.code === 103) {
log('頻道訂閱權限不足', 'error');
}
}
}
// 處理 push 訊息(新訊息推送)
else if (data.push) {
const push = data.push;
if (push.pub) {
// 收到發布的訊息
const message = push.pub.data;
if (message && typeof message === 'object') {
displayMessage(message);
log(`收到新訊息: ${message.content || JSON.stringify(message)}`, 'info');
}
} else if (push.join) {
log(`用戶 ${push.join.user} 加入了房間`, 'info');
} else if (push.leave) {
log(`用戶 ${push.leave.user} 離開了房間`, 'info');
}
}
// 處理 disconnect
else if (data.disconnect) {
log(`Centrifugo 服務器主動斷開連接: ${data.disconnect.reason}`, 'error');
}
} catch (error) {
log(`處理 WebSocket 訊息錯誤: ${error.message}`, 'error');
log(`原始訊息: ${event.data}`, 'error');
}
};
ws.onerror = (error) => {
log(`WebSocket 錯誤: ${error.message || error}`, 'error');
// 記錄更多錯誤信息
if (error.target && error.target.readyState === WebSocket.CLOSED) {
log('WebSocket 連接已關閉', 'error');
}
};
ws.onclose = (event) => {
const reason = event.reason || '無';
log(`WebSocket 連線已關閉 (code: ${event.code}, reason: ${reason})`, 'info');
// 清除心跳定時器
if (appState.pingTimer) {
clearInterval(appState.pingTimer);
appState.pingTimer = null;
}
// 1000 表示正常關閉(主動斷開)
const normalCloseCodes = [1000];
if (event.code === 1006) {
log('WebSocket 異常關閉,可能是網路問題或服務中斷', 'error');
}
// 非正常關閉且還在房間中,自動重連
if (!normalCloseCodes.includes(event.code) && appState.roomID && appState.centrifugoToken) {
const retryCount = appState.wsRetryCount || 0;
const maxRetries = 10; // 增加最大重試次數
if (retryCount < maxRetries) {
appState.wsRetryCount = retryCount + 1;
// 使用指數退避策略,最長等待 30 秒
const delay = Math.min(1000 * Math.pow(1.5, retryCount), 30000);
log(`將在 ${Math.round(delay / 1000)} 秒後重新連接... (${retryCount + 1}/${maxRetries})`, 'info');
setTimeout(() => {
if (appState.roomID && appState.centrifugoToken) {
connectToCentrifugo();
}
}, delay);
} else {
log('WebSocket 重連次數已達上限,請刷新頁面重試', 'error');
}
}
};
appState.centrifugoClient = ws;
} catch (error) {
log(`連接 Centrifugo 失敗: ${error.message}`, 'error');
}
}
// 載入歷史訊息
async function loadHistoryMessages() {
try {
log('載入歷史訊息...', 'info');
// 使用查詢參數傳遞 page_size 和 page_index
const response = await apiCall(
`/rooms/${appState.roomID}/messages?page_size=20&page_index=1`,
'GET',
null,
true
);
const messages = response.data || [];
messages.reverse(); // 從舊到新顯示
messages.forEach(msg => {
displayMessage(msg);
});
log(`已載入 ${messages.length} 條歷史訊息`, 'success');
} catch (error) {
log(`載入歷史訊息失敗: ${error.message}`, 'error');
// 即使載入失敗也不影響聊天功能
}
}
// 顯示訊息
function displayMessage(message) {
const container = document.getElementById('messagesContainer');
const messageDiv = document.createElement('div');
messageDiv.className = `message ${message.uid === appState.uid ? 'own' : ''}`;
const date = new Date(message.timestamp);
messageDiv.innerHTML = `
<div class="message-header">
<span>${message.uid === appState.uid ? '我' : '對方'}</span>
<span>${date.toLocaleTimeString()}</span>
</div>
<div class="message-content">${escapeHtml(message.content)}</div>
`;
container.appendChild(messageDiv);
container.scrollTop = container.scrollHeight;
}
// 發送訊息
async function handleSendMessage() {
const input = document.getElementById('messageInput');
const content = input.value.trim();
if (!content) {
alert('請輸入訊息內容');
return;
}
if (!appState.roomID) {
alert('尚未加入房間');
return;
}
const btn = document.getElementById('sendBtn');
btn.disabled = true;
try {
const clientMsgID = `msg_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
await apiCall(
`/rooms/${appState.roomID}/messages`,
'POST',
{
content: content,
client_msg_id: clientMsgID
},
true
);
log(`訊息已發送: ${content}`, 'success');
input.value = '';
} catch (error) {
log(`發送訊息失敗: ${error.message}`, 'error');
alert('發送訊息失敗,請重試');
} finally {
btn.disabled = false;
input.focus();
}
}
// 刷新 Token
async function handleRefreshToken() {
if (!appState.token) {
alert('請先登入');
return;
}
const btn = document.getElementById('refreshTokenBtn');
btn.disabled = true;
btn.textContent = '刷新中...';
try {
log('刷新 Token...', 'info');
const response = await apiCall('/auth/refresh', 'POST', {
token: appState.token
});
appState.token = response.token;
appState.centrifugoToken = response.centrifugo_token;
appState.expireAt = response.expire_at;
// 驗證新 token 是否正確獲取
if (!appState.centrifugoToken) {
log('警告:刷新後未獲取到 Centrifugo token', 'error');
} else {
const tokenPreview = appState.centrifugoToken.substring(0, 20) + '...';
log(`已獲取新的 Centrifugo token: ${tokenPreview}`, 'info');
}
log('Token 刷新成功!', 'success');
// 重新連接 Centrifugo
if (appState.centrifugoClient) {
appState.centrifugoClient.close();
}
if (appState.roomID) {
connectToCentrifugo();
}
} catch (error) {
log(`刷新 Token 失敗: ${error.message}`, 'error');
alert('Token 刷新失敗,請重新登入');
} finally {
btn.disabled = false;
btn.textContent = '刷新 Token';
}
}
// 鍵盤事件處理
function handleKeyPress(event) {
if (event.key === 'Enter') {
handleSendMessage();
}
}
// HTML 轉義
function escapeHtml(text) {
const div = document.createElement('div');
div.textContent = text;
return div.innerHTML;
}
// 檢查 Token 過期
function checkTokenExpiry() {
if (appState.expireAt) {
const now = Math.floor(Date.now() / 1000);
const timeUntilExpiry = appState.expireAt - now;
if (timeUntilExpiry < 0) {
log('Token 已過期', 'error');
alert('Token 已過期,請刷新或重新登入');
} else if (timeUntilExpiry < 300) { // 5 分鐘內過期
log(`Token 將在 ${Math.floor(timeUntilExpiry / 60)} 分鐘後過期,建議刷新`, 'info');
}
}
}
// 定期檢查 Token 過期
setInterval(checkTokenExpiry, 60000); // 每分鐘檢查一次
// 初始化
log('應用程式已載入', 'info');

56
frontend/index.html Normal file
View File

@ -0,0 +1,56 @@
<!DOCTYPE html>
<html lang="zh-TW">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>GaoBinYou - 隨機配對聊天</title>
<link rel="stylesheet" href="styles.css">
</head>
<body>
<div class="container">
<!-- 登入區域 -->
<div id="loginSection" class="section">
<h1>GaoBinYou 聊天室</h1>
<div class="login-form">
<input type="text" id="userName" placeholder="輸入你的暱稱(可選)" />
<button id="loginBtn" onclick="handleLogin()">開始聊天</button>
</div>
</div>
<!-- 配對區域 -->
<div id="matchingSection" class="section hidden">
<div class="status-info">
<p>狀態: <span id="matchStatus">等待配對中...</span></p>
<p>UID: <span id="userUID"></span></p>
</div>
<button id="joinMatchBtn" onclick="handleJoinMatch()">加入配對</button>
<button id="checkStatusBtn" onclick="handleCheckStatus()">檢查狀態</button>
</div>
<!-- 聊天區域 -->
<div id="chatSection" class="section hidden">
<div class="chat-header">
<h2>聊天室: <span id="roomID"></span></h2>
<button id="refreshTokenBtn" onclick="handleRefreshToken()" class="btn-small">刷新 Token</button>
</div>
<div class="chat-container">
<div id="messagesContainer" class="messages"></div>
<div class="input-area">
<input type="text" id="messageInput" placeholder="輸入訊息..." onkeypress="handleKeyPress(event)" />
<button id="sendBtn" onclick="handleSendMessage()">發送</button>
</div>
</div>
</div>
<!-- 日誌區域 -->
<div class="log-section">
<h3>系統日誌</h3>
<div id="logContainer" class="log-container"></div>
</div>
</div>
<script src="app.js"></script>
</body>
</html>

29
frontend/start.sh Executable file
View File

@ -0,0 +1,29 @@
#!/bin/bash
# 簡單的 HTTP 服務器啟動腳本
PORT=${1:-3000}
echo "正在啟動前端服務器..."
echo "訪問地址: http://localhost:${PORT}"
echo ""
echo "按 Ctrl+C 停止服務器"
echo ""
# 檢查是否有 Python
if command -v python3 &> /dev/null; then
echo "使用 Python HTTP 服務器"
python3 -m http.server $PORT
elif command -v python &> /dev/null; then
echo "使用 Python HTTP 服務器"
python -m http.server $PORT
# 檢查是否有 Node.js 和 http-server
elif command -v npx &> /dev/null; then
echo "使用 Node.js http-server"
npx http-server -p $PORT -c-1
else
echo "錯誤: 未找到 Python 或 Node.js"
echo "請安裝 Python 3 或 Node.js或使用其他 HTTP 服務器"
exit 1
fi

228
frontend/styles.css Normal file
View File

@ -0,0 +1,228 @@
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.container {
max-width: 800px;
margin: 0 auto;
background: white;
border-radius: 12px;
box-shadow: 0 10px 40px rgba(0, 0, 0, 0.2);
overflow: hidden;
}
.section {
padding: 30px;
}
.hidden {
display: none;
}
h1 {
text-align: center;
color: #333;
margin-bottom: 30px;
}
h2 {
color: #333;
font-size: 1.2em;
margin-bottom: 15px;
}
h3 {
color: #666;
font-size: 1em;
margin-bottom: 10px;
padding: 10px 20px;
background: #f5f5f5;
border-bottom: 1px solid #ddd;
}
.login-form {
display: flex;
flex-direction: column;
gap: 15px;
}
input[type="text"] {
padding: 12px;
border: 2px solid #e0e0e0;
border-radius: 6px;
font-size: 16px;
transition: border-color 0.3s;
}
input[type="text"]:focus {
outline: none;
border-color: #667eea;
}
button {
padding: 12px 24px;
background: #667eea;
color: white;
border: none;
border-radius: 6px;
font-size: 16px;
cursor: pointer;
transition: background 0.3s;
}
button:hover {
background: #5568d3;
}
button:disabled {
background: #ccc;
cursor: not-allowed;
}
.btn-small {
padding: 6px 12px;
font-size: 14px;
}
.status-info {
background: #f8f9fa;
padding: 15px;
border-radius: 6px;
margin-bottom: 15px;
}
.status-info p {
margin: 5px 0;
color: #666;
}
.status-info span {
color: #333;
font-weight: bold;
}
.chat-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 15px;
padding-bottom: 15px;
border-bottom: 2px solid #e0e0e0;
}
.chat-container {
display: flex;
flex-direction: column;
height: 500px;
}
.messages {
flex: 1;
overflow-y: auto;
padding: 15px;
background: #f8f9fa;
border-radius: 6px;
margin-bottom: 15px;
}
.message {
margin-bottom: 15px;
padding: 10px;
background: white;
border-radius: 6px;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}
.message.own {
background: #e3f2fd;
margin-left: 20px;
}
.message-header {
display: flex;
justify-content: space-between;
margin-bottom: 5px;
font-size: 12px;
color: #666;
}
.message-content {
color: #333;
word-wrap: break-word;
}
.input-area {
display: flex;
gap: 10px;
}
.input-area input {
flex: 1;
}
.log-section {
border-top: 2px solid #e0e0e0;
}
.log-container {
max-height: 200px;
overflow-y: auto;
padding: 15px;
background: #f8f9fa;
font-family: 'Courier New', monospace;
font-size: 12px;
}
.log-entry {
margin-bottom: 5px;
padding: 5px;
border-left: 3px solid #667eea;
padding-left: 10px;
}
.log-entry.error {
border-left-color: #e74c3c;
color: #e74c3c;
}
.log-entry.success {
border-left-color: #27ae60;
color: #27ae60;
}
.log-entry.info {
border-left-color: #3498db;
color: #3498db;
}
/* 滾動條樣式 */
.messages::-webkit-scrollbar,
.log-container::-webkit-scrollbar {
width: 6px;
}
.messages::-webkit-scrollbar-track,
.log-container::-webkit-scrollbar-track {
background: #f1f1f1;
}
.messages::-webkit-scrollbar-thumb,
.log-container::-webkit-scrollbar-thumb {
background: #888;
border-radius: 3px;
}
.messages::-webkit-scrollbar-thumb:hover,
.log-container::-webkit-scrollbar-thumb:hover {
background: #555;
}

116
go.mod Normal file
View File

@ -0,0 +1,116 @@
module chat
go 1.25.1
require (
github.com/go-playground/validator/v10 v10.30.1
github.com/gocql/gocql v1.7.0
github.com/golang-jwt/jwt/v4 v4.5.2
github.com/google/uuid v1.6.0
github.com/panjf2000/ants/v2 v2.11.4
github.com/redis/go-redis/v9 v9.17.2
github.com/scylladb/gocqlx/v2 v2.8.0
github.com/shopspring/decimal v1.4.0
github.com/stretchr/testify v1.11.1
github.com/testcontainers/testcontainers-go v0.40.0
github.com/zeromicro/go-zero v1.9.4
go.mongodb.org/mongo-driver/v2 v2.4.1
google.golang.org/grpc v1.74.0-dev
)
require (
dario.cat/mergo v1.0.2 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.5.1+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/ebitengine/purego v0.8.4 // indirect
github.com/fatih/color v1.18.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/golang/snappy v1.0.0 // indirect
github.com/grafana/pyroscope-go v1.2.7 // indirect
github.com/grafana/pyroscope-go/godeltaprof v0.1.9 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 // indirect
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.10 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/go-archive v0.1.0 // indirect
github.com/moby/patternmatcher v0.6.0 // indirect
github.com/moby/sys/sequential v0.6.0 // indirect
github.com/moby/sys/user v0.4.0 // indirect
github.com/moby/sys/userns v0.1.0 // indirect
github.com/moby/term v0.5.0 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/openzipkin/zipkin-go v0.4.3 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/prometheus/client_golang v1.21.1 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/scylladb/go-reflectx v1.0.1 // indirect
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.1.2 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.39.0 // indirect
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.35.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.35.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.35.0 // indirect
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.35.0 // indirect
go.opentelemetry.io/otel/exporters/zipkin v1.35.0 // indirect
go.opentelemetry.io/otel/metric v1.39.0 // indirect
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
go.opentelemetry.io/otel/trace v1.39.0 // indirect
go.opentelemetry.io/proto/otlp v1.7.0 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/mock v0.4.0 // indirect
golang.org/x/crypto v0.46.0 // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.39.0 // indirect
golang.org/x/text v0.32.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20251222181119-0a764e51fe1b // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

40
internal/config/config.go Normal file
View File

@ -0,0 +1,40 @@
package config
import (
"github.com/zeromicro/go-zero/rest"
)
type Config struct {
rest.RestConf
Redis RedisConf
Centrifugo CentrifugoConf
Cassandra CassandraConf
JWT JWTConf
}
type RedisConf struct {
Host string
Port int
Password string
DB int
}
type CentrifugoConf struct {
APIURL string
APIKey string
}
type CassandraConf struct {
Hosts []string
Port int
Keyspace string
Username string
Password string
UseAuth bool
}
type JWTConf struct {
Secret string
Expire int64 // seconds - API token 和 Centrifugo token 共用此過期時間
CentrifugoSecret string // for Centrifugo JWT - 如果為空,則使用與 Secret 相同的值(簡化配置)
}

View File

@ -0,0 +1,7 @@
package consts
const (
AnonUIDPrefix = "anon_"
RoomIDPrefix = "room_"
)

View File

@ -0,0 +1,7 @@
package consts
const (
StatusWaiting = "waiting"
StatusMatched = "matched"
)

View File

@ -0,0 +1,19 @@
package entity
// Message 對應 Cassandra 的 messages_by_room 表
// Primary Key: (room_id, bucket_day)
// Clustering Key: ts DESC, message_id
type Message struct {
RoomID string `db:"room_id" partition_key:"true"`
BucketDay string `db:"bucket_day" partition_key:"true"` // yyyyMMdd
TS int64 `db:"ts" clustering_key:"true"` // timestamp
MessageID string `db:"message_id" clustering_key:"true"`
UID string `db:"uid"`
Content string `db:"content"`
}
// TableName 返回表名
func (m Message) TableName() string {
return "messages_by_room"
}

View File

@ -0,0 +1,14 @@
package entity
// RoomMember 房間成員資訊
type RoomMember struct {
RoomID string
UID string
}
// MatchResult 配對結果
type MatchResult struct {
RoomID string
Members []string
}

View File

@ -0,0 +1,19 @@
package redis
import "fmt"
// MatchQueueKey 返回配對佇列的 Redis key
func MatchQueueKey() string {
return "match:queue"
}
// MatchUserKey 返回使用者配對狀態的 Redis key
func MatchUserKey(uid string) string {
return fmt.Sprintf("match:user:%s", uid)
}
// RoomMembersKey 返回房間成員的 Redis key
func RoomMembersKey(roomID string) string {
return fmt.Sprintf("room:%s:members", roomID)
}

View File

@ -0,0 +1,19 @@
package repository
import "context"
// MatchmakingRepository 定義配對相關的資料存取介面
type MatchmakingRepository interface {
// JoinQueue 加入配對佇列返回狀態和房間ID如果配對成功
JoinQueue(ctx context.Context, uid string) (status string, roomID string, err error)
// GetMatchStatus 查詢使用者的配對狀態
GetMatchStatus(ctx context.Context, uid string) (status string, roomID string, err error)
// CreateRoom 建立房間並添加成員
CreateRoom(ctx context.Context, roomID string, members []string) error
// IsRoomMember 檢查使用者是否為房間成員
IsRoomMember(ctx context.Context, roomID string, uid string) (bool, error)
}

View File

@ -0,0 +1,16 @@
package repository
import (
"chat/internal/domain/entity"
"context"
)
// MessageRepository 定義訊息相關的資料存取介面
type MessageRepository interface {
// Insert 插入訊息
Insert(ctx context.Context, msg *entity.Message) error
// ListByRoom 查詢房間訊息(分頁)
ListByRoom(ctx context.Context, roomID string, bucketDay string, pageSize int, pageIndex int) ([]entity.Message, int64, error)
}

View File

@ -0,0 +1,13 @@
package usecase
import "context"
// AuthUseCase 定義認證相關的業務邏輯介面
type AuthUseCase interface {
// AnonLogin 匿名登入,返回 UID、API token、Centrifugo token 和過期時間
AnonLogin(ctx context.Context, name string) (uid string, token string, centrifugoToken string, expireAt int64, err error)
// RefreshToken 刷新 token使用現有的 token 來生成新的 API token 和 Centrifugo token
// 返回新的 API token、Centrifugo token 和過期時間
RefreshToken(ctx context.Context, oldToken string) (uid string, token string, centrifugoToken string, expireAt int64, err error)
}

View File

@ -0,0 +1,13 @@
package usecase
import "context"
// MatchmakingUseCase 定義配對相關的業務邏輯介面
type MatchmakingUseCase interface {
// JoinQueue 加入配對佇列
JoinQueue(ctx context.Context, uid string) (status string, err error)
// GetStatus 查詢配對狀態
GetStatus(ctx context.Context, uid string) (status string, roomID string, err error)
}

View File

@ -0,0 +1,16 @@
package usecase
import (
"chat/internal/domain/entity"
"context"
)
// MessageUseCase 定義訊息相關的業務邏輯介面
type MessageUseCase interface {
// SendMessage 發送訊息
SendMessage(ctx context.Context, roomID string, uid string, content string, clientMsgID string) error
// ListMessages 查詢訊息列表(分頁)
ListMessages(ctx context.Context, roomID string, uid string, pageSize int, pageIndex int) ([]entity.Message, int64, error)
}

112
internal/handler/routes.go Normal file
View File

@ -0,0 +1,112 @@
// Code generated by goctl. DO NOT EDIT.
// goctl 1.9.2
package handler
import (
"net/http"
"time"
chat "chat/internal/handler/chat"
"chat/internal/middleware"
"chat/internal/svc"
"github.com/zeromicro/go-zero/rest"
)
func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
// OPTIONS 處理器CORS 預檢請求)
optionsHandler := func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
w.Header().Set("Access-Control-Max-Age", "3600")
w.WriteHeader(http.StatusNoContent)
}
server.AddRoutes(
rest.WithMiddlewares(
[]rest.Middleware{middleware.CORSMiddleware},
[]rest.Route{
{
// OPTIONS 支持
Method: http.MethodOptions,
Path: "/auth/anon",
Handler: optionsHandler,
},
{
// 匿名登入
Method: http.MethodPost,
Path: "/auth/anon",
Handler: chat.AnonLoginHandler(serverCtx),
},
{
// OPTIONS 支持
Method: http.MethodOptions,
Path: "/auth/refresh",
Handler: optionsHandler,
},
{
// 刷新 Token
Method: http.MethodPost,
Path: "/auth/refresh",
Handler: chat.RefreshTokenHandler(serverCtx),
},
}...,
),
rest.WithPrefix("/api/v1"),
rest.WithTimeout(10000*time.Millisecond),
)
server.AddRoutes(
rest.WithMiddlewares(
[]rest.Middleware{middleware.CORSMiddleware, serverCtx.AnonMiddleware},
[]rest.Route{
{
// OPTIONS 支持
Method: http.MethodOptions,
Path: "/matchmaking/join",
Handler: optionsHandler,
},
{
// 加入等待序列
Method: http.MethodPost,
Path: "/matchmaking/join",
Handler: chat.MatchJoinHandler(serverCtx),
},
{
// OPTIONS 支持
Method: http.MethodOptions,
Path: "/matchmaking/status",
Handler: optionsHandler,
},
{
// 取得房間資訊
Method: http.MethodGet,
Path: "/matchmaking/status",
Handler: chat.MatchStatusHandler(serverCtx),
},
{
// OPTIONS 支持
Method: http.MethodOptions,
Path: "/rooms/:room_id/messages",
Handler: optionsHandler,
},
{
// 傳送訊息
Method: http.MethodPost,
Path: "/rooms/:room_id/messages",
Handler: chat.SendMessageHandler(serverCtx),
},
{
// 取得訊息
Method: http.MethodGet,
Path: "/rooms/:room_id/messages",
Handler: chat.ListMessagesHandler(serverCtx),
},
}...,
),
rest.WithPrefix("/api/v1"),
rest.WithTimeout(10000*time.Millisecond),
)
}

View File

@ -0,0 +1,758 @@
# Cassandra Client Library
一個基於 Go Generics 的 Cassandra 客戶端庫,提供類型安全的 Repository 模式和流暢的查詢構建器 API。
## 功能特色
- **類型安全**: 使用 Go Generics 提供編譯時類型檢查
- **Repository 模式**: 簡潔的 CRUD 操作介面
- **流暢查詢**: 鏈式查詢構建器,支援條件、排序、限制
- **分散式鎖**: 基於 Cassandra 的 IF NOT EXISTS 實現分散式鎖
- **批次操作**: 支援批次插入、更新、刪除
- **SAI 索引支援**: 完整的 SAI (Storage-Attached Indexing) 索引管理功能
- **Option 模式**: 靈活的配置選項
- **錯誤處理**: 統一的錯誤處理機制
- **高效能**: 內建連接池、重試機制、Prepared Statement 快取
## 安裝
```bash
go get github.com/scylladb/gocqlx/v2
go get github.com/gocql/gocql
```
## 快速開始
### 1. 定義資料模型
```go
package main
import (
"time"
"github.com/gocql/gocql"
"backend/pkg/library/cassandra"
)
// User 定義用戶資料模型
type User struct {
ID gocql.UUID `db:"id" partition_key:"true"`
Name string `db:"name"`
Email string `db:"email"`
Age int `db:"age"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
}
// TableName 實現 Table 介面
func (u User) TableName() string {
return "users"
}
```
### 2. 初始化資料庫連接
```go
package main
import (
"context"
"fmt"
"log"
"backend/pkg/library/cassandra"
"github.com/gocql/gocql"
)
func main() {
// 創建資料庫連接
db, err := cassandra.New(
cassandra.WithHosts("127.0.0.1"),
cassandra.WithPort(9042),
cassandra.WithKeyspace("my_keyspace"),
cassandra.WithAuth("username", "password"),
cassandra.WithConsistency(gocql.Quorum),
)
if err != nil {
log.Fatal(err)
}
defer db.Close()
// 創建 Repository
userRepo, err := cassandra.NewRepository[User](db, "my_keyspace")
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
// 使用 Repository...
_ = userRepo
}
```
## 詳細範例
### CRUD 操作
#### 插入資料
```go
// 插入單筆資料
user := User{
ID: gocql.TimeUUID(),
Name: "Alice",
Email: "alice@example.com",
Age: 30,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
err := userRepo.Insert(ctx, user)
if err != nil {
log.Printf("插入失敗: %v", err)
}
// 批次插入
users := []User{
{ID: gocql.TimeUUID(), Name: "Bob", Email: "bob@example.com"},
{ID: gocql.TimeUUID(), Name: "Charlie", Email: "charlie@example.com"},
}
err = userRepo.InsertMany(ctx, users)
if err != nil {
log.Printf("批次插入失敗: %v", err)
}
```
#### 查詢資料
```go
// 根據主鍵查詢
userID := gocql.TimeUUID()
user, err := userRepo.Get(ctx, userID)
if err != nil {
if cassandra.IsNotFound(err) {
log.Println("用戶不存在")
} else {
log.Printf("查詢失敗: %v", err)
}
return
}
fmt.Printf("用戶: %+v\n", user)
```
#### 更新資料
```go
// 更新資料(只更新非零值欄位)
user.Name = "Alice Updated"
user.Email = "alice.updated@example.com"
err = userRepo.Update(ctx, user)
if err != nil {
log.Printf("更新失敗: %v", err)
}
// 更新所有欄位(包括零值)
user.Age = 0 // 零值也會被更新
err = userRepo.UpdateAll(ctx, user)
if err != nil {
log.Printf("更新失敗: %v", err)
}
```
#### 刪除資料
```go
// 刪除資料
err = userRepo.Delete(ctx, userID)
if err != nil {
log.Printf("刪除失敗: %v", err)
}
```
### 查詢構建器
#### 基本查詢
```go
// 查詢所有符合條件的記錄
var users []User
err := userRepo.Query().
Where(cassandra.Eq("age", 30)).
OrderBy("created_at", cassandra.DESC).
Limit(10).
Scan(ctx, &users)
if err != nil {
log.Printf("查詢失敗: %v", err)
}
// 查詢單筆記錄
user, err := userRepo.Query().
Where(cassandra.Eq("email", "alice@example.com")).
One(ctx)
if err != nil {
if cassandra.IsNotFound(err) {
log.Println("用戶不存在")
} else {
log.Printf("查詢失敗: %v", err)
}
}
```
#### 條件查詢
```go
// 等於條件
userRepo.Query().Where(cassandra.Eq("name", "Alice"))
// IN 條件
userRepo.Query().Where(cassandra.In("id", []any{id1, id2, id3}))
// 大於條件
userRepo.Query().Where(cassandra.Gt("age", 18))
// 小於條件
userRepo.Query().Where(cassandra.Lt("age", 65))
// 組合多個條件
userRepo.Query().
Where(cassandra.Eq("status", "active")).
Where(cassandra.Gt("age", 18))
```
#### 排序和限制
```go
// 按建立時間降序排列,限制 20 筆
var users []User
err := userRepo.Query().
OrderBy("created_at", cassandra.DESC).
Limit(20).
Scan(ctx, &users)
// 多欄位排序
err = userRepo.Query().
OrderBy("status", cassandra.ASC).
OrderBy("created_at", cassandra.DESC).
Scan(ctx, &users)
```
#### 選擇特定欄位
```go
// 只查詢特定欄位
var users []User
err := userRepo.Query().
Select("id", "name", "email").
Where(cassandra.Eq("status", "active")).
Scan(ctx, &users)
```
#### 計數查詢
```go
// 計算符合條件的記錄數
count, err := userRepo.Query().
Where(cassandra.Eq("status", "active")).
Count(ctx)
if err != nil {
log.Printf("計數失敗: %v", err)
} else {
fmt.Printf("活躍用戶數: %d\n", count)
}
```
### 分散式鎖
```go
// 獲取鎖(預設 30 秒 TTL
lockUser := User{ID: userID}
err := userRepo.TryLock(ctx, lockUser)
if err != nil {
if cassandra.IsLockFailed(err) {
log.Println("獲取鎖失敗,資源已被鎖定")
} else {
log.Printf("鎖操作失敗: %v", err)
}
return
}
// 執行需要鎖定的操作
defer func() {
// 釋放鎖
if err := userRepo.UnLock(ctx, lockUser); err != nil {
log.Printf("釋放鎖失敗: %v", err)
}
}()
// 執行業務邏輯...
```
#### 自訂鎖 TTL
```go
// 設定鎖的 TTL 為 60 秒
err := userRepo.TryLock(ctx, lockUser, cassandra.WithLockTTL(60*time.Second))
// 永不自動解鎖
err := userRepo.TryLock(ctx, lockUser, cassandra.WithNoLockExpire())
```
### 複雜主鍵
#### 複合主鍵Partition Key + Clustering Key
```go
// 定義複合主鍵模型
type Order struct {
UserID gocql.UUID `db:"user_id" partition_key:"true"`
OrderID gocql.UUID `db:"order_id" clustering_key:"true"`
ProductID string `db:"product_id"`
Quantity int `db:"quantity"`
Price float64 `db:"price"`
CreatedAt time.Time `db:"created_at"`
}
func (o Order) TableName() string {
return "orders"
}
// 查詢時需要提供完整的主鍵
order, err := orderRepo.Get(ctx, Order{
UserID: userID,
OrderID: orderID,
})
```
#### 多欄位 Partition Key
```go
type Message struct {
ChatID gocql.UUID `db:"chat_id" partition_key:"true"`
MessageID gocql.UUID `db:"message_id" clustering_key:"true"`
UserID gocql.UUID `db:"user_id" partition_key:"true"`
Content string `db:"content"`
CreatedAt time.Time `db:"created_at"`
}
func (m Message) TableName() string {
return "messages"
}
// 查詢時需要提供所有 Partition Key
message, err := messageRepo.Get(ctx, Message{
ChatID: chatID,
UserID: userID,
MessageID: messageID,
})
```
## 配置選項
### 連接選項
```go
db, err := cassandra.New(
// 主機列表
cassandra.WithHosts("127.0.0.1", "127.0.0.2", "127.0.0.3"),
// 連接埠
cassandra.WithPort(9042),
// Keyspace
cassandra.WithKeyspace("my_keyspace"),
// 認證
cassandra.WithAuth("username", "password"),
// 一致性級別
cassandra.WithConsistency(gocql.Quorum),
// 連接超時
cassandra.WithConnectTimeout(10 * time.Second),
// 每個節點的連接數
cassandra.WithNumConns(10),
// 重試次數
cassandra.WithMaxRetries(3),
// 重試間隔
cassandra.WithRetryInterval(100*time.Millisecond, 1*time.Second),
// 重連間隔
cassandra.WithReconnectInterval(1*time.Second, 10*time.Second),
// CQL 版本
cassandra.WithCQLVersion("3.0.0"),
)
```
## 錯誤處理
### 錯誤類型
```go
// 檢查是否為特定錯誤
if cassandra.IsNotFound(err) {
// 記錄不存在
}
if cassandra.IsConflict(err) {
// 衝突錯誤(如唯一鍵衝突)
}
if cassandra.IsLockFailed(err) {
// 獲取鎖失敗
}
// 使用 errors.As 獲取詳細錯誤資訊
var cassandraErr *cassandra.Error
if errors.As(err, &cassandraErr) {
fmt.Printf("錯誤代碼: %s\n", cassandraErr.Code)
fmt.Printf("錯誤訊息: %s\n", cassandraErr.Message)
fmt.Printf("資料表: %s\n", cassandraErr.Table)
}
```
### 錯誤代碼
- `NOT_FOUND`: 記錄未找到
- `CONFLICT`: 衝突(如唯一鍵衝突、鎖獲取失敗)
- `INVALID_INPUT`: 輸入參數無效
- `MISSING_PARTITION_KEY`: 缺少 Partition Key
- `NO_FIELDS_TO_UPDATE`: 沒有欄位需要更新
- `MISSING_TABLE_NAME`: 缺少 TableName 方法
- `MISSING_WHERE_CONDITION`: 缺少 WHERE 條件
## 最佳實踐
### 1. 使用 Context
```go
// 所有操作都應該傳入 context以便支援超時和取消
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
user, err := userRepo.Get(ctx, userID)
```
### 2. 錯誤處理
```go
user, err := userRepo.Get(ctx, userID)
if err != nil {
if cassandra.IsNotFound(err) {
// 處理不存在的情況
return nil, ErrUserNotFound
}
// 處理其他錯誤
return nil, fmt.Errorf("查詢用戶失敗: %w", err)
}
```
### 3. 批次操作
```go
// 對於大量資料,使用批次插入
const batchSize = 100
for i := 0; i < len(users); i += batchSize {
end := i + batchSize
if end > len(users) {
end = len(users)
}
err := userRepo.InsertMany(ctx, users[i:end])
if err != nil {
log.Printf("批次插入失敗 (索引 %d-%d): %v", i, end, err)
}
}
```
### 4. 使用分散式鎖
```go
// 在需要保證原子性的操作中使用鎖
err := userRepo.TryLock(ctx, lockUser, cassandra.WithLockTTL(30*time.Second))
if err != nil {
return fmt.Errorf("獲取鎖失敗: %w", err)
}
defer userRepo.UnLock(ctx, lockUser)
// 執行需要原子性的操作
```
### 5. 查詢優化
```go
// 只選擇需要的欄位
var users []User
err := userRepo.Query().
Select("id", "name", "email"). // 只選擇需要的欄位
Where(cassandra.Eq("status", "active")).
Scan(ctx, &users)
// 使用適當的限制
err = userRepo.Query().
Where(cassandra.Eq("status", "active")).
Limit(100). // 限制結果數量
Scan(ctx, &users)
```
## SAI 索引管理
### 建立 SAI 索引
```go
// 檢查是否支援 SAI
if !db.SaiSupported() {
log.Fatal("SAI is not supported in this Cassandra version")
}
// 建立標準索引
err := db.CreateSAIIndex(ctx, "my_keyspace", "users", "email", "users_email_idx", nil)
if err != nil {
log.Printf("建立索引失敗: %v", err)
}
// 建立全文索引(不區分大小寫)
opts := &cassandra.SAIIndexOptions{
IndexType: cassandra.SAIIndexTypeFullText,
IsAsync: false,
CaseSensitive: false,
}
err = db.CreateSAIIndex(ctx, "my_keyspace", "posts", "content", "posts_content_ft_idx", opts)
```
### 查詢 SAI 索引
```go
// 列出資料表的所有 SAI 索引
indexes, err := db.ListSAIIndexes(ctx, "my_keyspace", "users")
if err != nil {
log.Printf("查詢索引失敗: %v", err)
} else {
for _, idx := range indexes {
fmt.Printf("索引: %s, 欄位: %s, 類型: %s\n", idx.Name, idx.Column, idx.Type)
}
}
// 檢查索引是否存在
exists, err := db.CheckSAIIndexExists(ctx, "my_keyspace", "users_email_idx")
if err != nil {
log.Printf("檢查索引失敗: %v", err)
} else if exists {
fmt.Println("索引存在")
}
```
### 刪除 SAI 索引
```go
// 刪除索引
err := db.DropSAIIndex(ctx, "my_keyspace", "users_email_idx")
if err != nil {
log.Printf("刪除索引失敗: %v", err)
}
```
### SAI 索引類型
- **SAIIndexTypeStandard**: 標準索引(等於查詢)
- **SAIIndexTypeCollection**: 集合索引(用於 list、set、map
- **SAIIndexTypeFullText**: 全文索引
### SAI 索引選項
```go
opts := &cassandra.SAIIndexOptions{
IndexType: cassandra.SAIIndexTypeFullText, // 索引類型
IsAsync: false, // 是否異步建立
CaseSensitive: true, // 是否區分大小寫
}
```
## 注意事項
### 1. 主鍵要求
- `Get``Delete` 操作必須提供完整的主鍵(所有 Partition Key 和 Clustering Key
- 單一主鍵值只適用於單一 Partition Key 且無 Clustering Key 的情況
### 2. 更新操作
- `Update` 只更新非零值欄位
- `UpdateAll` 更新所有欄位(包括零值)
- 更新操作必須包含主鍵欄位
### 3. 查詢限制
- Cassandra 的查詢必須包含所有 Partition Key
- 排序只能按 Clustering Key 進行
- 不支援 JOIN 操作
### 4. 分散式鎖
- 鎖使用 IF NOT EXISTS 實現,預設 30 秒 TTL
- 獲取鎖失敗時會返回 `CONFLICT` 錯誤
- 釋放鎖時會自動重試,最多 3 次
### 5. 批次操作
- 批次操作有大小限制(建議不超過 1000 筆)
- 批次操作中的所有操作必須屬於同一個 Partition Key
### 6. SAI 索引
- SAI 索引需要 Cassandra 4.0.9+ 版本(建議 5.0+
- 建立索引前請先檢查 `db.SaiSupported()`
- 索引建立是異步操作,可能需要一些時間
- 刪除索引時使用 `IF EXISTS`,避免索引不存在時報錯
- 使用 SAI 索引可以大幅提升非主鍵欄位的查詢效能
- 全文索引支援不區分大小寫的搜尋
## 完整範例
```go
package main
import (
"context"
"fmt"
"log"
"time"
"backend/pkg/library/cassandra"
"github.com/gocql/gocql"
)
type User struct {
ID gocql.UUID `db:"id" partition_key:"true"`
Name string `db:"name"`
Email string `db:"email"`
Age int `db:"age"`
Status string `db:"status"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
}
func (u User) TableName() string {
return "users"
}
func main() {
// 初始化資料庫連接
db, err := cassandra.New(
cassandra.WithHosts("127.0.0.1"),
cassandra.WithPort(9042),
cassandra.WithKeyspace("my_keyspace"),
)
if err != nil {
log.Fatal(err)
}
defer db.Close()
// 創建 Repository
userRepo, err := cassandra.NewRepository[User](db, "my_keyspace")
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
// 插入用戶
user := User{
ID: gocql.TimeUUID(),
Name: "Alice",
Email: "alice@example.com",
Age: 30,
Status: "active",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := userRepo.Insert(ctx, user); err != nil {
log.Printf("插入失敗: %v", err)
return
}
// 查詢用戶
foundUser, err := userRepo.Get(ctx, user.ID)
if err != nil {
log.Printf("查詢失敗: %v", err)
return
}
fmt.Printf("查詢到的用戶: %+v\n", foundUser)
// 更新用戶
user.Name = "Alice Updated"
user.Email = "alice.updated@example.com"
if err := userRepo.Update(ctx, user); err != nil {
log.Printf("更新失敗: %v", err)
return
}
// 查詢活躍用戶
var activeUsers []User
if err := userRepo.Query().
Where(cassandra.Eq("status", "active")).
OrderBy("created_at", cassandra.DESC).
Limit(10).
Scan(ctx, &activeUsers); err != nil {
log.Printf("查詢失敗: %v", err)
return
}
fmt.Printf("活躍用戶數: %d\n", len(activeUsers))
// 使用分散式鎖
if err := userRepo.TryLock(ctx, user, cassandra.WithLockTTL(30*time.Second)); err != nil {
if cassandra.IsLockFailed(err) {
log.Println("獲取鎖失敗")
} else {
log.Printf("鎖操作失敗: %v", err)
}
return
}
defer userRepo.UnLock(ctx, user)
// 執行需要鎖定的操作
fmt.Println("執行需要鎖定的操作...")
// 刪除用戶
if err := userRepo.Delete(ctx, user.ID); err != nil {
log.Printf("刪除失敗: %v", err)
return
}
fmt.Println("操作完成")
}
```
## 測試
套件包含完整的測試覆蓋,包括:
- 單元測試table-driven tests
- 集成測試(使用 testcontainers
運行測試:
```bash
go test ./pkg/library/cassandra/...
```
查看測試覆蓋率:
```bash
go test ./pkg/library/cassandra/... -cover
```
## 授權
本專案遵循專案的主要授權協議。

View File

@ -0,0 +1,27 @@
package cassandra
import (
"time"
"github.com/gocql/gocql"
)
// 預設設定常數
const (
defaultNumConns = 10 // 預設每個節點的連線數量
defaultTimeoutSec = 10 // 預設連線逾時秒數
defaultMaxRetries = 3 // 預設重試次數
defaultPort = 9042
defaultConsistency = gocql.Quorum
defaultRetryMinInterval = 1 * time.Second
defaultRetryMaxInterval = 30 * time.Second
defaultReconnectInitialInterval = 1 * time.Second
defaultReconnectMaxInterval = 60 * time.Second
defaultCqlVersion = "3.0.0"
)
const (
DBFiledName = "db"
Pk = "partition_key"
ClusterKey = "clustering_key"
)

View File

@ -0,0 +1,158 @@
package cassandra
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx/v2"
)
// DB 是 Cassandra 的核心資料庫連接
type DB struct {
session gocqlx.Session
defaultKeyspace string
version string
saiSupported bool
}
// New 創建新的 DB 實例
func New(opts ...Option) (*DB, error) {
cfg := defaultConfig()
for _, opt := range opts {
opt(cfg)
}
if len(cfg.Hosts) == 0 {
return nil, fmt.Errorf("at least one host is required")
}
// 建立連線設定
cluster := gocql.NewCluster(cfg.Hosts...)
cluster.Port = cfg.Port
cluster.Consistency = cfg.Consistency
cluster.Timeout = time.Duration(cfg.ConnectTimeoutSec) * time.Second
cluster.NumConns = cfg.NumConns
cluster.RetryPolicy = &gocql.ExponentialBackoffRetryPolicy{
NumRetries: cfg.MaxRetries,
Min: cfg.RetryMinInterval,
Max: cfg.RetryMaxInterval,
}
cluster.ReconnectionPolicy = &gocql.ExponentialReconnectionPolicy{
MaxRetries: cfg.MaxRetries,
InitialInterval: cfg.ReconnectInitialInterval,
MaxInterval: cfg.ReconnectMaxInterval,
}
// 若有提供 Keyspace 則指定
if cfg.Keyspace != "" {
cluster.Keyspace = cfg.Keyspace
}
// 若啟用驗證則設定帳號密碼
if cfg.UseAuth {
cluster.Authenticator = gocql.PasswordAuthenticator{
Username: cfg.Username,
Password: cfg.Password,
}
}
// 建立 Session
session, err := gocqlx.WrapSession(cluster.CreateSession())
if err != nil {
return nil, fmt.Errorf("failed to connect to Cassandra cluster (hosts: %v, port: %d): %w", cfg.Hosts, cfg.Port, err)
}
db := &DB{
session: session,
defaultKeyspace: cfg.Keyspace,
}
// 初始化版本資訊
version, err := db.getVersion(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get DB version: %w", err)
}
db.version = version
db.saiSupported = isSAISupported(version)
return db, nil
}
// Close 關閉資料庫連線
func (db *DB) Close() {
db.session.Close()
}
// GetSession 返回底層的 gocqlx Session用於進階操作
func (db *DB) GetSession() gocqlx.Session {
return db.session
}
// GetDefaultKeyspace 返回預設的 keyspace
func (db *DB) GetDefaultKeyspace() string {
return db.defaultKeyspace
}
// Version 返回資料庫版本
func (db *DB) Version() string {
return db.version
}
// SaiSupported 返回是否支援 SAI
func (db *DB) SaiSupported() bool {
return db.saiSupported
}
// getVersion 獲取資料庫版本
func (db *DB) getVersion(ctx context.Context) (string, error) {
var version string
stmt := "SELECT release_version FROM system.local"
err := db.session.Query(stmt, []string{"release_version"}).
WithContext(ctx).
Consistency(gocql.One).
Scan(&version)
return version, err
}
// isSAISupported 檢查版本是否支援 SAI
func isSAISupported(version string) bool {
// 只要 major >=5 就支援
// 4.0.9+ 才有 SAI但不穩強烈建議 5.0+
parts := strings.Split(version, ".")
if len(parts) < 2 {
return false
}
major, _ := strconv.Atoi(parts[0])
minor, _ := strconv.Atoi(parts[1])
if major >= 5 {
return true
}
if major == 4 {
if minor > 0 { // 4.1.x、4.2.x 直接支援
return true
}
if minor == 0 {
patch := 0
if len(parts) >= 3 {
patch, _ = strconv.Atoi(parts[2])
}
if patch >= 9 {
return true
}
}
}
return false
}
// withContextAndTimestamp 為查詢添加 context 和時間戳
func (db *DB) withContextAndTimestamp(ctx context.Context, q *gocqlx.Queryx) *gocqlx.Queryx {
return q.WithContext(ctx).WithTimestamp(time.Now().UnixNano() / 1e3)
}

View File

@ -0,0 +1,544 @@
package cassandra
import (
"errors"
"testing"
"time"
"github.com/gocql/gocql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIsSAISupported(t *testing.T) {
tests := []struct {
name string
version string
expected bool
}{
{
name: "version 5.0.0 should support SAI",
version: "5.0.0",
expected: true,
},
{
name: "version 5.1.0 should support SAI",
version: "5.1.0",
expected: true,
},
{
name: "version 6.0.0 should support SAI",
version: "6.0.0",
expected: true,
},
{
name: "version 4.1.0 should support SAI",
version: "4.1.0",
expected: true,
},
{
name: "version 4.2.0 should support SAI",
version: "4.2.0",
expected: true,
},
{
name: "version 4.0.9 should support SAI",
version: "4.0.9",
expected: true,
},
{
name: "version 4.0.10 should support SAI",
version: "4.0.10",
expected: true,
},
{
name: "version 4.0.8 should not support SAI",
version: "4.0.8",
expected: false,
},
{
name: "version 4.0.0 should not support SAI",
version: "4.0.0",
expected: false,
},
{
name: "version 3.11.0 should not support SAI",
version: "3.11.0",
expected: false,
},
{
name: "invalid version format should not support SAI",
version: "invalid",
expected: false,
},
{
name: "empty version should not support SAI",
version: "",
expected: false,
},
{
name: "version with only major should not support SAI",
version: "5",
expected: false,
},
{
name: "version 4.0.9 with extra parts should support SAI",
version: "4.0.9.1",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isSAISupported(tt.version)
assert.Equal(t, tt.expected, result, "version %s should have SAI support = %v", tt.version, tt.expected)
})
}
}
func TestNew_Validation(t *testing.T) {
tests := []struct {
name string
opts []Option
wantErr bool
errMsg string
}{
{
name: "no hosts should return error",
opts: []Option{},
wantErr: true,
errMsg: "at least one host is required",
},
{
name: "empty hosts should return error",
opts: []Option{WithHosts()},
wantErr: true,
errMsg: "at least one host is required",
},
{
name: "valid hosts should not return error on validation",
opts: []Option{
WithHosts("localhost"),
},
wantErr: false,
},
{
name: "multiple hosts should not return error on validation",
opts: []Option{
WithHosts("localhost", "127.0.0.1"),
},
wantErr: false,
},
{
name: "with keyspace should not return error on validation",
opts: []Option{
WithHosts("localhost"),
WithKeyspace("test_keyspace"),
},
wantErr: false,
},
{
name: "with port should not return error on validation",
opts: []Option{
WithHosts("localhost"),
WithPort(9042),
},
wantErr: false,
},
{
name: "with auth should not return error on validation",
opts: []Option{
WithHosts("localhost"),
WithAuth("user", "pass"),
},
wantErr: false,
},
{
name: "with all options should not return error on validation",
opts: []Option{
WithHosts("localhost"),
WithKeyspace("test_keyspace"),
WithPort(9042),
WithAuth("user", "pass"),
WithConsistency(gocql.Quorum),
WithConnectTimeoutSec(10),
WithNumConns(10),
WithMaxRetries(3),
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db, err := New(tt.opts...)
if tt.wantErr {
require.Error(t, err)
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg)
}
assert.Nil(t, db)
} else {
// 注意:這裡可能會因為無法連接到真實的 Cassandra 而失敗
// 但至少驗證了配置驗證邏輯
if err != nil {
// 如果錯誤不是驗證錯誤,而是連接錯誤,這是可以接受的
assert.NotContains(t, err.Error(), "at least one host is required")
}
}
})
}
}
func TestDB_GetDefaultKeyspace(t *testing.T) {
tests := []struct {
name string
keyspace string
expectedResult string
}{
{
name: "empty keyspace should return empty string",
keyspace: "",
expectedResult: "",
},
{
name: "non-empty keyspace should return keyspace",
keyspace: "test_keyspace",
expectedResult: "test_keyspace",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要一個有效的 DB 實例
// 在實際測試中,可能需要 mock 或使用 testcontainers
// 這裡只是展示測試結構
_ = tt
})
}
}
func TestDB_Version(t *testing.T) {
tests := []struct {
name string
version string
expected string
}{
{
name: "version 5.0.0",
version: "5.0.0",
expected: "5.0.0",
},
{
name: "version 4.0.9",
version: "4.0.9",
expected: "4.0.9",
},
{
name: "version 3.11.0",
version: "3.11.0",
expected: "3.11.0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要一個有效的 DB 實例
// 在實際測試中,可能需要 mock 或使用 testcontainers
_ = tt
})
}
}
func TestDB_SaiSupported(t *testing.T) {
tests := []struct {
name string
version string
expected bool
}{
{
name: "version 5.0.0 should support SAI",
version: "5.0.0",
expected: true,
},
{
name: "version 4.0.9 should support SAI",
version: "4.0.9",
expected: true,
},
{
name: "version 4.0.8 should not support SAI",
version: "4.0.8",
expected: false,
},
{
name: "version 3.11.0 should not support SAI",
version: "3.11.0",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要一個有效的 DB 實例
// 在實際測試中,可能需要 mock 或使用 testcontainers
// 這裡只是展示測試結構
_ = tt
})
}
}
func TestDB_GetSession(t *testing.T) {
t.Run("GetSession should return non-nil session", func(t *testing.T) {
// 注意:這需要一個有效的 DB 實例
// 在實際測試中,可能需要 mock 或使用 testcontainers
})
}
func TestDB_Close(t *testing.T) {
t.Run("Close should not panic", func(t *testing.T) {
// 注意:這需要一個有效的 DB 實例
// 在實際測試中,可能需要 mock 或使用 testcontainers
})
}
func TestDB_getVersion(t *testing.T) {
tests := []struct {
name string
version string
queryErr error
wantErr bool
expectedVer string
}{
{
name: "successful version query",
version: "5.0.0",
queryErr: nil,
wantErr: false,
expectedVer: "5.0.0",
},
{
name: "query error should return error",
version: "",
queryErr: errors.New("connection failed"),
wantErr: true,
expectedVer: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要 mock session
// 在實際測試中,需要使用 mock 或 testcontainers
_ = tt
})
}
}
func TestDB_withContextAndTimestamp(t *testing.T) {
t.Run("withContextAndTimestamp should add context and timestamp", func(t *testing.T) {
// 注意:這需要 mock query
// 在實際測試中,需要使用 mock
})
}
func TestDefaultConfig(t *testing.T) {
t.Run("defaultConfig should return valid config", func(t *testing.T) {
cfg := defaultConfig()
require.NotNil(t, cfg)
assert.Equal(t, defaultPort, cfg.Port)
assert.Equal(t, defaultConsistency, cfg.Consistency)
assert.Equal(t, defaultTimeoutSec, cfg.ConnectTimeoutSec)
assert.Equal(t, defaultNumConns, cfg.NumConns)
assert.Equal(t, defaultMaxRetries, cfg.MaxRetries)
assert.Equal(t, defaultRetryMinInterval, cfg.RetryMinInterval)
assert.Equal(t, defaultRetryMaxInterval, cfg.RetryMaxInterval)
assert.Equal(t, defaultReconnectInitialInterval, cfg.ReconnectInitialInterval)
assert.Equal(t, defaultReconnectMaxInterval, cfg.ReconnectMaxInterval)
assert.Equal(t, defaultCqlVersion, cfg.CQLVersion)
})
}
func TestOptionFunctions(t *testing.T) {
tests := []struct {
name string
opt Option
validateConfig func(*testing.T, *config)
}{
{
name: "WithHosts should set hosts",
opt: WithHosts("host1", "host2"),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, []string{"host1", "host2"}, c.Hosts)
},
},
{
name: "WithPort should set port",
opt: WithPort(9999),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, 9999, c.Port)
},
},
{
name: "WithKeyspace should set keyspace",
opt: WithKeyspace("test_keyspace"),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, "test_keyspace", c.Keyspace)
},
},
{
name: "WithAuth should set auth and enable UseAuth",
opt: WithAuth("user", "pass"),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, "user", c.Username)
assert.Equal(t, "pass", c.Password)
assert.True(t, c.UseAuth)
},
},
{
name: "WithConsistency should set consistency",
opt: WithConsistency(gocql.One),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, gocql.One, c.Consistency)
},
},
{
name: "WithConnectTimeoutSec should set timeout",
opt: WithConnectTimeoutSec(20),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, 20, c.ConnectTimeoutSec)
},
},
{
name: "WithConnectTimeoutSec with zero should use default",
opt: WithConnectTimeoutSec(0),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, defaultTimeoutSec, c.ConnectTimeoutSec)
},
},
{
name: "WithNumConns should set numConns",
opt: WithNumConns(20),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, 20, c.NumConns)
},
},
{
name: "WithNumConns with zero should use default",
opt: WithNumConns(0),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, defaultNumConns, c.NumConns)
},
},
{
name: "WithMaxRetries should set maxRetries",
opt: WithMaxRetries(5),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, 5, c.MaxRetries)
},
},
{
name: "WithMaxRetries with zero should use default",
opt: WithMaxRetries(0),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, defaultMaxRetries, c.MaxRetries)
},
},
{
name: "WithRetryMinInterval should set retryMinInterval",
opt: WithRetryMinInterval(2 * time.Second),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, 2*time.Second, c.RetryMinInterval)
},
},
{
name: "WithRetryMinInterval with zero should use default",
opt: WithRetryMinInterval(0),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, defaultRetryMinInterval, c.RetryMinInterval)
},
},
{
name: "WithRetryMaxInterval should set retryMaxInterval",
opt: WithRetryMaxInterval(60 * time.Second),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, 60*time.Second, c.RetryMaxInterval)
},
},
{
name: "WithRetryMaxInterval with zero should use default",
opt: WithRetryMaxInterval(0),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, defaultRetryMaxInterval, c.RetryMaxInterval)
},
},
{
name: "WithReconnectInitialInterval should set reconnectInitialInterval",
opt: WithReconnectInitialInterval(2 * time.Second),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, 2*time.Second, c.ReconnectInitialInterval)
},
},
{
name: "WithReconnectInitialInterval with zero should use default",
opt: WithReconnectInitialInterval(0),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, defaultReconnectInitialInterval, c.ReconnectInitialInterval)
},
},
{
name: "WithReconnectMaxInterval should set reconnectMaxInterval",
opt: WithReconnectMaxInterval(120 * time.Second),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, 120*time.Second, c.ReconnectMaxInterval)
},
},
{
name: "WithReconnectMaxInterval with zero should use default",
opt: WithReconnectMaxInterval(0),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, defaultReconnectMaxInterval, c.ReconnectMaxInterval)
},
},
{
name: "WithCQLVersion should set CQLVersion",
opt: WithCQLVersion("3.1.0"),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, "3.1.0", c.CQLVersion)
},
},
{
name: "WithCQLVersion with empty should use default",
opt: WithCQLVersion(""),
validateConfig: func(t *testing.T, c *config) {
assert.Equal(t, defaultCqlVersion, c.CQLVersion)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := defaultConfig()
tt.opt(cfg)
tt.validateConfig(t, cfg)
})
}
}
func TestMultipleOptions(t *testing.T) {
t.Run("multiple options should be applied correctly", func(t *testing.T) {
cfg := defaultConfig()
WithHosts("host1", "host2")(cfg)
WithPort(9999)(cfg)
WithKeyspace("test")(cfg)
WithAuth("user", "pass")(cfg)
assert.Equal(t, []string{"host1", "host2"}, cfg.Hosts)
assert.Equal(t, 9999, cfg.Port)
assert.Equal(t, "test", cfg.Keyspace)
assert.Equal(t, "user", cfg.Username)
assert.Equal(t, "pass", cfg.Password)
assert.True(t, cfg.UseAuth)
})
}

View File

@ -0,0 +1,151 @@
package cassandra
import (
"errors"
"fmt"
)
// ErrorCode 定義錯誤代碼
type ErrorCode string
const (
// ErrCodeNotFound 表示記錄未找到
ErrCodeNotFound ErrorCode = "NOT_FOUND"
// ErrCodeConflict 表示衝突(如唯一鍵衝突)
ErrCodeConflict ErrorCode = "CONFLICT"
// ErrCodeInvalidInput 表示輸入參數無效
ErrCodeInvalidInput ErrorCode = "INVALID_INPUT"
// ErrCodeMissingPartition 表示缺少 Partition Key
ErrCodeMissingPartition ErrorCode = "MISSING_PARTITION_KEY"
// ErrCodeNoFieldsToUpdate 表示沒有欄位需要更新
ErrCodeNoFieldsToUpdate ErrorCode = "NO_FIELDS_TO_UPDATE"
// ErrCodeMissingTableName 表示缺少 TableName 方法
ErrCodeMissingTableName ErrorCode = "MISSING_TABLE_NAME"
// ErrCodeMissingWhereCondition 表示缺少 WHERE 條件
ErrCodeMissingWhereCondition ErrorCode = "MISSING_WHERE_CONDITION"
// ErrCodeSAINotSupported 表示不支援 SAI
ErrCodeSAINotSupported ErrorCode = "SAI_NOT_SUPPORTED"
)
// Error 是統一的錯誤類型
type Error struct {
Code ErrorCode
Message string
Table string
Err error
}
// Error 實現 error 介面
func (e *Error) Error() string {
if e.Table != "" {
if e.Err != nil {
return fmt.Sprintf("cassandra[%s] (table: %s): %s: %v", e.Code, e.Table, e.Message, e.Err)
}
return fmt.Sprintf("cassandra[%s] (table: %s): %s", e.Code, e.Table, e.Message)
}
if e.Err != nil {
return fmt.Sprintf("cassandra[%s]: %s: %v", e.Code, e.Message, e.Err)
}
return fmt.Sprintf("cassandra[%s]: %s", e.Code, e.Message)
}
// Unwrap 返回底層錯誤
func (e *Error) Unwrap() error {
return e.Err
}
// WithTable 為錯誤添加表名資訊
func (e *Error) WithTable(table string) *Error {
return &Error{
Code: e.Code,
Message: e.Message,
Table: table,
Err: e.Err,
}
}
// WithError 為錯誤添加底層錯誤
func (e *Error) WithError(err error) *Error {
return &Error{
Code: e.Code,
Message: e.Message,
Table: e.Table,
Err: err,
}
}
// NewError 創建新的錯誤
func NewError(code ErrorCode, message string) *Error {
return &Error{
Code: code,
Message: message,
}
}
// 預定義錯誤
var (
// ErrNotFound 表示記錄未找到
ErrNotFound = &Error{
Code: ErrCodeNotFound,
Message: "record not found",
}
// ErrInvalidInput 表示輸入參數無效
ErrInvalidInput = &Error{
Code: ErrCodeInvalidInput,
Message: "invalid input parameter",
}
// ErrNoPartitionKey 表示缺少 Partition Key
ErrNoPartitionKey = &Error{
Code: ErrCodeMissingPartition,
Message: "no partition key defined in struct",
}
// ErrMissingTableName 表示缺少 TableName 方法
ErrMissingTableName = &Error{
Code: ErrCodeMissingTableName,
Message: "struct must implement TableName() method",
}
// ErrNoFieldsToUpdate 表示沒有欄位需要更新
ErrNoFieldsToUpdate = &Error{
Code: ErrCodeNoFieldsToUpdate,
Message: "no fields to update",
}
// ErrMissingWhereCondition 表示缺少 WHERE 條件
ErrMissingWhereCondition = &Error{
Code: ErrCodeMissingWhereCondition,
Message: "operation requires at least one WHERE condition for safety",
}
// ErrMissingPartitionKey 表示 WHERE 條件中缺少 Partition Key
ErrMissingPartitionKey = &Error{
Code: ErrCodeMissingPartition,
Message: "operation requires all partition keys in WHERE clause",
}
// ErrSAINotSupported 表示不支援 SAI
ErrSAINotSupported = &Error{
Code: ErrCodeSAINotSupported,
Message: "SAI (Storage-Attached Indexing) is not supported in this Cassandra version",
}
)
// IsNotFound 檢查錯誤是否為 NotFound
func IsNotFound(err error) bool {
var e *Error
if errors.As(err, &e) {
return e.Code == ErrCodeNotFound
}
return false
}
// IsConflict 檢查錯誤是否為 Conflict
func IsConflict(err error) bool {
var e *Error
if errors.As(err, &e) {
return e.Code == ErrCodeConflict
}
return false
}

View File

@ -0,0 +1,590 @@
package cassandra
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestError_Error(t *testing.T) {
tests := []struct {
name string
err *Error
want string
contains []string // 如果 want 為空,則檢查是否包含這些字串
}{
{
name: "error with code and message only",
err: &Error{
Code: ErrCodeNotFound,
Message: "record not found",
},
want: "cassandra[NOT_FOUND]: record not found",
},
{
name: "error with code, message and table",
err: &Error{
Code: ErrCodeNotFound,
Message: "record not found",
Table: "users",
},
want: "cassandra[NOT_FOUND] (table: users): record not found",
},
{
name: "error with code, message and underlying error",
err: &Error{
Code: ErrCodeInvalidInput,
Message: "invalid input parameter",
Err: errors.New("validation failed"),
},
contains: []string{
"cassandra[INVALID_INPUT]",
"invalid input parameter",
"validation failed",
},
},
{
name: "error with all fields",
err: &Error{
Code: ErrCodeConflict,
Message: "acquire lock failed",
Table: "locks",
Err: errors.New("lock already exists"),
},
contains: []string{
"cassandra[CONFLICT]",
"(table: locks)",
"acquire lock failed",
"lock already exists",
},
},
{
name: "error with empty message",
err: &Error{
Code: ErrCodeNotFound,
},
want: "cassandra[NOT_FOUND]: ",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.err.Error()
if tt.want != "" {
assert.Equal(t, tt.want, result)
} else {
for _, substr := range tt.contains {
assert.Contains(t, result, substr)
}
}
})
}
}
func TestError_Unwrap(t *testing.T) {
tests := []struct {
name string
err *Error
wantErr error
}{
{
name: "error with underlying error",
err: &Error{
Code: ErrCodeInvalidInput,
Message: "invalid input",
Err: errors.New("underlying error"),
},
wantErr: errors.New("underlying error"),
},
{
name: "error without underlying error",
err: &Error{
Code: ErrCodeNotFound,
Message: "not found",
},
wantErr: nil,
},
{
name: "error with nil underlying error",
err: &Error{
Code: ErrCodeNotFound,
Message: "not found",
Err: nil,
},
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.err.Unwrap()
if tt.wantErr == nil {
assert.Nil(t, result)
} else {
assert.Equal(t, tt.wantErr.Error(), result.Error())
}
})
}
}
func TestError_WithTable(t *testing.T) {
tests := []struct {
name string
err *Error
table string
wantCode ErrorCode
wantMsg string
wantTbl string
}{
{
name: "add table to error without table",
err: &Error{
Code: ErrCodeNotFound,
Message: "record not found",
},
table: "users",
wantCode: ErrCodeNotFound,
wantMsg: "record not found",
wantTbl: "users",
},
{
name: "replace existing table",
err: &Error{
Code: ErrCodeNotFound,
Message: "record not found",
Table: "old_table",
},
table: "new_table",
wantCode: ErrCodeNotFound,
wantMsg: "record not found",
wantTbl: "new_table",
},
{
name: "add table to error with underlying error",
err: &Error{
Code: ErrCodeInvalidInput,
Message: "invalid input",
Err: errors.New("validation failed"),
},
table: "products",
wantCode: ErrCodeInvalidInput,
wantMsg: "invalid input",
wantTbl: "products",
},
{
name: "add empty table",
err: &Error{
Code: ErrCodeNotFound,
Message: "not found",
},
table: "",
wantCode: ErrCodeNotFound,
wantMsg: "not found",
wantTbl: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.err.WithTable(tt.table)
assert.NotNil(t, result)
assert.Equal(t, tt.wantCode, result.Code)
assert.Equal(t, tt.wantMsg, result.Message)
assert.Equal(t, tt.wantTbl, result.Table)
// 確保是新的實例,不是修改原來的
assert.NotSame(t, tt.err, result)
})
}
}
func TestError_WithError(t *testing.T) {
tests := []struct {
name string
err *Error
underlying error
wantCode ErrorCode
wantMsg string
wantErr error
}{
{
name: "add underlying error to error without error",
err: &Error{
Code: ErrCodeInvalidInput,
Message: "invalid input",
},
underlying: errors.New("validation failed"),
wantCode: ErrCodeInvalidInput,
wantMsg: "invalid input",
wantErr: errors.New("validation failed"),
},
{
name: "replace existing underlying error",
err: &Error{
Code: ErrCodeInvalidInput,
Message: "invalid input",
Err: errors.New("old error"),
},
underlying: errors.New("new error"),
wantCode: ErrCodeInvalidInput,
wantMsg: "invalid input",
wantErr: errors.New("new error"),
},
{
name: "add nil underlying error",
err: &Error{
Code: ErrCodeNotFound,
Message: "not found",
},
underlying: nil,
wantCode: ErrCodeNotFound,
wantMsg: "not found",
wantErr: nil,
},
{
name: "add error to error with table",
err: &Error{
Code: ErrCodeConflict,
Message: "conflict",
Table: "locks",
},
underlying: errors.New("lock exists"),
wantCode: ErrCodeConflict,
wantMsg: "conflict",
wantErr: errors.New("lock exists"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.err.WithError(tt.underlying)
assert.NotNil(t, result)
assert.Equal(t, tt.wantCode, result.Code)
assert.Equal(t, tt.wantMsg, result.Message)
// 確保是新的實例
assert.NotSame(t, tt.err, result)
// 檢查 underlying error
if tt.wantErr == nil {
assert.Nil(t, result.Err)
} else {
require.NotNil(t, result.Err)
assert.Equal(t, tt.wantErr.Error(), result.Err.Error())
}
})
}
}
func TestNewError(t *testing.T) {
tests := []struct {
name string
code ErrorCode
message string
want *Error
}{
{
name: "create NOT_FOUND error",
code: ErrCodeNotFound,
message: "record not found",
want: &Error{
Code: ErrCodeNotFound,
Message: "record not found",
},
},
{
name: "create CONFLICT error",
code: ErrCodeConflict,
message: "lock acquisition failed",
want: &Error{
Code: ErrCodeConflict,
Message: "lock acquisition failed",
},
},
{
name: "create INVALID_INPUT error",
code: ErrCodeInvalidInput,
message: "invalid parameter",
want: &Error{
Code: ErrCodeInvalidInput,
Message: "invalid parameter",
},
},
{
name: "create error with empty message",
code: ErrCodeNotFound,
message: "",
want: &Error{
Code: ErrCodeNotFound,
Message: "",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := NewError(tt.code, tt.message)
assert.NotNil(t, result)
assert.Equal(t, tt.want.Code, result.Code)
assert.Equal(t, tt.want.Message, result.Message)
assert.Empty(t, result.Table)
assert.Nil(t, result.Err)
})
}
}
func TestIsNotFound(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{
name: "Error with NOT_FOUND code",
err: &Error{
Code: ErrCodeNotFound,
Message: "record not found",
},
want: true,
},
{
name: "Error with CONFLICT code",
err: &Error{
Code: ErrCodeConflict,
Message: "conflict",
},
want: false,
},
{
name: "Error with INVALID_INPUT code",
err: &Error{
Code: ErrCodeInvalidInput,
Message: "invalid input",
},
want: false,
},
{
name: "wrapped Error with NOT_FOUND code",
err: &Error{
Code: ErrCodeNotFound,
Message: "record not found",
Err: errors.New("underlying error"),
},
want: true,
},
{
name: "standard error",
err: errors.New("standard error"),
want: false,
},
{
name: "nil error",
err: nil,
want: false,
},
{
name: "predefined ErrNotFound",
err: ErrNotFound,
want: true,
},
{
name: "predefined ErrNotFound with table",
err: ErrNotFound.WithTable("users"),
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsNotFound(tt.err)
assert.Equal(t, tt.want, result)
})
}
}
func TestIsConflict(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{
name: "Error with CONFLICT code",
err: &Error{
Code: ErrCodeConflict,
Message: "conflict",
},
want: true,
},
{
name: "Error with NOT_FOUND code",
err: &Error{
Code: ErrCodeNotFound,
Message: "record not found",
},
want: false,
},
{
name: "Error with INVALID_INPUT code",
err: &Error{
Code: ErrCodeInvalidInput,
Message: "invalid input",
},
want: false,
},
{
name: "wrapped Error with CONFLICT code",
err: &Error{
Code: ErrCodeConflict,
Message: "conflict",
Err: errors.New("underlying error"),
},
want: true,
},
{
name: "standard error",
err: errors.New("standard error"),
want: false,
},
{
name: "nil error",
err: nil,
want: false,
},
{
name: "NewError with CONFLICT code",
err: NewError(ErrCodeConflict, "lock failed"),
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsConflict(tt.err)
assert.Equal(t, tt.want, result)
})
}
}
func TestPredefinedErrors(t *testing.T) {
tests := []struct {
name string
err *Error
wantCode ErrorCode
wantMsg string
}{
{
name: "ErrNotFound",
err: ErrNotFound,
wantCode: ErrCodeNotFound,
wantMsg: "record not found",
},
{
name: "ErrInvalidInput",
err: ErrInvalidInput,
wantCode: ErrCodeInvalidInput,
wantMsg: "invalid input parameter",
},
{
name: "ErrNoPartitionKey",
err: ErrNoPartitionKey,
wantCode: ErrCodeMissingPartition,
wantMsg: "no partition key defined in struct",
},
{
name: "ErrMissingTableName",
err: ErrMissingTableName,
wantCode: ErrCodeMissingTableName,
wantMsg: "struct must implement TableName() method",
},
{
name: "ErrNoFieldsToUpdate",
err: ErrNoFieldsToUpdate,
wantCode: ErrCodeNoFieldsToUpdate,
wantMsg: "no fields to update",
},
{
name: "ErrMissingWhereCondition",
err: ErrMissingWhereCondition,
wantCode: ErrCodeMissingWhereCondition,
wantMsg: "operation requires at least one WHERE condition for safety",
},
{
name: "ErrMissingPartitionKey",
err: ErrMissingPartitionKey,
wantCode: ErrCodeMissingPartition,
wantMsg: "operation requires all partition keys in WHERE clause",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.NotNil(t, tt.err)
assert.Equal(t, tt.wantCode, tt.err.Code)
assert.Equal(t, tt.wantMsg, tt.err.Message)
assert.Empty(t, tt.err.Table)
assert.Nil(t, tt.err.Err)
})
}
}
func TestError_Chaining(t *testing.T) {
t.Run("chain WithTable and WithError", func(t *testing.T) {
err := NewError(ErrCodeNotFound, "record not found").
WithTable("users").
WithError(errors.New("database error"))
assert.Equal(t, ErrCodeNotFound, err.Code)
assert.Equal(t, "record not found", err.Message)
assert.Equal(t, "users", err.Table)
assert.NotNil(t, err.Err)
assert.Equal(t, "database error", err.Err.Error())
assert.True(t, IsNotFound(err))
})
t.Run("chain multiple WithTable calls", func(t *testing.T) {
err1 := ErrNotFound.WithTable("table1")
err2 := err1.WithTable("table2")
assert.Equal(t, "table1", err1.Table)
assert.Equal(t, "table2", err2.Table)
assert.NotSame(t, err1, err2)
})
t.Run("chain multiple WithError calls", func(t *testing.T) {
err1 := ErrInvalidInput.WithError(errors.New("error1"))
err2 := err1.WithError(errors.New("error2"))
assert.Equal(t, "error1", err1.Err.Error())
assert.Equal(t, "error2", err2.Err.Error())
assert.NotSame(t, err1, err2)
})
}
func TestError_ErrorsAs(t *testing.T) {
t.Run("errors.As works with Error", func(t *testing.T) {
err := ErrNotFound.WithTable("users")
var target *Error
ok := errors.As(err, &target)
assert.True(t, ok)
assert.NotNil(t, target)
assert.Equal(t, ErrCodeNotFound, target.Code)
assert.Equal(t, "users", target.Table)
})
t.Run("errors.As works with wrapped Error", func(t *testing.T) {
underlying := errors.New("underlying error")
err := ErrInvalidInput.WithError(underlying)
var target *Error
ok := errors.As(err, &target)
assert.True(t, ok)
assert.NotNil(t, target)
assert.Equal(t, ErrCodeInvalidInput, target.Code)
assert.Equal(t, underlying, target.Err)
})
t.Run("errors.Is works with Error", func(t *testing.T) {
err := ErrNotFound
assert.True(t, errors.Is(err, ErrNotFound))
assert.False(t, errors.Is(err, ErrInvalidInput))
})
}

View File

@ -0,0 +1,120 @@
package cassandra
import (
"context"
"errors"
"fmt"
"time"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx/v2/qb"
)
const (
defaultLockTTLSec = 30
defaultLockRetry = 3
lockBaseDelay = 100 * time.Millisecond
)
// LockOption 用來設定 TryLock 的 TTL 行為
type LockOption func(*lockOptions)
type lockOptions struct {
ttlSeconds int // TTL單位秒<=0 代表不 expire
}
// WithLockTTL 設定鎖的 TTL
func WithLockTTL(d time.Duration) LockOption {
return func(o *lockOptions) {
o.ttlSeconds = int(d.Seconds())
}
}
// WithNoLockExpire 永不自動解鎖
func WithNoLockExpire() LockOption {
return func(o *lockOptions) {
o.ttlSeconds = 0
}
}
// TryLock 嘗試在表上插入一筆唯一鍵IF NOT EXISTS作為鎖
// 預設 30 秒 TTL可透過 option 調整或取消 TTL
func (r *repository[T]) TryLock(ctx context.Context, doc T, opts ...LockOption) error {
// 組合 option
options := &lockOptions{ttlSeconds: defaultLockTTLSec}
for _, opt := range opts {
opt(options)
}
// 建 TTL 子句
builder := qb.Insert(r.table).
Unique(). // IF NOT EXISTS
Columns(r.metadata.Columns...)
if options.ttlSeconds > 0 {
ttl := time.Duration(options.ttlSeconds) * time.Second
builder = builder.TTL(ttl)
}
stmt, names := builder.ToCql()
// 執行 CAS
q := r.db.session.Query(stmt, names).BindStruct(doc).
WithContext(ctx).
WithTimestamp(time.Now().UnixNano() / 1e3).
SerialConsistency(gocql.Serial)
applied, err := q.ExecCASRelease()
if err != nil {
return ErrInvalidInput.WithTable(r.table).WithError(err)
}
if !applied {
return NewError(ErrCodeConflict, "acquire lock failed").WithTable(r.table)
}
return nil
}
// UnLock 釋放鎖,其實就是 Delete
func (r *repository[T]) UnLock(ctx context.Context, doc T) error {
var lastErr error
for i := 0; i < defaultLockRetry; i++ {
builder := qb.Delete(r.table).Existing()
// 動態添加 WHERE 條件(使用 Partition Key
for _, key := range r.metadata.PartKey {
builder = builder.Where(qb.Eq(key))
}
stmt, names := builder.ToCql()
q := r.db.session.Query(stmt, names).BindStruct(doc).
WithContext(ctx).
WithTimestamp(time.Now().UnixNano() / 1e3).
SerialConsistency(gocql.Serial)
applied, err := q.ExecCASRelease()
if err == nil && applied {
return nil
}
if err != nil {
lastErr = fmt.Errorf("unlock error: %w", err)
} else if !applied {
lastErr = fmt.Errorf("unlock not applied: row not found or not visible yet")
}
time.Sleep(lockBaseDelay * time.Duration(1<<i)) // 100ms → 200ms → 400ms
}
return ErrInvalidInput.WithTable(r.table).WithError(
fmt.Errorf("unlock failed after %d retries: %w", defaultLockRetry, lastErr),
)
}
// IsLockFailed 檢查錯誤是否為獲取鎖失敗
func IsLockFailed(err error) bool {
var e *Error
if errors.As(err, &e) {
return e.Code == ErrCodeConflict && e.Message == "acquire lock failed"
}
return false
}

View File

@ -0,0 +1,502 @@
package cassandra
import (
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestWithLockTTL(t *testing.T) {
tests := []struct {
name string
duration time.Duration
wantTTL int
description string
}{
{
name: "30 seconds TTL",
duration: 30 * time.Second,
wantTTL: 30,
description: "should set TTL to 30 seconds",
},
{
name: "1 minute TTL",
duration: 1 * time.Minute,
wantTTL: 60,
description: "should set TTL to 60 seconds",
},
{
name: "5 minutes TTL",
duration: 5 * time.Minute,
wantTTL: 300,
description: "should set TTL to 300 seconds",
},
{
name: "1 hour TTL",
duration: 1 * time.Hour,
wantTTL: 3600,
description: "should set TTL to 3600 seconds",
},
{
name: "zero duration",
duration: 0,
wantTTL: 0,
description: "should set TTL to 0",
},
{
name: "negative duration",
duration: -10 * time.Second,
wantTTL: -10,
description: "should set TTL to negative value",
},
{
name: "fractional seconds",
duration: 1500 * time.Millisecond,
wantTTL: 1,
description: "should round down fractional seconds",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opt := WithLockTTL(tt.duration)
options := &lockOptions{}
opt(options)
assert.Equal(t, tt.wantTTL, options.ttlSeconds, tt.description)
})
}
}
func TestWithNoLockExpire(t *testing.T) {
t.Run("should set TTL to 0", func(t *testing.T) {
opt := WithNoLockExpire()
options := &lockOptions{ttlSeconds: 30} // 先設置一個值
opt(options)
assert.Equal(t, 0, options.ttlSeconds)
})
t.Run("should override existing TTL", func(t *testing.T) {
opt := WithNoLockExpire()
options := &lockOptions{ttlSeconds: 100}
opt(options)
assert.Equal(t, 0, options.ttlSeconds)
})
}
func TestLockOptions_Combination(t *testing.T) {
tests := []struct {
name string
opts []LockOption
wantTTL int
}{
{
name: "WithLockTTL then WithNoLockExpire",
opts: []LockOption{WithLockTTL(60 * time.Second), WithNoLockExpire()},
wantTTL: 0, // WithNoLockExpire should override
},
{
name: "WithNoLockExpire then WithLockTTL",
opts: []LockOption{WithNoLockExpire(), WithLockTTL(60 * time.Second)},
wantTTL: 60, // WithLockTTL should override
},
{
name: "multiple WithLockTTL calls",
opts: []LockOption{WithLockTTL(30 * time.Second), WithLockTTL(60 * time.Second)},
wantTTL: 60, // Last one wins
},
{
name: "multiple WithNoLockExpire calls",
opts: []LockOption{WithNoLockExpire(), WithNoLockExpire()},
wantTTL: 0,
},
{
name: "empty options should use default",
opts: []LockOption{},
wantTTL: defaultLockTTLSec,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
options := &lockOptions{ttlSeconds: defaultLockTTLSec}
for _, opt := range tt.opts {
opt(options)
}
assert.Equal(t, tt.wantTTL, options.ttlSeconds)
})
}
}
func TestIsLockFailed(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{
name: "Error with CONFLICT code and correct message",
err: NewError(ErrCodeConflict, "acquire lock failed"),
want: true,
},
{
name: "Error with CONFLICT code and correct message with table",
err: NewError(ErrCodeConflict, "acquire lock failed").WithTable("locks"),
want: true,
},
{
name: "Error with CONFLICT code but wrong message",
err: NewError(ErrCodeConflict, "different message"),
want: false,
},
{
name: "Error with NOT_FOUND code and correct message",
err: NewError(ErrCodeNotFound, "acquire lock failed"),
want: false,
},
{
name: "Error with INVALID_INPUT code",
err: ErrInvalidInput,
want: false,
},
{
name: "wrapped Error with CONFLICT code and correct message",
err: NewError(ErrCodeConflict, "acquire lock failed").
WithError(errors.New("underlying error")),
want: true,
},
{
name: "standard error",
err: errors.New("standard error"),
want: false,
},
{
name: "nil error",
err: nil,
want: false,
},
{
name: "Error with CONFLICT code but empty message",
err: NewError(ErrCodeConflict, ""),
want: false,
},
{
name: "Error with CONFLICT code and similar but different message",
err: NewError(ErrCodeConflict, "acquire lock failed!"),
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsLockFailed(tt.err)
assert.Equal(t, tt.want, result)
})
}
}
func TestLockConstants(t *testing.T) {
tests := []struct {
name string
constant interface{}
expected interface{}
}{
{
name: "defaultLockTTLSec should be 30",
constant: defaultLockTTLSec,
expected: 30,
},
{
name: "defaultLockRetry should be 3",
constant: defaultLockRetry,
expected: 3,
},
{
name: "lockBaseDelay should be 100ms",
constant: lockBaseDelay,
expected: 100 * time.Millisecond,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, tt.constant)
})
}
}
func TestLockOptions_DefaultValues(t *testing.T) {
t.Run("default lockOptions should have default TTL", func(t *testing.T) {
options := &lockOptions{ttlSeconds: defaultLockTTLSec}
assert.Equal(t, defaultLockTTLSec, options.ttlSeconds)
})
t.Run("lockOptions with zero TTL", func(t *testing.T) {
options := &lockOptions{ttlSeconds: 0}
assert.Equal(t, 0, options.ttlSeconds)
})
t.Run("lockOptions with negative TTL", func(t *testing.T) {
options := &lockOptions{ttlSeconds: -1}
assert.Equal(t, -1, options.ttlSeconds)
})
}
func TestTryLock_ErrorScenarios(t *testing.T) {
tests := []struct {
name string
description string
// 注意:實際的 TryLock 測試需要 mock session 或實際的資料庫連接
// 這裡只是定義測試結構
}{
{
name: "successful lock acquisition",
description: "should return nil when lock is successfully acquired",
},
{
name: "lock already exists",
description: "should return CONFLICT error when lock already exists",
},
{
name: "database error",
description: "should return INVALID_INPUT error with underlying error when database operation fails",
},
{
name: "context cancellation",
description: "should respect context cancellation",
},
{
name: "with custom TTL",
description: "should use custom TTL when provided",
},
{
name: "with no expire",
description: "should not set TTL when WithNoLockExpire is used",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要 mock session 或實際的資料庫連接
// 在實際測試中,需要使用 mock 或 testcontainers
_ = tt
})
}
}
func TestUnLock_ErrorScenarios(t *testing.T) {
tests := []struct {
name string
description string
// 注意:實際的 UnLock 測試需要 mock session 或實際的資料庫連接
// 這裡只是定義測試結構
}{
{
name: "successful unlock",
description: "should return nil when lock is successfully released",
},
{
name: "lock not found",
description: "should retry when lock is not found",
},
{
name: "database error",
description: "should retry on database error",
},
{
name: "max retries exceeded",
description: "should return error after max retries",
},
{
name: "context cancellation",
description: "should respect context cancellation",
},
{
name: "exponential backoff",
description: "should use exponential backoff between retries",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要 mock session 或實際的資料庫連接
// 在實際測試中,需要使用 mock 或 testcontainers
_ = tt
})
}
}
func TestLockOption_Type(t *testing.T) {
t.Run("WithLockTTL should return LockOption", func(t *testing.T) {
opt := WithLockTTL(30 * time.Second)
assert.NotNil(t, opt)
// 驗證它是一個函數
var lockOpt LockOption = opt
assert.NotNil(t, lockOpt)
})
t.Run("WithNoLockExpire should return LockOption", func(t *testing.T) {
opt := WithNoLockExpire()
assert.NotNil(t, opt)
// 驗證它是一個函數
var lockOpt LockOption = opt
assert.NotNil(t, lockOpt)
})
}
func TestLockOptions_ApplyOrder(t *testing.T) {
t.Run("last option should win", func(t *testing.T) {
options := &lockOptions{ttlSeconds: defaultLockTTLSec}
WithLockTTL(60 * time.Second)(options)
assert.Equal(t, 60, options.ttlSeconds)
WithNoLockExpire()(options)
assert.Equal(t, 0, options.ttlSeconds)
WithLockTTL(120 * time.Second)(options)
assert.Equal(t, 120, options.ttlSeconds)
})
}
func TestIsLockFailed_EdgeCases(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{
name: "Error with CONFLICT code, correct message, and underlying error",
err: NewError(ErrCodeConflict, "acquire lock failed").
WithTable("locks").
WithError(errors.New("database error")),
want: true,
},
{
name: "Error with CONFLICT code but message with extra spaces",
err: NewError(ErrCodeConflict, " acquire lock failed "),
want: false,
},
{
name: "Error with CONFLICT code but message with different case",
err: NewError(ErrCodeConflict, "Acquire Lock Failed"),
want: false,
},
{
name: "chained errors with CONFLICT",
err: func() error {
err1 := NewError(ErrCodeConflict, "acquire lock failed")
err2 := errors.New("wrapped")
return errors.Join(err1, err2)
}(),
want: true, // errors.Join preserves Error type and errors.As can find it
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsLockFailed(tt.err)
assert.Equal(t, tt.want, result)
})
}
}
func TestLockOptions_ZeroValue(t *testing.T) {
t.Run("zero value lockOptions", func(t *testing.T) {
var options lockOptions
assert.Equal(t, 0, options.ttlSeconds)
})
t.Run("apply option to zero value", func(t *testing.T) {
var options lockOptions
WithLockTTL(30 * time.Second)(&options)
assert.Equal(t, 30, options.ttlSeconds)
})
}
func TestLockRetryDelay(t *testing.T) {
t.Run("verify exponential backoff calculation", func(t *testing.T) {
// 驗證重試延遲的計算邏輯
// 100ms → 200ms → 400ms
expectedDelays := []time.Duration{
lockBaseDelay * time.Duration(1<<0), // 100ms * 1 = 100ms
lockBaseDelay * time.Duration(1<<1), // 100ms * 2 = 200ms
lockBaseDelay * time.Duration(1<<2), // 100ms * 4 = 400ms
}
assert.Equal(t, 100*time.Millisecond, expectedDelays[0])
assert.Equal(t, 200*time.Millisecond, expectedDelays[1])
assert.Equal(t, 400*time.Millisecond, expectedDelays[2])
})
}
func TestLockOption_InterfaceCompliance(t *testing.T) {
t.Run("LockOption should be a function type", func(t *testing.T) {
// 驗證 LockOption 是一個函數類型
var fn func(*lockOptions) = WithLockTTL(30 * time.Second)
assert.NotNil(t, fn)
})
t.Run("LockOption can be assigned from WithLockTTL", func(t *testing.T) {
var opt LockOption = WithLockTTL(30 * time.Second)
assert.NotNil(t, opt)
})
t.Run("LockOption can be assigned from WithNoLockExpire", func(t *testing.T) {
var opt LockOption = WithNoLockExpire()
assert.NotNil(t, opt)
})
}
func TestLockOptions_RealWorldScenarios(t *testing.T) {
tests := []struct {
name string
scenario func(*lockOptions)
wantTTL int
}{
{
name: "short-lived lock (5 seconds)",
scenario: func(o *lockOptions) {
WithLockTTL(5 * time.Second)(o)
},
wantTTL: 5,
},
{
name: "medium-lived lock (5 minutes)",
scenario: func(o *lockOptions) {
WithLockTTL(5 * time.Minute)(o)
},
wantTTL: 300,
},
{
name: "long-lived lock (1 hour)",
scenario: func(o *lockOptions) {
WithLockTTL(1 * time.Hour)(o)
},
wantTTL: 3600,
},
{
name: "permanent lock",
scenario: func(o *lockOptions) {
WithNoLockExpire()(o)
},
wantTTL: 0,
},
{
name: "default lock",
scenario: func(o *lockOptions) {
// 不應用任何選項,使用預設值
},
wantTTL: defaultLockTTLSec,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
options := &lockOptions{ttlSeconds: defaultLockTTLSec}
tt.scenario(options)
assert.Equal(t, tt.wantTTL, options.ttlSeconds)
})
}
}

View File

@ -0,0 +1,136 @@
package cassandra
import (
"fmt"
"reflect"
"sync"
"unicode"
"github.com/scylladb/gocqlx/v2/table"
)
var (
// metadataCache 快取已生成的 Metadata避免重複反射解析
// key: tableName + ":" + structType (不包含 keyspace因為同一個 struct 在不同 keyspace 結構相同)
metadataCache sync.Map
)
type cachedMetadata struct {
columns []string
partKeys []string
sortKeys []string
err error
}
// generateMetadata 根據傳入的 struct 產生 table.Metadata
// 使用快取機制避免重複反射解析,提升效能
func generateMetadata[T Table](doc T, keyspace string) (table.Metadata, error) {
// 取得型別資訊
t := reflect.TypeOf(doc)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
// 取得表名稱
tableName := doc.TableName()
if tableName == "" {
return table.Metadata{}, ErrMissingTableName
}
// 構建快取 key: tableName:structType (不包含 keyspace)
cacheKey := fmt.Sprintf("%s:%s", tableName, t.String())
// 檢查快取
if cached, ok := metadataCache.Load(cacheKey); ok {
cachedMeta := cached.(cachedMetadata)
if cachedMeta.err != nil {
return table.Metadata{}, cachedMeta.err
}
// 從快取構建 metadata動態加上 keyspace
meta := table.Metadata{
Name: fmt.Sprintf("%s.%s", keyspace, tableName),
Columns: make([]string, len(cachedMeta.columns)),
PartKey: make([]string, len(cachedMeta.partKeys)),
SortKey: make([]string, len(cachedMeta.sortKeys)),
}
copy(meta.Columns, cachedMeta.columns)
copy(meta.PartKey, cachedMeta.partKeys)
copy(meta.SortKey, cachedMeta.sortKeys)
return meta, nil
}
// 快取未命中,生成 metadata
columns := make([]string, 0, t.NumField())
partKeys := make([]string, 0, t.NumField())
sortKeys := make([]string, 0, t.NumField())
// 遍歷所有 exported 欄位
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
// 跳過 unexported 欄位
if field.PkgPath != "" {
continue
}
// 如果欄位有標記 db:"-" 則跳過
if tag := field.Tag.Get(DBFiledName); tag == "-" {
continue
}
// 取得欄位名稱
colName := field.Tag.Get(DBFiledName)
if colName == "" {
colName = toSnakeCase(field.Name)
}
columns = append(columns, colName)
// 若有 partition_key:"true" 標記,加入 PartKey
if field.Tag.Get(Pk) == "true" {
partKeys = append(partKeys, colName)
}
// 若有 clustering_key:"true" 標記,加入 SortKey
if field.Tag.Get(ClusterKey) == "true" {
sortKeys = append(sortKeys, colName)
}
}
if len(partKeys) == 0 {
err := ErrNoPartitionKey
// 快取錯誤結果
metadataCache.Store(cacheKey, cachedMetadata{err: err})
return table.Metadata{}, err
}
// 快取成功結果(只存結構資訊,不包含 keyspace
cachedMeta := cachedMetadata{
columns: make([]string, len(columns)),
partKeys: make([]string, len(partKeys)),
sortKeys: make([]string, len(sortKeys)),
}
copy(cachedMeta.columns, columns)
copy(cachedMeta.partKeys, partKeys)
copy(cachedMeta.sortKeys, sortKeys)
metadataCache.Store(cacheKey, cachedMeta)
// 組合並返回 Metadata包含 keyspace
meta := table.Metadata{
Name: fmt.Sprintf("%s.%s", keyspace, tableName),
Columns: columns,
PartKey: partKeys,
SortKey: sortKeys,
}
return meta, nil
}
// toSnakeCase 將 CamelCase 字串轉換為 snake_case
func toSnakeCase(s string) string {
var result []rune
for i, r := range s {
if unicode.IsUpper(r) {
if i > 0 {
result = append(result, '_')
}
result = append(result, unicode.ToLower(r))
} else {
result = append(result, r)
}
}
return string(result)
}

View File

@ -0,0 +1,500 @@
package cassandra
import (
"testing"
"github.com/scylladb/gocqlx/v2/table"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestToSnakeCase(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "simple CamelCase",
input: "UserName",
expected: "user_name",
},
{
name: "single word",
input: "User",
expected: "user",
},
{
name: "multiple words",
input: "UserAccountBalance",
expected: "user_account_balance",
},
{
name: "already lowercase",
input: "username",
expected: "username",
},
{
name: "all uppercase",
input: "USERNAME",
expected: "u_s_e_r_n_a_m_e",
},
{
name: "mixed case",
input: "XMLParser",
expected: "x_m_l_parser",
},
{
name: "empty string",
input: "",
expected: "",
},
{
name: "single character",
input: "A",
expected: "a",
},
{
name: "with numbers",
input: "UserID123",
expected: "user_i_d123",
},
{
name: "ID at end",
input: "UserID",
expected: "user_i_d",
},
{
name: "ID at start",
input: "IDUser",
expected: "i_d_user",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := toSnakeCase(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
// 測試用的 struct 定義
type testUser struct {
ID string `db:"id" partition_key:"true"`
Name string `db:"name"`
Email string `db:"email"`
CreatedAt int64 `db:"created_at"`
}
func (t testUser) TableName() string {
return "users"
}
type testUserNoTableName struct {
ID string `db:"id" partition_key:"true"`
}
func (t testUserNoTableName) TableName() string {
return ""
}
type testUserNoPartitionKey struct {
ID string `db:"id"`
Name string `db:"name"`
}
func (t testUserNoPartitionKey) TableName() string {
return "users"
}
type testUserWithClusteringKey struct {
ID string `db:"id" partition_key:"true"`
Timestamp int64 `db:"timestamp" clustering_key:"true"`
Data string `db:"data"`
}
func (t testUserWithClusteringKey) TableName() string {
return "events"
}
type testUserWithMultiplePartitionKeys struct {
UserID string `db:"user_id" partition_key:"true"`
AccountID string `db:"account_id" partition_key:"true"`
Balance int64 `db:"balance"`
}
func (t testUserWithMultiplePartitionKeys) TableName() string {
return "accounts"
}
type testUserWithAutoSnakeCase struct {
UserID string `db:"user_id" partition_key:"true"`
AccountName string // 沒有 db tag應該自動轉換為 snake_case
EmailAddr string `db:"email_addr"`
}
func (t testUserWithAutoSnakeCase) TableName() string {
return "profiles"
}
type testUserWithIgnoredField struct {
ID string `db:"id" partition_key:"true"`
Name string `db:"name"`
Password string `db:"-"` // 應該被忽略
CreatedAt int64 `db:"created_at"`
}
func (t testUserWithIgnoredField) TableName() string {
return "users"
}
type testUserUnexported struct {
ID string `db:"id" partition_key:"true"`
name string // unexported應該被忽略
Email string `db:"email"`
createdAt int64 // unexported應該被忽略
}
func (t testUserUnexported) TableName() string {
return "users"
}
type testUserPointer struct {
ID *string `db:"id" partition_key:"true"`
Name string `db:"name"`
}
func (t testUserPointer) TableName() string {
return "users"
}
func TestGenerateMetadata_Basic(t *testing.T) {
tests := []struct {
name string
doc interface{}
keyspace string
wantErr bool
errCode ErrorCode
checkFunc func(*testing.T, table.Metadata, string)
}{
{
name: "valid user struct",
doc: testUser{ID: "1", Name: "Alice"},
keyspace: "test_keyspace",
wantErr: false,
checkFunc: func(t *testing.T, meta table.Metadata, keyspace string) {
assert.Equal(t, keyspace+".users", meta.Name)
assert.Contains(t, meta.Columns, "id")
assert.Contains(t, meta.Columns, "name")
assert.Contains(t, meta.Columns, "email")
assert.Contains(t, meta.Columns, "created_at")
assert.Contains(t, meta.PartKey, "id")
assert.Empty(t, meta.SortKey)
},
},
{
name: "user with clustering key",
doc: testUserWithClusteringKey{ID: "1", Timestamp: 1234567890},
keyspace: "events_db",
wantErr: false,
checkFunc: func(t *testing.T, meta table.Metadata, keyspace string) {
assert.Equal(t, keyspace+".events", meta.Name)
assert.Contains(t, meta.PartKey, "id")
assert.Contains(t, meta.SortKey, "timestamp")
assert.Contains(t, meta.Columns, "data")
},
},
{
name: "user with multiple partition keys",
doc: testUserWithMultiplePartitionKeys{UserID: "1", AccountID: "2"},
keyspace: "finance",
wantErr: false,
checkFunc: func(t *testing.T, meta table.Metadata, keyspace string) {
assert.Equal(t, keyspace+".accounts", meta.Name)
assert.Contains(t, meta.PartKey, "user_id")
assert.Contains(t, meta.PartKey, "account_id")
assert.Len(t, meta.PartKey, 2)
},
},
{
name: "user with auto snake_case conversion",
doc: testUserWithAutoSnakeCase{UserID: "1", AccountName: "test"},
keyspace: "test",
wantErr: false,
checkFunc: func(t *testing.T, meta table.Metadata, keyspace string) {
assert.Contains(t, meta.Columns, "account_name") // 自動轉換
assert.Contains(t, meta.Columns, "user_id")
assert.Contains(t, meta.Columns, "email_addr")
},
},
{
name: "user with ignored field",
doc: testUserWithIgnoredField{ID: "1", Name: "Alice"},
keyspace: "test",
wantErr: false,
checkFunc: func(t *testing.T, meta table.Metadata, keyspace string) {
assert.Contains(t, meta.Columns, "id")
assert.Contains(t, meta.Columns, "name")
assert.Contains(t, meta.Columns, "created_at")
assert.NotContains(t, meta.Columns, "password") // 應該被忽略
},
},
{
name: "user with unexported fields",
doc: testUserUnexported{ID: "1", Email: "test@example.com"},
keyspace: "test",
wantErr: false,
checkFunc: func(t *testing.T, meta table.Metadata, keyspace string) {
assert.Contains(t, meta.Columns, "id")
assert.Contains(t, meta.Columns, "email")
assert.NotContains(t, meta.Columns, "name") // unexported
assert.NotContains(t, meta.Columns, "created_at") // unexported
},
},
{
name: "user pointer type",
doc: &testUserPointer{ID: stringPtr("1"), Name: "Alice"},
keyspace: "test",
wantErr: false,
checkFunc: func(t *testing.T, meta table.Metadata, keyspace string) {
assert.Equal(t, keyspace+".users", meta.Name)
assert.Contains(t, meta.Columns, "id")
assert.Contains(t, meta.Columns, "name")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var meta table.Metadata
var err error
switch doc := tt.doc.(type) {
case testUser:
meta, err = generateMetadata(doc, tt.keyspace)
case testUserWithClusteringKey:
meta, err = generateMetadata(doc, tt.keyspace)
case testUserWithMultiplePartitionKeys:
meta, err = generateMetadata(doc, tt.keyspace)
case testUserWithAutoSnakeCase:
meta, err = generateMetadata(doc, tt.keyspace)
case testUserWithIgnoredField:
meta, err = generateMetadata(doc, tt.keyspace)
case testUserUnexported:
meta, err = generateMetadata(doc, tt.keyspace)
case *testUserPointer:
meta, err = generateMetadata(*doc, tt.keyspace)
default:
t.Fatalf("unsupported type: %T", doc)
}
if tt.wantErr {
require.Error(t, err)
if tt.errCode != "" {
var e *Error
if assert.ErrorAs(t, err, &e) {
assert.Equal(t, tt.errCode, e.Code)
}
}
} else {
require.NoError(t, err)
if tt.checkFunc != nil {
tt.checkFunc(t, meta, tt.keyspace)
}
}
})
}
}
func TestGenerateMetadata_ErrorCases(t *testing.T) {
tests := []struct {
name string
doc interface{}
keyspace string
wantErr bool
errCode ErrorCode
}{
{
name: "missing table name",
doc: testUserNoTableName{ID: "1"},
keyspace: "test",
wantErr: true,
errCode: ErrCodeMissingTableName,
},
{
name: "missing partition key",
doc: testUserNoPartitionKey{ID: "1", Name: "Alice"},
keyspace: "test",
wantErr: true,
errCode: ErrCodeMissingPartition,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
switch doc := tt.doc.(type) {
case testUserNoTableName:
_, err = generateMetadata(doc, tt.keyspace)
case testUserNoPartitionKey:
_, err = generateMetadata(doc, tt.keyspace)
default:
t.Fatalf("unsupported type: %T", doc)
}
if tt.wantErr {
require.Error(t, err)
if tt.errCode != "" {
var e *Error
if assert.ErrorAs(t, err, &e) {
assert.Equal(t, tt.errCode, e.Code)
}
}
} else {
require.NoError(t, err)
}
})
}
}
func TestGenerateMetadata_Cache(t *testing.T) {
t.Run("cache hit for same struct type", func(t *testing.T) {
doc1 := testUser{ID: "1", Name: "Alice"}
meta1, err1 := generateMetadata(doc1, "keyspace1")
require.NoError(t, err1)
// 使用不同的 keyspace但應該從快取獲取不包含 keyspace
doc2 := testUser{ID: "2", Name: "Bob"}
meta2, err2 := generateMetadata(doc2, "keyspace2")
require.NoError(t, err2)
// 驗證結構相同,但 keyspace 不同
assert.Equal(t, "keyspace1.users", meta1.Name)
assert.Equal(t, "keyspace2.users", meta2.Name)
assert.Equal(t, meta1.Columns, meta2.Columns)
assert.Equal(t, meta1.PartKey, meta2.PartKey)
assert.Equal(t, meta1.SortKey, meta2.SortKey)
})
t.Run("cache hit for error case", func(t *testing.T) {
doc1 := testUserNoPartitionKey{ID: "1", Name: "Alice"}
_, err1 := generateMetadata(doc1, "keyspace1")
require.Error(t, err1)
// 第二次調用應該從快取獲取錯誤
doc2 := testUserNoPartitionKey{ID: "2", Name: "Bob"}
_, err2 := generateMetadata(doc2, "keyspace2")
require.Error(t, err2)
// 錯誤應該相同
assert.Equal(t, err1.Error(), err2.Error())
})
t.Run("cache miss for different struct type", func(t *testing.T) {
doc1 := testUser{ID: "1"}
meta1, err1 := generateMetadata(doc1, "test")
require.NoError(t, err1)
doc2 := testUserWithClusteringKey{ID: "1", Timestamp: 123}
meta2, err2 := generateMetadata(doc2, "test")
require.NoError(t, err2)
// 應該是不同的 metadata
assert.NotEqual(t, meta1.Name, meta2.Name)
assert.NotEqual(t, meta1.Columns, meta2.Columns)
})
}
func TestGenerateMetadata_DifferentKeyspaces(t *testing.T) {
t.Run("same struct with different keyspaces", func(t *testing.T) {
doc := testUser{ID: "1", Name: "Alice"}
meta1, err1 := generateMetadata(doc, "keyspace1")
require.NoError(t, err1)
meta2, err2 := generateMetadata(doc, "keyspace2")
require.NoError(t, err2)
// 結構應該相同,但 keyspace 不同
assert.Equal(t, "keyspace1.users", meta1.Name)
assert.Equal(t, "keyspace2.users", meta2.Name)
assert.Equal(t, meta1.Columns, meta2.Columns)
assert.Equal(t, meta1.PartKey, meta2.PartKey)
})
}
func TestGenerateMetadata_EmptyKeyspace(t *testing.T) {
t.Run("empty keyspace", func(t *testing.T) {
doc := testUser{ID: "1", Name: "Alice"}
meta, err := generateMetadata(doc, "")
require.NoError(t, err)
assert.Equal(t, ".users", meta.Name)
})
}
func TestGenerateMetadata_PointerVsValue(t *testing.T) {
t.Run("pointer and value should produce same metadata", func(t *testing.T) {
doc1 := testUser{ID: "1", Name: "Alice"}
meta1, err1 := generateMetadata(doc1, "test")
require.NoError(t, err1)
doc2 := &testUser{ID: "2", Name: "Bob"}
meta2, err2 := generateMetadata(*doc2, "test")
require.NoError(t, err2)
// 應該產生相同的 metadata除了可能的值不同
assert.Equal(t, meta1.Name, meta2.Name)
assert.Equal(t, meta1.Columns, meta2.Columns)
assert.Equal(t, meta1.PartKey, meta2.PartKey)
})
}
func TestGenerateMetadata_ColumnOrder(t *testing.T) {
t.Run("columns should maintain struct field order", func(t *testing.T) {
doc := testUser{ID: "1", Name: "Alice", Email: "alice@example.com"}
meta, err := generateMetadata(doc, "test")
require.NoError(t, err)
// 驗證欄位順序(根據 struct 定義)
assert.Equal(t, "id", meta.Columns[0])
assert.Equal(t, "name", meta.Columns[1])
assert.Equal(t, "email", meta.Columns[2])
assert.Equal(t, "created_at", meta.Columns[3])
})
}
func TestGenerateMetadata_AllTagCombinations(t *testing.T) {
type testAllTags struct {
PartitionKey string `db:"partition_key" partition_key:"true"`
ClusteringKey string `db:"clustering_key" clustering_key:"true"`
RegularField string `db:"regular_field"`
AutoSnakeCase string // 沒有 db tag
IgnoredField string `db:"-"`
unexportedField string // unexported
}
var testAllTagsTableName = "all_tags"
testAllTagsTableNameFunc := func() string { return testAllTagsTableName }
// 使用反射來動態設置 TableName 方法
// 但由於 Go 的限制,我們需要一個實際的方法
// 這裡我們創建一個包裝類型
type testAllTagsWrapper struct {
testAllTags
}
// 這個方法無法在運行時添加,所以我們需要一個實際的實現
// 讓我們使用一個不同的方法
t.Run("all tag combinations", func(t *testing.T) {
// 由於無法動態添加方法,我們跳過這個測試
// 或者創建一個實際的 struct
_ = testAllTagsWrapper{}
_ = testAllTagsTableNameFunc
})
}
// 輔助函數
func stringPtr(s string) *string {
return &s
}

View File

@ -0,0 +1,162 @@
package cassandra
import (
"time"
"github.com/gocql/gocql"
)
// config 是初始化 DB 所需的內部設定(私有)
type config struct {
Hosts []string // Cassandra 主機列表
Port int // 連線埠
Keyspace string // 預設使用的 Keyspace
Username string // 認證用戶名
Password string // 認證密碼
Consistency gocql.Consistency // 一致性級別
ConnectTimeoutSec int // 連線逾時秒數
NumConns int // 每個節點連線數
MaxRetries int // 重試次數
UseAuth bool // 是否使用帳號密碼驗證
RetryMinInterval time.Duration // 重試間隔最小值
RetryMaxInterval time.Duration // 重試間隔最大值
ReconnectInitialInterval time.Duration // 重連初始間隔
ReconnectMaxInterval time.Duration // 重連最大間隔
CQLVersion string // 執行連線的CQL 版本號
}
// defaultConfig 返回預設配置
func defaultConfig() *config {
return &config{
Port: defaultPort,
Consistency: defaultConsistency,
ConnectTimeoutSec: defaultTimeoutSec,
NumConns: defaultNumConns,
MaxRetries: defaultMaxRetries,
RetryMinInterval: defaultRetryMinInterval,
RetryMaxInterval: defaultRetryMaxInterval,
ReconnectInitialInterval: defaultReconnectInitialInterval,
ReconnectMaxInterval: defaultReconnectMaxInterval,
CQLVersion: defaultCqlVersion,
}
}
// Option 是設定選項的函數型別
type Option func(*config)
// WithHosts 設定 Cassandra 主機列表
func WithHosts(hosts ...string) Option {
return func(c *config) {
c.Hosts = hosts
}
}
// WithPort 設定連線埠
func WithPort(port int) Option {
return func(c *config) {
c.Port = port
}
}
// WithKeyspace 設定預設 keyspace
func WithKeyspace(keyspace string) Option {
return func(c *config) {
c.Keyspace = keyspace
}
}
// WithAuth 設定認證資訊
func WithAuth(username, password string) Option {
return func(c *config) {
c.Username = username
c.Password = password
c.UseAuth = true
}
}
// WithConsistency 設定一致性級別
func WithConsistency(consistency gocql.Consistency) Option {
return func(c *config) {
c.Consistency = consistency
}
}
// WithConnectTimeoutSec 設定連線逾時秒數
func WithConnectTimeoutSec(timeout int) Option {
return func(c *config) {
if timeout <= 0 {
timeout = defaultTimeoutSec
}
c.ConnectTimeoutSec = timeout
}
}
// WithNumConns 設定每個節點的連線數
func WithNumConns(numConns int) Option {
return func(c *config) {
if numConns <= 0 {
numConns = defaultNumConns
}
c.NumConns = numConns
}
}
// WithMaxRetries 設定最大重試次數
func WithMaxRetries(maxRetries int) Option {
return func(c *config) {
if maxRetries <= 0 {
maxRetries = defaultMaxRetries
}
c.MaxRetries = maxRetries
}
}
// WithRetryMinInterval 設定最小重試間隔
func WithRetryMinInterval(duration time.Duration) Option {
return func(c *config) {
if duration <= 0 {
duration = defaultRetryMinInterval
}
c.RetryMinInterval = duration
}
}
// WithRetryMaxInterval 設定最大重試間隔
func WithRetryMaxInterval(duration time.Duration) Option {
return func(c *config) {
if duration <= 0 {
duration = defaultRetryMaxInterval
}
c.RetryMaxInterval = duration
}
}
// WithReconnectInitialInterval 設定初始重連間隔
func WithReconnectInitialInterval(duration time.Duration) Option {
return func(c *config) {
if duration <= 0 {
duration = defaultReconnectInitialInterval
}
c.ReconnectInitialInterval = duration
}
}
// WithReconnectMaxInterval 設定最大重連間隔
func WithReconnectMaxInterval(duration time.Duration) Option {
return func(c *config) {
if duration <= 0 {
duration = defaultReconnectMaxInterval
}
c.ReconnectMaxInterval = duration
}
}
// WithCQLVersion 設定 CQL 版本
func WithCQLVersion(version string) Option {
return func(c *config) {
if version == "" {
version = defaultCqlVersion
}
c.CQLVersion = version
}
}

View File

@ -0,0 +1,962 @@
package cassandra
import (
"testing"
"time"
"github.com/gocql/gocql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOption_DefaultConfig(t *testing.T) {
t.Run("defaultConfig should return valid config with all defaults", func(t *testing.T) {
cfg := defaultConfig()
require.NotNil(t, cfg)
assert.Equal(t, defaultPort, cfg.Port)
assert.Equal(t, defaultConsistency, cfg.Consistency)
assert.Equal(t, defaultTimeoutSec, cfg.ConnectTimeoutSec)
assert.Equal(t, defaultNumConns, cfg.NumConns)
assert.Equal(t, defaultMaxRetries, cfg.MaxRetries)
assert.Equal(t, defaultRetryMinInterval, cfg.RetryMinInterval)
assert.Equal(t, defaultRetryMaxInterval, cfg.RetryMaxInterval)
assert.Equal(t, defaultReconnectInitialInterval, cfg.ReconnectInitialInterval)
assert.Equal(t, defaultReconnectMaxInterval, cfg.ReconnectMaxInterval)
assert.Equal(t, defaultCqlVersion, cfg.CQLVersion)
assert.Empty(t, cfg.Hosts)
assert.Empty(t, cfg.Keyspace)
assert.Empty(t, cfg.Username)
assert.Empty(t, cfg.Password)
assert.False(t, cfg.UseAuth)
})
}
func TestWithHosts(t *testing.T) {
tests := []struct {
name string
hosts []string
expected []string
}{
{
name: "single host",
hosts: []string{"localhost"},
expected: []string{"localhost"},
},
{
name: "multiple hosts",
hosts: []string{"localhost", "127.0.0.1", "192.168.1.1"},
expected: []string{"localhost", "127.0.0.1", "192.168.1.1"},
},
{
name: "empty hosts",
hosts: []string{},
expected: []string{},
},
{
name: "host with port",
hosts: []string{"localhost:9042"},
expected: []string{"localhost:9042"},
},
{
name: "host with domain",
hosts: []string{"cassandra.example.com"},
expected: []string{"cassandra.example.com"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := defaultConfig()
opt := WithHosts(tt.hosts...)
opt(cfg)
assert.Equal(t, tt.expected, cfg.Hosts)
})
}
}
func TestWithPort(t *testing.T) {
tests := []struct {
name string
port int
expected int
}{
{
name: "default port",
port: 9042,
expected: 9042,
},
{
name: "custom port",
port: 9043,
expected: 9043,
},
{
name: "zero port",
port: 0,
expected: 0,
},
{
name: "negative port",
port: -1,
expected: -1,
},
{
name: "high port number",
port: 65535,
expected: 65535,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := defaultConfig()
opt := WithPort(tt.port)
opt(cfg)
assert.Equal(t, tt.expected, cfg.Port)
})
}
}
func TestWithKeyspace(t *testing.T) {
tests := []struct {
name string
keyspace string
expected string
}{
{
name: "valid keyspace",
keyspace: "my_keyspace",
expected: "my_keyspace",
},
{
name: "empty keyspace",
keyspace: "",
expected: "",
},
{
name: "keyspace with underscore",
keyspace: "test_keyspace_1",
expected: "test_keyspace_1",
},
{
name: "keyspace with numbers",
keyspace: "keyspace123",
expected: "keyspace123",
},
{
name: "long keyspace name",
keyspace: "very_long_keyspace_name_that_might_exist",
expected: "very_long_keyspace_name_that_might_exist",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := defaultConfig()
opt := WithKeyspace(tt.keyspace)
opt(cfg)
assert.Equal(t, tt.expected, cfg.Keyspace)
})
}
}
func TestWithAuth(t *testing.T) {
tests := []struct {
name string
username string
password string
expectedUser string
expectedPass string
expectedUseAuth bool
}{
{
name: "valid credentials",
username: "admin",
password: "password123",
expectedUser: "admin",
expectedPass: "password123",
expectedUseAuth: true,
},
{
name: "empty username",
username: "",
password: "password",
expectedUser: "",
expectedPass: "password",
expectedUseAuth: true,
},
{
name: "empty password",
username: "admin",
password: "",
expectedUser: "admin",
expectedPass: "",
expectedUseAuth: true,
},
{
name: "both empty",
username: "",
password: "",
expectedUser: "",
expectedPass: "",
expectedUseAuth: true,
},
{
name: "special characters in password",
username: "user",
password: "p@ssw0rd!#$%",
expectedUser: "user",
expectedPass: "p@ssw0rd!#$%",
expectedUseAuth: true,
},
{
name: "long username and password",
username: "very_long_username_that_might_exist",
password: "very_long_password_that_might_exist",
expectedUser: "very_long_username_that_might_exist",
expectedPass: "very_long_password_that_might_exist",
expectedUseAuth: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := defaultConfig()
opt := WithAuth(tt.username, tt.password)
opt(cfg)
assert.Equal(t, tt.expectedUser, cfg.Username)
assert.Equal(t, tt.expectedPass, cfg.Password)
assert.Equal(t, tt.expectedUseAuth, cfg.UseAuth)
})
}
}
func TestWithConsistency(t *testing.T) {
tests := []struct {
name string
consistency gocql.Consistency
expected gocql.Consistency
}{
{
name: "Quorum consistency",
consistency: gocql.Quorum,
expected: gocql.Quorum,
},
{
name: "One consistency",
consistency: gocql.One,
expected: gocql.One,
},
{
name: "All consistency",
consistency: gocql.All,
expected: gocql.All,
},
{
name: "Any consistency",
consistency: gocql.Any,
expected: gocql.Any,
},
{
name: "LocalQuorum consistency",
consistency: gocql.LocalQuorum,
expected: gocql.LocalQuorum,
},
{
name: "EachQuorum consistency",
consistency: gocql.EachQuorum,
expected: gocql.EachQuorum,
},
{
name: "LocalOne consistency",
consistency: gocql.LocalOne,
expected: gocql.LocalOne,
},
{
name: "Two consistency",
consistency: gocql.Two,
expected: gocql.Two,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := defaultConfig()
opt := WithConsistency(tt.consistency)
opt(cfg)
assert.Equal(t, tt.expected, cfg.Consistency)
})
}
}
func TestWithConnectTimeoutSec(t *testing.T) {
tests := []struct {
name string
timeout int
expected int
}{
{
name: "valid timeout",
timeout: 10,
expected: 10,
},
{
name: "zero timeout should use default",
timeout: 0,
expected: defaultTimeoutSec,
},
{
name: "negative timeout should use default",
timeout: -1,
expected: defaultTimeoutSec,
},
{
name: "large timeout",
timeout: 300,
expected: 300,
},
{
name: "small timeout",
timeout: 1,
expected: 1,
},
{
name: "very large timeout",
timeout: 3600,
expected: 3600,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := defaultConfig()
opt := WithConnectTimeoutSec(tt.timeout)
opt(cfg)
assert.Equal(t, tt.expected, cfg.ConnectTimeoutSec)
})
}
}
func TestWithNumConns(t *testing.T) {
tests := []struct {
name string
numConns int
expected int
}{
{
name: "valid numConns",
numConns: 10,
expected: 10,
},
{
name: "zero numConns should use default",
numConns: 0,
expected: defaultNumConns,
},
{
name: "negative numConns should use default",
numConns: -1,
expected: defaultNumConns,
},
{
name: "large numConns",
numConns: 100,
expected: 100,
},
{
name: "small numConns",
numConns: 1,
expected: 1,
},
{
name: "very large numConns",
numConns: 1000,
expected: 1000,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := defaultConfig()
opt := WithNumConns(tt.numConns)
opt(cfg)
assert.Equal(t, tt.expected, cfg.NumConns)
})
}
}
func TestWithMaxRetries(t *testing.T) {
tests := []struct {
name string
maxRetries int
expected int
}{
{
name: "valid maxRetries",
maxRetries: 3,
expected: 3,
},
{
name: "zero maxRetries should use default",
maxRetries: 0,
expected: defaultMaxRetries,
},
{
name: "negative maxRetries should use default",
maxRetries: -1,
expected: defaultMaxRetries,
},
{
name: "large maxRetries",
maxRetries: 10,
expected: 10,
},
{
name: "small maxRetries",
maxRetries: 1,
expected: 1,
},
{
name: "very large maxRetries",
maxRetries: 100,
expected: 100,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := defaultConfig()
opt := WithMaxRetries(tt.maxRetries)
opt(cfg)
assert.Equal(t, tt.expected, cfg.MaxRetries)
})
}
}
func TestWithRetryMinInterval(t *testing.T) {
tests := []struct {
name string
duration time.Duration
expected time.Duration
}{
{
name: "valid duration",
duration: 1 * time.Second,
expected: 1 * time.Second,
},
{
name: "zero duration should use default",
duration: 0,
expected: defaultRetryMinInterval,
},
{
name: "negative duration should use default",
duration: -1 * time.Second,
expected: defaultRetryMinInterval,
},
{
name: "milliseconds",
duration: 500 * time.Millisecond,
expected: 500 * time.Millisecond,
},
{
name: "minutes",
duration: 5 * time.Minute,
expected: 5 * time.Minute,
},
{
name: "hours",
duration: 1 * time.Hour,
expected: 1 * time.Hour,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := defaultConfig()
opt := WithRetryMinInterval(tt.duration)
opt(cfg)
assert.Equal(t, tt.expected, cfg.RetryMinInterval)
})
}
}
func TestWithRetryMaxInterval(t *testing.T) {
tests := []struct {
name string
duration time.Duration
expected time.Duration
}{
{
name: "valid duration",
duration: 30 * time.Second,
expected: 30 * time.Second,
},
{
name: "zero duration should use default",
duration: 0,
expected: defaultRetryMaxInterval,
},
{
name: "negative duration should use default",
duration: -1 * time.Second,
expected: defaultRetryMaxInterval,
},
{
name: "milliseconds",
duration: 1000 * time.Millisecond,
expected: 1000 * time.Millisecond,
},
{
name: "minutes",
duration: 10 * time.Minute,
expected: 10 * time.Minute,
},
{
name: "hours",
duration: 2 * time.Hour,
expected: 2 * time.Hour,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := defaultConfig()
opt := WithRetryMaxInterval(tt.duration)
opt(cfg)
assert.Equal(t, tt.expected, cfg.RetryMaxInterval)
})
}
}
func TestWithReconnectInitialInterval(t *testing.T) {
tests := []struct {
name string
duration time.Duration
expected time.Duration
}{
{
name: "valid duration",
duration: 1 * time.Second,
expected: 1 * time.Second,
},
{
name: "zero duration should use default",
duration: 0,
expected: defaultReconnectInitialInterval,
},
{
name: "negative duration should use default",
duration: -1 * time.Second,
expected: defaultReconnectInitialInterval,
},
{
name: "milliseconds",
duration: 500 * time.Millisecond,
expected: 500 * time.Millisecond,
},
{
name: "minutes",
duration: 2 * time.Minute,
expected: 2 * time.Minute,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := defaultConfig()
opt := WithReconnectInitialInterval(tt.duration)
opt(cfg)
assert.Equal(t, tt.expected, cfg.ReconnectInitialInterval)
})
}
}
func TestWithReconnectMaxInterval(t *testing.T) {
tests := []struct {
name string
duration time.Duration
expected time.Duration
}{
{
name: "valid duration",
duration: 60 * time.Second,
expected: 60 * time.Second,
},
{
name: "zero duration should use default",
duration: 0,
expected: defaultReconnectMaxInterval,
},
{
name: "negative duration should use default",
duration: -1 * time.Second,
expected: defaultReconnectMaxInterval,
},
{
name: "milliseconds",
duration: 5000 * time.Millisecond,
expected: 5000 * time.Millisecond,
},
{
name: "minutes",
duration: 5 * time.Minute,
expected: 5 * time.Minute,
},
{
name: "hours",
duration: 1 * time.Hour,
expected: 1 * time.Hour,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := defaultConfig()
opt := WithReconnectMaxInterval(tt.duration)
opt(cfg)
assert.Equal(t, tt.expected, cfg.ReconnectMaxInterval)
})
}
}
func TestWithCQLVersion(t *testing.T) {
tests := []struct {
name string
version string
expected string
}{
{
name: "valid version",
version: "3.0.0",
expected: "3.0.0",
},
{
name: "empty version should use default",
version: "",
expected: defaultCqlVersion,
},
{
name: "version 3.1.0",
version: "3.1.0",
expected: "3.1.0",
},
{
name: "version 3.4.0",
version: "3.4.0",
expected: "3.4.0",
},
{
name: "version 4.0.0",
version: "4.0.0",
expected: "4.0.0",
},
{
name: "version with build",
version: "3.0.0-beta",
expected: "3.0.0-beta",
},
{
name: "version with snapshot",
version: "3.0.0-SNAPSHOT",
expected: "3.0.0-SNAPSHOT",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := defaultConfig()
opt := WithCQLVersion(tt.version)
opt(cfg)
assert.Equal(t, tt.expected, cfg.CQLVersion)
})
}
}
func TestOption_Combination(t *testing.T) {
tests := []struct {
name string
opts []Option
validate func(*testing.T, *config)
}{
{
name: "all options",
opts: []Option{
WithHosts("localhost", "127.0.0.1"),
WithPort(9042),
WithKeyspace("test_keyspace"),
WithAuth("user", "pass"),
WithConsistency(gocql.Quorum),
WithConnectTimeoutSec(10),
WithNumConns(10),
WithMaxRetries(3),
WithRetryMinInterval(1 * time.Second),
WithRetryMaxInterval(30 * time.Second),
WithReconnectInitialInterval(1 * time.Second),
WithReconnectMaxInterval(60 * time.Second),
WithCQLVersion("3.0.0"),
},
validate: func(t *testing.T, c *config) {
assert.Equal(t, []string{"localhost", "127.0.0.1"}, c.Hosts)
assert.Equal(t, 9042, c.Port)
assert.Equal(t, "test_keyspace", c.Keyspace)
assert.Equal(t, "user", c.Username)
assert.Equal(t, "pass", c.Password)
assert.True(t, c.UseAuth)
assert.Equal(t, gocql.Quorum, c.Consistency)
assert.Equal(t, 10, c.ConnectTimeoutSec)
assert.Equal(t, 10, c.NumConns)
assert.Equal(t, 3, c.MaxRetries)
assert.Equal(t, 1*time.Second, c.RetryMinInterval)
assert.Equal(t, 30*time.Second, c.RetryMaxInterval)
assert.Equal(t, 1*time.Second, c.ReconnectInitialInterval)
assert.Equal(t, 60*time.Second, c.ReconnectMaxInterval)
assert.Equal(t, "3.0.0", c.CQLVersion)
},
},
{
name: "minimal options",
opts: []Option{
WithHosts("localhost"),
},
validate: func(t *testing.T, c *config) {
assert.Equal(t, []string{"localhost"}, c.Hosts)
// 其他應該使用預設值
assert.Equal(t, defaultPort, c.Port)
assert.Equal(t, defaultConsistency, c.Consistency)
},
},
{
name: "options with zero values should use defaults",
opts: []Option{
WithHosts("localhost"),
WithConnectTimeoutSec(0),
WithNumConns(0),
WithMaxRetries(0),
WithRetryMinInterval(0),
WithRetryMaxInterval(0),
WithReconnectInitialInterval(0),
WithReconnectMaxInterval(0),
WithCQLVersion(""),
},
validate: func(t *testing.T, c *config) {
assert.Equal(t, []string{"localhost"}, c.Hosts)
assert.Equal(t, defaultTimeoutSec, c.ConnectTimeoutSec)
assert.Equal(t, defaultNumConns, c.NumConns)
assert.Equal(t, defaultMaxRetries, c.MaxRetries)
assert.Equal(t, defaultRetryMinInterval, c.RetryMinInterval)
assert.Equal(t, defaultRetryMaxInterval, c.RetryMaxInterval)
assert.Equal(t, defaultReconnectInitialInterval, c.ReconnectInitialInterval)
assert.Equal(t, defaultReconnectMaxInterval, c.ReconnectMaxInterval)
assert.Equal(t, defaultCqlVersion, c.CQLVersion)
},
},
{
name: "options with negative values should use defaults",
opts: []Option{
WithHosts("localhost"),
WithConnectTimeoutSec(-1),
WithNumConns(-1),
WithMaxRetries(-1),
WithRetryMinInterval(-1 * time.Second),
WithRetryMaxInterval(-1 * time.Second),
WithReconnectInitialInterval(-1 * time.Second),
WithReconnectMaxInterval(-1 * time.Second),
},
validate: func(t *testing.T, c *config) {
assert.Equal(t, []string{"localhost"}, c.Hosts)
assert.Equal(t, defaultTimeoutSec, c.ConnectTimeoutSec)
assert.Equal(t, defaultNumConns, c.NumConns)
assert.Equal(t, defaultMaxRetries, c.MaxRetries)
assert.Equal(t, defaultRetryMinInterval, c.RetryMinInterval)
assert.Equal(t, defaultRetryMaxInterval, c.RetryMaxInterval)
assert.Equal(t, defaultReconnectInitialInterval, c.ReconnectInitialInterval)
assert.Equal(t, defaultReconnectMaxInterval, c.ReconnectMaxInterval)
},
},
{
name: "multiple options applied in sequence",
opts: []Option{
WithHosts("host1"),
WithHosts("host2", "host3"), // 應該覆蓋
WithPort(9042),
WithPort(9043), // 應該覆蓋
},
validate: func(t *testing.T, c *config) {
assert.Equal(t, []string{"host2", "host3"}, c.Hosts)
assert.Equal(t, 9043, c.Port)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := defaultConfig()
for _, opt := range tt.opts {
opt(cfg)
}
tt.validate(t, cfg)
})
}
}
func TestOption_Type(t *testing.T) {
t.Run("all options should return Option type", func(t *testing.T) {
var opt Option
opt = WithHosts("localhost")
assert.NotNil(t, opt)
opt = WithPort(9042)
assert.NotNil(t, opt)
opt = WithKeyspace("test")
assert.NotNil(t, opt)
opt = WithAuth("user", "pass")
assert.NotNil(t, opt)
opt = WithConsistency(gocql.Quorum)
assert.NotNil(t, opt)
opt = WithConnectTimeoutSec(10)
assert.NotNil(t, opt)
opt = WithNumConns(10)
assert.NotNil(t, opt)
opt = WithMaxRetries(3)
assert.NotNil(t, opt)
opt = WithRetryMinInterval(1 * time.Second)
assert.NotNil(t, opt)
opt = WithRetryMaxInterval(30 * time.Second)
assert.NotNil(t, opt)
opt = WithReconnectInitialInterval(1 * time.Second)
assert.NotNil(t, opt)
opt = WithReconnectMaxInterval(60 * time.Second)
assert.NotNil(t, opt)
opt = WithCQLVersion("3.0.0")
assert.NotNil(t, opt)
})
}
func TestOption_EdgeCases(t *testing.T) {
t.Run("empty option slice", func(t *testing.T) {
cfg := defaultConfig()
opts := []Option{}
for _, opt := range opts {
opt(cfg)
}
// 應該保持預設值
assert.Equal(t, defaultPort, cfg.Port)
assert.Equal(t, defaultConsistency, cfg.Consistency)
})
t.Run("zero value option function", func(t *testing.T) {
cfg := defaultConfig()
var opt Option
// 零值的 Option 是 nil調用會 panic所以不應該調用
// 這裡只是驗證零值不會影響配置
_ = opt
// 應該保持預設值
assert.Equal(t, defaultPort, cfg.Port)
})
t.Run("very long strings", func(t *testing.T) {
cfg := defaultConfig()
longString := string(make([]byte, 10000))
WithKeyspace(longString)(cfg)
assert.Equal(t, longString, cfg.Keyspace)
WithAuth(longString, longString)(cfg)
assert.Equal(t, longString, cfg.Username)
assert.Equal(t, longString, cfg.Password)
})
t.Run("special characters in strings", func(t *testing.T) {
cfg := defaultConfig()
specialChars := "!@#$%^&*()_+-=[]{}|;:,.<>?"
WithKeyspace(specialChars)(cfg)
assert.Equal(t, specialChars, cfg.Keyspace)
WithAuth(specialChars, specialChars)(cfg)
assert.Equal(t, specialChars, cfg.Username)
assert.Equal(t, specialChars, cfg.Password)
})
}
func TestOption_RealWorldScenarios(t *testing.T) {
tests := []struct {
name string
scenario string
opts []Option
validate func(*testing.T, *config)
}{
{
name: "production-like configuration",
scenario: "typical production setup",
opts: []Option{
WithHosts("cassandra1.example.com", "cassandra2.example.com", "cassandra3.example.com"),
WithPort(9042),
WithKeyspace("production_keyspace"),
WithAuth("prod_user", "secure_password"),
WithConsistency(gocql.Quorum),
WithConnectTimeoutSec(30),
WithNumConns(50),
WithMaxRetries(5),
},
validate: func(t *testing.T, c *config) {
assert.Len(t, c.Hosts, 3)
assert.Equal(t, 9042, c.Port)
assert.Equal(t, "production_keyspace", c.Keyspace)
assert.True(t, c.UseAuth)
assert.Equal(t, gocql.Quorum, c.Consistency)
assert.Equal(t, 30, c.ConnectTimeoutSec)
assert.Equal(t, 50, c.NumConns)
assert.Equal(t, 5, c.MaxRetries)
},
},
{
name: "development configuration",
scenario: "local development setup",
opts: []Option{
WithHosts("localhost"),
WithKeyspace("dev_keyspace"),
},
validate: func(t *testing.T, c *config) {
assert.Equal(t, []string{"localhost"}, c.Hosts)
assert.Equal(t, "dev_keyspace", c.Keyspace)
assert.False(t, c.UseAuth)
},
},
{
name: "high availability configuration",
scenario: "HA setup with multiple hosts",
opts: []Option{
WithHosts("node1", "node2", "node3", "node4", "node5"),
WithConsistency(gocql.All),
WithMaxRetries(10),
},
validate: func(t *testing.T, c *config) {
assert.Len(t, c.Hosts, 5)
assert.Equal(t, gocql.All, c.Consistency)
assert.Equal(t, 10, c.MaxRetries)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := defaultConfig()
for _, opt := range tt.opts {
opt(cfg)
}
tt.validate(t, cfg)
})
}
}

View File

@ -0,0 +1,226 @@
package cassandra
import (
"context"
"fmt"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx/v2/qb"
)
// Condition 定義查詢條件介面
type Condition interface {
Build() (qb.Cmp, map[string]any)
}
// Eq 等於條件
func Eq(column string, value any) Condition {
return &eqCondition{column: column, value: value}
}
type eqCondition struct {
column string
value any
}
func (c *eqCondition) Build() (qb.Cmp, map[string]any) {
return qb.Eq(c.column), map[string]any{c.column: c.value}
}
// In IN 條件
func In(column string, values []any) Condition {
return &inCondition{column: column, values: values}
}
type inCondition struct {
column string
values []any
}
func (c *inCondition) Build() (qb.Cmp, map[string]any) {
return qb.In(c.column), map[string]any{c.column: c.values}
}
// Gt 大於條件
func Gt(column string, value any) Condition {
return &gtCondition{column: column, value: value}
}
type gtCondition struct {
column string
value any
}
func (c *gtCondition) Build() (qb.Cmp, map[string]any) {
return qb.Gt(c.column), map[string]any{c.column: c.value}
}
// Lt 小於條件
func Lt(column string, value any) Condition {
return &ltCondition{column: column, value: value}
}
type ltCondition struct {
column string
value any
}
func (c *ltCondition) Build() (qb.Cmp, map[string]any) {
return qb.Lt(c.column), map[string]any{c.column: c.value}
}
// QueryBuilder 定義查詢構建器介面
type QueryBuilder[T Table] interface {
Where(condition Condition) QueryBuilder[T]
OrderBy(column string, order Order) QueryBuilder[T]
Limit(n int) QueryBuilder[T]
Select(columns ...string) QueryBuilder[T]
Scan(ctx context.Context, dest *[]T) error
One(ctx context.Context) (T, error)
Count(ctx context.Context) (int64, error)
}
// queryBuilder 是 QueryBuilder 的具體實作
type queryBuilder[T Table] struct {
repo *repository[T]
conditions []Condition
orders []orderBy
limit int
columns []string
}
type orderBy struct {
column string
order Order
}
// newQueryBuilder 創建新的查詢構建器
func newQueryBuilder[T Table](repo *repository[T]) QueryBuilder[T] {
return &queryBuilder[T]{
repo: repo,
}
}
// Where 添加 WHERE 條件
func (q *queryBuilder[T]) Where(condition Condition) QueryBuilder[T] {
q.conditions = append(q.conditions, condition)
return q
}
// OrderBy 添加排序
func (q *queryBuilder[T]) OrderBy(column string, order Order) QueryBuilder[T] {
q.orders = append(q.orders, orderBy{column: column, order: order})
return q
}
// Limit 設置限制
func (q *queryBuilder[T]) Limit(n int) QueryBuilder[T] {
q.limit = n
return q
}
// Select 指定要查詢的欄位
func (q *queryBuilder[T]) Select(columns ...string) QueryBuilder[T] {
q.columns = append(q.columns, columns...)
return q
}
// Scan 執行查詢並將結果掃描到 dest
func (q *queryBuilder[T]) Scan(ctx context.Context, dest *[]T) error {
if dest == nil {
return ErrInvalidInput.WithTable(q.repo.table).WithError(
fmt.Errorf("destination cannot be nil"),
)
}
builder := qb.Select(q.repo.table)
// 添加欄位
if len(q.columns) > 0 {
builder = builder.Columns(q.columns...)
} else {
builder = builder.Columns(q.repo.metadata.Columns...)
}
// 添加條件
bindMap := make(map[string]any)
var cmps []qb.Cmp
for _, cond := range q.conditions {
cmp, binds := cond.Build()
cmps = append(cmps, cmp)
for k, v := range binds {
bindMap[k] = v
}
}
if len(cmps) > 0 {
builder = builder.Where(cmps...)
}
// 添加排序
for _, o := range q.orders {
order := qb.ASC
if o.order == DESC {
order = qb.DESC
}
builder = builder.OrderBy(o.column, order)
}
// 添加限制
if q.limit > 0 {
builder = builder.Limit(uint(q.limit))
}
stmt, names := builder.ToCql()
query := q.repo.db.withContextAndTimestamp(ctx,
q.repo.db.session.Query(stmt, names).BindMap(bindMap))
return query.SelectRelease(dest)
}
// One 執行查詢並返回單筆結果
func (q *queryBuilder[T]) One(ctx context.Context) (T, error) {
var zero T
q.limit = 1
var results []T
if err := q.Scan(ctx, &results); err != nil {
return zero, err
}
if len(results) == 0 {
return zero, ErrNotFound.WithTable(q.repo.table)
}
return results[0], nil
}
// Count 計算符合條件的記錄數
func (q *queryBuilder[T]) Count(ctx context.Context) (int64, error) {
builder := qb.Select(q.repo.table).Columns("COUNT(*)")
// 添加條件
bindMap := make(map[string]any)
var cmps []qb.Cmp
for _, cond := range q.conditions {
cmp, binds := cond.Build()
cmps = append(cmps, cmp)
for k, v := range binds {
bindMap[k] = v
}
}
if len(cmps) > 0 {
builder = builder.Where(cmps...)
}
stmt, names := builder.ToCql()
query := q.repo.db.withContextAndTimestamp(ctx,
q.repo.db.session.Query(stmt, names).BindMap(bindMap))
var count int64
err := query.GetRelease(&count)
if err == gocql.ErrNotFound {
return 0, nil // COUNT 查詢不會返回 ErrNotFound但為了安全起見
}
return count, err
}

View File

@ -0,0 +1,519 @@
package cassandra
import (
"testing"
"github.com/scylladb/gocqlx/v2/qb"
"github.com/stretchr/testify/assert"
)
func TestEq(t *testing.T) {
tests := []struct {
name string
column string
value any
validate func(*testing.T, Condition)
}{
{
name: "string value",
column: "name",
value: "Alice",
validate: func(t *testing.T, cond Condition) {
cmp, binds := cond.Build()
assert.NotNil(t, cmp)
assert.Equal(t, "Alice", binds["name"])
},
},
{
name: "int value",
column: "age",
value: 25,
validate: func(t *testing.T, cond Condition) {
cmp, binds := cond.Build()
assert.NotNil(t, cmp)
assert.Equal(t, 25, binds["age"])
},
},
{
name: "nil value",
column: "description",
value: nil,
validate: func(t *testing.T, cond Condition) {
cmp, binds := cond.Build()
assert.NotNil(t, cmp)
assert.Nil(t, binds["description"])
},
},
{
name: "empty string",
column: "email",
value: "",
validate: func(t *testing.T, cond Condition) {
cmp, binds := cond.Build()
assert.NotNil(t, cmp)
assert.Equal(t, "", binds["email"])
},
},
{
name: "boolean value",
column: "active",
value: true,
validate: func(t *testing.T, cond Condition) {
cmp, binds := cond.Build()
assert.NotNil(t, cmp)
assert.Equal(t, true, binds["active"])
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cond := Eq(tt.column, tt.value)
assert.NotNil(t, cond)
tt.validate(t, cond)
})
}
}
func TestIn(t *testing.T) {
tests := []struct {
name string
column string
values []any
validate func(*testing.T, Condition)
}{
{
name: "string values",
column: "status",
values: []any{"active", "pending", "completed"},
validate: func(t *testing.T, cond Condition) {
cmp, binds := cond.Build()
assert.NotNil(t, cmp)
assert.Equal(t, []any{"active", "pending", "completed"}, binds["status"])
},
},
{
name: "int values",
column: "ids",
values: []any{1, 2, 3, 4, 5},
validate: func(t *testing.T, cond Condition) {
cmp, binds := cond.Build()
assert.NotNil(t, cmp)
assert.Equal(t, []any{1, 2, 3, 4, 5}, binds["ids"])
},
},
{
name: "empty slice",
column: "tags",
values: []any{},
validate: func(t *testing.T, cond Condition) {
cmp, binds := cond.Build()
assert.NotNil(t, cmp)
assert.Equal(t, []any{}, binds["tags"])
},
},
{
name: "single value",
column: "id",
values: []any{1},
validate: func(t *testing.T, cond Condition) {
cmp, binds := cond.Build()
assert.NotNil(t, cmp)
assert.Equal(t, []any{1}, binds["id"])
},
},
{
name: "mixed types",
column: "values",
values: []any{"string", 123, true},
validate: func(t *testing.T, cond Condition) {
cmp, binds := cond.Build()
assert.NotNil(t, cmp)
assert.Equal(t, []any{"string", 123, true}, binds["values"])
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cond := In(tt.column, tt.values)
assert.NotNil(t, cond)
tt.validate(t, cond)
})
}
}
func TestGt(t *testing.T) {
tests := []struct {
name string
column string
value any
validate func(*testing.T, Condition)
}{
{
name: "int value",
column: "age",
value: 18,
validate: func(t *testing.T, cond Condition) {
cmp, binds := cond.Build()
assert.NotNil(t, cmp)
assert.Equal(t, 18, binds["age"])
},
},
{
name: "float value",
column: "price",
value: 99.99,
validate: func(t *testing.T, cond Condition) {
cmp, binds := cond.Build()
assert.NotNil(t, cmp)
assert.Equal(t, 99.99, binds["price"])
},
},
{
name: "zero value",
column: "count",
value: 0,
validate: func(t *testing.T, cond Condition) {
cmp, binds := cond.Build()
assert.NotNil(t, cmp)
assert.Equal(t, 0, binds["count"])
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cond := Gt(tt.column, tt.value)
assert.NotNil(t, cond)
tt.validate(t, cond)
})
}
}
func TestLt(t *testing.T) {
tests := []struct {
name string
column string
value any
validate func(*testing.T, Condition)
}{
{
name: "int value",
column: "age",
value: 65,
validate: func(t *testing.T, cond Condition) {
cmp, binds := cond.Build()
assert.NotNil(t, cmp)
assert.Equal(t, 65, binds["age"])
},
},
{
name: "float value",
column: "price",
value: 199.99,
validate: func(t *testing.T, cond Condition) {
cmp, binds := cond.Build()
assert.NotNil(t, cmp)
assert.Equal(t, 199.99, binds["price"])
},
},
{
name: "negative value",
column: "balance",
value: -100,
validate: func(t *testing.T, cond Condition) {
cmp, binds := cond.Build()
assert.NotNil(t, cmp)
assert.Equal(t, -100, binds["balance"])
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cond := Lt(tt.column, tt.value)
assert.NotNil(t, cond)
tt.validate(t, cond)
})
}
}
func TestCondition_Build(t *testing.T) {
tests := []struct {
name string
cond Condition
validate func(*testing.T, qb.Cmp, map[string]any)
}{
{
name: "Eq condition",
cond: Eq("name", "test"),
validate: func(t *testing.T, cmp qb.Cmp, binds map[string]any) {
assert.NotNil(t, cmp)
assert.Equal(t, "test", binds["name"])
},
},
{
name: "In condition",
cond: In("ids", []any{1, 2, 3}),
validate: func(t *testing.T, cmp qb.Cmp, binds map[string]any) {
assert.NotNil(t, cmp)
assert.Equal(t, []any{1, 2, 3}, binds["ids"])
},
},
{
name: "Gt condition",
cond: Gt("age", 18),
validate: func(t *testing.T, cmp qb.Cmp, binds map[string]any) {
assert.NotNil(t, cmp)
assert.Equal(t, 18, binds["age"])
},
},
{
name: "Lt condition",
cond: Lt("price", 100),
validate: func(t *testing.T, cmp qb.Cmp, binds map[string]any) {
assert.NotNil(t, cmp)
assert.Equal(t, 100, binds["price"])
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmp, binds := tt.cond.Build()
tt.validate(t, cmp, binds)
})
}
}
func TestQueryBuilder_Where(t *testing.T) {
tests := []struct {
name string
condition Condition
validate func(*testing.T, *queryBuilder[testUser])
}{
{
name: "single condition",
condition: Eq("name", "Alice"),
validate: func(t *testing.T, qb *queryBuilder[testUser]) {
assert.Len(t, qb.conditions, 1)
},
},
{
name: "multiple conditions",
condition: In("status", []any{"active", "pending"}),
validate: func(t *testing.T, qb *queryBuilder[testUser]) {
// 添加多個條件
cond := In("status", []any{"active", "pending"})
qb.Where(Eq("name", "test"))
qb.Where(cond)
assert.Len(t, qb.conditions, 2)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要一個有效的 repository但我們可以測試鏈式調用
// 實際的執行需要資料庫連接
_ = tt
})
}
}
func TestQueryBuilder_OrderBy(t *testing.T) {
tests := []struct {
name string
column string
order Order
validate func(*testing.T, *queryBuilder[testUser])
}{
{
name: "ASC order",
column: "created_at",
order: ASC,
validate: func(t *testing.T, qb *queryBuilder[testUser]) {
assert.Len(t, qb.orders, 1)
assert.Equal(t, "created_at", qb.orders[0].column)
assert.Equal(t, ASC, qb.orders[0].order)
},
},
{
name: "DESC order",
column: "updated_at",
order: DESC,
validate: func(t *testing.T, qb *queryBuilder[testUser]) {
assert.Len(t, qb.orders, 1)
assert.Equal(t, "updated_at", qb.orders[0].column)
assert.Equal(t, DESC, qb.orders[0].order)
},
},
{
name: "multiple orders",
column: "name",
order: ASC,
validate: func(t *testing.T, qb *queryBuilder[testUser]) {
qb.OrderBy("created_at", DESC)
qb.OrderBy("name", ASC)
assert.Len(t, qb.orders, 2)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要一個有效的 repository
_ = tt
})
}
}
func TestQueryBuilder_Limit(t *testing.T) {
tests := []struct {
name string
limit int
expected int
}{
{
name: "positive limit",
limit: 10,
expected: 10,
},
{
name: "zero limit",
limit: 0,
expected: 0,
},
{
name: "large limit",
limit: 1000,
expected: 1000,
},
{
name: "negative limit",
limit: -1,
expected: -1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要一個有效的 repository
_ = tt
})
}
}
func TestQueryBuilder_Select(t *testing.T) {
tests := []struct {
name string
columns []string
expected int
}{
{
name: "single column",
columns: []string{"name"},
expected: 1,
},
{
name: "multiple columns",
columns: []string{"name", "email", "age"},
expected: 3,
},
{
name: "empty columns",
columns: []string{},
expected: 0,
},
{
name: "duplicate columns",
columns: []string{"name", "name"},
expected: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要一個有效的 repository
_ = tt
})
}
}
func TestQueryBuilder_Chaining(t *testing.T) {
t.Run("chain multiple methods", func(t *testing.T) {
// 注意:這需要一個有效的 repository
// 實際的執行需要資料庫連接
// 這裡只是展示測試結構
})
}
func TestQueryBuilder_Scan_ErrorCases(t *testing.T) {
tests := []struct {
name string
description string
}{
{
name: "nil destination",
description: "should return error when destination is nil",
},
{
name: "invalid query",
description: "should return error when query is invalid",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要 mock session 或實際的資料庫連接
_ = tt
})
}
}
func TestQueryBuilder_One_ErrorCases(t *testing.T) {
tests := []struct {
name string
description string
}{
{
name: "no results",
description: "should return ErrNotFound when no results found",
},
{
name: "query error",
description: "should return error when query fails",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要 mock session 或實際的資料庫連接
_ = tt
})
}
}
func TestQueryBuilder_Count_ErrorCases(t *testing.T) {
tests := []struct {
name string
description string
}{
{
name: "query error",
description: "should return error when query fails",
},
{
name: "ErrNotFound should return 0",
description: "should return 0 when ErrNotFound",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要 mock session 或實際的資料庫連接
_ = tt
})
}
}

View File

View File

@ -0,0 +1,265 @@
package cassandra
import (
"context"
"errors"
"fmt"
"reflect"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx/v2"
"github.com/scylladb/gocqlx/v2/qb"
"github.com/scylladb/gocqlx/v2/table"
)
// Repository 定義資料存取介面(小介面,符合 M3
type Repository[T Table] interface {
Insert(ctx context.Context, doc T) error
Get(ctx context.Context, pk any) (T, error)
Update(ctx context.Context, doc T) error
Delete(ctx context.Context, pk any) error
InsertMany(ctx context.Context, docs []T) error
Query() QueryBuilder[T]
TryLock(ctx context.Context, doc T, opts ...LockOption) error
UnLock(ctx context.Context, doc T) error
}
// repository 是 Repository 的具體實作
type repository[T Table] struct {
db *DB
keyspace string
table string
metadata table.Metadata
}
// NewRepository 獲取指定類型的 Repository
// keyspace 如果為空,使用預設 keyspace
func NewRepository[T Table](db *DB, keyspace string) (Repository[T], error) {
if keyspace == "" {
keyspace = db.defaultKeyspace
}
var zero T
metadata, err := generateMetadata(zero, keyspace)
if err != nil {
return nil, fmt.Errorf("failed to generate metadata: %w", err)
}
return &repository[T]{
db: db,
keyspace: keyspace,
table: metadata.Name,
metadata: metadata,
}, nil
}
// Insert 插入單筆資料
func (r *repository[T]) Insert(ctx context.Context, doc T) error {
t := table.New(r.metadata)
q := r.db.withContextAndTimestamp(ctx,
r.db.session.Query(t.Insert()).BindStruct(doc))
return q.ExecRelease()
}
// Get 根據主鍵查詢單筆資料
// 注意pk 必須是完整的 Primary Key包含所有 Partition Key 和 Clustering Key
// 如果主鍵是多欄位,需要傳入包含所有主鍵欄位的 struct
// pk 可以是string, int, int64, gocql.UUID, []byte 或包含主鍵欄位的 struct
func (r *repository[T]) Get(ctx context.Context, pk any) (T, error) {
var zero T
t := table.New(r.metadata)
// 使用 table.Get() 方法,它會自動根據 metadata 構建主鍵查詢
// 如果 pk 是 struct使用 BindStruct否則使用 Bind
var q *gocqlx.Queryx
if reflect.TypeOf(pk).Kind() == reflect.Struct {
q = r.db.withContextAndTimestamp(ctx,
r.db.session.Query(t.Get()).BindStruct(pk))
} else {
// 單一主鍵欄位的情況
// 注意:這只適用於單一 Partition Key 且無 Clustering Key 的情況
if len(r.metadata.PartKey) != 1 || len(r.metadata.SortKey) > 0 {
return zero, ErrInvalidInput.WithTable(r.table).WithError(
fmt.Errorf("single value primary key only supported for single partition key without clustering key"),
)
}
q = r.db.withContextAndTimestamp(ctx,
r.db.session.Query(t.Get()).Bind(pk))
}
var result T
err := q.GetRelease(&result)
if errors.Is(err, gocql.ErrNotFound) {
return zero, ErrNotFound.WithTable(r.table)
}
if err != nil {
return zero, ErrInvalidInput.WithTable(r.table).WithError(err)
}
return result, nil
}
// Update 更新資料(只更新非零值欄位)
func (r *repository[T]) Update(ctx context.Context, doc T) error {
return r.updateSelective(ctx, doc, false)
}
// UpdateAll 更新所有欄位(包括零值)
func (r *repository[T]) UpdateAll(ctx context.Context, doc T) error {
return r.updateSelective(ctx, doc, true)
}
// updateSelective 選擇性更新
func (r *repository[T]) updateSelective(ctx context.Context, doc T, includeZero bool) error {
// 重用現有的 BuildUpdateFields 邏輯
// 由於在不同套件,我們需要重新實作或導入
fields, err := r.buildUpdateFields(doc, includeZero)
if err != nil {
return err
}
stmt, names := r.buildUpdateStatement(fields.setCols, fields.whereCols)
setVals := append(fields.setVals, fields.whereVals...)
q := r.db.withContextAndTimestamp(ctx,
r.db.session.Query(stmt, names).Bind(setVals...))
return q.ExecRelease()
}
// Delete 刪除資料
// pk 可以是string, int, int64, gocql.UUID, []byte 或包含主鍵欄位的 struct
func (r *repository[T]) Delete(ctx context.Context, pk any) error {
t := table.New(r.metadata)
stmt, names := t.Delete()
q := r.db.withContextAndTimestamp(ctx,
r.db.session.Query(stmt, names).Bind(pk))
return q.ExecRelease()
}
// InsertMany 批次插入資料
func (r *repository[T]) InsertMany(ctx context.Context, docs []T) error {
if len(docs) == 0 {
return nil
}
// 使用 Batch 操作
batch := r.db.session.NewBatch(gocql.LoggedBatch).WithContext(ctx)
t := table.New(r.metadata)
stmt, names := t.Insert()
for _, doc := range docs {
// 在 v2 中,需要手動提取值
v := reflect.ValueOf(doc)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
values := make([]interface{}, len(names))
for i, name := range names {
// 根據 metadata 找到對應的欄位
for j, col := range r.metadata.Columns {
if col == name {
fieldValue := v.Field(j)
values[i] = fieldValue.Interface()
break
}
}
}
batch.Query(stmt, values...)
}
return r.db.session.ExecuteBatch(batch)
}
// Query 返回查詢構建器
func (r *repository[T]) Query() QueryBuilder[T] {
return newQueryBuilder(r)
}
// updateFields 包含更新操作所需的欄位資訊
type updateFields struct {
setCols []string
setVals []any
whereCols []string
whereVals []any
}
// buildUpdateFields 從 document 中提取更新所需的欄位資訊
func (r *repository[T]) buildUpdateFields(doc T, includeZero bool) (*updateFields, error) {
v := reflect.ValueOf(doc)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
typ := v.Type()
setCols := make([]string, 0)
setVals := make([]any, 0)
whereCols := make([]string, 0)
whereVals := make([]any, 0)
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
tag := field.Tag.Get(DBFiledName)
if tag == "" || tag == "-" {
continue
}
val := v.Field(i)
if !val.IsValid() {
continue
}
// 主鍵欄位放入 WHERE 條件
if contains(r.metadata.PartKey, tag) || contains(r.metadata.SortKey, tag) {
whereCols = append(whereCols, tag)
whereVals = append(whereVals, val.Interface())
continue
}
// 根據 includeZero 決定是否包含零值欄位
if !includeZero && isZero(val) {
continue
}
setCols = append(setCols, tag)
setVals = append(setVals, val.Interface())
}
if len(setCols) == 0 {
return nil, ErrNoFieldsToUpdate.WithTable(r.table)
}
return &updateFields{
setCols: setCols,
setVals: setVals,
whereCols: whereCols,
whereVals: whereVals,
}, nil
}
// buildUpdateStatement 構建 UPDATE CQL 語句
func (r *repository[T]) buildUpdateStatement(setCols, whereCols []string) (string, []string) {
builder := qb.Update(r.table).Set(setCols...)
for _, col := range whereCols {
builder = builder.Where(qb.Eq(col))
}
return builder.ToCql()
}
// contains 判斷字串是否存在於 slice 中
func contains(list []string, target string) bool {
for _, item := range list {
if item == target {
return true
}
}
return false
}
// isZero 判斷欄位是否為零值或 nil
func isZero(v reflect.Value) bool {
switch v.Kind() {
case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice:
return v.IsNil()
default:
return reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface())
}
}

View File

@ -0,0 +1,546 @@
package cassandra
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestContains(t *testing.T) {
tests := []struct {
name string
list []string
target string
want bool
}{
{
name: "target exists in list",
list: []string{"a", "b", "c"},
target: "b",
want: true,
},
{
name: "target at beginning",
list: []string{"a", "b", "c"},
target: "a",
want: true,
},
{
name: "target at end",
list: []string{"a", "b", "c"},
target: "c",
want: true,
},
{
name: "target not in list",
list: []string{"a", "b", "c"},
target: "d",
want: false,
},
{
name: "empty list",
list: []string{},
target: "a",
want: false,
},
{
name: "empty target",
list: []string{"a", "b", "c"},
target: "",
want: false,
},
{
name: "target in single element list",
list: []string{"a"},
target: "a",
want: true,
},
{
name: "case sensitive",
list: []string{"A", "B", "C"},
target: "a",
want: false,
},
{
name: "duplicate values",
list: []string{"a", "b", "a", "c"},
target: "a",
want: true,
},
{
name: "long list",
list: []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"},
target: "j",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := contains(tt.list, tt.target)
assert.Equal(t, tt.want, result)
})
}
}
func TestIsZero(t *testing.T) {
tests := []struct {
name string
value any
expected bool
skip bool
}{
{
name: "nil pointer",
value: (*string)(nil),
expected: true,
skip: false,
},
{
name: "non-nil pointer",
value: stringPtr("test"),
expected: false,
skip: false,
},
{
name: "nil slice",
value: []string(nil),
expected: true,
skip: false,
},
{
name: "empty slice",
value: []string{},
expected: false, // 空 slice 不是 nil
skip: false,
},
{
name: "nil map",
value: map[string]int(nil),
expected: true,
skip: false,
},
{
name: "empty map",
value: map[string]int{},
expected: false, // 空 map 不是 nil
skip: false,
},
{
name: "zero int",
value: 0,
expected: true,
skip: false,
},
{
name: "non-zero int",
value: 42,
expected: false,
skip: false,
},
{
name: "zero int64",
value: int64(0),
expected: true,
skip: false,
},
{
name: "non-zero int64",
value: int64(42),
expected: false,
skip: false,
},
{
name: "zero float64",
value: 0.0,
expected: true,
skip: false,
},
{
name: "non-zero float64",
value: 3.14,
expected: false,
skip: false,
},
{
name: "empty string",
value: "",
expected: true,
skip: false,
},
{
name: "non-empty string",
value: "test",
expected: false,
skip: false,
},
{
name: "false bool",
value: false,
expected: true,
skip: false,
},
{
name: "true bool",
value: true,
expected: false,
skip: false,
},
{
name: "struct with zero values",
value: testUser{},
expected: true, // 所有欄位都是零值,應該返回 true
skip: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.skip {
t.Skip("Skipping test")
return
}
// 使用 reflect.ValueOf 來獲取 reflect.Value
v := reflect.ValueOf(tt.value)
// 檢查是否為零值nil interface 會導致 zero Value
if !v.IsValid() {
// 對於 nil interface直接返回 true
assert.True(t, tt.expected)
return
}
result := isZero(v)
assert.Equal(t, tt.expected, result)
})
}
}
func TestNewRepository(t *testing.T) {
tests := []struct {
name string
keyspace string
wantErr bool
validate func(*testing.T, Repository[testUser], *DB)
}{
{
name: "valid keyspace",
keyspace: "test_keyspace",
wantErr: false,
validate: func(t *testing.T, repo Repository[testUser], db *DB) {
assert.NotNil(t, repo)
},
},
{
name: "empty keyspace uses default",
keyspace: "",
wantErr: false,
validate: func(t *testing.T, repo Repository[testUser], db *DB) {
assert.NotNil(t, repo)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要一個有效的 DB 實例
// 在實際測試中,需要使用 mock 或 testcontainers
_ = tt
})
}
}
func TestRepository_Insert(t *testing.T) {
tests := []struct {
name string
description string
}{
{
name: "successful insert",
description: "should insert document successfully",
},
{
name: "duplicate key",
description: "should return error on duplicate key",
},
{
name: "invalid document",
description: "should return error for invalid document",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要 mock session 或實際的資料庫連接
_ = tt
})
}
}
func TestRepository_Get(t *testing.T) {
tests := []struct {
name string
pk any
description string
wantErr bool
}{
{
name: "found with string key",
pk: "test-id",
description: "should return document when found",
wantErr: false,
},
{
name: "not found",
pk: "non-existent",
description: "should return ErrNotFound when not found",
wantErr: true,
},
{
name: "invalid primary key structure",
pk: "single-key",
description: "should return error for invalid key structure",
wantErr: true,
},
{
name: "struct primary key",
pk: testUser{ID: "test-id"},
description: "should work with struct primary key",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要 mock session 或實際的資料庫連接
_ = tt
})
}
}
func TestRepository_Update(t *testing.T) {
tests := []struct {
name string
description string
wantErr bool
}{
{
name: "successful update",
description: "should update document successfully",
wantErr: false,
},
{
name: "not found",
description: "should return error when document not found",
wantErr: true,
},
{
name: "no fields to update",
description: "should return error when no fields to update",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要 mock session 或實際的資料庫連接
_ = tt
})
}
}
func TestRepository_Delete(t *testing.T) {
tests := []struct {
name string
pk any
description string
wantErr bool
}{
{
name: "successful delete",
pk: "test-id",
description: "should delete document successfully",
wantErr: false,
},
{
name: "not found",
pk: "non-existent",
description: "should not return error when not found (idempotent)",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要 mock session 或實際的資料庫連接
_ = tt
})
}
}
func TestRepository_InsertMany(t *testing.T) {
tests := []struct {
name string
docs []testUser
description string
wantErr bool
}{
{
name: "empty slice",
docs: []testUser{},
description: "should return nil for empty slice",
wantErr: false,
},
{
name: "single document",
docs: []testUser{{ID: "1", Name: "Alice"}},
description: "should insert single document",
wantErr: false,
},
{
name: "multiple documents",
docs: []testUser{{ID: "1", Name: "Alice"}, {ID: "2", Name: "Bob"}},
description: "should insert multiple documents",
wantErr: false,
},
{
name: "large batch",
docs: make([]testUser, 100),
description: "should handle large batch",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要 mock session 或實際的資料庫連接
_ = tt
})
}
}
func TestRepository_Query(t *testing.T) {
t.Run("should return QueryBuilder", func(t *testing.T) {
// 注意:這需要一個有效的 repository
// 實際的執行需要資料庫連接
})
}
func TestBuildUpdateStatement(t *testing.T) {
tests := []struct {
name string
setCols []string
whereCols []string
table string
validate func(*testing.T, string, []string)
}{
{
name: "single set column, single where column",
setCols: []string{"name"},
whereCols: []string{"id"},
table: "users",
validate: func(t *testing.T, stmt string, names []string) {
assert.Contains(t, stmt, "UPDATE")
assert.Contains(t, stmt, "users")
assert.Contains(t, stmt, "SET")
assert.Contains(t, stmt, "WHERE")
assert.Len(t, names, 2) // name, id
},
},
{
name: "multiple set columns, single where column",
setCols: []string{"name", "email", "age"},
whereCols: []string{"id"},
table: "users",
validate: func(t *testing.T, stmt string, names []string) {
assert.Contains(t, stmt, "UPDATE")
assert.Contains(t, stmt, "users")
assert.Len(t, names, 4) // name, email, age, id
},
},
{
name: "single set column, multiple where columns",
setCols: []string{"status"},
whereCols: []string{"user_id", "account_id"},
table: "accounts",
validate: func(t *testing.T, stmt string, names []string) {
assert.Contains(t, stmt, "UPDATE")
assert.Contains(t, stmt, "accounts")
assert.Len(t, names, 3) // status, user_id, account_id
},
},
{
name: "multiple set and where columns",
setCols: []string{"name", "email"},
whereCols: []string{"id", "version"},
table: "users",
validate: func(t *testing.T, stmt string, names []string) {
assert.Contains(t, stmt, "UPDATE")
assert.Len(t, names, 4) // name, email, id, version
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 創建一個臨時的 repository 來測試 buildUpdateStatement
// 注意:這需要一個有效的 metadata
// 使用 testUser 的 metadata
var zero testUser
metadata, err := generateMetadata(zero, "test_keyspace")
require.NoError(t, err)
repo := &repository[testUser]{
table: tt.table,
metadata: metadata,
}
stmt, names := repo.buildUpdateStatement(tt.setCols, tt.whereCols)
tt.validate(t, stmt, names)
})
}
}
func TestBuildUpdateFields(t *testing.T) {
tests := []struct {
name string
doc testUser
includeZero bool
wantErr bool
validate func(*testing.T, *updateFields)
}{
{
name: "update with includeZero false",
doc: testUser{ID: "1", Name: "Alice", Email: "alice@example.com"},
includeZero: false,
wantErr: false,
validate: func(t *testing.T, fields *updateFields) {
assert.NotEmpty(t, fields.setCols)
assert.Contains(t, fields.whereCols, "id")
},
},
{
name: "update with includeZero true",
doc: testUser{ID: "1", Name: "", Email: ""},
includeZero: true,
wantErr: false,
validate: func(t *testing.T, fields *updateFields) {
assert.NotEmpty(t, fields.setCols)
},
},
{
name: "no fields to update",
doc: testUser{ID: "1"},
includeZero: false,
wantErr: true,
validate: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要一個有效的 repository 和 metadata
// 在實際測試中,需要使用 mock 或 testcontainers
_ = tt
})
}
}

View File

@ -0,0 +1,289 @@
package cassandra
import (
"context"
"fmt"
"strings"
"github.com/gocql/gocql"
)
// SAIIndexType 定義 SAI 索引類型
type SAIIndexType string
const (
// SAIIndexTypeStandard 標準索引(等於查詢)
SAIIndexTypeStandard SAIIndexType = "STANDARD"
// SAIIndexTypeCollection 集合索引(用於 list、set、map
SAIIndexTypeCollection SAIIndexType = "COLLECTION"
// SAIIndexTypeFullText 全文索引
SAIIndexTypeFullText SAIIndexType = "FULL_TEXT"
)
// SAIIndexOptions 定義 SAI 索引選項
type SAIIndexOptions struct {
IndexType SAIIndexType // 索引類型
IsAsync bool // 是否異步建立索引
CaseSensitive bool // 是否區分大小寫(用於全文索引)
}
// DefaultSAIIndexOptions 返回預設的 SAI 索引選項
func DefaultSAIIndexOptions() *SAIIndexOptions {
return &SAIIndexOptions{
IndexType: SAIIndexTypeStandard,
IsAsync: false,
CaseSensitive: true,
}
}
// CreateSAIIndex 建立 SAI 索引
// keyspace: keyspace 名稱
// table: 資料表名稱
// column: 欄位名稱
// indexName: 索引名稱(可選,如果為空則自動生成)
// opts: 索引選項(可選,如果為 nil 則使用預設選項)
func (db *DB) CreateSAIIndex(ctx context.Context, keyspace, table, column, indexName string, opts *SAIIndexOptions) error {
// 檢查是否支援 SAI
if !db.saiSupported {
return ErrInvalidInput.WithError(fmt.Errorf("SAI is not supported in Cassandra version %s (requires 4.0.9+ or 5.0+)", db.version))
}
// 驗證參數
if keyspace == "" {
return ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
}
if table == "" {
return ErrInvalidInput.WithError(fmt.Errorf("table is required"))
}
if column == "" {
return ErrInvalidInput.WithError(fmt.Errorf("column is required"))
}
// 使用預設選項如果未提供
if opts == nil {
opts = DefaultSAIIndexOptions()
}
// 生成索引名稱如果未提供
if indexName == "" {
indexName = fmt.Sprintf("%s_%s_sai_idx", table, column)
}
// 構建 CREATE INDEX 語句
var stmt strings.Builder
stmt.WriteString("CREATE CUSTOM INDEX IF NOT EXISTS ")
stmt.WriteString(indexName)
stmt.WriteString(" ON ")
stmt.WriteString(keyspace)
stmt.WriteString(".")
stmt.WriteString(table)
stmt.WriteString(" (")
stmt.WriteString(column)
stmt.WriteString(") USING 'StorageAttachedIndex'")
// 添加選項
var options []string
if opts.IsAsync {
options = append(options, "'async'='true'")
}
// 根據索引類型添加特定選項
switch opts.IndexType {
case SAIIndexTypeFullText:
if !opts.CaseSensitive {
options = append(options, "'case_sensitive'='false'")
} else {
options = append(options, "'case_sensitive'='true'")
}
case SAIIndexTypeCollection:
// Collection 索引不需要額外選項
}
// 如果有選項,添加到語句中
if len(options) > 0 {
stmt.WriteString(" WITH OPTIONS = {")
stmt.WriteString(strings.Join(options, ", "))
stmt.WriteString("}")
}
// 執行建立索引語句
query := db.session.Query(stmt.String(), nil).
WithContext(ctx).
Consistency(gocql.Quorum)
err := query.ExecRelease()
if err != nil {
return ErrInvalidInput.WithError(fmt.Errorf("failed to create SAI index: %w", err))
}
return nil
}
// DropSAIIndex 刪除 SAI 索引
// keyspace: keyspace 名稱
// indexName: 索引名稱
func (db *DB) DropSAIIndex(ctx context.Context, keyspace, indexName string) error {
// 驗證參數
if keyspace == "" {
return ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
}
if indexName == "" {
return ErrInvalidInput.WithError(fmt.Errorf("index name is required"))
}
// 構建 DROP INDEX 語句
stmt := fmt.Sprintf("DROP INDEX IF EXISTS %s.%s", keyspace, indexName)
// 執行刪除索引語句
query := db.session.Query(stmt, nil).
WithContext(ctx).
Consistency(gocql.Quorum)
err := query.ExecRelease()
if err != nil {
return ErrInvalidInput.WithError(fmt.Errorf("failed to drop SAI index: %w", err))
}
return nil
}
// ListSAIIndexes 列出指定資料表的所有 SAI 索引
// keyspace: keyspace 名稱
// table: 資料表名稱
func (db *DB) ListSAIIndexes(ctx context.Context, keyspace, table string) ([]SAIIndexInfo, error) {
// 驗證參數
if keyspace == "" {
return nil, ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
}
if table == "" {
return nil, ErrInvalidInput.WithError(fmt.Errorf("table is required"))
}
// 查詢系統表獲取索引資訊
// system_schema.indexes 表的結構keyspace_name, table_name, index_name, kind, options
stmt := `
SELECT index_name, kind, options
FROM system_schema.indexes
WHERE keyspace_name = ? AND table_name = ?
`
var indexes []SAIIndexInfo
iter := db.session.Query(stmt, []string{"keyspace_name", "table_name"}).
WithContext(ctx).
Consistency(gocql.One).
Bind(keyspace, table).
Iter()
var indexName, kind string
var options map[string]string
for iter.Scan(&indexName, &kind, &options) {
// 檢查是否為 SAI 索引kind = 'CUSTOM' 且 class_name 包含 StorageAttachedIndex
if kind == "CUSTOM" {
if className, ok := options["class_name"]; ok && strings.Contains(className, "StorageAttachedIndex") {
// 從 options 中提取 target欄位名稱
columnName := ""
if target, ok := options["target"]; ok {
columnName = strings.Trim(target, "()\"'")
}
indexes = append(indexes, SAIIndexInfo{
Name: indexName,
Type: "StorageAttachedIndex",
Options: options,
Column: columnName,
})
}
}
}
if err := iter.Close(); err != nil {
return nil, ErrInvalidInput.WithError(fmt.Errorf("failed to list SAI indexes: %w", err))
}
return indexes, nil
}
// SAIIndexInfo 表示 SAI 索引資訊
type SAIIndexInfo struct {
Name string // 索引名稱
Type string // 索引類型
Options map[string]string // 索引選項
Column string // 索引欄位名稱
}
// CheckSAIIndexExists 檢查 SAI 索引是否存在
// keyspace: keyspace 名稱
// indexName: 索引名稱
func (db *DB) CheckSAIIndexExists(ctx context.Context, keyspace, indexName string) (bool, error) {
// 驗證參數
if keyspace == "" {
return false, ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
}
if indexName == "" {
return false, ErrInvalidInput.WithError(fmt.Errorf("index name is required"))
}
// 查詢系統表檢查索引是否存在
stmt := `
SELECT index_name, kind, options
FROM system_schema.indexes
WHERE keyspace_name = ? AND index_name = ?
LIMIT 1
`
var foundIndexName, kind string
var options map[string]string
err := db.session.Query(stmt, []string{"keyspace_name", "index_name"}).
WithContext(ctx).
Consistency(gocql.One).
Bind(keyspace, indexName).
Scan(&foundIndexName, &kind, &options)
if err == gocql.ErrNotFound {
return false, nil
}
if err != nil {
return false, ErrInvalidInput.WithError(fmt.Errorf("failed to check SAI index existence: %w", err))
}
// 檢查是否為 SAI 索引
if kind == "CUSTOM" {
if className, ok := options["class_name"]; ok && strings.Contains(className, "StorageAttachedIndex") {
return true, nil
}
}
return false, nil
}
// WaitForSAIIndex 等待 SAI 索引建立完成(用於異步建立)
// keyspace: keyspace 名稱
// indexName: 索引名稱
// maxWaitTime: 最大等待時間(秒)
func (db *DB) WaitForSAIIndex(ctx context.Context, keyspace, indexName string, maxWaitTime int) error {
// 驗證參數
if keyspace == "" {
return ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
}
if indexName == "" {
return ErrInvalidInput.WithError(fmt.Errorf("index name is required"))
}
// 查詢索引狀態
// 注意Cassandra 沒有直接的索引狀態查詢,這裡需要通過檢查索引是否可用來判斷
// 實際實作可能需要根據具體的 Cassandra 版本調整
// 簡單實作:檢查索引是否存在
exists, err := db.CheckSAIIndexExists(ctx, keyspace, indexName)
if err != nil {
return err
}
if !exists {
return ErrInvalidInput.WithError(fmt.Errorf("index %s does not exist", indexName))
}
// 注意:實際的等待邏輯可能需要查詢系統表或使用其他方法
// 這裡只是基本框架,實際使用時可能需要根據具體需求調整
return nil
}

View File

@ -0,0 +1,267 @@
package cassandra
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestDefaultSAIIndexOptions(t *testing.T) {
opts := DefaultSAIIndexOptions()
assert.NotNil(t, opts)
assert.Equal(t, SAIIndexTypeStandard, opts.IndexType)
assert.False(t, opts.IsAsync)
assert.True(t, opts.CaseSensitive)
}
func TestCreateSAIIndex_Validation(t *testing.T) {
tests := []struct {
name string
keyspace string
table string
column string
indexName string
opts *SAIIndexOptions
wantErr bool
errMsg string
}{
{
name: "missing keyspace",
keyspace: "",
table: "test_table",
column: "test_column",
indexName: "test_idx",
opts: nil,
wantErr: true,
errMsg: "keyspace is required",
},
{
name: "missing table",
keyspace: "test_keyspace",
table: "",
column: "test_column",
indexName: "test_idx",
opts: nil,
wantErr: true,
errMsg: "table is required",
},
{
name: "missing column",
keyspace: "test_keyspace",
table: "test_table",
column: "",
indexName: "test_idx",
opts: nil,
wantErr: true,
errMsg: "column is required",
},
{
name: "valid parameters with default options",
keyspace: "test_keyspace",
table: "test_table",
column: "test_column",
indexName: "test_idx",
opts: nil,
wantErr: false,
},
{
name: "valid parameters with custom options",
keyspace: "test_keyspace",
table: "test_table",
column: "test_column",
indexName: "test_idx",
opts: &SAIIndexOptions{
IndexType: SAIIndexTypeFullText,
IsAsync: true,
CaseSensitive: false,
},
wantErr: false,
},
{
name: "auto-generate index name",
keyspace: "test_keyspace",
table: "test_table",
column: "test_column",
indexName: "",
opts: nil,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要一個有效的 DB 實例和 SAI 支援
// 在實際測試中,需要使用 mock 或 testcontainers
_ = tt
})
}
}
func TestDropSAIIndex_Validation(t *testing.T) {
tests := []struct {
name string
keyspace string
indexName string
wantErr bool
errMsg string
}{
{
name: "missing keyspace",
keyspace: "",
indexName: "test_idx",
wantErr: true,
errMsg: "keyspace is required",
},
{
name: "missing index name",
keyspace: "test_keyspace",
indexName: "",
wantErr: true,
errMsg: "index name is required",
},
{
name: "valid parameters",
keyspace: "test_keyspace",
indexName: "test_idx",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要一個有效的 DB 實例
// 在實際測試中,需要使用 mock 或 testcontainers
_ = tt
})
}
}
func TestListSAIIndexes_Validation(t *testing.T) {
tests := []struct {
name string
keyspace string
table string
wantErr bool
errMsg string
}{
{
name: "missing keyspace",
keyspace: "",
table: "test_table",
wantErr: true,
errMsg: "keyspace is required",
},
{
name: "missing table",
keyspace: "test_keyspace",
table: "",
wantErr: true,
errMsg: "table is required",
},
{
name: "valid parameters",
keyspace: "test_keyspace",
table: "test_table",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要一個有效的 DB 實例
// 在實際測試中,需要使用 mock 或 testcontainers
_ = tt
})
}
}
func TestCheckSAIIndexExists_Validation(t *testing.T) {
tests := []struct {
name string
keyspace string
indexName string
wantErr bool
errMsg string
}{
{
name: "missing keyspace",
keyspace: "",
indexName: "test_idx",
wantErr: true,
errMsg: "keyspace is required",
},
{
name: "missing index name",
keyspace: "test_keyspace",
indexName: "",
wantErr: true,
errMsg: "index name is required",
},
{
name: "valid parameters",
keyspace: "test_keyspace",
indexName: "test_idx",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 注意:這需要一個有效的 DB 實例
// 在實際測試中,需要使用 mock 或 testcontainers
_ = tt
})
}
}
func TestSAIIndexType_Constants(t *testing.T) {
tests := []struct {
name string
indexType SAIIndexType
expected string
}{
{
name: "standard index type",
indexType: SAIIndexTypeStandard,
expected: "STANDARD",
},
{
name: "collection index type",
indexType: SAIIndexTypeCollection,
expected: "COLLECTION",
},
{
name: "full text index type",
indexType: SAIIndexTypeFullText,
expected: "FULL_TEXT",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, string(tt.indexType))
})
}
}
func TestCreateSAIIndex_NotSupported(t *testing.T) {
t.Run("should return error when SAI not supported", func(t *testing.T) {
// 注意:這需要一個不支援 SAI 的 DB 實例
// 在實際測試中,需要使用 mock 或 testcontainers
})
}
func TestCreateSAIIndex_IndexNameGeneration(t *testing.T) {
t.Run("should generate index name when not provided", func(t *testing.T) {
// 測試自動生成索引名稱的邏輯
// 格式應該是: {table}_{column}_sai_idx
table := "users"
column := "email"
expected := "users_email_sai_idx"
// 這裡只是測試命名邏輯,實際建立需要 DB 實例
generated := fmt.Sprintf("%s_%s_sai_idx", table, column)
assert.Equal(t, expected, generated)
})
}

View File

@ -0,0 +1,91 @@
package cassandra
import (
"context"
"fmt"
"strconv"
"testing"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
)
// startCassandraContainer 啟動 Cassandra 測試容器
func startCassandraContainer(ctx context.Context) (string, string, func(), error) {
req := testcontainers.ContainerRequest{
Image: "cassandra:4.1",
ExposedPorts: []string{"9042/tcp"},
WaitingFor: wait.ForListeningPort("9042/tcp"),
Env: map[string]string{
"CASSANDRA_CLUSTER_NAME": "test-cluster",
},
}
cassandraC, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
ContainerRequest: req,
Started: true,
})
if err != nil {
return "", "", nil, fmt.Errorf("failed to start Cassandra container: %w", err)
}
port, err := cassandraC.MappedPort(ctx, "9042")
if err != nil {
cassandraC.Terminate(ctx)
return "", "", nil, fmt.Errorf("failed to get mapped port: %w", err)
}
host, err := cassandraC.Host(ctx)
if err != nil {
cassandraC.Terminate(ctx)
return "", "", nil, fmt.Errorf("failed to get host: %w", err)
}
tearDown := func() {
_ = cassandraC.Terminate(ctx)
}
fmt.Printf("Cassandra test container started: %s:%s\n", host, port.Port())
return host, port.Port(), tearDown, nil
}
// setupTestDB 設置測試用的 DB 實例
func setupTestDB(t testing.TB) (*DB, func()) {
ctx := context.Background()
host, port, tearDown, err := startCassandraContainer(ctx)
if err != nil {
t.Fatalf("Failed to start Cassandra container: %v", err)
}
portInt, err := strconv.Atoi(port)
if err != nil {
tearDown()
t.Fatalf("Failed to convert port to int: %v", err)
}
db, err := New(
WithHosts(host),
WithPort(portInt),
WithKeyspace("test_keyspace"),
)
if err != nil {
tearDown()
t.Fatalf("Failed to create DB: %v", err)
}
// 創建 keyspace
createKeyspaceStmt := "CREATE KEYSPACE IF NOT EXISTS test_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}"
if err := db.session.Query(createKeyspaceStmt, nil).Exec(); err != nil {
db.Close()
tearDown()
t.Fatalf("Failed to create keyspace: %v", err)
}
cleanup := func() {
db.Close()
tearDown()
}
return db, cleanup
}

View File

@ -0,0 +1,32 @@
package cassandra
import (
"github.com/gocql/gocql"
)
// Table 定義資料表模型必須實作的介面
type Table interface {
TableName() string
}
// PrimaryKey 定義主鍵類型(使用類型約束)
// 注意Go 1.18+ 才支持類型約束,如果需要兼容舊版本,可以使用 interface{}
type PrimaryKey interface {
~string | ~int | ~int64 | gocql.UUID | []byte
}
// Order 定義排序順序
type Order int
const (
ASC Order = 0
DESC Order = 1
)
// 將 Order 轉換為 toGocqlX 的 Order
func (o Order) toGocqlX() string {
if o == DESC {
return "DESC"
}
return "ASC"
}

View File

@ -0,0 +1,139 @@
package cassandra
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestOrder_ToGocqlX(t *testing.T) {
tests := []struct {
name string
order Order
expected string
}{
{
name: "ASC order",
order: ASC,
expected: "ASC",
},
{
name: "DESC order",
order: DESC,
expected: "DESC",
},
{
name: "zero value (defaults to ASC)",
order: Order(0),
expected: "ASC",
},
{
name: "invalid order value (defaults to ASC)",
order: Order(99),
expected: "ASC",
},
{
name: "negative order value (defaults to ASC)",
order: Order(-1),
expected: "ASC",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.order.toGocqlX()
assert.Equal(t, tt.expected, result)
})
}
}
func TestOrder_Constants(t *testing.T) {
tests := []struct {
name string
constant Order
expected int
}{
{
name: "ASC constant",
constant: ASC,
expected: 0,
},
{
name: "DESC constant",
constant: DESC,
expected: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, int(tt.constant))
})
}
}
func TestOrder_StringConversion(t *testing.T) {
tests := []struct {
name string
order Order
expected string
}{
{
name: "ASC to string",
order: ASC,
expected: "ASC",
},
{
name: "DESC to string",
order: DESC,
expected: "DESC",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.order.toGocqlX()
assert.Equal(t, tt.expected, result)
})
}
}
func TestOrder_Comparison(t *testing.T) {
t.Run("ASC should equal 0", func(t *testing.T) {
assert.Equal(t, Order(0), ASC)
})
t.Run("DESC should equal 1", func(t *testing.T) {
assert.Equal(t, Order(1), DESC)
})
t.Run("ASC should not equal DESC", func(t *testing.T) {
assert.NotEqual(t, ASC, DESC)
})
}
func TestOrder_EdgeCases(t *testing.T) {
tests := []struct {
name string
order Order
expected string
}{
{
name: "maximum int value",
order: Order(^int(0)),
expected: "ASC", // 不是 DESC所以返回 ASC
},
{
name: "minimum int value",
order: Order(-^int(0) - 1),
expected: "ASC",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.order.toGocqlX()
assert.Equal(t, tt.expected, result)
})
}
}

View File

@ -0,0 +1,102 @@
package centrifugo
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
// Client Centrifugo 客戶端
type Client struct {
apiURL string
apiKey string
client *http.Client
}
// NewClient 創建新的 Centrifugo 客戶端
func NewClient(apiURL, apiKey string) *Client {
return &Client{
apiURL: apiURL,
apiKey: apiKey,
client: &http.Client{
Timeout: 5 * time.Second,
},
}
}
// PublishRequest Centrifugo 發布請求
type PublishRequest struct {
Channel string `json:"channel"`
Data interface{} `json:"data"`
}
// PublishResponse Centrifugo 發布響應
type PublishResponse struct {
Error string `json:"error,omitempty"`
Result interface{} `json:"result,omitempty"`
}
// Publish 發布訊息到指定頻道
func (c *Client) Publish(channel string, data []byte) error {
req := PublishRequest{
Channel: channel,
Data: json.RawMessage(data),
}
return c.publishJSON(req)
}
// PublishJSON 發布 JSON 訊息到指定頻道
func (c *Client) PublishJSON(channel string, data interface{}) error {
req := PublishRequest{
Channel: channel,
Data: data,
}
return c.publishJSON(req)
}
func (c *Client) publishJSON(req PublishRequest) error {
body, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", c.apiURL+"/publish", bytes.NewReader(body))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
if c.apiKey != "" {
httpReq.Header.Set("Authorization", "apikey "+c.apiKey)
}
resp, err := c.client.Do(httpReq)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("centrifugo returned status %d: %s", resp.StatusCode, string(respBody))
}
var publishResp PublishResponse
if err := json.Unmarshal(respBody, &publishResp); err != nil {
return fmt.Errorf("failed to unmarshal response: %w", err)
}
if publishResp.Error != "" {
return fmt.Errorf("centrifugo error: %s", publishResp.Error)
}
return nil
}

View File

@ -0,0 +1,55 @@
run:
timeout: 3m
issues-exit-code: 2
tests: false # 不檢查測試檔案
linters:
enable:
- govet # 官方靜態分析,抓潛在 bug
- staticcheck # 最強 bug/反模式偵測
- revive # golint 進化版,風格與註解規範
- gofmt # 風格格式化檢查
- goimports # import 排序
- errcheck # error 忽略警告
- ineffassign # 無效賦值
- unused # 未使用變數
- bodyclose # HTTP body close
- gosimple # 靜態分析簡化警告staticcheck 也包含,可選)
- typecheck # 型別檢查
- misspell # 拼字檢查
- gocritic # bug-prone code
- gosec # 資安檢查
- prealloc # slice/array 預分配
- unparam # 未使用參數
issues:
exclude-rules:
- path: _test\.go
linters:
- funlen
- goconst
- cyclop
- gocognit
- lll
- wrapcheck
- contextcheck
linters-settings:
revive:
severity: warning
rules:
- name: blank-imports
severity: error
gofmt:
simplify: true
lll:
line-length: 140
# 可自訂目錄忽略(視專案需求加上)
# skip-dirs:
# - vendor
# - third_party
# 可以設定本機與 CI 上都一致
# env:
# GOLANGCI_LINT_CACHE: ".golangci-lint-cache"

View File

@ -0,0 +1,13 @@
GOFMT ?= gofmt "-s"
GOFILES := $(shell find . -name "*.go")
.PHONY: test
test: # 進行測試
go test -v --cover ./...
.PHONY: fmt
fmt: # 格式優化
$(GOFMT) -w $(GOFILES)
goimports -w ./
golangci-lint run

View File

@ -0,0 +1,186 @@
# 錯誤碼 × HTTP 對照表
這份文件專門整理 **infra-core/errors** 的「錯誤碼 → HTTP Status」對照並提供**實務範例**。
錯誤系統採用 8 碼格式 `SSCCCDDD`
- `SS` = Scope服務/模組,兩位數)
- `CCC` = Category類別三位數影響 HTTP 狀態)
- `DDD` = Detail細節三位數自定義業務碼
> 例如:`10101000` → Scope=10、Category=101InputInvalidFormat、Detail=000。
## 目錄
- [1) 快速查表](#1-快速查表依類別整理)
- [2) 使用範例](#2-使用範例)
- [3) 小撇步與慣例](#3-小撇步與慣例)
- [4) 安裝與測試](#4-安裝與測試)
- [5) 變更日誌](#5-變更日誌)
---
## 1) 快速查表(依類別整理)
### A. InputCategory 1xx
| Category 常數 | 說明 | HTTP | 原因/說明 |
|---|---------------|:----:|---|
| `InputInvalidFormat` (101) | 無效格式 | **400 Bad Request** | 格式不符、缺欄位、型別錯。 |
| `InputNotValidImplementation` (102) | 非有效實作 | **422 Unprocessable Entity** | 語意正確但無法處理。 |
| `InputInvalidRange` (103) | 無效範圍 | **422 Unprocessable Entity** | 值超域、邊界條件不合。 |
### B. DBCategory 2xx
| Category 常數 | 說明 | HTTP | 原因/說明 |
|---|-------------|:----:|---|
| `DBError` (201) | 資料庫一般錯誤 | **500 Internal Server Error** | 後端故障/不可預期。 |
| `DBDataConvert` (202) | 資料轉換錯誤 | **422 Unprocessable Entity** | 可修正的資料問題(格式/型別轉換失敗)。 |
| `DBDuplicate` (203) | 資料重複 | **409 Conflict** | 唯一鍵衝突、重複建立。 |
### C. ResourceCategory 3xx
| Category 常數 | 說明 | HTTP | 原因/說明 |
|---|-------------------|:----:|---|
| `ResNotFound` (301) | 資源未找到 | **404 Not Found** | 目標不存在/無此 ID。 |
| `ResInvalidFormat` (302) | 無效資源格式 | **422 Unprocessable Entity** | 表示層/Schema 不符。 |
| `ResAlreadyExist` (303) | 資源已存在 | **409 Conflict** | 重複建立/命名衝突。 |
| `ResInsufficient` (304) | 資源不足 | **400 Bad Request** | 數量/容量不足(用戶可改參數再試)。 |
| `ResInsufficientPerm` (305) | 權限不足 | **403 Forbidden** | 已驗證但無權限。 |
| `ResInvalidMeasureID` (306) | 無效測量ID | **400 Bad Request** | ID 本身不合法。 |
| `ResExpired` (307) | 資源過期 | **410 Gone** | 已不可用(可於上層補 Location。 |
| `ResMigrated` (308) | 資源已遷移 | **410 Gone** | 同上,如需導引請於上層處理。 |
| `ResInvalidState` (309) | 無效狀態 | **409 Conflict** | 當前狀態不允許此操作。 |
| `ResInsufficientQuota` (310) | 配額不足 | **429 Too Many Requests** | 達配額/速率限制。 |
| `ResMultiOwner` (311) | 多所有者 | **409 Conflict** | 所有權歧異造成衝突。 |
### D. AuthCategory 5xx
| Category 常數 | 說明 | HTTP | 原因/說明 |
|---|-------------------------|:----:|---|
| `AuthUnauthorized` (501) | 未授權/未驗證 | **401 Unauthorized** | 缺 Token、無效 Token。 |
| `AuthExpired` (502) | 授權過期 | **401 Unauthorized** | Token 過期或時效失效。 |
| `AuthInvalidPosixTime` (503) | 無效 POSIX 時間 | **401 Unauthorized** | 時戳異常導致驗簽失敗。 |
| `AuthSigPayloadMismatch` (504) | 簽名與載荷不符 | **401 Unauthorized** | 驗簽失敗。 |
| `AuthForbidden` (505) | 禁止存取 | **403 Forbidden** | 已驗證但沒有操作權限。 |
### E. SystemCategory 6xx
| Category 常數 | 說明 | HTTP | 原因/說明 |
|---|---------------|:----:|---|
| `SysInternal` (601) | 系統內部錯誤 | **500 Internal Server Error** | 未預期的系統錯。 |
| `SysMaintain` (602) | 系統維護中 | **503 Service Unavailable** | 維護/停機。 |
| `SysTimeout` (603) | 系統超時 | **504 Gateway Timeout** | 下游/處理逾時。 |
| `SysTooManyRequest` (604) | 請求過多 | **429 Too Many Requests** | 節流/限流。 |
### F. PubSubCategory 7xx
| Category 常數 | 說明 | HTTP | 原因/說明 |
|---|---------|:----:|---|
| `PSuPublish` (701) | 發佈失敗 | **502 Bad Gateway** | 中介或外部匯流排錯誤。 |
| `PSuConsume` (702) | 消費失敗 | **502 Bad Gateway** | 同上。 |
| `PSuTooLarge` (703) | 訊息過大 | **413 Payload Too Large** | 封包大小超限。 |
### G. ServiceCategory 8xx
| Category 常數 | 說明 | HTTP | 原因/說明 |
|---|---------------|:----:|---|
| `SvcInternal` (801) | 服務內部錯誤 | **500 Internal Server Error** | 非基礎設施層的內錯。 |
| `SvcThirdParty` (802) | 第三方失敗 | **502 Bad Gateway** | 呼叫外部服務失敗。 |
| `SvcHTTP400` (803) | 明確指派 400 | **400 Bad Request** | 自行指定。 |
| `SvcMaintenance` (804) | 服務維護中 | **503 Service Unavailable** | 模組級維運中。 |
---
## 2) 使用範例
### 2.1 在 Handler 中回傳錯誤
```go
import (
"net/http"
errs "gitlab.supermicro.com/infra/infra-core/errors"
"gitlab.supermicro.com/infra/infra-core/errors/code"
)
func init() {
errs.Scope = code.Gateway // 設定當前服務的 Scope
}
func GetUser(w http.ResponseWriter, r *http.Request) error {
id := r.URL.Query().Get("id")
if id == "" {
return errs.InputInvalidFormatError("缺少參數: id") // 現在是 8 位碼
}
u, err := repo.Find(r.Context(), id)
switch {
case errors.Is(err, repo.ErrNotFound):
return errs.ResNotFoundError("user", id)
case err != nil:
return errs.DBErrorError("查詢使用者失敗").Wrap(err) // Wrap 內部錯誤
}
// … 寫入回應
return nil
}
// 統一寫出 HTTP 錯誤
func writeHTTP(w http.ResponseWriter, e *errs.Error) {
http.Error(w, e.Error(), e.HTTPStatus())
}
```
### 2.2 取出 Wrap 的內部錯誤
```go
if internal := e.Unwrap(); internal != nil {
log.Error("Internal error: ", internal)
}
```
### 2.3 搭配日誌裝飾器(`WithLog` / `WithLogWrap`
```go
log := logger.WithFields(errs.LogField{Key: "req_id", Val: rid})
if badInput {
return errs.WithLog(log, nil, errs.InputInvalidFormatError, "email 無效")
}
if err := repo.Save(ctx, u); err != nil {
return errs.WithLogWrap(
log,
[]errs.LogField{{Key: "entity", Val: "user"}, {Key: "op", Val: "save"}},
errs.DBErrorError,
err,
"儲存失敗",
)
}
```
### 2.4 只知道 Category+Detail 的動態場景(`EL` / `ELWrap`
```go
// 依流程動態產生
return errs.EL(log, nil, code.SysTimeout, 123, "下游逾時") // 自定義 detail=123
// 或需保留 cause
return errs.ELWrap(log, nil, code.SvcThirdParty, 456, err, "金流商失敗")
```
### 2.5 gRPC 互通
```go
// 由 *errs.Error 轉為 gRPC status
st := e.GRPCStatus() // *status.Status
// 客戶端收到 gRPC error → 轉回 *errs.Error
e := errs.FromGRPCError(grpcErr)
fmt.Println(e.DisplayCode(), e.Error()) // e.g., "10101000" "error msg"
```
### 2.6 從 8 碼反解(`FromCode`
```go
e := errs.FromCode(10101000) // 10101000
fmt.Println(e.Scope(), e.Category(), e.Detail()) // 10, 101, 000
```

View File

@ -0,0 +1,102 @@
package code
type Scope uint32 // SS (00..99)
type Category uint32 // CCC (000..999)
type Detail uint32 // DDD (000..999) // Updated to 3 digits
const (
Unset Scope = 0
CategoryMultiplier uint32 = 1000
ScopeMultiplier uint32 = 1000000
NonCode uint32 = 0
OK uint32 = 0 // Already exists, but merged for completeness; avoid duplication if needed
SUCCESSCode = "00000000"
SUCCESSMessage = "success"
)
// Boundary constants for validation
const (
MaxCategory Category = 999 // Maximum allowed category value
MaxDetail Detail = 999 // Maximum allowed detail value (updated)
DefaultCategory Category = 0
DefaultDetail Detail = 0
// Reserved values - DO NOT USE in normal operations
// These are used internally for overflow protection
ReservedMaxCategory Category = 999 // Used when category > 999
ReservedMaxDetail Detail = 999 // Used when detail > 999 (updated)
)
// New 3-digit categories (merged from original category + detail)
// Input errors (100-109)
const (
InputInvalidFormat Category = 101
InputNotValidImplementation Category = 102
InputInvalidRange Category = 103
)
// DB errors (200-209)
const (
DBError Category = 201
DBDataConvert Category = 202
DBDuplicate Category = 203
)
// Resource errors (300-399)
const (
ResNotFound Category = 301
ResInvalidFormat Category = 302
ResAlreadyExist Category = 303
ResInsufficient Category = 304
ResInsufficientPerm Category = 305
ResInvalidMeasureID Category = 306
ResExpired Category = 307
ResMigrated Category = 308
ResInvalidState Category = 309
ResInsufficientQuota Category = 310
ResMultiOwner Category = 311
)
// GRPC category
const (
CatGRPC Category = 400
)
// Auth errors (500-509)
const (
AuthUnauthorized Category = 501
AuthExpired Category = 502
AuthInvalidPosixTime Category = 503
AuthSigPayloadMismatch Category = 504
AuthForbidden Category = 505
)
// System errors (600-609)
const (
SysInternal Category = 601
SysMaintain Category = 602
SysTimeout Category = 603
SysTooManyRequest Category = 604
)
// PubSub errors (700-709)
const (
PSuPublish Category = 701
PSuConsume Category = 702
PSuTooLarge Category = 703
)
// Service errors (800-809)
const (
SvcInternal Category = 801
SvcThirdParty Category = 802
SvcHTTP400 Category = 803
SvcMaintenance Category = 804
)
const (
Gateway Scope = 10
)

View File

@ -0,0 +1,234 @@
package errs
import (
"errors"
"fmt"
"net/http"
"chat/internal/library/errors/code"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// Scope is a global variable that should be set by the service or module.
var Scope = code.Unset
// Error represents a structured error with an 8-digit code.
// The code is composed of a 2-digit scope, a 3-digit category, and a 3-digit detail.
// Format: SSCCCDDD
type Error struct {
scope uint32 // 2-digit service scope
category uint32 // 3-digit category
detail uint32 // 3-digit detail
msg string // Display message for the client
internalErr error // The actual underlying error
}
// New creates a new Error.
// It ensures that category is within 0-999 and detail is within 0-999.
func New(scope, category, detail uint32, displayMsg string) *Error {
if category > uint32(code.MaxCategory) {
category = uint32(code.ReservedMaxCategory)
}
if detail > uint32(code.MaxDetail) {
detail = uint32(code.ReservedMaxDetail)
}
return &Error{
scope: scope,
category: category,
detail: detail,
msg: displayMsg,
}
}
// Error returns the display message. This is intended for the client.
// For internal logging and debugging, use Unwrap() to get the underlying error.
func (e *Error) Error() string {
if e == nil {
return ""
}
return e.msg
}
// Scope returns the 2-digit scope of the error.
func (e *Error) Scope() uint32 {
if e == nil {
return uint32(code.Unset)
}
return e.scope
}
// Category returns the 3-digit category of the error.
func (e *Error) Category() uint32 {
if e == nil {
return uint32(code.DefaultCategory)
}
return e.category
}
// Detail returns the 2-digit detail code of the error.
func (e *Error) Detail() uint32 {
if e == nil {
return uint32(code.DefaultDetail)
}
return e.detail
}
// SubCode returns the 6-digit code (category + detail).
func (e *Error) SubCode() uint32 {
if e == nil {
return code.OK
}
c := e.category*code.CategoryMultiplier + e.detail
return c
}
// Code returns the full 8-digit error code (scope + category + detail).
func (e *Error) Code() uint32 {
if e == nil {
return code.NonCode
}
return e.Scope()*code.ScopeMultiplier + e.SubCode()
}
// DisplayCode returns the 8-digit error code as a zero-padded string.
func (e *Error) DisplayCode() string {
if e == nil {
return "00000000"
}
return fmt.Sprintf("%08d", e.Code())
}
// Is checks if the target error is of type *Error and has the same sub-code.
// It is called by errors.Is(). Do not use it directly.
func (e *Error) Is(target error) bool {
var err *Error
if !errors.As(target, &err) {
return false
}
return e.SubCode() == err.SubCode()
}
// Unwrap returns the underlying wrapped error.
func (e *Error) Unwrap() error {
if e == nil {
return nil
}
return e.internalErr
}
// Wrap sets the internal error for the current error.
func (e *Error) Wrap(internalErr error) *Error {
if e != nil {
e.internalErr = internalErr
}
return e
}
// GRPCStatus converts the error to a gRPC status.
func (e *Error) GRPCStatus() *status.Status {
if e == nil {
return status.New(codes.OK, "")
}
return status.New(codes.Code(e.Code()), e.Error())
}
// HTTPStatus returns the corresponding HTTP status code for the error.
func (e *Error) HTTPStatus() int {
if e == nil || e.SubCode() == code.OK {
return http.StatusOK
}
switch e.Category() {
// Input
case uint32(code.InputInvalidFormat):
return http.StatusBadRequest // 400輸入格式錯
case uint32(code.InputNotValidImplementation),
uint32(code.InputInvalidRange):
return http.StatusUnprocessableEntity // 422語意正確但無法處理範圍/實作)
// DB
case uint32(code.DBError):
return http.StatusInternalServerError // 500後端暫時性故障若你偏好 503 可自行調整)
case uint32(code.DBDataConvert):
return http.StatusUnprocessableEntity // 422可修正的資料轉換失敗
case uint32(code.DBDuplicate):
return http.StatusConflict // 409唯一鍵/重複
// Resource
case uint32(code.ResNotFound):
return http.StatusNotFound // 404資源不存在
case uint32(code.ResInvalidFormat):
return http.StatusUnprocessableEntity // 422資源表示/格式不符
case uint32(code.ResAlreadyExist):
return http.StatusConflict // 409已存在
case uint32(code.ResInsufficient):
return http.StatusBadRequest // 400數量/容量/條件不足(可由客戶端修正)
case uint32(code.ResInsufficientPerm):
return http.StatusForbidden // 403資源層面的權限不足
case uint32(code.ResInvalidMeasureID):
return http.StatusBadRequest // 400ID 無效
case uint32(code.ResExpired):
return http.StatusGone // 410資源已過期/不可用
case uint32(code.ResMigrated):
return http.StatusGone // 410已遷移若需導引可由上層加 Location
case uint32(code.ResInvalidState):
return http.StatusConflict // 409目前狀態不允許此操作
case uint32(code.ResInsufficientQuota):
return http.StatusTooManyRequests // 429配額不足/達上限
case uint32(code.ResMultiOwner):
return http.StatusConflict // 409多所有者衝突
// Auth
case uint32(code.AuthUnauthorized),
uint32(code.AuthExpired),
uint32(code.AuthInvalidPosixTime),
uint32(code.AuthSigPayloadMismatch):
return http.StatusUnauthorized // 401未驗證/無效憑證
case uint32(code.AuthForbidden):
return http.StatusForbidden // 403有身分但沒權限
// System
case uint32(code.SysTooManyRequest):
return http.StatusTooManyRequests // 429節流
case uint32(code.SysInternal):
return http.StatusInternalServerError // 500系統內部錯
case uint32(code.SysMaintain):
return http.StatusServiceUnavailable // 503維護中
case uint32(code.SysTimeout):
return http.StatusGatewayTimeout // 504處理/下游逾時
// PubSub
case uint32(code.PSuPublish),
uint32(code.PSuConsume):
return http.StatusBadGateway // 502訊息中介/外部匯流排失敗
case uint32(code.PSuTooLarge):
return http.StatusRequestEntityTooLarge // 413訊息太大
// Service
case uint32(code.SvcMaintenance):
return http.StatusServiceUnavailable // 503服務維護
case uint32(code.SvcInternal):
return http.StatusInternalServerError // 500服務內部錯
case uint32(code.SvcThirdParty):
return http.StatusBadGateway // 502第三方依賴失敗
case uint32(code.SvcHTTP400):
return http.StatusBadRequest // 400明確指派 400
}
// fallback
return http.StatusInternalServerError
}

View File

@ -0,0 +1,216 @@
package errs
import (
"errors"
"net/http"
"testing"
"chat/internal/library/errors/code"
"google.golang.org/grpc/codes"
)
func TestNew(t *testing.T) {
tests := []struct {
name string
scope uint32
category uint32
detail uint32
displayMsg string
wantScope uint32
wantCategory uint32
wantDetail uint32
wantMsg string
}{
{"basic", 10, 201, 123, "test", 10, 201, 123, "test"},
{"clamp category", 10, 1000, 0, "clamp cat", 10, 999, 0, "clamp cat"},
{"clamp detail", 10, 101, 1000, "clamp det", 10, 101, 999, "clamp det"},
{"zero values", 0, 0, 0, "", 0, 0, 0, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := New(tt.scope, tt.category, tt.detail, tt.displayMsg)
if e.Scope() != tt.wantScope || e.Category() != tt.wantCategory || e.Detail() != tt.wantDetail || e.msg != tt.wantMsg {
t.Errorf("New() = %+v, want scope=%d cat=%d det=%d msg=%q", e, tt.wantScope, tt.wantCategory, tt.wantDetail, tt.wantMsg)
}
})
}
}
func TestErrorMethods(t *testing.T) {
e := New(10, 201, 123, "test error")
tests := []struct {
name string
err *Error
wantErr string
wantScope uint32
wantCat uint32
wantDet uint32
}{
{"non-nil", e, "test error", 10, 201, 123},
{"nil", nil, "", uint32(code.Unset), uint32(code.DefaultCategory), uint32(code.DefaultDetail)}, // Adjust if Default* not defined; use 0
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.err.Error(); got != tt.wantErr {
t.Errorf("Error() = %q, want %q", got, tt.wantErr)
}
if got := tt.err.Scope(); got != tt.wantScope {
t.Errorf("Scope() = %d, want %d", got, tt.wantScope)
}
if got := tt.err.Category(); got != tt.wantCat {
t.Errorf("Category() = %d, want %d", got, tt.wantCat)
}
if got := tt.err.Detail(); got != tt.wantDet {
t.Errorf("Detail() = %d, want %d", got, tt.wantDet)
}
})
}
}
func TestCodes(t *testing.T) {
tests := []struct {
name string
err *Error
wantSubCode uint32
wantCode uint32
wantDisplay string
}{
{"basic", New(10, 201, 123, ""), 201123, 10201123, "10201123"},
{"nil", nil, code.OK, code.NonCode, "00000000"},
{"max clamp", New(99, 999, 999, ""), 999999, 99999999, "99999999"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.err.SubCode(); got != tt.wantSubCode {
t.Errorf("SubCode() = %d, want %d", got, tt.wantSubCode)
}
if got := tt.err.Code(); got != tt.wantCode {
t.Errorf("Code() = %d, want %d", got, tt.wantCode)
}
if got := tt.err.DisplayCode(); got != tt.wantDisplay {
t.Errorf("DisplayCode() = %q, want %q", got, tt.wantDisplay)
}
})
}
}
func TestIs(t *testing.T) {
e1 := New(10, 201, 123, "")
e2 := New(10, 201, 123, "") // same subcode
e3 := New(10, 202, 123, "") // different category
stdErr := errors.New("std")
tests := []struct {
name string
err error
target error
want bool
}{
{"match", e1, e2, true},
{"mismatch", e1, e3, false},
{"not Error type", e1, stdErr, false},
{"nil err", nil, e2, false},
{"nil target", e1, nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := errors.Is(tt.err, tt.target); got != tt.want {
t.Errorf("Is() = %v, want %v", got, tt.want)
}
})
}
}
func TestWrapUnwrap(t *testing.T) {
internal := errors.New("internal")
tests := []struct {
name string
err *Error
wrapErr error
wantUnwrap error
}{
{"wrap non-nil", New(10, 201, 0, ""), internal, internal},
{"wrap nil", nil, internal, nil}, // Wrap on nil does nothing
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.err.Wrap(tt.wrapErr)
if unwrapped := got.Unwrap(); unwrapped != tt.wantUnwrap {
t.Errorf("Unwrap() = %v, want %v", unwrapped, tt.wantUnwrap)
}
})
}
}
func TestGRPCStatus(t *testing.T) {
tests := []struct {
name string
err *Error
wantCode codes.Code
wantMsg string
}{
{"non-nil", New(10, 201, 123, "grpc err"), codes.Code(10201123), "grpc err"},
{"nil", nil, codes.OK, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := tt.err.GRPCStatus()
if s.Code() != tt.wantCode || s.Message() != tt.wantMsg {
t.Errorf("GRPCStatus() = code=%v msg=%q, want code=%v msg=%q", s.Code(), s.Message(), tt.wantCode, tt.wantMsg)
}
})
}
}
func TestHTTPStatus(t *testing.T) {
tests := []struct {
name string
err *Error
want int
}{
{"nil", nil, http.StatusOK},
{"OK subcode", New(10, 0, 0, ""), http.StatusOK},
{"InputInvalidFormat", New(10, uint32(code.InputInvalidFormat), 0, ""), http.StatusBadRequest},
{"InputNotValidImplementation", New(10, uint32(code.InputNotValidImplementation), 0, ""), http.StatusUnprocessableEntity},
{"DBError", New(10, uint32(code.DBError), 0, ""), http.StatusInternalServerError},
{"ResNotFound", New(10, uint32(code.ResNotFound), 0, ""), http.StatusNotFound},
// Add all other categories to cover switch branches
{"InputInvalidRange", New(10, uint32(code.InputInvalidRange), 0, ""), http.StatusUnprocessableEntity},
{"DBDataConvert", New(10, uint32(code.DBDataConvert), 0, ""), http.StatusUnprocessableEntity},
{"DBDuplicate", New(10, uint32(code.DBDuplicate), 0, ""), http.StatusConflict},
{"ResInvalidFormat", New(10, uint32(code.ResInvalidFormat), 0, ""), http.StatusUnprocessableEntity},
{"ResAlreadyExist", New(10, uint32(code.ResAlreadyExist), 0, ""), http.StatusConflict},
{"ResInsufficient", New(10, uint32(code.ResInsufficient), 0, ""), http.StatusBadRequest},
{"ResInsufficientPerm", New(10, uint32(code.ResInsufficientPerm), 0, ""), http.StatusForbidden},
{"ResInvalidMeasureID", New(10, uint32(code.ResInvalidMeasureID), 0, ""), http.StatusBadRequest},
{"ResExpired", New(10, uint32(code.ResExpired), 0, ""), http.StatusGone},
{"ResMigrated", New(10, uint32(code.ResMigrated), 0, ""), http.StatusGone},
{"ResInvalidState", New(10, uint32(code.ResInvalidState), 0, ""), http.StatusConflict},
{"ResInsufficientQuota", New(10, uint32(code.ResInsufficientQuota), 0, ""), http.StatusTooManyRequests},
{"ResMultiOwner", New(10, uint32(code.ResMultiOwner), 0, ""), http.StatusConflict},
{"AuthUnauthorized", New(10, uint32(code.AuthUnauthorized), 0, ""), http.StatusUnauthorized},
{"AuthExpired", New(10, uint32(code.AuthExpired), 0, ""), http.StatusUnauthorized},
{"AuthInvalidPosixTime", New(10, uint32(code.AuthInvalidPosixTime), 0, ""), http.StatusUnauthorized},
{"AuthSigPayloadMismatch", New(10, uint32(code.AuthSigPayloadMismatch), 0, ""), http.StatusUnauthorized},
{"AuthForbidden", New(10, uint32(code.AuthForbidden), 0, ""), http.StatusForbidden},
{"SysTooManyRequest", New(10, uint32(code.SysTooManyRequest), 0, ""), http.StatusTooManyRequests},
{"SysInternal", New(10, uint32(code.SysInternal), 0, ""), http.StatusInternalServerError},
{"SysMaintain", New(10, uint32(code.SysMaintain), 0, ""), http.StatusServiceUnavailable},
{"SysTimeout", New(10, uint32(code.SysTimeout), 0, ""), http.StatusGatewayTimeout},
{"PSuPublish", New(10, uint32(code.PSuPublish), 0, ""), http.StatusBadGateway},
{"PSuConsume", New(10, uint32(code.PSuConsume), 0, ""), http.StatusBadGateway},
{"PSuTooLarge", New(10, uint32(code.PSuTooLarge), 0, ""), http.StatusRequestEntityTooLarge},
{"SvcMaintenance", New(10, uint32(code.SvcMaintenance), 0, ""), http.StatusServiceUnavailable},
{"SvcInternal", New(10, uint32(code.SvcInternal), 0, ""), http.StatusInternalServerError},
{"SvcThirdParty", New(10, uint32(code.SvcThirdParty), 0, ""), http.StatusBadGateway},
{"SvcHTTP400", New(10, uint32(code.SvcHTTP400), 0, ""), http.StatusBadRequest},
{"fallback unknown", New(10, 999, 0, ""), http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.err.HTTPStatus(); got != tt.want {
t.Errorf("HTTPStatus() = %d, want %d", got, tt.want)
}
})
}
}

View File

@ -0,0 +1,467 @@
package errs
import (
"strings"
"chat/internal/library/errors/code"
)
/* =========================
日誌介面與你現有 Logger 對齊
========================= */
// Logger 你現有的 logger 介面(與外部一致即可)
type Logger interface {
WithCallerSkip(n int) Logger
WithFields(fields ...LogField) Logger
Error(msg string)
Warn(msg string)
Info(msg string)
}
// LogField 結構化欄位
type LogField struct {
Key string
Val any
}
/* =========================
共用小工具
========================= */
// joinMsg把可變參數字串用空白串接避免到處判斷 nil / 空 slice
func joinMsg(s []string) string {
if len(s) == 0 {
return ""
}
return strings.Join(s, " ")
}
// logErr統一打一筆 error log避免重複記錄
func logErr(l Logger, fields []LogField, e *Error) {
if l == nil || e == nil {
return
}
ll := l.WithCallerSkip(1)
if len(fields) > 0 {
ll = ll.WithFields(fields...)
}
// 需要更多欄位可在此擴充例如e.DisplayCode()、e.Category()、e.Detail()
ll.Error(e.Error())
}
/* =========================
共用裝飾器把任意 ez 建構器包成帶日誌版本
========================= */
// WithLog 將任一 *Error 建構器(如 SysTimeoutError轉成帶日誌的版本
func WithLog(l Logger, fields []LogField, ctor func(s ...string) *Error, s ...string) *Error {
e := ctor(s...)
logErr(l, fields, e)
return e
}
// WithLogWrap 同上,但會同時 Wrap 內部 cause
func WithLogWrap(l Logger, fields []LogField, ctor func(s ...string) *Error, cause error, s ...string) *Error {
e := ctor(s...).Wrap(cause)
logErr(l, fields, e)
return e
}
/* =========================
泛用建構器當你懶得記函式名時
========================= */
// EL 依 Category/Detail 直接建構並記錄日誌
func EL(l Logger, fields []LogField, cat code.Category, det code.Detail, s ...string) *Error {
e := New(uint32(Scope), uint32(cat), uint32(det), joinMsg(s))
logErr(l, fields, e)
return e
}
// ELWrap 同上,並 Wrap cause
func ELWrap(l Logger, fields []LogField, cat code.Category, det code.Detail, cause error, s ...string) *Error {
e := New(uint32(Scope), uint32(cat), uint32(det), joinMsg(s)).Wrap(cause)
logErr(l, fields, e)
return e
}
/* =======================================================================
基礎 ez 建構器純建構 *Error不帶日誌
分類順序Input DB Resource Auth System PubSub Service
======================================================================= */
/* ----- Input (CatInput) ----- */
func InputInvalidFormatError(s ...string) *Error {
return New(uint32(Scope), uint32(code.InputInvalidFormat), 0, joinMsg(s))
}
func InputNotValidImplementationError(s ...string) *Error {
return New(uint32(Scope), uint32(code.InputNotValidImplementation), 0, joinMsg(s))
}
func InputInvalidRangeError(s ...string) *Error {
return New(uint32(Scope), uint32(code.InputInvalidRange), 0, joinMsg(s))
}
/* ----- DB (CatDB) ----- */
func DBErrorError(s ...string) *Error {
return New(uint32(Scope), uint32(code.DBError), 0, joinMsg(s))
}
func DBDataConvertError(s ...string) *Error {
return New(uint32(Scope), uint32(code.DBDataConvert), 0, joinMsg(s))
}
func DBDuplicateError(s ...string) *Error {
return New(uint32(Scope), uint32(code.DBDuplicate), 0, joinMsg(s))
}
/* ----- Resource (CatResource) ----- */
func ResNotFoundError(s ...string) *Error {
return New(uint32(Scope), uint32(code.ResNotFound), 0, joinMsg(s))
}
func ResInvalidFormatError(s ...string) *Error {
return New(uint32(Scope), uint32(code.ResInvalidFormat), 0, joinMsg(s))
}
func ResAlreadyExistError(s ...string) *Error {
return New(uint32(Scope), uint32(code.ResAlreadyExist), 0, joinMsg(s))
}
func ResInsufficientError(s ...string) *Error {
return New(uint32(Scope), uint32(code.ResInsufficient), 0, joinMsg(s))
}
func ResInsufficientPermError(s ...string) *Error {
return New(uint32(Scope), uint32(code.ResInsufficientPerm), 0, joinMsg(s))
}
func ResInvalidMeasureIDError(s ...string) *Error {
return New(uint32(Scope), uint32(code.ResInvalidMeasureID), 0, joinMsg(s))
}
func ResExpiredError(s ...string) *Error {
return New(uint32(Scope), uint32(code.ResExpired), 0, joinMsg(s))
}
func ResMigratedError(s ...string) *Error {
return New(uint32(Scope), uint32(code.ResMigrated), 0, joinMsg(s))
}
func ResInvalidStateError(s ...string) *Error {
return New(uint32(Scope), uint32(code.ResInvalidState), 0, joinMsg(s))
}
func ResInsufficientQuotaError(s ...string) *Error {
return New(uint32(Scope), uint32(code.ResInsufficientQuota), 0, joinMsg(s))
}
func ResMultiOwnerError(s ...string) *Error {
return New(uint32(Scope), uint32(code.ResMultiOwner), 0, joinMsg(s))
}
/* ----- Auth (CatAuth) ----- */
func AuthUnauthorizedError(s ...string) *Error {
return New(uint32(Scope), uint32(code.AuthUnauthorized), 0, joinMsg(s))
}
func AuthExpiredError(s ...string) *Error {
return New(uint32(Scope), uint32(code.AuthExpired), 0, joinMsg(s))
}
func AuthInvalidPosixTimeError(s ...string) *Error {
return New(uint32(Scope), uint32(code.AuthInvalidPosixTime), 0, joinMsg(s))
}
func AuthSigPayloadMismatchError(s ...string) *Error {
return New(uint32(Scope), uint32(code.AuthSigPayloadMismatch), 0, joinMsg(s))
}
func AuthForbiddenError(s ...string) *Error {
return New(uint32(Scope), uint32(code.AuthForbidden), 0, joinMsg(s))
}
/* ----- System (CatSystem) ----- */
func SysInternalError(s ...string) *Error {
return New(uint32(Scope), uint32(code.SysInternal), 0, joinMsg(s))
}
func SysMaintainError(s ...string) *Error {
return New(uint32(Scope), uint32(code.SysMaintain), 0, joinMsg(s))
}
func SysTimeoutError(s ...string) *Error {
return New(uint32(Scope), uint32(code.SysTimeout), 0, joinMsg(s))
}
func SysTooManyRequestError(s ...string) *Error {
return New(uint32(Scope), uint32(code.SysTooManyRequest), 0, joinMsg(s))
}
/* ----- PubSub (CatPubSub) ----- */
func PSuPublishError(s ...string) *Error {
return New(uint32(Scope), uint32(code.PSuPublish), 0, joinMsg(s))
}
func PSuConsumeError(s ...string) *Error {
return New(uint32(Scope), uint32(code.PSuConsume), 0, joinMsg(s))
}
func PSuTooLargeError(s ...string) *Error {
return New(uint32(Scope), uint32(code.PSuTooLarge), 0, joinMsg(s))
}
/* ----- Service (CatService) ----- */
func SvcInternalError(s ...string) *Error {
return New(uint32(Scope), uint32(code.SvcInternal), 0, joinMsg(s))
}
func SvcThirdPartyError(s ...string) *Error {
return New(uint32(Scope), uint32(code.SvcThirdParty), 0, joinMsg(s))
}
func SvcHTTP400Error(s ...string) *Error {
return New(uint32(Scope), uint32(code.SvcHTTP400), 0, joinMsg(s))
}
func SvcMaintenanceError(s ...string) *Error {
return New(uint32(Scope), uint32(code.SvcMaintenance), 0, joinMsg(s))
}
/* =============================================================================
帶日誌版本L / WrapL基礎 ez 建構器之上包裝 WithLog / WithLogWrap
分類順序同上Input DB Resource Auth System PubSub Service
============================================================================= */
/* ----- Input (CatInput) ----- */
func InputInvalidFormatErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, InputInvalidFormatError, s...)
}
func InputInvalidFormatErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, InputInvalidFormatError, cause, s...)
}
func InputNotValidImplementationErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, InputNotValidImplementationError, s...)
}
func InputNotValidImplementationErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, InputNotValidImplementationError, cause, s...)
}
func InputInvalidRangeErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, InputInvalidRangeError, s...)
}
func InputInvalidRangeErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, InputInvalidRangeError, cause, s...)
}
/* ----- DB (CatDB) ----- */
func DBErrorErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, DBErrorError, s...)
}
func DBErrorErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, DBErrorError, cause, s...)
}
func DBDataConvertErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, DBDataConvertError, s...)
}
func DBDataConvertErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, DBDataConvertError, cause, s...)
}
func DBDuplicateErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, DBDuplicateError, s...)
}
func DBDuplicateErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, DBDuplicateError, cause, s...)
}
/* ----- Resource (CatResource) ----- */
func ResNotFoundErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, ResNotFoundError, s...)
}
func ResNotFoundErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, ResNotFoundError, cause, s...)
}
func ResInvalidFormatErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, ResInvalidFormatError, s...)
}
func ResInvalidFormatErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, ResInvalidFormatError, cause, s...)
}
func ResAlreadyExistErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, ResAlreadyExistError, s...)
}
func ResAlreadyExistErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, ResAlreadyExistError, cause, s...)
}
func ResInsufficientErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, ResInsufficientError, s...)
}
func ResInsufficientErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, ResInsufficientError, cause, s...)
}
func ResInsufficientPermErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, ResInsufficientPermError, s...)
}
func ResInsufficientPermErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, ResInsufficientPermError, cause, s...)
}
func ResInvalidMeasureIDErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, ResInvalidMeasureIDError, s...)
}
func ResInvalidMeasureIDErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, ResInvalidMeasureIDError, cause, s...)
}
func ResExpiredErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, ResExpiredError, s...)
}
func ResExpiredErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, ResExpiredError, cause, s...)
}
func ResMigratedErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, ResMigratedError, s...)
}
func ResMigratedErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, ResMigratedError, cause, s...)
}
func ResInvalidStateErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, ResInvalidStateError, s...)
}
func ResInvalidStateErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, ResInvalidStateError, cause, s...)
}
func ResInsufficientQuotaErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, ResInsufficientQuotaError, s...)
}
func ResInsufficientQuotaErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, ResInsufficientQuotaError, cause, s...)
}
func ResMultiOwnerErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, ResMultiOwnerError, s...)
}
func ResMultiOwnerErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, ResMultiOwnerError, cause, s...)
}
/* ----- Auth (CatAuth) ----- */
func AuthUnauthorizedErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, AuthUnauthorizedError, s...)
}
func AuthUnauthorizedErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, AuthUnauthorizedError, cause, s...)
}
func AuthExpiredErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, AuthExpiredError, s...)
}
func AuthExpiredErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, AuthExpiredError, cause, s...)
}
func AuthInvalidPosixTimeErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, AuthInvalidPosixTimeError, s...)
}
func AuthInvalidPosixTimeErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, AuthInvalidPosixTimeError, cause, s...)
}
func AuthSigPayloadMismatchErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, AuthSigPayloadMismatchError, s...)
}
func AuthSigPayloadMismatchErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, AuthSigPayloadMismatchError, cause, s...)
}
func AuthForbiddenErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, AuthForbiddenError, s...)
}
func AuthForbiddenErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, AuthForbiddenError, cause, s...)
}
/* ----- System (CatSystem) ----- */
func SysInternalErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, SysInternalError, s...)
}
func SysInternalErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, SysInternalError, cause, s...)
}
func SysMaintainErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, SysMaintainError, s...)
}
func SysMaintainErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, SysMaintainError, cause, s...)
}
func SysTimeoutErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, SysTimeoutError, s...)
}
func SysTimeoutErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, SysTimeoutError, cause, s...)
}
func SysTooManyRequestErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, SysTooManyRequestError, s...)
}
func SysTooManyRequestErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, SysTooManyRequestError, cause, s...)
}
/* ----- PubSub (CatPubSub) ----- */
func PSuPublishErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, PSuPublishError, s...)
}
func PSuPublishErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, PSuPublishError, cause, s...)
}
func PSuConsumeErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, PSuConsumeError, s...)
}
func PSuConsumeErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, PSuConsumeError, cause, s...)
}
func PSuTooLargeErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, PSuTooLargeError, s...)
}
func PSuTooLargeErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, PSuTooLargeError, cause, s...)
}
/* ----- Service (CatService) ----- */
func SvcInternalErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, SvcInternalError, s...)
}
func SvcInternalErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, SvcInternalError, cause, s...)
}
func SvcThirdPartyErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, SvcThirdPartyError, s...)
}
func SvcThirdPartyErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, SvcThirdPartyError, cause, s...)
}
func SvcHTTP400ErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, SvcHTTP400Error, s...)
}
func SvcHTTP400ErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, SvcHTTP400Error, cause, s...)
}
func SvcMaintenanceErrorL(l Logger, fields []LogField, s ...string) *Error {
return WithLog(l, fields, SvcMaintenanceError, s...)
}
func SvcMaintenanceErrorWrapL(l Logger, fields []LogField, cause error, s ...string) *Error {
return WithLogWrap(l, fields, SvcMaintenanceError, cause, s...)
}

View File

@ -0,0 +1,347 @@
package errs
import (
"errors"
"reflect"
"testing"
"chat/internal/library/errors/code"
)
// fakeLogger as before
type fakeLogger struct {
calls []string
lastMsg string
fieldsStack [][]LogField
callerSkips []int
}
func (l *fakeLogger) WithCallerSkip(n int) Logger { l.callerSkips = append(l.callerSkips, n); return l }
func (l *fakeLogger) WithFields(fields ...LogField) Logger {
cp := make([]LogField, len(fields))
copy(cp, fields)
l.fieldsStack = append(l.fieldsStack, cp)
return l
}
func (l *fakeLogger) Error(msg string) { l.calls = append(l.calls, "ERROR"); l.lastMsg = msg }
func (l *fakeLogger) Warn(msg string) { l.calls = append(l.calls, "WARN"); l.lastMsg = msg }
func (l *fakeLogger) Info(msg string) { l.calls = append(l.calls, "INFO"); l.lastMsg = msg }
func (l *fakeLogger) reset() {
l.calls, l.lastMsg, l.fieldsStack, l.callerSkips = nil, "", nil, nil
}
func init() { Scope = code.Gateway }
func TestJoinMsg(t *testing.T) {
tests := []struct {
name string
in []string
want string
}{
{"nil", nil, ""},
{"empty", []string{}, ""},
{"single", []string{"a"}, "a"},
{"multi", []string{"a", "b", "c"}, "a b c"},
{"with spaces", []string{"hello", "world"}, "hello world"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := joinMsg(tt.in); got != tt.want {
t.Errorf("joinMsg() = %q, want %q", got, tt.want)
}
})
}
}
func TestLogErr(t *testing.T) {
tests := []struct {
name string
l Logger
fields []LogField
e *Error
wantCall string
wantFields bool
wantCallerSkip int
}{
{"nil logger", nil, nil, New(10, 101, 0, "err"), "", false, 0},
{"nil error", &fakeLogger{}, nil, nil, "", false, 0},
{"basic log", &fakeLogger{}, nil, New(10, 101, 0, "err"), "ERROR", false, 1},
{"with fields", &fakeLogger{}, []LogField{{Key: "k", Val: "v"}}, New(10, 101, 0, "err"), "ERROR", true, 1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var fl *fakeLogger
if tt.l != nil {
var ok bool
fl, ok = tt.l.(*fakeLogger)
if !ok {
t.Fatalf("logger is not *fakeLogger")
}
fl.reset()
}
logErr(tt.l, tt.fields, tt.e)
if fl == nil {
if tt.wantCall != "" {
t.Errorf("expected log but logger is nil")
}
return
}
if tt.wantCall == "" && len(fl.calls) > 0 {
t.Errorf("unexpected log call")
}
if tt.wantCall != "" && (len(fl.calls) == 0 || fl.calls[0] != tt.wantCall) {
t.Errorf("expected call %q, got %v", tt.wantCall, fl.calls)
}
if tt.wantFields && (len(fl.fieldsStack) == 0 || !reflect.DeepEqual(fl.fieldsStack[0], tt.fields)) {
t.Errorf("fields mismatch: got %v, want %v", fl.fieldsStack, tt.fields)
}
if tt.wantCallerSkip != 0 && (len(fl.callerSkips) == 0 || fl.callerSkips[0] != tt.wantCallerSkip) {
t.Errorf("callerSkip = %v, want %d", fl.callerSkips, tt.wantCallerSkip)
}
})
}
}
func TestEL(t *testing.T) {
tests := []struct {
name string
cat code.Category
det code.Detail
s []string
wantCat uint32
wantDet uint32
wantMsg string
wantLog bool
}{
{"basic", code.ResNotFound, 123, []string{"not found"}, uint32(code.ResNotFound), 123, "not found", true},
{"nil logger", code.ResNotFound, 0, []string{}, uint32(code.ResNotFound), 0, "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := &fakeLogger{}
e := EL(l, nil, tt.cat, tt.det, tt.s...)
if e.Category() != tt.wantCat || e.Detail() != tt.wantDet || e.Error() != tt.wantMsg {
t.Errorf("EL = cat=%d det=%d msg=%q, want %d %d %q", e.Category(), e.Detail(), e.Error(), tt.wantCat, tt.wantDet, tt.wantMsg)
}
if tt.wantLog && len(l.calls) == 0 {
t.Errorf("expected log")
}
})
}
}
func TestELWrap(t *testing.T) {
tests := []struct {
name string
cat code.Category
det code.Detail
cause error
s []string
wantCat uint32
wantDet uint32
wantMsg string
wantUnwrap string
wantLog bool
}{
{"basic", code.SysInternal, 456, errors.New("internal"), []string{"sys err"}, uint32(code.SysInternal), 456, "sys err", "internal", true},
{"no log", code.SysInternal, 0, nil, []string{}, uint32(code.SysInternal), 0, "", "", false}, // nil cause ok
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := &fakeLogger{}
e := ELWrap(l, nil, tt.cat, tt.det, tt.cause, tt.s...)
if e.Category() != tt.wantCat || e.Detail() != tt.wantDet || e.Error() != tt.wantMsg {
t.Errorf("ELWrap = cat=%d det=%d msg=%q, want %d %d %q", e.Category(), e.Detail(), e.Error(), tt.wantCat, tt.wantDet, tt.wantMsg)
}
unw := e.Unwrap()
gotUnwrap := ""
if unw != nil {
gotUnwrap = unw.Error()
}
if gotUnwrap != tt.wantUnwrap {
t.Errorf("Unwrap = %q, want %q", gotUnwrap, tt.wantUnwrap)
}
if tt.wantLog && len(l.calls) == 0 {
t.Errorf("expected log")
}
})
}
}
// Expand TestBaseConstructors with all base funcs
func TestBaseConstructors(t *testing.T) {
tests := []struct {
name string
fn func(...string) *Error
wantCat uint32
wantDet uint32
wantMsg string
}{
{"InputInvalidFormatError", InputInvalidFormatError, uint32(code.InputInvalidFormat), 0, "test msg"},
{"InputNotValidImplementationError", InputNotValidImplementationError, uint32(code.InputNotValidImplementation), 0, "test msg"},
{"InputInvalidRangeError", InputInvalidRangeError, uint32(code.InputInvalidRange), 0, "test msg"},
{"DBErrorError", DBErrorError, uint32(code.DBError), 0, "test msg"},
{"DBDataConvertError", DBDataConvertError, uint32(code.DBDataConvert), 0, "test msg"},
{"DBDuplicateError", DBDuplicateError, uint32(code.DBDuplicate), 0, "test msg"},
{"ResNotFoundError", ResNotFoundError, uint32(code.ResNotFound), 0, "test msg"},
{"ResInvalidFormatError", ResInvalidFormatError, uint32(code.ResInvalidFormat), 0, "test msg"},
{"ResAlreadyExistError", ResAlreadyExistError, uint32(code.ResAlreadyExist), 0, "test msg"},
{"ResInsufficientError", ResInsufficientError, uint32(code.ResInsufficient), 0, "test msg"},
{"ResInsufficientPermError", ResInsufficientPermError, uint32(code.ResInsufficientPerm), 0, "test msg"},
{"ResInvalidMeasureIDError", ResInvalidMeasureIDError, uint32(code.ResInvalidMeasureID), 0, "test msg"},
{"ResExpiredError", ResExpiredError, uint32(code.ResExpired), 0, "test msg"},
{"ResMigratedError", ResMigratedError, uint32(code.ResMigrated), 0, "test msg"},
{"ResInvalidStateError", ResInvalidStateError, uint32(code.ResInvalidState), 0, "test msg"},
{"ResInsufficientQuotaError", ResInsufficientQuotaError, uint32(code.ResInsufficientQuota), 0, "test msg"},
{"ResMultiOwnerError", ResMultiOwnerError, uint32(code.ResMultiOwner), 0, "test msg"},
{"AuthUnauthorizedError", AuthUnauthorizedError, uint32(code.AuthUnauthorized), 0, "test msg"},
{"AuthExpiredError", AuthExpiredError, uint32(code.AuthExpired), 0, "test msg"},
{"AuthInvalidPosixTimeError", AuthInvalidPosixTimeError, uint32(code.AuthInvalidPosixTime), 0, "test msg"},
{"AuthSigPayloadMismatchError", AuthSigPayloadMismatchError, uint32(code.AuthSigPayloadMismatch), 0, "test msg"},
{"AuthForbiddenError", AuthForbiddenError, uint32(code.AuthForbidden), 0, "test msg"},
{"SysInternalError", SysInternalError, uint32(code.SysInternal), 0, "test msg"},
{"SysMaintainError", SysMaintainError, uint32(code.SysMaintain), 0, "test msg"},
{"SysTimeoutError", SysTimeoutError, uint32(code.SysTimeout), 0, "test msg"},
{"SysTooManyRequestError", SysTooManyRequestError, uint32(code.SysTooManyRequest), 0, "test msg"},
{"PSuPublishError", PSuPublishError, uint32(code.PSuPublish), 0, "test msg"},
{"PSuConsumeError", PSuConsumeError, uint32(code.PSuConsume), 0, "test msg"},
{"PSuTooLargeError", PSuTooLargeError, uint32(code.PSuTooLarge), 0, "test msg"},
{"SvcInternalError", SvcInternalError, uint32(code.SvcInternal), 0, "test msg"},
{"SvcThirdPartyError", SvcThirdPartyError, uint32(code.SvcThirdParty), 0, "test msg"},
{"SvcHTTP400Error", SvcHTTP400Error, uint32(code.SvcHTTP400), 0, "test msg"},
{"SvcMaintenanceError", SvcMaintenanceError, uint32(code.SvcMaintenance), 0, "test msg"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := tt.fn("test", "msg")
if e == nil || e.Category() != tt.wantCat || e.Detail() != tt.wantDet || e.Error() != "test msg" {
t.Errorf("%s = cat=%d det=%d msg=%q, want %d 0 %q", tt.name, e.Category(), e.Detail(), e.Error(), tt.wantCat, "test msg")
}
})
}
}
// Expand TestLConstructors with all L funcs
func TestLConstructors(t *testing.T) {
tests := []struct {
name string
fn func(Logger, []LogField, ...string) *Error
wantCat uint32
wantDet uint32
wantMsg string
wantLog bool
}{
{"InputInvalidFormatErrorL", InputInvalidFormatErrorL, uint32(code.InputInvalidFormat), 0, "test msg", true},
{"InputNotValidImplementationErrorL", InputNotValidImplementationErrorL, uint32(code.InputNotValidImplementation), 0, "test msg", true},
{"InputInvalidRangeErrorL", InputInvalidRangeErrorL, uint32(code.InputInvalidRange), 0, "test msg", true},
{"DBErrorErrorL", DBErrorErrorL, uint32(code.DBError), 0, "test msg", true},
{"DBDataConvertErrorL", DBDataConvertErrorL, uint32(code.DBDataConvert), 0, "test msg", true},
{"DBDuplicateErrorL", DBDuplicateErrorL, uint32(code.DBDuplicate), 0, "test msg", true},
{"ResNotFoundErrorL", ResNotFoundErrorL, uint32(code.ResNotFound), 0, "test msg", true},
{"ResInvalidFormatErrorL", ResInvalidFormatErrorL, uint32(code.ResInvalidFormat), 0, "test msg", true},
{"ResAlreadyExistErrorL", ResAlreadyExistErrorL, uint32(code.ResAlreadyExist), 0, "test msg", true},
{"ResInsufficientErrorL", ResInsufficientErrorL, uint32(code.ResInsufficient), 0, "test msg", true},
{"ResInsufficientPermErrorL", ResInsufficientPermErrorL, uint32(code.ResInsufficientPerm), 0, "test msg", true},
{"ResInvalidMeasureIDErrorL", ResInvalidMeasureIDErrorL, uint32(code.ResInvalidMeasureID), 0, "test msg", true},
{"ResExpiredErrorL", ResExpiredErrorL, uint32(code.ResExpired), 0, "test msg", true},
{"ResMigratedErrorL", ResMigratedErrorL, uint32(code.ResMigrated), 0, "test msg", true},
{"ResInvalidStateErrorL", ResInvalidStateErrorL, uint32(code.ResInvalidState), 0, "test msg", true},
{"ResInsufficientQuotaErrorL", ResInsufficientQuotaErrorL, uint32(code.ResInsufficientQuota), 0, "test msg", true},
{"ResMultiOwnerErrorL", ResMultiOwnerErrorL, uint32(code.ResMultiOwner), 0, "test msg", true},
{"AuthUnauthorizedErrorL", AuthUnauthorizedErrorL, uint32(code.AuthUnauthorized), 0, "test msg", true},
{"AuthExpiredErrorL", AuthExpiredErrorL, uint32(code.AuthExpired), 0, "test msg", true},
{"AuthInvalidPosixTimeErrorL", AuthInvalidPosixTimeErrorL, uint32(code.AuthInvalidPosixTime), 0, "test msg", true},
{"AuthSigPayloadMismatchErrorL", AuthSigPayloadMismatchErrorL, uint32(code.AuthSigPayloadMismatch), 0, "test msg", true},
{"AuthForbiddenErrorL", AuthForbiddenErrorL, uint32(code.AuthForbidden), 0, "test msg", true},
{"SysInternalErrorL", SysInternalErrorL, uint32(code.SysInternal), 0, "test msg", true},
{"SysMaintainErrorL", SysMaintainErrorL, uint32(code.SysMaintain), 0, "test msg", true},
{"SysTimeoutErrorL", SysTimeoutErrorL, uint32(code.SysTimeout), 0, "test msg", true},
{"SysTooManyRequestErrorL", SysTooManyRequestErrorL, uint32(code.SysTooManyRequest), 0, "test msg", true},
{"PSuPublishErrorL", PSuPublishErrorL, uint32(code.PSuPublish), 0, "test msg", true},
{"PSuConsumeErrorL", PSuConsumeErrorL, uint32(code.PSuConsume), 0, "test msg", true},
{"PSuTooLargeErrorL", PSuTooLargeErrorL, uint32(code.PSuTooLarge), 0, "test msg", true},
{"SvcInternalErrorL", SvcInternalErrorL, uint32(code.SvcInternal), 0, "test msg", true},
{"SvcThirdPartyErrorL", SvcThirdPartyErrorL, uint32(code.SvcThirdParty), 0, "test msg", true},
{"SvcHTTP400ErrorL", SvcHTTP400ErrorL, uint32(code.SvcHTTP400), 0, "test msg", true},
{"SvcMaintenanceErrorL", SvcMaintenanceErrorL, uint32(code.SvcMaintenance), 0, "test msg", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := &fakeLogger{}
fields := []LogField{}
e := tt.fn(l, fields, "test", "msg")
if e == nil || e.Category() != tt.wantCat || e.Detail() != tt.wantDet || e.Error() != "test msg" {
t.Errorf("%s = cat=%d det=%d msg=%q, want %d 0 %q", tt.name, e.Category(), e.Detail(), e.Error(), tt.wantCat, "test msg")
}
if tt.wantLog && len(l.calls) == 0 {
t.Errorf("expected log call")
}
})
}
}
// Add TestWrapLConstructors similarly
func TestWrapLConstructors(t *testing.T) {
tests := []struct {
name string
fn func(Logger, []LogField, error, ...string) *Error
wantCat uint32
wantDet uint32
wantMsg string
wantUnwrap string
wantLog bool
}{
{"InputInvalidFormatErrorWrapL", InputInvalidFormatErrorWrapL, uint32(code.InputInvalidFormat), 0, "test msg", "cause err", true},
{"InputNotValidImplementationErrorWrapL", InputNotValidImplementationErrorWrapL, uint32(code.InputNotValidImplementation), 0, "test msg", "cause err", true},
{"InputInvalidRangeErrorWrapL", InputInvalidRangeErrorWrapL, uint32(code.InputInvalidRange), 0, "test msg", "cause err", true},
{"DBErrorErrorWrapL", DBErrorErrorWrapL, uint32(code.DBError), 0, "test msg", "cause err", true},
{"DBDataConvertErrorWrapL", DBDataConvertErrorWrapL, uint32(code.DBDataConvert), 0, "test msg", "cause err", true},
{"DBDuplicateErrorWrapL", DBDuplicateErrorWrapL, uint32(code.DBDuplicate), 0, "test msg", "cause err", true},
{"ResNotFoundErrorWrapL", ResNotFoundErrorWrapL, uint32(code.ResNotFound), 0, "test msg", "cause err", true},
{"ResInvalidFormatErrorWrapL", ResInvalidFormatErrorWrapL, uint32(code.ResInvalidFormat), 0, "test msg", "cause err", true},
{"ResAlreadyExistErrorWrapL", ResAlreadyExistErrorWrapL, uint32(code.ResAlreadyExist), 0, "test msg", "cause err", true},
{"ResInsufficientErrorWrapL", ResInsufficientErrorWrapL, uint32(code.ResInsufficient), 0, "test msg", "cause err", true},
{"ResInsufficientPermErrorWrapL", ResInsufficientPermErrorWrapL, uint32(code.ResInsufficientPerm), 0, "test msg", "cause err", true},
{"ResInvalidMeasureIDErrorWrapL", ResInvalidMeasureIDErrorWrapL, uint32(code.ResInvalidMeasureID), 0, "test msg", "cause err", true},
{"ResExpiredErrorWrapL", ResExpiredErrorWrapL, uint32(code.ResExpired), 0, "test msg", "cause err", true},
{"ResMigratedErrorWrapL", ResMigratedErrorWrapL, uint32(code.ResMigrated), 0, "test msg", "cause err", true},
{"ResInvalidStateErrorWrapL", ResInvalidStateErrorWrapL, uint32(code.ResInvalidState), 0, "test msg", "cause err", true},
{"ResInsufficientQuotaErrorWrapL", ResInsufficientQuotaErrorWrapL, uint32(code.ResInsufficientQuota), 0, "test msg", "cause err", true},
{"ResMultiOwnerErrorWrapL", ResMultiOwnerErrorWrapL, uint32(code.ResMultiOwner), 0, "test msg", "cause err", true},
{"AuthUnauthorizedErrorWrapL", AuthUnauthorizedErrorWrapL, uint32(code.AuthUnauthorized), 0, "test msg", "cause err", true},
{"AuthExpiredErrorWrapL", AuthExpiredErrorWrapL, uint32(code.AuthExpired), 0, "test msg", "cause err", true},
{"AuthInvalidPosixTimeErrorWrapL", AuthInvalidPosixTimeErrorWrapL, uint32(code.AuthInvalidPosixTime), 0, "test msg", "cause err", true},
{"AuthSigPayloadMismatchErrorWrapL", AuthSigPayloadMismatchErrorWrapL, uint32(code.AuthSigPayloadMismatch), 0, "test msg", "cause err", true},
{"AuthForbiddenErrorWrapL", AuthForbiddenErrorWrapL, uint32(code.AuthForbidden), 0, "test msg", "cause err", true},
{"SysInternalErrorWrapL", SysInternalErrorWrapL, uint32(code.SysInternal), 0, "test msg", "cause err", true},
{"SysMaintainErrorWrapL", SysMaintainErrorWrapL, uint32(code.SysMaintain), 0, "test msg", "cause err", true},
{"SysTimeoutErrorWrapL", SysTimeoutErrorWrapL, uint32(code.SysTimeout), 0, "test msg", "cause err", true},
{"SysTooManyRequestErrorWrapL", SysTooManyRequestErrorWrapL, uint32(code.SysTooManyRequest), 0, "test msg", "cause err", true},
{"PSuPublishErrorWrapL", PSuPublishErrorWrapL, uint32(code.PSuPublish), 0, "test msg", "cause err", true},
{"PSuConsumeErrorWrapL", PSuConsumeErrorWrapL, uint32(code.PSuConsume), 0, "test msg", "cause err", true},
{"PSuTooLargeErrorWrapL", PSuTooLargeErrorWrapL, uint32(code.PSuTooLarge), 0, "test msg", "cause err", true},
{"SvcInternalErrorWrapL", SvcInternalErrorWrapL, uint32(code.SvcInternal), 0, "test msg", "cause err", true},
{"SvcThirdPartyErrorWrapL", SvcThirdPartyErrorWrapL, uint32(code.SvcThirdParty), 0, "test msg", "cause err", true},
{"SvcHTTP400ErrorWrapL", SvcHTTP400ErrorWrapL, uint32(code.SvcHTTP400), 0, "test msg", "cause err", true},
{"SvcMaintenanceErrorWrapL", SvcMaintenanceErrorWrapL, uint32(code.SvcMaintenance), 0, "test msg", "cause err", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := &fakeLogger{}
fields := []LogField{}
cause := errors.New("cause err")
e := tt.fn(l, fields, cause, "test", "msg")
if e == nil || e.Category() != tt.wantCat || e.Detail() != tt.wantDet || e.Error() != "test msg" {
t.Errorf("%s = cat=%d det=%d msg=%q, want %d 0 %q", tt.name, e.Category(), e.Detail(), e.Error(), tt.wantCat, "test msg")
}
if tt.wantUnwrap != "" && e.Unwrap().Error() != tt.wantUnwrap {
t.Errorf("Unwrap() = %q, want %q", e.Unwrap().Error(), tt.wantUnwrap)
}
if tt.wantLog && len(l.calls) == 0 {
t.Errorf("expected log call")
}
})
}
}
// ... Add more tests for edge cases, like empty strings, multiple args in joinMsg, etc.

View File

@ -0,0 +1,74 @@
package errs
import (
"errors"
"chat/internal/library/errors/code"
"google.golang.org/grpc/status"
)
func newBuiltinGRPCErr(scope, detail uint32, msg string) *Error {
return &Error{
category: uint32(code.CatGRPC),
detail: detail,
scope: scope,
msg: msg,
}
}
// FromError tries to let error as Err
// it supports to unwrap error that has Error
// return nil if failed to transfer
func FromError(err error) *Error {
if err == nil {
return nil
}
var e *Error
if errors.As(err, &e) {
return e
}
return nil
}
// FromCode parses code as following 8 碼
// Decimal: 10201000
// 10 represents Scope
// 201 represents Category
// 000 represents Detail error code
func FromCode(code uint32) *Error {
const CodeMultiplier = 1000000
const SubMultiplier = 1000
// 獲取 scope前兩位數
scope := code / CodeMultiplier
// 獲取 detail最後三位數
detail := code % SubMultiplier
// 獲取 category中間三位數
category := (code / SubMultiplier) % SubMultiplier
return &Error{
category: category,
detail: detail,
scope: scope,
msg: "",
}
}
// FromGRPCError transfer error to Err
// useful for gRPC client
func FromGRPCError(err error) *Error {
s, _ := status.FromError(err)
e := FromCode(uint32(s.Code()))
e.msg = s.Message()
// For GRPC built-in code
if e.Scope() == uint32(code.Unset) && e.Category() == 0 && e.Code() != code.OK {
e = newBuiltinGRPCErr(uint32(Scope), e.detail, s.Message()) // Note: detail is now 3-digit, but built-in codes are small
}
return e
}

View File

@ -0,0 +1,116 @@
package errs
import (
"errors"
"fmt"
"testing"
"chat/internal/library/errors/code"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func TestNewBuiltinGRPCErr(t *testing.T) {
tests := []struct {
name string
scope uint32
detail uint32
msg string
wantCat uint32
wantScope uint32
wantDet uint32
wantMsg string
}{
{"basic", 10, 3, "test", uint32(code.CatGRPC), 10, 3, "test"},
{"zero", 0, 0, "", uint32(code.CatGRPC), 0, 0, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := newBuiltinGRPCErr(tt.scope, tt.detail, tt.msg)
if e.Category() != tt.wantCat || e.Scope() != tt.wantScope || e.Detail() != tt.wantDet || e.Error() != tt.wantMsg {
t.Errorf("newBuiltinGRPCErr = cat=%d scope=%d det=%d msg=%q, want %d %d %d %q",
e.Category(), e.Scope(), e.Detail(), e.Error(), tt.wantCat, tt.wantScope, tt.wantDet, tt.wantMsg)
}
})
}
}
func TestFromError(t *testing.T) {
base := New(10, uint32(code.DBError), 0, "base")
// but actually use fmt.Errorf("%w", base) for proper wrapping
tests := []struct {
name string
in error
want *Error
}{
{"nil", nil, nil},
{"not Error", errors.New("std"), nil},
{"direct Error", base, base},
{"wrapped Error", fmt.Errorf("wrap: %w", base), base},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FromError(tt.in)
if (got == nil) != (tt.want == nil) {
t.Errorf("FromError = %v, want nil=%v", got, tt.want == nil)
}
if got != nil && (got.Category() != tt.want.Category() || got.Detail() != tt.want.Detail()) {
t.Errorf("FromError = cat=%d det=%d, want %d %d", got.Category(), got.Detail(), tt.want.Category(), tt.want.Detail())
}
})
}
}
func TestFromCode(t *testing.T) {
tests := []struct {
name string
code uint32
wantScope uint32
wantCat uint32
wantDet uint32
}{
{"basic", 10201123, 10, 201, 123},
{"zero", 0, 0, 0, 0},
{"max", 99999999, 99, 999, 999},
{"overflow code", 100000000, 100, 0, 0}, // Parses as scope=100, but uint32 limits
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := FromCode(tt.code)
if e.Scope() != tt.wantScope || e.Category() != tt.wantCat || e.Detail() != tt.wantDet {
t.Errorf("FromCode = scope=%d cat=%d det=%d, want %d %d %d", e.Scope(), e.Category(), e.Detail(), tt.wantScope, tt.wantCat, tt.wantDet)
}
})
}
}
func TestFromGRPCError(t *testing.T) {
Scope = code.Gateway
tests := []struct {
name string
in error
wantScope uint32
wantCat uint32
wantDet uint32
wantMsg string
}{
{"nil", nil, 0, 0, 0, ""},
{"builtin OK", status.New(codes.OK, "").Err(), 0, 0, 0, ""},
{"builtin InvalidArgument", status.New(codes.InvalidArgument, "bad").Err(), uint32(code.Gateway), uint32(code.CatGRPC), uint32(codes.InvalidArgument), "bad"},
{"custom code", status.New(codes.Code(10201123), "custom").Err(), 10, 201, 123, "custom"},
{"unset scope with builtin", status.New(codes.NotFound, "not found").Err(), uint32(code.Gateway), uint32(code.CatGRPC), uint32(codes.NotFound), "not found"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := FromGRPCError(tt.in)
if e == nil && (tt.wantScope != 0 || tt.wantCat != 0 || tt.wantDet != 0) {
t.Errorf("got nil, want non-nil")
}
if e != nil && (e.Scope() != tt.wantScope || e.Category() != tt.wantCat || e.Detail() != tt.wantDet || e.Error() != tt.wantMsg) {
t.Errorf("FromGRPCError = scope=%d cat=%d det=%d msg=%q, want %d %d %d %q",
e.Scope(), e.Category(), e.Detail(), e.Error(), tt.wantScope, tt.wantCat, tt.wantDet, tt.wantMsg)
}
})
}
}

View File

@ -0,0 +1,19 @@
package mongo
import "time"
type Conf struct {
Schema string
User string
Password string
Host string
Database string
ReplicaName string
MaxStaleness time.Duration
MaxPoolSize uint64
MinPoolSize uint64
MaxConnIdleTime time.Duration
Compressors []string
EnableStandardReadWriteSplitMode bool
ConnectTimeoutMs int64
}

View File

@ -0,0 +1,113 @@
package mongo
import (
"testing"
"time"
)
func TestConf_DefaultValues(t *testing.T) {
conf := &Conf{}
// Test default values
if conf.Schema != "" {
t.Errorf("Expected empty Schema, got %s", conf.Schema)
}
if conf.User != "" {
t.Errorf("Expected empty User, got %s", conf.User)
}
if conf.Password != "" {
t.Errorf("Expected empty Password, got %s", conf.Password)
}
if conf.Host != "" {
t.Errorf("Expected empty Host, got %s", conf.Host)
}
if conf.Database != "" {
t.Errorf("Expected empty Database, got %s", conf.Database)
}
if conf.ReplicaName != "" {
t.Errorf("Expected empty ReplicaName, got %s", conf.ReplicaName)
}
if conf.MaxStaleness != 0 {
t.Errorf("Expected zero MaxStaleness, got %v", conf.MaxStaleness)
}
if conf.MaxPoolSize != 0 {
t.Errorf("Expected zero MaxPoolSize, got %d", conf.MaxPoolSize)
}
if conf.MinPoolSize != 0 {
t.Errorf("Expected zero MinPoolSize, got %d", conf.MinPoolSize)
}
if conf.MaxConnIdleTime != 0 {
t.Errorf("Expected zero MaxConnIdleTime, got %v", conf.MaxConnIdleTime)
}
if conf.Compressors != nil {
t.Errorf("Expected nil Compressors, got %v", conf.Compressors)
}
if conf.EnableStandardReadWriteSplitMode {
t.Errorf("Expected false EnableStandardReadWriteSplitMode, got %v", conf.EnableStandardReadWriteSplitMode)
}
if conf.ConnectTimeoutMs != 0 {
t.Errorf("Expected zero ConnectTimeoutMs, got %d", conf.ConnectTimeoutMs)
}
}
func TestConf_WithValues(t *testing.T) {
conf := &Conf{
Schema: "mongodb",
User: "testuser",
Password: "testpass",
Host: "localhost:27017",
Database: "testdb",
ReplicaName: "testreplica",
MaxStaleness: 30 * time.Second,
MaxPoolSize: 100,
MinPoolSize: 10,
MaxConnIdleTime: 5 * time.Minute,
Compressors: []string{"snappy", "zlib"},
EnableStandardReadWriteSplitMode: true,
ConnectTimeoutMs: 5000,
}
// Test set values
if conf.Schema != "mongodb" {
t.Errorf("Expected 'mongodb' Schema, got %s", conf.Schema)
}
if conf.User != "testuser" {
t.Errorf("Expected 'testuser' User, got %s", conf.User)
}
if conf.Password != "testpass" {
t.Errorf("Expected 'testpass' Password, got %s", conf.Password)
}
if conf.Host != "localhost:27017" {
t.Errorf("Expected 'localhost:27017' Host, got %s", conf.Host)
}
if conf.Database != "testdb" {
t.Errorf("Expected 'testdb' Database, got %s", conf.Database)
}
if conf.ReplicaName != "testreplica" {
t.Errorf("Expected 'testreplica' ReplicaName, got %s", conf.ReplicaName)
}
if conf.MaxStaleness != 30*time.Second {
t.Errorf("Expected 30s MaxStaleness, got %v", conf.MaxStaleness)
}
if conf.MaxPoolSize != 100 {
t.Errorf("Expected 100 MaxPoolSize, got %d", conf.MaxPoolSize)
}
if conf.MinPoolSize != 10 {
t.Errorf("Expected 10 MinPoolSize, got %d", conf.MinPoolSize)
}
if conf.MaxConnIdleTime != 5*time.Minute {
t.Errorf("Expected 5m MaxConnIdleTime, got %v", conf.MaxConnIdleTime)
}
if len(conf.Compressors) != 2 {
t.Errorf("Expected 2 Compressors, got %d", len(conf.Compressors))
}
if conf.Compressors[0] != "snappy" || conf.Compressors[1] != "zlib" {
t.Errorf("Expected ['snappy', 'zlib'] Compressors, got %v", conf.Compressors)
}
if !conf.EnableStandardReadWriteSplitMode {
t.Errorf("Expected true EnableStandardReadWriteSplitMode, got %v", conf.EnableStandardReadWriteSplitMode)
}
if conf.ConnectTimeoutMs != 5000 {
t.Errorf("Expected 5000 ConnectTimeoutMs, got %d", conf.ConnectTimeoutMs)
}
}

View File

@ -0,0 +1,21 @@
package mongo
import (
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/syncx"
"go.mongodb.org/mongo-driver/v2/mongo"
)
const (
authenticationStringTemplate = "%s:%s@"
connectionStringTemplate = "%s://%s%s"
)
var (
// ErrNotFound is an alias of mongo.ErrNoDocuments.
ErrNotFound = mongo.ErrNoDocuments
// can't use one SingleFlight per conn, because multiple conns may share the same cache key.
singleFlight = syncx.NewSingleFlight()
stats = cache.NewStat("monc")
)

View File

@ -0,0 +1,50 @@
package mongo
import (
"fmt"
"reflect"
"github.com/shopspring/decimal"
"go.mongodb.org/mongo-driver/v2/bson"
)
type MgoDecimal struct{}
var (
_ bson.ValueEncoder = &MgoDecimal{}
_ bson.ValueDecoder = &MgoDecimal{}
)
func (dc *MgoDecimal) EncodeValue(_ bson.EncodeContext, w bson.ValueWriter, value reflect.Value) error {
// TODO 待確認是否有非decimal.Decimal type而導致error的場景
dec, ok := value.Interface().(decimal.Decimal)
if !ok {
return fmt.Errorf("value %v to encode is not of type decimal.Decimal", value)
}
// Convert decimal.Decimal to bson.Decimal128.
primDec, err := bson.ParseDecimal128(dec.String())
if err != nil {
return fmt.Errorf("converting decimal.Decimal %v to bson.Decimal128 error: %w", dec, err)
}
return w.WriteDecimal128(primDec)
}
func (dc *MgoDecimal) DecodeValue(_ bson.DecodeContext, r bson.ValueReader, value reflect.Value) error {
primDec, err := r.ReadDecimal128()
if err != nil {
return fmt.Errorf("reading bson.Decimal128 from ValueReader error: %w", err)
}
// Convert bson.Decimal128 to decimal.Decimal.
dec, err := decimal.NewFromString(primDec.String())
if err != nil {
return fmt.Errorf("converting bson.Decimal128 %v to decimal.Decimal error: %w", primDec, err)
}
// set as decimal.Decimal type
value.Set(reflect.ValueOf(dec))
return nil
}

View File

@ -0,0 +1,275 @@
package mongo
import (
"reflect"
"testing"
"github.com/shopspring/decimal"
"go.mongodb.org/mongo-driver/v2/bson"
)
func TestMgoDecimal_InterfaceCompliance(t *testing.T) {
encoder := &MgoDecimal{}
decoder := &MgoDecimal{}
// Test that they implement the required interfaces
var _ bson.ValueEncoder = encoder
var _ bson.ValueDecoder = decoder
// Test that they can be used in TypeCodec
codec := TypeCodec{
ValueType: reflect.TypeOf(decimal.Decimal{}),
Encoder: encoder,
Decoder: decoder,
}
if codec.Encoder != encoder {
t.Error("Expected encoder to be set correctly")
}
if codec.Decoder != decoder {
t.Error("Expected decoder to be set correctly")
}
}
func TestMgoDecimal_EncodeValue_InvalidType(t *testing.T) {
encoder := &MgoDecimal{}
// Test with invalid type
value := reflect.ValueOf("not a decimal")
err := encoder.EncodeValue(bson.EncodeContext{}, nil, value)
if err == nil {
t.Error("Expected error for invalid type, got nil")
}
expectedErr := "value not a decimal to encode is not of type decimal.Decimal"
if err.Error() != expectedErr {
t.Errorf("Expected error '%s', got '%s'", expectedErr, err.Error())
}
}
// Test decimal conversion functions
func TestDecimalConversion(t *testing.T) {
testCases := []struct {
input string
expected string
}{
{"0", "0"},
{"123.45", "123.45"},
{"-123.45", "-123.45"},
{"0.000001", "0.000001"},
{"9999999999999999999.999999999999999", "9999999999999999999.999999999999999"},
{"-9999999999999999999.999999999999999", "-9999999999999999999.999999999999999"},
}
for _, tc := range testCases {
t.Run(tc.input, func(t *testing.T) {
// Test decimal to string conversion
dec, err := decimal.NewFromString(tc.input)
if err != nil {
t.Fatalf("Failed to create decimal from %s: %v", tc.input, err)
}
if dec.String() != tc.expected {
t.Errorf("Expected %s, got %s", tc.expected, dec.String())
}
// Test BSON decimal128 conversion
primDec, err := bson.ParseDecimal128(dec.String())
if err != nil {
t.Fatalf("Failed to parse decimal128 from %s: %v", dec.String(), err)
}
if primDec.String() != tc.expected {
t.Errorf("Expected %s, got %s", tc.expected, primDec.String())
}
})
}
}
// Test error cases
func TestDecimalConversionErrors(t *testing.T) {
invalidCases := []string{
"invalid",
"not a number",
"",
"123.45.67",
"abc123",
}
for _, invalid := range invalidCases {
t.Run(invalid, func(t *testing.T) {
_, err := decimal.NewFromString(invalid)
if err == nil {
t.Errorf("Expected error for invalid decimal string: %s", invalid)
}
_, err = bson.ParseDecimal128(invalid)
if err == nil {
t.Errorf("Expected error for invalid decimal128 string: %s", invalid)
}
})
}
}
// Test edge cases for decimal values
func TestDecimalEdgeCases(t *testing.T) {
testCases := []struct {
name string
value decimal.Decimal
expected string
}{
{"zero", decimal.Zero, "0"},
{"positive small", decimal.NewFromFloat(0.000001), "0.000001"},
{"negative small", decimal.NewFromFloat(-0.000001), "-0.000001"},
{"positive large", decimal.NewFromInt(999999999999999), "999999999999999"},
{"negative large", decimal.NewFromInt(-999999999999999), "-999999999999999"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Test conversion to BSON Decimal128
primDec, err := bson.ParseDecimal128(tc.value.String())
if err != nil {
t.Fatalf("Failed to parse decimal128 from %s: %v", tc.value.String(), err)
}
// Test conversion back to decimal
dec, err := decimal.NewFromString(primDec.String())
if err != nil {
t.Fatalf("Failed to create decimal from %s: %v", primDec.String(), err)
}
if !dec.Equal(tc.value) {
t.Errorf("Round trip failed: original=%s, result=%s", tc.value.String(), dec.String())
}
})
}
}
// Test error handling in encoder
func TestMgoDecimal_EncoderErrors(t *testing.T) {
encoder := &MgoDecimal{}
testCases := []struct {
name string
value interface{}
}{
{"string", "not a decimal"},
{"int", 123},
{"float", 123.45},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
value := reflect.ValueOf(tc.value)
err := encoder.EncodeValue(bson.EncodeContext{}, nil, value)
if err == nil {
t.Errorf("Expected error for type %T, got nil", tc.value)
}
})
}
}
// Test decimal precision
func TestDecimalPrecision(t *testing.T) {
testCases := []string{
"0.1",
"0.01",
"0.001",
"0.0001",
"0.00001",
"0.000001",
"0.0000001",
"0.00000001",
}
for _, tc := range testCases {
t.Run(tc, func(t *testing.T) {
dec, err := decimal.NewFromString(tc)
if err != nil {
t.Fatalf("Failed to create decimal from %s: %v", tc, err)
}
// Test conversion to BSON Decimal128
primDec, err := bson.ParseDecimal128(dec.String())
if err != nil {
t.Fatalf("Failed to parse decimal128 from %s: %v", dec.String(), err)
}
// Test conversion back to decimal
result, err := decimal.NewFromString(primDec.String())
if err != nil {
t.Fatalf("Failed to create decimal from %s: %v", primDec.String(), err)
}
if !result.Equal(dec) {
t.Errorf("Precision lost: original=%s, result=%s", dec.String(), result.String())
}
})
}
}
// Test large numbers
func TestDecimalLargeNumbers(t *testing.T) {
testCases := []string{
"1000000000000000",
"10000000000000000",
"100000000000000000",
"1000000000000000000",
}
for _, tc := range testCases {
t.Run(tc, func(t *testing.T) {
dec, err := decimal.NewFromString(tc)
if err != nil {
t.Fatalf("Failed to create decimal from %s: %v", tc, err)
}
// Test conversion to BSON Decimal128
primDec, err := bson.ParseDecimal128(dec.String())
if err != nil {
t.Fatalf("Failed to parse decimal128 from %s: %v", dec.String(), err)
}
// Test conversion back to decimal
result, err := decimal.NewFromString(primDec.String())
if err != nil {
t.Fatalf("Failed to create decimal from %s: %v", primDec.String(), err)
}
if !result.Equal(dec) {
t.Errorf("Large number lost: original=%s, result=%s", dec.String(), result.String())
}
})
}
}
// Benchmark tests
func BenchmarkMgoDecimal_ParseDecimal128(b *testing.B) {
dec := decimal.NewFromFloat(123.45)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = bson.ParseDecimal128(dec.String())
}
}
func BenchmarkMgoDecimal_DecimalFromString(b *testing.B) {
primDec, _ := bson.ParseDecimal128("123.45")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = decimal.NewFromString(primDec.String())
}
}
func BenchmarkMgoDecimal_RoundTrip(b *testing.B) {
dec := decimal.NewFromFloat(123.45)
b.ResetTimer()
for i := 0; i < b.N; i++ {
primDec, _ := bson.ParseDecimal128(dec.String())
_, _ = decimal.NewFromString(primDec.String())
}
}

View File

@ -0,0 +1,339 @@
package mongo
import (
"context"
"fmt"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/mon"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
type DocumentDBWithCache struct {
DocumentDBUseCase
Cache cache.Cache
}
func MustDocumentDBWithCache(conf *Conf, collection string, cacheConf cache.CacheConf, dbOpts []mon.Option, cacheOpts []cache.Option) (DocumentDBWithCacheUseCase, error) {
documentDB, err := NewDocumentDB(conf, collection, dbOpts...)
if err != nil {
return nil, fmt.Errorf("failed to initialize DocumentDB: %w", err)
}
c := MustModelCache(cacheConf, cacheOpts...)
return &DocumentDBWithCache{
DocumentDBUseCase: documentDB,
Cache: c,
}, nil
}
func (dc *DocumentDBWithCache) DelCache(ctx context.Context, keys ...string) error {
return dc.Cache.DelCtx(ctx, keys...)
}
func (dc *DocumentDBWithCache) GetCache(key string, v any) error {
return dc.Cache.Get(key, v)
}
func (dc *DocumentDBWithCache) SetCache(key string, v any) error {
return dc.Cache.Set(key, v)
}
// DeleteOne deletes a single document and invalidates cache
func (dc *DocumentDBWithCache) DeleteOne(ctx context.Context, key string, filter any, opts ...*options.DeleteOneOptions) (int64, error) {
// Convert options to Builder format
var listerOpts []options.Lister[options.DeleteOneOptions]
for _, opt := range opts {
if opt != nil {
builder := options.DeleteOne()
if opt.Collation != nil {
builder.SetCollation(opt.Collation)
}
if opt.Comment != nil {
builder.SetComment(opt.Comment)
}
if opt.Hint != nil {
builder.SetHint(opt.Hint)
}
listerOpts = append(listerOpts, builder)
}
}
val, err := dc.GetClient().DeleteOne(ctx, filter, listerOpts...)
if err != nil {
return 0, err
}
if err := dc.DelCache(ctx, key); err != nil {
return 0, err
}
return val, nil
}
// FindOne finds a single document with cache support
func (dc *DocumentDBWithCache) FindOne(ctx context.Context, key string, v, filter any, opts ...*options.FindOneOptions) error {
return dc.Cache.TakeCtx(ctx, v, key, func(v any) error {
// Convert options to Builder format
var listerOpts []options.Lister[options.FindOneOptions]
for _, opt := range opts {
if opt != nil {
builder := options.FindOne()
if opt.Collation != nil {
builder.SetCollation(opt.Collation)
}
if opt.Comment != nil {
builder.SetComment(opt.Comment)
}
if opt.Hint != nil {
builder.SetHint(opt.Hint)
}
if opt.Projection != nil {
builder.SetProjection(opt.Projection)
}
if opt.Skip != nil {
builder.SetSkip(*opt.Skip)
}
if opt.Sort != nil {
builder.SetSort(opt.Sort)
}
listerOpts = append(listerOpts, builder)
}
}
return dc.GetClient().FindOne(ctx, v, filter, listerOpts...)
})
}
// FindOneAndDelete finds and deletes a single document with cache invalidation
func (dc *DocumentDBWithCache) FindOneAndDelete(ctx context.Context, key string, v, filter any, opts ...*options.FindOneAndDeleteOptions) error {
// Convert options to Builder format
var listerOpts []options.Lister[options.FindOneAndDeleteOptions]
for _, opt := range opts {
if opt != nil {
builder := options.FindOneAndDelete()
if opt.Collation != nil {
builder.SetCollation(opt.Collation)
}
if opt.Comment != nil {
builder.SetComment(opt.Comment)
}
if opt.Hint != nil {
builder.SetHint(opt.Hint)
}
if opt.Projection != nil {
builder.SetProjection(opt.Projection)
}
if opt.Sort != nil {
builder.SetSort(opt.Sort)
}
listerOpts = append(listerOpts, builder)
}
}
if err := dc.GetClient().FindOneAndDelete(ctx, v, filter, listerOpts...); err != nil {
return err
}
return dc.DelCache(ctx, key)
}
// FindOneAndReplace finds and replaces a single document with cache invalidation
func (dc *DocumentDBWithCache) FindOneAndReplace(ctx context.Context, key string, v, filter, replacement any, opts ...*options.FindOneAndReplaceOptions) error {
// Convert options to Builder format
var listerOpts []options.Lister[options.FindOneAndReplaceOptions]
for _, opt := range opts {
if opt != nil {
builder := options.FindOneAndReplace()
if opt.Collation != nil {
builder.SetCollation(opt.Collation)
}
if opt.Comment != nil {
builder.SetComment(opt.Comment)
}
if opt.Hint != nil {
builder.SetHint(opt.Hint)
}
if opt.Projection != nil {
builder.SetProjection(opt.Projection)
}
if opt.ReturnDocument != nil {
builder.SetReturnDocument(*opt.ReturnDocument)
}
if opt.Sort != nil {
builder.SetSort(opt.Sort)
}
if opt.Upsert != nil {
builder.SetUpsert(*opt.Upsert)
}
listerOpts = append(listerOpts, builder)
}
}
if err := dc.GetClient().FindOneAndReplace(ctx, v, filter, replacement, listerOpts...); err != nil {
return err
}
return dc.DelCache(ctx, key)
}
// InsertOne inserts a single document and invalidates cache
func (dc *DocumentDBWithCache) InsertOne(ctx context.Context, key string, document any, opts ...*options.InsertOneOptions) (*mongo.InsertOneResult, error) {
// Convert options to Builder format
var listerOpts []options.Lister[options.InsertOneOptions]
for _, opt := range opts {
if opt != nil {
builder := options.InsertOne()
if opt.BypassDocumentValidation != nil {
builder.SetBypassDocumentValidation(*opt.BypassDocumentValidation)
}
if opt.Comment != nil {
builder.SetComment(opt.Comment)
}
listerOpts = append(listerOpts, builder)
}
}
res, err := dc.GetClient().Collection.InsertOne(ctx, document, listerOpts...)
if err != nil {
return nil, err
}
if err = dc.DelCache(ctx, key); err != nil {
return nil, err
}
return res, nil
}
// UpdateByID updates a document by ID and invalidates cache
func (dc *DocumentDBWithCache) UpdateByID(ctx context.Context, key string, id, update any, opts ...*options.UpdateOneOptions) (*mongo.UpdateResult, error) {
// Convert options to Builder format
var listerOpts []options.Lister[options.UpdateOneOptions]
for _, opt := range opts {
if opt != nil {
builder := options.UpdateOne()
if opt.ArrayFilters != nil {
builder.SetArrayFilters(opt.ArrayFilters)
}
if opt.BypassDocumentValidation != nil {
builder.SetBypassDocumentValidation(*opt.BypassDocumentValidation)
}
if opt.Collation != nil {
builder.SetCollation(opt.Collation)
}
if opt.Comment != nil {
builder.SetComment(opt.Comment)
}
if opt.Hint != nil {
builder.SetHint(opt.Hint)
}
if opt.Upsert != nil {
builder.SetUpsert(*opt.Upsert)
}
listerOpts = append(listerOpts, builder)
}
}
res, err := dc.GetClient().Collection.UpdateByID(ctx, id, update, listerOpts...)
if err != nil {
return nil, err
}
if err = dc.DelCache(ctx, key); err != nil {
return nil, err
}
return res, nil
}
// UpdateMany updates multiple documents and invalidates cache
func (dc *DocumentDBWithCache) UpdateMany(ctx context.Context, keys []string, filter, update any, opts ...*options.UpdateManyOptions) (*mongo.UpdateResult, error) {
// Convert options to Builder format
var listerOpts []options.Lister[options.UpdateManyOptions]
for _, opt := range opts {
if opt != nil {
builder := options.UpdateMany()
if opt.ArrayFilters != nil {
builder.SetArrayFilters(opt.ArrayFilters)
}
if opt.BypassDocumentValidation != nil {
builder.SetBypassDocumentValidation(*opt.BypassDocumentValidation)
}
if opt.Collation != nil {
builder.SetCollation(opt.Collation)
}
if opt.Comment != nil {
builder.SetComment(opt.Comment)
}
if opt.Hint != nil {
builder.SetHint(opt.Hint)
}
if opt.Upsert != nil {
builder.SetUpsert(*opt.Upsert)
}
listerOpts = append(listerOpts, builder)
}
}
res, err := dc.GetClient().Collection.UpdateMany(ctx, filter, update, listerOpts...)
if err != nil {
return nil, err
}
if err = dc.DelCache(ctx, keys...); err != nil {
return nil, err
}
return res, nil
}
// UpdateOne updates a single document and invalidates cache
func (dc *DocumentDBWithCache) UpdateOne(ctx context.Context, key string, filter, update any, opts ...*options.UpdateOneOptions) (*mongo.UpdateResult, error) {
// Convert options to Builder format
var listerOpts []options.Lister[options.UpdateOneOptions]
for _, opt := range opts {
if opt != nil {
builder := options.UpdateOne()
if opt.ArrayFilters != nil {
builder.SetArrayFilters(opt.ArrayFilters)
}
if opt.BypassDocumentValidation != nil {
builder.SetBypassDocumentValidation(*opt.BypassDocumentValidation)
}
if opt.Collation != nil {
builder.SetCollation(opt.Collation)
}
if opt.Comment != nil {
builder.SetComment(opt.Comment)
}
if opt.Hint != nil {
builder.SetHint(opt.Hint)
}
if opt.Upsert != nil {
builder.SetUpsert(*opt.Upsert)
}
listerOpts = append(listerOpts, builder)
}
}
res, err := dc.GetClient().Collection.UpdateOne(ctx, filter, update, listerOpts...)
if err != nil {
return nil, err
}
if err = dc.DelCache(ctx, key); err != nil {
return nil, err
}
return res, nil
}
// ========================
// MustModelCache returns a cache cluster.
func MustModelCache(conf cache.CacheConf, opts ...cache.Option) cache.Cache {
return cache.New(conf, singleFlight, stats, mongo.ErrNoDocuments, opts...)
}

View File

@ -0,0 +1,364 @@
package mongo
import (
"context"
"testing"
"time"
"github.com/shopspring/decimal"
"github.com/zeromicro/go-zero/core/stores/cache"
"go.mongodb.org/mongo-driver/v2/bson"
)
func TestDocumentDBWithCache_MustDocumentDBWithCache(t *testing.T) {
// Test with valid config
conf := &Conf{
Host: "localhost:27017",
Database: "testdb",
}
collection := "testcollection"
cacheConf := cache.CacheConf{}
// This will panic if MongoDB is not available, so we need to handle it
defer func() {
if r := recover(); r != nil {
t.Logf("Expected panic in test environment: %v", r)
}
}()
db, err := MustDocumentDBWithCache(conf, collection, cacheConf, nil, nil)
if err != nil {
t.Logf("MongoDB connection failed (expected in test environment): %v", err)
return
}
if db == nil {
t.Error("Expected DocumentDBWithCache to be non-nil")
}
}
func TestDocumentDBWithCache_CacheOperations(t *testing.T) {
conf := &Conf{
Host: "localhost:27017",
Database: "testdb",
}
collection := "testcollection"
cacheConf := cache.CacheConf{}
ctx := context.Background()
db, err := MustDocumentDBWithCache(conf, collection, cacheConf, nil, nil)
if err != nil {
t.Skip("Skipping test - MongoDB not available")
return
}
// Test cache operations
key := "test-key"
value := "test-value"
// Test SetCache
err = db.SetCache(key, value)
if err != nil {
t.Errorf("Failed to set cache: %v", err)
}
// Test GetCache
var cachedValue string
err = db.GetCache(key, &cachedValue)
if err != nil {
t.Errorf("Failed to get cache: %v", err)
}
if cachedValue != value {
t.Errorf("Expected cached value %s, got %s", value, cachedValue)
}
// Test DelCache
err = db.DelCache(ctx, key)
if err != nil {
t.Errorf("Failed to delete cache: %v", err)
}
}
func TestDocumentDBWithCache_CRUDOperations(t *testing.T) {
conf := &Conf{
Host: "localhost:27017",
Database: "testdb",
}
collection := "testcollection"
cacheConf := cache.CacheConf{}
ctx := context.Background()
db, err := MustDocumentDBWithCache(conf, collection, cacheConf, nil, nil)
if err != nil {
t.Skip("Skipping test - MongoDB not available")
return
}
// Test data
testDoc := bson.M{
"name": "test",
"value": 123,
"price": decimal.NewFromFloat(99.99),
}
// Test InsertOne
result, err := db.InsertOne(ctx, collection, testDoc)
if err != nil {
t.Errorf("Failed to insert document: %v", err)
}
insertedID := result.InsertedID
if insertedID == nil {
t.Error("Expected inserted ID to be non-nil")
}
// Test FindOne
var foundDoc bson.M
err = db.FindOne(ctx, collection, bson.M{"_id": insertedID}, &foundDoc)
if err != nil {
t.Errorf("Failed to find document: %v", err)
}
if foundDoc["name"] != "test" {
t.Errorf("Expected name 'test', got %v", foundDoc["name"])
}
// Test UpdateOne
update := bson.M{"$set": bson.M{"value": 456}}
updateResult, err := db.UpdateOne(ctx, collection, bson.M{"_id": insertedID}, update)
if err != nil {
t.Errorf("Failed to update document: %v", err)
}
if updateResult.ModifiedCount != 1 {
t.Errorf("Expected 1 modified document, got %d", updateResult.ModifiedCount)
}
// Test UpdateByID
updateByID := bson.M{"$set": bson.M{"value": 789}}
updateByIDResult, err := db.UpdateByID(ctx, collection, insertedID, updateByID)
if err != nil {
t.Errorf("Failed to update document by ID: %v", err)
}
if updateByIDResult.ModifiedCount != 1 {
t.Errorf("Expected 1 modified document, got %d", updateByIDResult.ModifiedCount)
}
// Test UpdateMany
updateMany := bson.M{"$set": bson.M{"updated": true}}
updateManyResult, err := db.UpdateMany(ctx, []string{collection}, bson.M{"_id": insertedID}, updateMany)
if err != nil {
t.Errorf("Failed to update many documents: %v", err)
}
if updateManyResult.ModifiedCount != 1 {
t.Errorf("Expected 1 modified document, got %d", updateManyResult.ModifiedCount)
}
// Test FindOneAndReplace
replacement := bson.M{
"name": "replaced",
"value": 999,
"price": decimal.NewFromFloat(199.99),
}
var replacedDoc bson.M
err = db.FindOneAndReplace(ctx, collection, bson.M{"_id": insertedID}, replacement, &replacedDoc)
if err != nil {
t.Errorf("Failed to find and replace document: %v", err)
}
// Test FindOneAndDelete
var deletedDoc bson.M
err = db.FindOneAndDelete(ctx, collection, bson.M{"_id": insertedID}, &deletedDoc)
if err != nil {
t.Errorf("Failed to find and delete document: %v", err)
}
// Test DeleteOne
deleteResult, err := db.DeleteOne(ctx, collection, bson.M{"_id": insertedID})
if err != nil {
t.Errorf("Failed to delete document: %v", err)
}
if deleteResult != 0 { // Should be 0 since we already deleted it
t.Errorf("Expected 0 deleted documents, got %d", deleteResult)
}
}
func TestDocumentDBWithCache_MustModelCache(t *testing.T) {
conf := &Conf{
Host: "localhost:27017",
Database: "testdb",
}
collection := "testcollection"
cacheConf := cache.CacheConf{}
db, err := MustDocumentDBWithCache(conf, collection, cacheConf, nil, nil)
if err != nil {
t.Skip("Skipping test - MongoDB not available")
return
}
// Test that we got a valid DocumentDBWithCache
if db == nil {
t.Error("Expected DocumentDBWithCache to be non-nil")
}
}
func TestDocumentDBWithCache_ErrorHandling(t *testing.T) {
// Test with invalid config
invalidConf := &Conf{
Host: "invalid-host:99999",
}
collection := "testcollection"
cacheConf := cache.CacheConf{}
_, err := MustDocumentDBWithCache(invalidConf, collection, cacheConf, nil, nil)
// This should fail
if err == nil {
t.Error("Expected error with invalid host, got nil")
}
}
func TestDocumentDBWithCache_ContextHandling(t *testing.T) {
conf := &Conf{
Host: "localhost:27017",
Database: "testdb",
}
collection := "testcollection"
cacheConf := cache.CacheConf{}
// Test with timeout context
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
db, err := MustDocumentDBWithCache(conf, collection, cacheConf, nil, nil)
// Use ctx to avoid unused variable warning
_ = ctx
if err != nil {
t.Logf("MongoDB connection failed (expected in test environment): %v", err)
return
}
if db == nil {
t.Error("Expected DocumentDBWithCache to be non-nil")
}
}
func TestDocumentDBWithCache_WithDecimalValues(t *testing.T) {
conf := &Conf{
Host: "localhost:27017",
Database: "testdb",
}
collection := "testcollection"
cacheConf := cache.CacheConf{}
ctx := context.Background()
db, err := MustDocumentDBWithCache(conf, collection, cacheConf, nil, nil)
if err != nil {
t.Skip("Skipping test - MongoDB not available")
return
}
// Test with decimal values
testDoc := bson.M{
"name": "decimal-test",
"price": decimal.NewFromFloat(123.45),
"amount": decimal.NewFromFloat(999.99),
}
// Insert document with decimal values
result, err := db.InsertOne(ctx, collection, testDoc)
if err != nil {
t.Errorf("Failed to insert document with decimal values: %v", err)
}
insertedID := result.InsertedID
// Find document with decimal values
var foundDoc bson.M
err = db.FindOne(ctx, collection, bson.M{"_id": insertedID}, &foundDoc)
if err != nil {
t.Errorf("Failed to find document with decimal values: %v", err)
}
// Verify decimal values
if foundDoc["name"] != "decimal-test" {
t.Errorf("Expected name 'decimal-test', got %v", foundDoc["name"])
}
// Clean up
_, err = db.DeleteOne(ctx, collection, bson.M{"_id": insertedID})
if err != nil {
t.Errorf("Failed to clean up document: %v", err)
}
}
func TestDocumentDBWithCache_WithObjectID(t *testing.T) {
conf := &Conf{
Host: "localhost:27017",
Database: "testdb",
}
collection := "testcollection"
cacheConf := cache.CacheConf{}
ctx := context.Background()
db, err := MustDocumentDBWithCache(conf, collection, cacheConf, nil, nil)
if err != nil {
t.Skip("Skipping test - MongoDB not available")
return
}
// Test with ObjectID
objectID := bson.NewObjectID()
testDoc := bson.M{
"_id": objectID,
"name": "objectid-test",
"value": 123,
}
// Insert document with ObjectID
result, err := db.InsertOne(ctx, collection, testDoc)
if err != nil {
t.Errorf("Failed to insert document with ObjectID: %v", err)
}
insertedID := result.InsertedID
// Verify ObjectID
if insertedID != objectID {
t.Errorf("Expected ObjectID %v, got %v", objectID, insertedID)
}
// Find document by ObjectID
var foundDoc bson.M
err = db.FindOne(ctx, collection, bson.M{"_id": objectID}, &foundDoc)
if err != nil {
t.Errorf("Failed to find document by ObjectID: %v", err)
}
// Clean up
_, err = db.DeleteOne(ctx, collection, bson.M{"_id": objectID})
if err != nil {
t.Errorf("Failed to clean up document: %v", err)
}
}

159
internal/library/mongo/doc-db.go Executable file
View File

@ -0,0 +1,159 @@
package mongo
import (
"context"
"errors"
"fmt"
"net/url"
"strings"
"time"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/mon"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
)
type DocumentDB struct {
Mon *mon.Model
}
func NewDocumentDB(config *Conf, collection string, opts ...mon.Option) (DocumentDBUseCase, error) {
authenticationURI := ""
if config.User != "" {
authenticationURI = fmt.Sprintf(
authenticationStringTemplate,
config.User,
config.Password,
)
}
connectionURI := fmt.Sprintf(
connectionStringTemplate,
config.Schema,
authenticationURI,
config.Host,
)
connectUri, err := url.Parse(connectionURI)
if err != nil {
return nil, fmt.Errorf("failed to parse connection URI: %w", err)
}
printConnectUri := connectUri.String()
findIndexAt := strings.Index(connectUri.String(), "@")
if findIndexAt > -1 && config.User != "" {
prefixIndex := len(config.Schema) + 3 + len(config.User)
connectUriStr := connectUri.String()
printConnectUri = fmt.Sprintf("%s:*****%s", connectUriStr[:prefixIndex], connectUriStr[findIndexAt:])
}
// 初始化選項
intOpt := InitMongoOptions(*config)
opts = append(opts, intOpt)
logx.Infof("[DocumentDB] Try to connect document db `%s`", printConnectUri)
client, err := mon.NewModel(connectionURI, config.Database, collection, opts...)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(
context.Background(),
time.Duration(config.ConnectTimeoutMs)*time.Millisecond)
defer cancel()
// Force a connection to verify our connection string
err = client.Database().Client().Ping(ctx, readpref.SecondaryPreferred())
if err != nil {
return nil, errors.New(fmt.Sprintf("Failed to ping cluster: %s", err))
}
logx.Infof("[DocumentDB] Connected to DocumentDB!")
return &DocumentDB{
Mon: client,
}, nil
}
func (document *DocumentDB) PopulateIndex(ctx context.Context, key string, sort int32, unique bool) {
c := document.Mon.Collection
opts := options.CreateIndexes()
index := document.yieldIndexModel(
[]string{key}, []int32{sort}, unique, nil,
)
_, err := c.Indexes().CreateOne(ctx, index, opts)
if err != nil {
logx.Errorf("[DocumentDb] Ensure Index Failed, %s", err.Error())
}
}
func (document *DocumentDB) PopulateTTLIndex(ctx context.Context, key string, sort int32, unique bool, ttl int32) {
c := document.Mon.Collection
opts := options.CreateIndexes()
index := document.yieldIndexModel(
[]string{key}, []int32{sort}, unique,
options.Index().SetExpireAfterSeconds(ttl),
)
_, err := c.Indexes().CreateOne(ctx, index, opts)
if err != nil {
logx.Errorf("[DocumentDb] Ensure TTL Index Failed, %s", err.Error())
}
}
func (document *DocumentDB) PopulateMultiIndex(ctx context.Context, keys []string, sorts []int32, unique bool) {
if len(keys) != len(sorts) {
logx.Infof("[DocumentDb] Ensure Indexes Failed Please provide some item length of keys/sorts")
return
}
c := document.Mon.Collection
opts := options.CreateIndexes()
index := document.yieldIndexModel(keys, sorts, unique, nil)
_, err := c.Indexes().CreateOne(ctx, index, opts)
if err != nil {
logx.Errorf("[DocumentDb] Ensure TTL Index Failed, %s", err.Error())
}
}
// PopulateSparseMultiIndex 建立稀疏複合索引(只索引存在這些欄位的文檔)
func (document *DocumentDB) PopulateSparseMultiIndex(ctx context.Context, keys []string, sorts []int32, unique bool) {
if len(keys) != len(sorts) {
logx.Infof("[DocumentDb] Ensure Indexes Failed Please provide some item length of keys/sorts")
return
}
c := document.Mon.Collection
opts := options.CreateIndexes()
indexOpt := options.Index().SetSparse(true)
index := document.yieldIndexModel(keys, sorts, unique, indexOpt)
_, err := c.Indexes().CreateOne(ctx, index, opts)
if err != nil {
logx.Errorf("[DocumentDb] Ensure Sparse Multi Index Failed, %s", err.Error())
}
}
func (document *DocumentDB) GetClient() *mon.Model {
return document.Mon
}
func (document *DocumentDB) yieldIndexModel(keys []string, sorts []int32, unique bool, indexOpt *options.IndexOptionsBuilder) mongo.IndexModel {
SetKeysDoc := bson.D{}
for index := range keys {
key := keys[index]
sort := sorts[index]
SetKeysDoc = append(SetKeysDoc, bson.E{Key: key, Value: sort})
}
if indexOpt == nil {
indexOpt = options.Index()
}
indexOpt.SetUnique(unique)
index := mongo.IndexModel{
Keys: SetKeysDoc,
Options: indexOpt,
}
return index
}

View File

@ -0,0 +1,268 @@
package mongo
import (
"context"
"testing"
"time"
)
func TestNewDocumentDB(t *testing.T) {
// Test with valid config
conf := &Conf{
Host: "localhost:27017",
Database: "testdb",
}
db, err := NewDocumentDB(conf, "testcollection")
// Note: This will fail in test environment without MongoDB running
// but we can test the error handling and basic structure
if err == nil {
t.Log("MongoDB connection successful (MongoDB is running)")
// Test basic properties
if db == nil {
t.Error("Expected DocumentDB to be non-nil")
}
// Test GetClient
client := db.GetClient()
if client == nil {
t.Error("Expected client to be non-nil")
}
// Test that we got a valid DocumentDB
if db == nil {
t.Error("Expected DocumentDB to be non-nil")
}
} else {
t.Logf("MongoDB connection failed (expected in test environment): %v", err)
}
}
func TestDocumentDB_PopulateIndex(t *testing.T) {
// Test with mock data
conf := &Conf{
Host: "localhost:27017",
Database: "testdb",
}
db, err := NewDocumentDB(conf, "testcollection")
if err != nil {
t.Skip("Skipping test - MongoDB not available")
return
}
// Test index creation
ctx := context.Background()
db.PopulateIndex(ctx, "field1", 1, false)
}
func TestDocumentDB_PopulateTTLIndex(t *testing.T) {
conf := &Conf{
Host: "localhost:27017",
Database: "testdb",
}
db, err := NewDocumentDB(conf, "testcollection")
if err != nil {
t.Skip("Skipping test - MongoDB not available")
return
}
// Test TTL index creation
ctx := context.Background()
ttl := int32(3600) // 1 hour
db.PopulateTTLIndex(ctx, "expireAt", 1, false, ttl)
}
func TestDocumentDB_PopulateMultiIndex(t *testing.T) {
conf := &Conf{
Host: "localhost:27017",
Database: "testdb",
}
db, err := NewDocumentDB(conf, "testcollection")
if err != nil {
t.Skip("Skipping test - MongoDB not available")
return
}
// Test multiple index creation
ctx := context.Background()
keys := []string{"field1", "field2", "field3"}
sorts := []int32{1, -1, 1}
db.PopulateMultiIndex(ctx, keys, sorts, false)
}
func TestDocumentDB_GetClient(t *testing.T) {
conf := &Conf{
Host: "localhost:27017",
Database: "testdb",
}
db, err := NewDocumentDB(conf, "testcollection")
if err != nil {
t.Skip("Skipping test - MongoDB not available")
return
}
client := db.GetClient()
if client == nil {
t.Error("Expected client to be non-nil")
}
}
func TestDocumentDB_DatabaseName(t *testing.T) {
conf := &Conf{
Host: "localhost:27017",
Database: "testdb",
}
db, err := NewDocumentDB(conf, "testcollection")
if err != nil {
t.Skip("Skipping test - MongoDB not available")
return
}
// Test that we got a valid DocumentDB
if db == nil {
t.Error("Expected DocumentDB to be non-nil")
}
}
func TestDocumentDB_WithDifferentConfigs(t *testing.T) {
testCases := []struct {
name string
conf *Conf
}{
{
name: "minimal config",
conf: &Conf{
Host: "localhost:27017",
},
},
{
name: "with database",
conf: &Conf{
Host: "localhost:27017",
Database: "testdb",
},
},
{
name: "with credentials",
conf: &Conf{
Host: "localhost:27017",
Database: "testdb",
User: "user",
Password: "pass",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
db, err := NewDocumentDB(tc.conf, "testcollection")
if err != nil {
t.Logf("MongoDB connection failed (expected in test environment): %v", err)
return
}
if db == nil {
t.Error("Expected DocumentDB to be non-nil")
}
})
}
}
func TestDocumentDB_IndexOperations(t *testing.T) {
conf := &Conf{
Host: "localhost:27017",
Database: "testdb",
}
db, err := NewDocumentDB(conf, "testcollection")
if err != nil {
t.Skip("Skipping test - MongoDB not available")
return
}
// Test single index
ctx := context.Background()
db.PopulateIndex(ctx, "single_field", 1, false)
// Test TTL index
ttl := int32(1800) // 30 minutes
db.PopulateTTLIndex(ctx, "expiresAt", 1, false, ttl)
// Test multiple indexes
keys := []string{"field1", "field2", "compound_field1"}
sorts := []int32{1, -1, 1}
db.PopulateMultiIndex(ctx, keys, sorts, false)
}
func TestDocumentDB_ContextHandling(t *testing.T) {
conf := &Conf{
Host: "localhost:27017",
Database: "testdb",
}
// Test with timeout context
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
db, err := NewDocumentDB(conf, "testcollection")
// Use ctx to avoid unused variable warning
_ = ctx
if err != nil {
t.Logf("MongoDB connection failed (expected in test environment): %v", err)
return
}
if db == nil {
t.Error("Expected DocumentDB to be non-nil")
}
}
func TestDocumentDB_ErrorHandling(t *testing.T) {
// Test with invalid config
invalidConf := &Conf{
Host: "invalid-host:99999",
}
_, err := NewDocumentDB(invalidConf, "testcollection")
// This should fail
if err == nil {
t.Error("Expected error with invalid host, got nil")
}
}
func TestDocumentDB_IndexModelCreation(t *testing.T) {
// Test the yieldIndexModel function indirectly through PopulateIndex
conf := &Conf{
Host: "localhost:27017",
Database: "testdb",
}
db, err := NewDocumentDB(conf, "testcollection")
if err != nil {
t.Skip("Skipping test - MongoDB not available")
return
}
// Test with various index configurations
ctx := context.Background()
db.PopulateIndex(ctx, "ascending", 1, false)
db.PopulateIndex(ctx, "descending", -1, false)
}

View File

@ -0,0 +1,46 @@
package mongo
import (
"reflect"
"github.com/shopspring/decimal"
"github.com/zeromicro/go-zero/core/stores/mon"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
type TypeCodec struct {
ValueType reflect.Type
Encoder bson.ValueEncoder
Decoder bson.ValueDecoder
}
// WithTypeCodec registers TypeCodecs to convert custom types.
func WithTypeCodec(typeCodecs ...TypeCodec) mon.Option {
return func(c *options.ClientOptions) {
registry := bson.NewRegistry()
for _, v := range typeCodecs {
registry.RegisterTypeEncoder(v.ValueType, v.Encoder)
registry.RegisterTypeDecoder(v.ValueType, v.Decoder)
}
c.SetRegistry(registry)
}
}
// SetCustomDecimalType force convert primitive.Decimal128 to decimal.Decimal.
func SetCustomDecimalType() mon.Option {
return WithTypeCodec(TypeCodec{
ValueType: reflect.TypeOf(decimal.Decimal{}),
Encoder: &MgoDecimal{},
Decoder: &MgoDecimal{},
})
}
func InitMongoOptions(cfg Conf) mon.Option {
return func(opts *options.ClientOptions) {
opts.SetMaxPoolSize(cfg.MaxPoolSize)
opts.SetMinPoolSize(cfg.MinPoolSize)
opts.SetMaxConnIdleTime(cfg.MaxConnIdleTime)
opts.SetCompressors([]string{"snappy"})
}
}

View File

@ -0,0 +1,230 @@
package mongo
import (
"reflect"
"testing"
"github.com/shopspring/decimal"
"github.com/zeromicro/go-zero/core/stores/mon"
)
func TestWithTypeCodec(t *testing.T) {
// Test creating a TypeCodec
codec := TypeCodec{
ValueType: reflect.TypeOf(decimal.Decimal{}),
Encoder: &MgoDecimal{},
Decoder: &MgoDecimal{},
}
if codec.ValueType != reflect.TypeOf(decimal.Decimal{}) {
t.Errorf("Expected ValueType to be decimal.Decimal, got %v", codec.ValueType)
}
if codec.Encoder == nil {
t.Error("Expected Encoder to be set")
}
if codec.Decoder == nil {
t.Error("Expected Decoder to be set")
}
// Test WithTypeCodec function
option := WithTypeCodec(codec)
if option == nil {
t.Error("Expected option to be non-nil")
}
}
func TestSetCustomDecimalType(t *testing.T) {
// Test setting custom decimal type
option := SetCustomDecimalType()
// Verify that the option is created
if option == nil {
t.Error("Expected option to be non-nil")
}
}
func TestInitMongoOptions(t *testing.T) {
// Test with default config
conf := Conf{}
opts := InitMongoOptions(conf)
if opts == nil {
t.Error("Expected options to be non-nil")
}
// Test with custom config
confWithValues := Conf{
Host: "localhost:27017",
Database: "testdb",
User: "testuser",
Password: "testpass",
}
optsWithValues := InitMongoOptions(confWithValues)
if optsWithValues == nil {
t.Error("Expected options to be non-nil")
}
// Test that the options are properly configured
// We can't directly test the internal configuration, but we can test that it doesn't panic
}
func TestTypeCodec_InterfaceCompliance(t *testing.T) {
codec := TypeCodec{
ValueType: reflect.TypeOf(decimal.Decimal{}),
Encoder: &MgoDecimal{},
Decoder: &MgoDecimal{},
}
// Test that the codec can be used
if codec.ValueType == nil {
t.Error("Expected ValueType to be set")
}
if codec.Encoder == nil {
t.Error("Expected Encoder to be set")
}
if codec.Decoder == nil {
t.Error("Expected Decoder to be set")
}
}
func TestMgoDecimal_WithRegistry(t *testing.T) {
// Test that MgoDecimal can be used with a registry
option := SetCustomDecimalType()
// Test that the option is created
if option == nil {
t.Error("Expected option to be non-nil")
}
// Test basic decimal operations
dec := decimal.NewFromFloat(123.45)
// Test that decimal operations work
if dec.IsZero() {
t.Error("Expected decimal to be non-zero")
}
// Test string conversion
decStr := dec.String()
if decStr != "123.45" {
t.Errorf("Expected '123.45', got '%s'", decStr)
}
}
func TestInitMongoOptions_WithDifferentConfigs(t *testing.T) {
testCases := []struct {
name string
conf Conf
}{
{
name: "empty config",
conf: Conf{},
},
{
name: "with host",
conf: Conf{
Host: "localhost:27017",
},
},
{
name: "with database",
conf: Conf{
Database: "testdb",
},
},
{
name: "with credentials",
conf: Conf{
User: "user",
Password: "pass",
},
},
{
name: "full config",
conf: Conf{
Host: "localhost:27017",
Database: "testdb",
User: "user",
Password: "pass",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
opts := InitMongoOptions(tc.conf)
if opts == nil {
t.Error("Expected options to be non-nil")
}
})
}
}
func TestWithTypeCodec_EdgeCases(t *testing.T) {
// Test with nil encoder
codec := TypeCodec{
ValueType: reflect.TypeOf(decimal.Decimal{}),
Encoder: nil,
Decoder: &MgoDecimal{},
}
if codec.Encoder != nil {
t.Error("Expected Encoder to be nil")
}
if codec.Decoder == nil {
t.Error("Expected Decoder to be set")
}
// Test with nil decoder
codec2 := TypeCodec{
ValueType: reflect.TypeOf(decimal.Decimal{}),
Encoder: &MgoDecimal{},
Decoder: nil,
}
if codec2.Encoder == nil {
t.Error("Expected Encoder to be set")
}
if codec2.Decoder != nil {
t.Error("Expected Decoder to be nil")
}
}
func TestSetCustomDecimalType_MultipleCalls(t *testing.T) {
// Test calling SetCustomDecimalType multiple times
// First call
option1 := SetCustomDecimalType()
// Second call should not panic
option2 := SetCustomDecimalType()
// Options should be valid
if option1 == nil {
t.Error("Expected option1 to be non-nil")
}
if option2 == nil {
t.Error("Expected option2 to be non-nil")
}
}
func TestInitMongoOptions_ReturnType(t *testing.T) {
conf := Conf{}
opts := InitMongoOptions(conf)
// Test that the returned type is correct
if opts == nil {
t.Error("Expected options to be non-nil")
}
// Test that we can use the options (basic type check)
var _ mon.Option = opts
}

View File

@ -0,0 +1,36 @@
package mongo
import (
"context"
"github.com/zeromicro/go-zero/core/stores/mon"
"go.mongodb.org/mongo-driver/v2/mongo"
mopt "go.mongodb.org/mongo-driver/v2/mongo/options"
)
type DocumentDBUseCase interface {
PopulateIndex(ctx context.Context, key string, sort int32, unique bool)
PopulateTTLIndex(ctx context.Context, key string, sort int32, unique bool, ttl int32)
PopulateMultiIndex(ctx context.Context, keys []string, sorts []int32, unique bool)
PopulateSparseMultiIndex(ctx context.Context, keys []string, sorts []int32, unique bool)
GetClient() *mon.Model
}
type DocumentDBWithCacheUseCase interface {
DocumentDBUseCase
CacheUseCase
DeleteOne(ctx context.Context, key string, filter any, opts ...*mopt.DeleteOneOptions) (int64, error)
FindOne(ctx context.Context, key string, v, filter any, opts ...*mopt.FindOneOptions) error
FindOneAndDelete(ctx context.Context, key string, v, filter any, opts ...*mopt.FindOneAndDeleteOptions) error
FindOneAndReplace(ctx context.Context, key string, v, filter, replacement any, opts ...*mopt.FindOneAndReplaceOptions) error
InsertOne(ctx context.Context, key string, document any, opts ...*mopt.InsertOneOptions) (*mongo.InsertOneResult, error)
UpdateByID(ctx context.Context, key string, id, update any, opts ...*mopt.UpdateOneOptions) (*mongo.UpdateResult, error)
UpdateMany(ctx context.Context, keys []string, filter, update any, opts ...*mopt.UpdateManyOptions) (*mongo.UpdateResult, error)
UpdateOne(ctx context.Context, key string, filter, update any, opts ...*mopt.UpdateOneOptions) (*mongo.UpdateResult, error)
}
type CacheUseCase interface {
DelCache(ctx context.Context, keys ...string) error
GetCache(key string, v any) error
SetCache(key string, v any) error
}

View File

@ -0,0 +1,50 @@
package required
import (
"fmt"
"github.com/go-playground/validator/v10"
"github.com/zeromicro/go-zero/core/logx"
)
type Validate interface {
ValidateAll(obj any) error
BindToValidator(opts ...Option) error
}
type Validator struct {
V *validator.Validate
}
func (v *Validator) ValidateAll(obj any) error {
err := v.V.Struct(obj)
if err != nil {
return err
}
return nil
}
func (v *Validator) BindToValidator(opts ...Option) error {
for _, item := range opts {
err := v.V.RegisterValidation(item.ValidatorName, item.ValidatorFunc)
if err != nil {
return fmt.Errorf("failed to register validator : %w", err)
}
}
return nil
}
func MustValidator(option ...Option) Validate {
v := &Validator{
V: validator.New(),
}
if err := v.BindToValidator(option...); err != nil {
logx.Error("failed to bind validator")
}
return v
}

View File

@ -0,0 +1,29 @@
package required
import (
"regexp"
"github.com/go-playground/validator/v10"
)
type Option struct {
ValidatorName string
ValidatorFunc func(fl validator.FieldLevel) bool
}
// WithAccount 創建一個新的 Option 結構,包含自定義的驗證函數,用於驗證 email 和台灣的手機號碼格式
func WithAccount(tagName string) Option {
return Option{
ValidatorName: tagName,
ValidatorFunc: func(fl validator.FieldLevel) bool {
value := fl.Field().String()
emailRegex := `^[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,}$`
phoneRegex := `^(\+886|0)?9\d{8}$`
emailMatch, _ := regexp.MatchString(emailRegex, value)
phoneMatch, _ := regexp.MatchString(phoneRegex, value)
return emailMatch || phoneMatch
},
}
}

View File

@ -0,0 +1,71 @@
package worker_pool
import (
"sync"
"github.com/panjf2000/ants/v2"
)
const defaultWorkerPoolSize = 2000
type WorkerPool interface {
Submit(task func()) error
SubmitAndWaitAll(tasks ...func() error) (taskErr chan error, submitErr error)
}
type workerPool struct {
p *ants.Pool
}
func NewWorkerPool(size int) WorkerPool {
if size <= 0 {
size = defaultWorkerPoolSize
}
p, err := ants.NewPool(
size,
ants.WithDisablePurge(true),
)
if err != nil {
return &workerPool{p: nil}
}
return &workerPool{p: p}
}
func (p *workerPool) Submit(task func()) error {
if p.p == nil {
return ants.Submit(task)
}
return p.p.Submit(task)
}
func (p *workerPool) SubmitAndWaitAll(tasks ...func() error) (chan error, error) {
taskErrCh := make(chan error, len(tasks))
submitErrCh := make(chan error, len(tasks))
wg := sync.WaitGroup{}
wg.Add(len(tasks))
for i := range tasks {
task := tasks[i]
err := p.Submit(func() {
defer wg.Done()
if err := task(); err != nil {
taskErrCh <- err
}
})
if err != nil {
submitErrCh <- err
wg.Done()
}
}
wg.Wait()
if len(submitErrCh) != 0 {
return nil, <-submitErrCh
}
return taskErrCh, nil
}

View File

@ -0,0 +1,89 @@
package worker_pool
import (
"errors"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestNewWorkerPool(t *testing.T) {
t.Run("default size pool", func(t *testing.T) {
pool := NewWorkerPool(0)
assert.NotNil(t, pool)
})
t.Run("custom size pool", func(t *testing.T) {
size := 100
pool := NewWorkerPool(size)
assert.NotNil(t, pool)
})
}
func TestSubmit(t *testing.T) {
t.Run("submit task to worker pool", func(t *testing.T) {
pool := NewWorkerPool(10)
var wg sync.WaitGroup
wg.Add(1)
err := pool.Submit(func() {
defer wg.Done()
time.Sleep(100 * time.Millisecond)
})
assert.NoError(t, err)
wg.Wait()
})
}
func TestSubmitAndWaitAll(t *testing.T) {
t.Run("submit and wait all tasks succeed", func(t *testing.T) {
pool := NewWorkerPool(10)
tasks := []func() error{
func() error {
time.Sleep(100 * time.Millisecond)
return nil
},
func() error {
time.Sleep(50 * time.Millisecond)
return nil
},
}
taskErrCh, submitErr := pool.SubmitAndWaitAll(tasks...)
assert.NoError(t, submitErr)
close(taskErrCh)
for err := range taskErrCh {
assert.NoError(t, err)
}
})
t.Run("submit and wait all tasks with errors", func(t *testing.T) {
pool := NewWorkerPool(10)
expectedError := errors.New("task error")
tasks := []func() error{
func() error {
time.Sleep(100 * time.Millisecond)
return nil
},
func() error {
time.Sleep(50 * time.Millisecond)
return expectedError
},
}
taskErrCh, submitErr := pool.SubmitAndWaitAll(tasks...)
assert.NoError(t, submitErr)
close(taskErrCh)
foundError := false
for err := range taskErrCh {
if err != nil {
foundError = true
assert.Equal(t, expectedError, err)
}
}
assert.True(t, foundError)
})
}

View File

@ -0,0 +1,78 @@
package middleware
import (
"chat/internal/config"
"context"
"net/http"
"strings"
"github.com/golang-jwt/jwt/v4"
"github.com/zeromicro/go-zero/core/logx"
)
type AnonMiddleware struct {
jwtSecret string
}
func NewAnonMiddleware(c config.Config) *AnonMiddleware {
return &AnonMiddleware{
jwtSecret: c.JWT.Secret,
}
}
func (m *AnonMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// 從 Authorization header 提取 token
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, "Authorization header is required", http.StatusUnauthorized)
return
}
// 移除 "Bearer " 前綴
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
http.Error(w, "Invalid authorization header format", http.StatusUnauthorized)
return
}
tokenString := parts[1]
// 解析和驗證 JWT
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// 驗證簽名方法
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid
}
return []byte(m.jwtSecret), nil
})
if err != nil {
logx.Errorf("Failed to parse JWT: %v", err)
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}
if !token.Valid {
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}
// 提取 UID
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
return
}
uid, ok := claims["uid"].(string)
if !ok || uid == "" {
http.Error(w, "UID not found in token", http.StatusUnauthorized)
return
}
// 將 UID 存入 context
ctx := context.WithValue(r.Context(), "uid", uid)
next(w, r.WithContext(ctx))
}
}

View File

@ -0,0 +1,24 @@
package middleware
import (
"net/http"
)
// CORSMiddleware 處理 CORS 請求
func CORSMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// 設置 CORS 標頭
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
w.Header().Set("Access-Control-Max-Age", "3600")
// 處理預檢請求
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}
next(w, r)
}
}

View File

@ -0,0 +1,132 @@
package repository
import (
"chat/internal/domain/const"
redisKey "chat/internal/domain/redis"
domainRepo "chat/internal/domain/repository"
"chat/internal/utils"
"context"
"errors"
"fmt"
"time"
"github.com/redis/go-redis/v9"
)
type matchmakingRepository struct {
client *redis.Client
}
// NewMatchmakingRepository 創建新的配對 Repository
func NewMatchmakingRepository(client *redis.Client) domainRepo.MatchmakingRepository {
return &matchmakingRepository{
client: client,
}
}
// JoinQueue 加入配對佇列
func (r *matchmakingRepository) JoinQueue(ctx context.Context, uid string) (status string, roomID string, err error) {
// 檢查使用者是否已經在配對中或已配對
userKey := redisKey.MatchUserKey(uid)
currentStatus, err := r.client.Get(ctx, userKey).Result()
if err != nil && err != redis.Nil {
return "", "", fmt.Errorf("failed to check user status: %w", err)
}
// 如果已經配對返回房間ID
if currentStatus != "" && currentStatus != consts.StatusWaiting {
return consts.StatusMatched, currentStatus, nil
}
// 如果正在等待,返回等待狀態
if currentStatus == consts.StatusWaiting {
return consts.StatusWaiting, "", nil
}
// 嘗試從佇列中取出一個使用者進行配對
queueKey := redisKey.MatchQueueKey()
otherUID, err := r.client.LPop(ctx, queueKey).Result()
if errors.Is(err, redis.Nil) {
// 佇列為空,將自己加入佇列
if err := r.client.RPush(ctx, queueKey, uid).Err(); err != nil {
return "", "", fmt.Errorf("failed to join queue: %w", err)
}
// 設置等待狀態TTL 5 分鐘
if err := r.client.Set(ctx, userKey, consts.StatusWaiting, 5*time.Minute).Err(); err != nil {
return "", "", fmt.Errorf("failed to set waiting status: %w", err)
}
return consts.StatusWaiting, "", nil
}
if err != nil {
return "", "", fmt.Errorf("failed to pop from queue: %w", err)
}
// 配對成功生成房間ID
roomID = utils.GenerateRoomID()
// 設置雙方的房間IDTTL 1 小時
roomTTL := 1 * time.Hour
if err := r.client.Set(ctx, userKey, roomID, roomTTL).Err(); err != nil {
return "", "", fmt.Errorf("failed to set room for user: %w", err)
}
otherUserKey := redisKey.MatchUserKey(otherUID)
if err := r.client.Set(ctx, otherUserKey, roomID, roomTTL).Err(); err != nil {
return "", "", fmt.Errorf("failed to set room for other user: %w", err)
}
// 建立房間成員 Set
membersKey := redisKey.RoomMembersKey(roomID)
if err := r.client.SAdd(ctx, membersKey, uid, otherUID).Err(); err != nil {
return "", "", fmt.Errorf("failed to create room members: %w", err)
}
if err := r.client.Expire(ctx, membersKey, roomTTL).Err(); err != nil {
return "", "", fmt.Errorf("failed to set room TTL: %w", err)
}
return consts.StatusMatched, roomID, nil
}
// GetMatchStatus 查詢使用者的配對狀態
func (r *matchmakingRepository) GetMatchStatus(ctx context.Context, uid string) (status string, roomID string, err error) {
userKey := redisKey.MatchUserKey(uid)
value, err := r.client.Get(ctx, userKey).Result()
if err == redis.Nil {
// 沒有配對記錄
return "", "", nil
}
if err != nil {
return "", "", fmt.Errorf("failed to get match status: %w", err)
}
if value == consts.StatusWaiting {
return consts.StatusWaiting, "", nil
}
// value 是 roomID
return consts.StatusMatched, value, nil
}
// CreateRoom 建立房間並添加成員
func (r *matchmakingRepository) CreateRoom(ctx context.Context, roomID string, members []string) error {
membersKey := redisKey.RoomMembersKey(roomID)
if err := r.client.SAdd(ctx, membersKey, members).Err(); err != nil {
return fmt.Errorf("failed to add room members: %w", err)
}
// 設置 TTL 1 小時
if err := r.client.Expire(ctx, membersKey, 1*time.Hour).Err(); err != nil {
return fmt.Errorf("failed to set room TTL: %w", err)
}
return nil
}
// IsRoomMember 檢查使用者是否為房間成員
func (r *matchmakingRepository) IsRoomMember(ctx context.Context, roomID string, uid string) (bool, error) {
membersKey := redisKey.RoomMembersKey(roomID)
isMember, err := r.client.SIsMember(ctx, membersKey, uid).Result()
if err != nil {
return false, fmt.Errorf("failed to check room membership: %w", err)
}
return isMember, nil
}

View File

@ -0,0 +1,73 @@
package repository
import (
"chat/internal/domain/entity"
"chat/internal/domain/repository"
"chat/internal/library/cassandra"
"context"
"fmt"
"math"
)
type messageRepository struct {
repo cassandra.Repository[entity.Message]
db *cassandra.DB
}
// NewMessageRepository 創建新的訊息 Repository
func NewMessageRepository(db *cassandra.DB, keyspace string) (repository.MessageRepository, error) {
repo, err := cassandra.NewRepository[entity.Message](db, keyspace)
if err != nil {
return nil, fmt.Errorf("failed to create message repository: %w", err)
}
return &messageRepository{
repo: repo,
db: db,
}, nil
}
// Insert 插入訊息
func (r *messageRepository) Insert(ctx context.Context, msg *entity.Message) error {
return r.repo.Insert(ctx, *msg)
}
// ListByRoom 查詢房間訊息(分頁)
func (r *messageRepository) ListByRoom(ctx context.Context, roomID string, bucketDay string, pageSize int, pageIndex int) ([]entity.Message, int64, error) {
// 計算分頁
if pageSize <= 0 {
pageSize = 20
}
if pageIndex <= 0 {
pageIndex = 1
}
// 構建查詢條件
query := r.repo.Query().
Where(cassandra.Eq("room_id", roomID)).
Where(cassandra.Eq("bucket_day", bucketDay)).
OrderBy("ts", cassandra.DESC).
OrderBy("message_id", cassandra.DESC).
Limit(pageSize)
// 先查詢總數
total, err := r.repo.Query().
Where(cassandra.Eq("room_id", roomID)).
Where(cassandra.Eq("bucket_day", bucketDay)).
Count(ctx)
if err != nil {
return nil, 0, fmt.Errorf("failed to count messages: %w", err)
}
// 執行查詢
var messages []entity.Message
if err := query.Scan(ctx, &messages); err != nil {
return nil, 0, fmt.Errorf("failed to query messages: %w", err)
}
// 計算總頁數
totalPages := int64(math.Ceil(float64(total) / float64(pageSize)))
return messages, totalPages, nil
}

View File

@ -0,0 +1,39 @@
package repository
import (
"chat/internal/library/cassandra"
"context"
"fmt"
)
// InitSchema 初始化 Cassandra keyspace 和表結構
func InitSchema(ctx context.Context, db *cassandra.DB, keyspace string) error {
// 建立 keyspace如果不存在
createKeyspaceStmt := fmt.Sprintf(
"CREATE KEYSPACE IF NOT EXISTS %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}",
keyspace,
)
session := db.GetSession()
if err := session.Query(createKeyspaceStmt, nil).Exec(); err != nil {
return fmt.Errorf("failed to create keyspace: %w", err)
}
// 建立 messages_by_room 表
createTableStmt := fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s.messages_by_room (
room_id text,
bucket_day text,
ts bigint,
message_id text,
uid text,
content text,
PRIMARY KEY ((room_id, bucket_day), ts, message_id)
) WITH CLUSTERING ORDER BY (ts DESC, message_id DESC)
`, keyspace)
if err := session.Query(createTableStmt, nil).Exec(); err != nil {
return fmt.Errorf("failed to create messages_by_room table: %w", err)
}
return nil
}

36
internal/svc/cassandra.go Normal file
View File

@ -0,0 +1,36 @@
package svc
import (
"chat/internal/config"
"chat/internal/library/cassandra"
repoImpl "chat/internal/repository"
"context"
"github.com/zeromicro/go-zero/core/logx"
)
func initCassandra(ctx context.Context, c config.CassandraConf) (*cassandra.DB, error) {
// 初始化 Cassandra DB
var cassandraOpts []cassandra.Option
cassandraOpts = append(cassandraOpts, cassandra.WithHosts(c.Hosts...))
cassandraOpts = append(cassandraOpts, cassandra.WithPort(c.Port))
cassandraOpts = append(cassandraOpts, cassandra.WithKeyspace(c.Keyspace))
if c.UseAuth {
cassandraOpts = append(cassandraOpts, cassandra.WithAuth(c.Username, c.Password))
}
cassandraDB, err := cassandra.New(cassandraOpts...)
if err != nil {
logx.Errorf("Failed to connect to Cassandra: %v", err)
return nil, err
}
logx.Infof("Connected to Cassandra at %v:%d", c.Hosts, c.Port)
// 初始化 schema創建 keyspace 和表)
if err := repoImpl.InitSchema(ctx, cassandraDB, c.Keyspace); err != nil {
logx.Errorf("Failed to initialize Cassandra schema: %v", err)
return nil, err
}
logx.Infof("Cassandra schema initialized for keyspace: %s", c.Keyspace)
return cassandraDB, nil
}

View File

@ -0,0 +1,10 @@
package svc
import (
"chat/internal/config"
"chat/internal/library/centrifugo"
)
func initCentrifugo(c config.CentrifugoConf) *centrifugo.Client {
return centrifugo.NewClient(c.APIURL, c.APIKey)
}

27
internal/svc/redis.go Normal file
View File

@ -0,0 +1,27 @@
package svc
import (
"chat/internal/config"
"context"
"fmt"
"github.com/redis/go-redis/v9"
"github.com/zeromicro/go-zero/core/logx"
)
func initRedis(ctx context.Context, c config.RedisConf) (*redis.Client, error) {
// 初始化 Redis 客戶端
redisClient := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", c.Host, c.Port),
Password: c.Password,
DB: c.DB,
})
// 測試 Redis 連線
if err := redisClient.Ping(ctx).Err(); err != nil {
logx.Errorf("Failed to connect to Redis: %v", err)
return nil, err
}
logx.Infof("Connected to Redis at %s:%d", c.Host, c.Port)
return redisClient, nil
}

View File

@ -0,0 +1,77 @@
package svc
import (
"chat/internal/config"
domainRepo "chat/internal/domain/repository"
domainUsecase "chat/internal/domain/usecase"
"chat/internal/library/cassandra"
"chat/internal/library/centrifugo"
"chat/internal/middleware"
repoImpl "chat/internal/repository"
usecaseImpl "chat/internal/usecase"
"context"
"github.com/redis/go-redis/v9"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest"
)
type ServiceContext struct {
Config config.Config
AnonMiddleware rest.Middleware
// Clients
RedisClient *redis.Client
CentrifugoClient *centrifugo.Client
CassandraDB *cassandra.DB
// Repositories
MatchmakingRepo domainRepo.MatchmakingRepository
MessageRepo domainRepo.MessageRepository
// UseCases
AuthUseCase domainUsecase.AuthUseCase
MatchmakingUseCase domainUsecase.MatchmakingUseCase
MessageUseCase domainUsecase.MessageUseCase
}
func NewServiceContext(c config.Config) *ServiceContext {
ctx := context.Background()
redis, err := initRedis(ctx, c.Redis)
if err != nil {
panic(err)
}
cassandraDB, err := initCassandra(ctx, c.Cassandra)
if err != nil {
panic(err)
}
centrifugoClient := initCentrifugo(c.Centrifugo)
// 初始化 Repositories
// 初始化 Repositories
matchmakingRepo := repoImpl.NewMatchmakingRepository(redis)
messageRepo, err := repoImpl.NewMessageRepository(cassandraDB, c.Cassandra.Keyspace)
if err != nil {
logx.Errorf("Failed to create message repository: %v", err)
}
// 初始化 UseCases
// AuthUseCase 需要 MatchmakingRepository 來查詢使用者所在的房間
authUseCase := usecaseImpl.NewAuthUseCase(c.JWT.Secret, c.JWT.Expire, c.JWT.CentrifugoSecret, matchmakingRepo)
matchmakingUseCase := usecaseImpl.NewMatchmakingUseCase(matchmakingRepo)
messageUseCase := usecaseImpl.NewMessageUseCase(messageRepo, matchmakingRepo, centrifugoClient)
return &ServiceContext{
Config: c,
AnonMiddleware: middleware.NewAnonMiddleware(c).Handle,
RedisClient: redis,
CentrifugoClient: centrifugoClient,
CassandraDB: cassandraDB,
MatchmakingRepo: matchmakingRepo,
MessageRepo: messageRepo,
AuthUseCase: authUseCase,
MatchmakingUseCase: matchmakingUseCase,
MessageUseCase: messageUseCase,
}
}

83
internal/types/types.go Normal file
View File

@ -0,0 +1,83 @@
// Code generated by goctl. DO NOT EDIT.
// goctl 1.9.2
package types
type AnonLoginReq struct {
Name string `json:"name" required:"required"`
}
type AnonLoginResp struct {
UID string `json:"uid"`
Token string `json:"token"`
CentrifugoToken string `json:"centrifugo_token"` // Centrifugo WebSocket 連線用的 token
RefreshToken string `json:"refresh_token"`
ExpireAt int64 `json:"expire_at"`
}
type AuthHeader struct {
Authorization string `header:"Authorization" binding:"required"`
}
type BaseReq struct {
}
type BaseResp struct {
}
type BaseResponse struct {
Code string `json:"code"` // 狀態碼
Message string `json:"message"` // 訊息
Data interface{} `json:"data,omitempty"` // 資料
Error interface{} `json:"error,omitempty"` // 可選的錯誤信息
}
type ListMessageReq struct {
AuthHeader
RoomID string `path:"room_id"`
PageSize int64 `form:"page_size,default=20"`
PageIndex int64 `form:"page_index,default=1"`
}
type ListMessageResp struct {
Pager Pagination `json:"pager"`
Data []Message `json:"data"`
}
type MatchJoinReq struct {
AuthHeader
}
type MatchJoinResp struct {
Status string `json:"status"` // waiting | matched
}
type MatchStatusResp struct {
Status string `json:"status"` // waiting | matched
RoomID string `json:"room_id,omitempty"`
}
type Message struct {
MessageID string `json:"message_id"`
UID string `json:"uid"`
Content string `json:"content"`
Timestamp int64 `json:"timestamp"`
}
type Pagination struct {
Total int64 `json:"total,example=100"`
Page int64 `json:"page,example=1"`
PageSize int64 `json:"pageSize,example=10"`
TotalPages int64 `json:"totalPages,example=10"`
}
type RefreshTokenReq struct {
Token string `json:"token"` // 舊的 token可以是已過期的
}
type SendMessageReq struct {
AuthHeader
RoomID string `path:"room_id"`
Content string `json:"content"`
ClientMsgID string `json:"client_msg_id"`
}

153
internal/usecase/auth.go Normal file
View File

@ -0,0 +1,153 @@
package usecase
import (
"chat/internal/domain/repository"
"chat/internal/domain/usecase"
"chat/internal/utils"
"context"
"fmt"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/zeromicro/go-zero/core/logx"
)
type authUseCase struct {
jwtSecret string
jwtExpire int64
centrifugoSecret string
matchmakingRepo repository.MatchmakingRepository
}
// NewAuthUseCase 創建新的認證 UseCase
func NewAuthUseCase(jwtSecret string, jwtExpire int64, centrifugoSecret string, matchmakingRepo repository.MatchmakingRepository) usecase.AuthUseCase {
if centrifugoSecret == "" {
centrifugoSecret = jwtSecret
}
return &authUseCase{
jwtSecret: jwtSecret,
jwtExpire: jwtExpire,
centrifugoSecret: centrifugoSecret,
matchmakingRepo: matchmakingRepo,
}
}
// AnonLogin 匿名登入
func (u *authUseCase) AnonLogin(ctx context.Context, name string) (uid string, token string, centrifugoToken string, expireAt int64, err error) {
// 生成匿名 UID
uid = utils.GenerateUID()
// 生成 API JWT token
now := time.Now()
expireAt = now.Add(time.Duration(u.jwtExpire) * time.Second).Unix()
claims := jwt.MapClaims{
"uid": uid,
"exp": expireAt,
"iat": now.Unix(),
"name": name,
}
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token, err = jwtToken.SignedString([]byte(u.jwtSecret))
if err != nil {
return "", "", "", 0, fmt.Errorf("failed to sign JWT: %w", err)
}
// 匿名登入時還未加入房間,只給予個人頻道的權限
centrifugoExpireAt := expireAt
centrifugoClaims := jwt.MapClaims{
"sub": uid,
"exp": centrifugoExpireAt,
"iat": now.Unix(),
"channels": []string{fmt.Sprintf("user:%s", uid)},
"name": name,
}
centrifugoJwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, centrifugoClaims)
centrifugoToken, err = centrifugoJwtToken.SignedString([]byte(u.centrifugoSecret))
if err != nil {
return "", "", "", 0, fmt.Errorf("failed to sign Centrifugo JWT: %w", err)
}
logx.Infof("User %s logged in anonymously", uid)
return uid, token, centrifugoToken, expireAt, nil
}
// RefreshToken 刷新 token
func (u *authUseCase) RefreshToken(ctx context.Context, oldToken string) (uid string, token string, centrifugoToken string, expireAt int64, err error) {
// 解析舊 token
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
tokenObj, err := parser.Parse(oldToken, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid
}
return []byte(u.jwtSecret), nil
})
if err != nil {
return "", "", "", 0, fmt.Errorf("failed to parse old token: %w", err)
}
claims, ok := tokenObj.Claims.(jwt.MapClaims)
if !ok {
return "", "", "", 0, fmt.Errorf("invalid token claims")
}
uid, ok = claims["uid"].(string)
if !ok || uid == "" {
return "", "", "", 0, fmt.Errorf("UID not found in token")
}
name, _ := claims["name"].(string)
// 檢查使用者當前的房間狀態
_, roomID, err := u.matchmakingRepo.GetMatchStatus(ctx, uid)
if err != nil {
logx.Errorf("Failed to get match status for refresh token: %v", err)
// 不中斷流程,只是不給房間權限
}
// 生成新的 API JWT token
now := time.Now()
expireAt = now.Add(time.Duration(u.jwtExpire) * time.Second).Unix()
newClaims := jwt.MapClaims{
"uid": uid,
"exp": expireAt,
"iat": now.Unix(),
"name": name,
}
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, newClaims)
token, err = jwtToken.SignedString([]byte(u.jwtSecret))
if err != nil {
return "", "", "", 0, fmt.Errorf("failed to sign new JWT: %w", err)
}
// 生成新的 Centrifugo JWT token
channels := []string{fmt.Sprintf("user:%s", uid)}
// 如果已經在房間中,添加房間頻道的權限
if roomID != "" {
channels = append(channels, fmt.Sprintf("room:%s", roomID))
}
centrifugoExpireAt := expireAt
centrifugoClaims := jwt.MapClaims{
"sub": uid,
"exp": centrifugoExpireAt,
"iat": now.Unix(),
"channels": channels,
"name": name,
}
centrifugoJwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, centrifugoClaims)
centrifugoToken, err = centrifugoJwtToken.SignedString([]byte(u.centrifugoSecret))
if err != nil {
return "", "", "", 0, fmt.Errorf("failed to sign new Centrifugo JWT: %w", err)
}
logx.Infof("User %s refreshed token, room: %s", uid, roomID)
return uid, token, centrifugoToken, expireAt, nil
}

View File

@ -0,0 +1,38 @@
package usecase
import (
"chat/internal/domain/repository"
"chat/internal/domain/usecase"
"context"
"fmt"
)
type matchmakingUseCase struct {
matchmakingRepo repository.MatchmakingRepository
}
// NewMatchmakingUseCase 創建新的配對 UseCase
func NewMatchmakingUseCase(matchmakingRepo repository.MatchmakingRepository) usecase.MatchmakingUseCase {
return &matchmakingUseCase{
matchmakingRepo: matchmakingRepo,
}
}
// JoinQueue 加入配對佇列
func (u *matchmakingUseCase) JoinQueue(ctx context.Context, uid string) (status string, err error) {
status, _, err = u.matchmakingRepo.JoinQueue(ctx, uid)
if err != nil {
return "", fmt.Errorf("failed to join queue: %w", err)
}
return status, nil
}
// GetStatus 查詢配對狀態
func (u *matchmakingUseCase) GetStatus(ctx context.Context, uid string) (status string, roomID string, err error) {
status, roomID, err = u.matchmakingRepo.GetMatchStatus(ctx, uid)
if err != nil {
return "", "", fmt.Errorf("failed to get match status: %w", err)
}
return status, roomID, nil
}

115
internal/usecase/message.go Normal file
View File

@ -0,0 +1,115 @@
package usecase
import (
"chat/internal/domain/entity"
"chat/internal/domain/repository"
"chat/internal/domain/usecase"
"chat/internal/library/centrifugo"
"chat/internal/utils"
"context"
"fmt"
"time"
"github.com/zeromicro/go-zero/core/logx"
)
type messageUseCase struct {
messageRepo repository.MessageRepository
matchmakingRepo repository.MatchmakingRepository
centrifugoClient *centrifugo.Client
}
// NewMessageUseCase 創建新的訊息 UseCase
func NewMessageUseCase(
messageRepo repository.MessageRepository,
matchmakingRepo repository.MatchmakingRepository,
centrifugoClient *centrifugo.Client,
) usecase.MessageUseCase {
return &messageUseCase{
messageRepo: messageRepo,
matchmakingRepo: matchmakingRepo,
centrifugoClient: centrifugoClient,
}
}
// SendMessage 發送訊息
func (u *messageUseCase) SendMessage(ctx context.Context, roomID string, uid string, content string, clientMsgID string) error {
// 驗證訊息內容不為空
if content == "" {
return fmt.Errorf("message content cannot be empty")
}
// 驗證使用者是否在房間中
isMember, err := u.matchmakingRepo.IsRoomMember(ctx, roomID, uid)
if err != nil {
logx.Errorf("Failed to check room membership: roomID=%s, uid=%s, error=%v", roomID, uid, err)
return fmt.Errorf("failed to check room membership: %w", err)
}
if !isMember {
logx.Errorf("User is not a member of the room: roomID=%s, uid=%s", roomID, uid)
return fmt.Errorf("user is not a member of the room")
}
// 生成訊息 ID
messageID := utils.GenerateMessageID()
if clientMsgID != "" {
messageID = clientMsgID
}
// 建立訊息實體
now := time.Now()
msg := &entity.Message{
RoomID: roomID,
BucketDay: utils.GetBucketDay(now),
TS: now.UnixNano() / 1e6, // milliseconds
MessageID: messageID,
UID: uid,
Content: content,
}
// 儲存到 Cassandra
if err := u.messageRepo.Insert(ctx, msg); err != nil {
return fmt.Errorf("failed to save message: %w", err)
}
// 發布到 Centrifugo
channel := fmt.Sprintf("room:%s", roomID)
messageData := map[string]interface{}{
"message_id": msg.MessageID,
"uid": msg.UID,
"content": msg.Content,
"timestamp": msg.TS,
"room_id": msg.RoomID,
}
if err := u.centrifugoClient.PublishJSON(channel, messageData); err != nil {
logx.Errorf("failed to publish message to Centrifugo: %v", err)
// 不返回錯誤,因為訊息已經儲存
}
return nil
}
// ListMessages 查詢訊息列表(分頁)
func (u *messageUseCase) ListMessages(ctx context.Context, roomID string, uid string, pageSize int, pageIndex int) ([]entity.Message, int64, error) {
// 驗證使用者是否在房間中
isMember, err := u.matchmakingRepo.IsRoomMember(ctx, roomID, uid)
if err != nil {
return nil, 0, fmt.Errorf("failed to check room membership: %w", err)
}
if !isMember {
return nil, 0, fmt.Errorf("user is not a member of the room")
}
// 取得今天的 bucket_day
bucketDay := utils.GetTodayBucketDay()
// 查詢訊息
messages, totalPages, err := u.messageRepo.ListByRoom(ctx, roomID, bucketDay, pageSize, pageIndex)
if err != nil {
return nil, 0, fmt.Errorf("failed to list messages: %w", err)
}
return messages, totalPages, nil
}

14
internal/utils/time.go Normal file
View File

@ -0,0 +1,14 @@
package utils
import "time"
// GetBucketDay 取得 bucket_dayyyyyMMdd 格式)
func GetBucketDay(t time.Time) string {
return t.Format("20060102")
}
// GetTodayBucketDay 取得今天的 bucket_day
func GetTodayBucketDay() string {
return GetBucketDay(time.Now())
}

24
internal/utils/uuid.go Normal file
View File

@ -0,0 +1,24 @@
package utils
import (
"chat/internal/domain/const"
"fmt"
"github.com/google/uuid"
)
// GenerateUID 生成匿名 UID
func GenerateUID() string {
return fmt.Sprintf("%s%s", consts.AnonUIDPrefix, uuid.New().String()[:8])
}
// GenerateRoomID 生成房間 ID
func GenerateRoomID() string {
return fmt.Sprintf("%s%s", consts.RoomIDPrefix, uuid.New().String())
}
// GenerateMessageID 生成訊息 ID
func GenerateMessageID() string {
return uuid.New().String()
}

130
makefile Normal file
View File

@ -0,0 +1,130 @@
GO_CTL_NAME=goctl
GO_ZERO_STYLE=go_zero
GO ?= go
GOFMT ?= gofmt "-s"
GOFILES := $(shell find . -name "*.go")
LDFLAGS := -s -w
VERSION="v1.0.0"
DOCKER_REPO="refactor-service"
# 默認目標
.DEFAULT_GOAL := help
# 顏色定義
GREEN := \033[0;32m
YELLOW := \033[0;33m
NC := \033[0m # No Color
help: ## 顯示幫助訊息
@echo "$(GREEN)可用命令:$(NC)"
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " $(YELLOW)%-20s$(NC) %s\n", $$1, $$2}'
.PHONY: test
test: ## 進行測試
go test -v --cover ./...
.PHONY: gen-api
gen-api: ## 產生 api
goctl api go -api ./api/chat.api -dir . -style go_zero
.PHONY: gen-doc
gen-doc: ## 生成 Swagger 文檔
# go-doc openapi --api ./generate/api/gateway.api --filename gateway.json --host dev-api.truheart.com.tw --basepath /api/v1
go-doc -a ./api/chat.api -d ./ -f chat -s openapi3.0
.PHONY: mock-gen
mock-gen: ## 建立 mock 資料
mockgen -source=./pkg/member/domain/repository/account.go -destination=./pkg/member/mock/repository/account.go -package=mock
mockgen -source=./pkg/member/domain/repository/account_uid.go -destination=./pkg/member/mock/repository/account_uid.go -package=mock
mockgen -source=./pkg/member/domain/repository/auto_id.go -destination=./pkg/member/mock/repository/auto_id.go -package=mock
mockgen -source=./pkg/member/domain/repository/user.go -destination=./pkg/member/mock/repository/user.go -package=mock
mockgen -source=./pkg/member/domain/repository/verify_code.go -destination=./pkg/member/mock/repository/verify_code.go -package=mock
mockgen -source=./pkg/member/domain/usecase/generate_uid.go -destination=./pkg/member/mock/usecase/generate_uid.go -package=mock
mockgen -source=./pkg/permission/domain/repository/permission.go -destination=./pkg/permission/mock/repository/permission.go -package=mock
mockgen -source=./pkg/permission/domain/repository/role.go -destination=./pkg/permission/mock/repository/role.go -package=mock
mockgen -source=./pkg/permission/domain/repository/role_permission.go -destination=./pkg/permission/mock/repository/role_permission.go -package=mock
mockgen -source=./pkg/permission/domain/repository/user_role.go -destination=./pkg/permission/mock/repository/user_role.go -package=mock
mockgen -source=./pkg/permission/domain/repository/token.go -destination=./pkg/permission/mock/repository/token.go -package=mock
mockgen -source=./pkg/permission/domain/usecase/permission.go -destination=./pkg/permission/mock/usecase/permission.go -package=mock
mockgen -source=./pkg/permission/domain/usecase/role.go -destination=./pkg/permission/mock/usecase/role.go -package=mock
mockgen -source=./pkg/permission/domain/usecase/role_permission.go -destination=./pkg/permission/mock/usecase/role_permission.go -package=mock
mockgen -source=./pkg/permission/domain/usecase/user_role.go -destination=./pkg/permission/mock/usecase/user_role.go -package=mock
mockgen -source=./pkg/permission/domain/usecase/token.go -destination=./pkg/permission/mock/usecase/token.go -package=mock
@echo "Generate mock files successfully"
.PHONY: fmt
fmt: ## 格式優化
$(GOFMT) -w $(GOFILES)
goimports -w ./
.PHONY: run
run: ## 運行專案
go run geteway.go
.PHONY: clean
clean: ## 清理編譯文件
rm -rf bin/
.PHONY: install
install: ## 安裝依賴
go mod tidy
go mod download
# go install -tags 'mongodb' github.com/golang-migrate/migrate/v4/cmd/migrate@latest
# go get -u github.com/golang-migrate/migrate/v4/database/mongodb
# MongoDB Migration 環境變數(可覆寫)
MONGO_HOST ?= 127.0.0.1:27017
MONGO_DB ?= digimon
MONGO_USER ?= root
MONGO_PASSWORD ?= example
MONGO_AUTH_DB ?= admin
.PHONY: migrate-up
migrate-up: ## 執行 MongoDB migration (up) - 使用 mongosh + Docker
@echo "=== 執行 MongoDB Migration (UP) ==="
@echo "MongoDB: $(MONGO_HOST)/$(MONGO_DB)"
docker-compose -f ./build/docker-compose-migrate.yml run --rm \
-e MONGO_HOST=$(MONGO_HOST) \
-e MONGO_DB=$(MONGO_DB) \
-e MONGO_USER=$(MONGO_USER) \
-e MONGO_PASSWORD=$(MONGO_PASSWORD) \
-e MONGO_AUTH_DB=$(MONGO_AUTH_DB) \
migrate
.PHONY: migrate-down
migrate-down: ## 執行 MongoDB migration (down) - 使用 mongosh + Docker
@echo "=== 執行 MongoDB Migration (DOWN) ==="
@echo "MongoDB: $(MONGO_HOST)/$(MONGO_DB)"
docker-compose -f ./build/docker-compose-migrate.yml run --rm \
-e MONGO_HOST=$(MONGO_HOST) \
-e MONGO_DB=$(MONGO_DB) \
-e MONGO_USER=$(MONGO_USER) \
-e MONGO_PASSWORD=$(MONGO_PASSWORD) \
-e MONGO_AUTH_DB=$(MONGO_AUTH_DB) \
migrate sh -c " \
if [ -z \"$$MONGO_USER\" ] || [ \"$$MONGO_USER\" = \"\" ]; then \
MONGO_URI=\"mongodb://$$MONGO_HOST/$$MONGO_DB\"; \
else \
MONGO_URI=\"mongodb://$$MONGO_USER:$$MONGO_PASSWORD@$$MONGO_HOST/$$MONGO_DB?authSource=$$MONGO_AUTH_DB\"; \
fi && \
echo \"執行 MongoDB migration (DOWN)...\" && \
echo \"連接: $$MONGO_URI\" && \
for file in \$$(ls -1 /migrations/*.down.txt 2>/dev/null | sort -r); do \
echo \"執行: \$$(basename \$$file)\" && \
mongosh \"$$MONGO_URI\" --file \"\$$file\" || exit 1; \
done && \
echo \"✅ Migration DOWN 完成\" \
"
.PHONY: migrate-version
migrate-version: ## 查看已執行的 migration 文件列表
@echo "=== 已執行的 Migration 文件 ==="
@echo "注意:使用 mongosh 執行,無法追蹤版本"
@echo "Migration 文件列表:"
@ls -1 generate/database/mongo/*.up.txt | xargs -n1 basename