fix api server update version

This commit is contained in:
王性驊 2026-01-06 15:15:18 +08:00
parent 08d1cf7069
commit 377b52515d
37 changed files with 5760 additions and 9 deletions

View File

@ -0,0 +1,44 @@
{
"token_hmac_secret_key": "your-secret-key-change-in-production",
"admin_password": "admin",
"admin_secret": "admin-secret",
"api_key": "api-key",
"allowed_origins": [
"*"
],
"log_level": "info",
"log_handler": "stdout",
"websocket_compression": true,
"websocket_read_buffer_size": 1024,
"websocket_write_buffer_size": 1024,
"namespaces": [
{
"name": "default",
"publish": true,
"subscribe_to_publish": true,
"presence": true,
"join_leave": true,
"history_size": 100,
"history_ttl": "300s"
},
{
"name": "room",
"publish": true,
"subscribe_to_publish": true,
"presence": true,
"join_leave": true,
"history_size": 100,
"history_ttl": "300s"
},
{
"name": "user",
"publish": false,
"subscribe_to_publish": false,
"presence": false,
"join_leave": false,
"history_size": 0,
"history_ttl": "0s"
}
],
"redis_address": "redis:6379"
}

View File

@ -27,7 +27,6 @@ services:
restart: always restart: always
ports: ports:
- "6379:6379" - "6379:6379"
minio: minio:
image: minio/minio image: minio/minio
container_name: minio container_name: minio
@ -39,3 +38,36 @@ services:
MINIO_ROOT_PASSWORD: minioadmin # Replace with your desired root password MINIO_ROOT_PASSWORD: minioadmin # Replace with your desired root password
# MINIO_DEFAULT_BUCKETS: mybucket # Optional: Create a default bucket on startup # MINIO_DEFAULT_BUCKETS: mybucket # Optional: Create a default bucket on startup
command: server /data --console-address ":9001" # Start MinIO server and specify console address command: server /data --console-address ":9001" # Start MinIO server and specify console address
centrifugo:
image: centrifugo/centrifugo:v5
container_name: centrifugo
restart: always
ports:
- "8000:8000" # HTTP API
- "8001:8001" # WebSocket
volumes:
- ./centrifugo.json:/centrifugo/config.json:ro
command: centrifugo --config=/centrifugo/config.json
healthcheck:
test: ["CMD", "wget", "--quiet", "--tries=1", "--spider", "http://localhost:8000/health"]
interval: 10s
timeout: 5s
retries: 3
depends_on:
- redis
cassandra:
image: cassandra:5.0.4
restart: always
ports:
- "9042:9042"
environment:
TZ: ${TIMEZONE:-UTC}
MAX_HEAP_SIZE: 4G
HEAP_NEWSIZE: 2G
healthcheck:
test: [ "CMD", "cqlsh", "-k", "sccflex" ]
interval: 10s
timeout: 10s
retries: 12
mem_limit: 8g # <--- 單機 docker-compose up 時建議明確加這行
memswap_limit: 8g # <--- 關掉 swap

View File

@ -1259,7 +1259,7 @@
"url": "https://localhost:8888" "url": "https://localhost:8888"
} }
], ],
"x-date": "2025-11-12 14:59:58", "x-date": "2026-01-05 10:01:16",
"x-description": "This is a go-doc generated swagger file.", "x-description": "This is a go-doc generated swagger file.",
"x-generator": "go-doc", "x-generator": "go-doc",
"x-github": "https://github.com/danielchan-25/go-doc", "x-github": "https://github.com/danielchan-25/go-doc",

View File

@ -0,0 +1,9 @@
CREATE TABLE messages (
room_id uuid,
bucket_day date,
ts bigint,
msg_id uuid,
uid text,
content text,
PRIMARY KEY ((room_id, bucket_day), ts, msg_id)
) WITH CLUSTERING ORDER BY (ts ASC);

View File

@ -0,0 +1,8 @@
CREATE TABLE message_dedup (
room_id uuid,
uid text,
content_md5 text,
bucket_sec bigint,
PRIMARY KEY ((room_id, uid), bucket_sec, content_md5)
) WITH default_time_to_live = 2;

1
go.mod
View File

@ -12,6 +12,7 @@ require (
github.com/go-playground/validator/v10 v10.28.0 github.com/go-playground/validator/v10 v10.28.0
github.com/gocql/gocql v1.7.0 github.com/gocql/gocql v1.7.0
github.com/golang-jwt/jwt/v4 v4.5.2 github.com/golang-jwt/jwt/v4 v4.5.2
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/matcornic/hermes/v2 v2.1.0 github.com/matcornic/hermes/v2 v2.1.0
github.com/minchao/go-mitake v1.0.0 github.com/minchao/go-mitake v1.0.0

2
go.sum
View File

@ -101,6 +101,8 @@ github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4= github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4=
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=

View File

@ -108,4 +108,20 @@ type Config struct {
SecretKey string SecretKey string
CloudFrontID string CloudFrontID string
} }
// Cassandra 配置
Cassandra struct {
Hosts []string
Port int
Keyspace string
Username string
Password string
UseAuth bool
}
// Centrifugo 配置
Centrifugo struct {
APIURL string
APIKey string
}
} }

67
internal/svc/chat.go Normal file
View File

@ -0,0 +1,67 @@
package svc
import (
"backend/internal/config"
"backend/pkg/chat/domain/usecase"
repo "backend/pkg/chat/repository"
uc "backend/pkg/chat/usecase"
"backend/pkg/library/cassandra"
"backend/pkg/library/centrifugo"
errs "backend/pkg/library/errors"
"fmt"
)
func MustMessageUseCase(c *config.Config, logger errs.Logger) usecase.MessageUseCase {
// 初始化 Cassandra DB
cassandraDB, err := initCassandraDB(c)
if err != nil {
panic(fmt.Sprintf("failed to initialize Cassandra DB: %v", err))
}
// 初始化 Message Repository
messageRepo := repo.MustMessageRepository(repo.MessageRepositoryParam{
DB: cassandraDB,
Keyspace: c.Cassandra.Keyspace,
})
// 初始化 Room Repository
roomRepo := repo.MustRoomRepository(repo.RoomRepositoryParam{
DB: cassandraDB,
Keyspace: c.Cassandra.Keyspace,
})
// 初始化 Centrifugo Client
msgClient := centrifugo.NewClientWithConfig(centrifugo.HighPerformanceConfig(c.Centrifugo.APIURL, c.Centrifugo.APIKey))
return uc.NewMessageUseCase(uc.MessageUseCaseParam{
MessageRepo: messageRepo,
RoomRepo: roomRepo,
MsgClient: msgClient,
Logger: logger,
})
}
// initCassandraDB 初始化 Cassandra 資料庫連接
func initCassandraDB(c *config.Config) (*cassandra.DB, error) {
if len(c.Cassandra.Hosts) == 0 {
return nil, fmt.Errorf("cassandra hosts are required")
}
opts := []cassandra.Option{
cassandra.WithHosts(c.Cassandra.Hosts...),
}
if c.Cassandra.Port > 0 {
opts = append(opts, cassandra.WithPort(c.Cassandra.Port))
}
if c.Cassandra.Keyspace != "" {
opts = append(opts, cassandra.WithKeyspace(c.Cassandra.Keyspace))
}
if c.Cassandra.UseAuth {
opts = append(opts, cassandra.WithAuth(c.Cassandra.Username, c.Cassandra.Password))
}
return cassandra.New(opts...)
}

View File

@ -9,6 +9,7 @@ import (
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
chatUC "backend/pkg/chat/domain/usecase"
fileStorageUC "backend/pkg/fileStorage/domain/usecase" fileStorageUC "backend/pkg/fileStorage/domain/usecase"
vi "backend/pkg/library/validator" vi "backend/pkg/library/validator"
memberUC "backend/pkg/member/domain/usecase" memberUC "backend/pkg/member/domain/usecase"
@ -31,6 +32,7 @@ type ServiceContext struct {
UserRoleUC tokenUC.UserRoleUseCase UserRoleUC tokenUC.UserRoleUseCase
DeliveryUC deliveryUC.DeliveryUseCase DeliveryUC deliveryUC.DeliveryUseCase
FileStorageUC fileStorageUC.FileStorageUseCase FileStorageUC fileStorageUC.FileStorageUseCase
MessageUC chatUC.MessageUseCase
Redis *redis.Redis Redis *redis.Redis
Logger errs.Logger Logger errs.Logger
} }
@ -62,6 +64,7 @@ func NewServiceContext(c config.Config) *ServiceContext {
Redis: rds, Redis: rds,
DeliveryUC: MustDeliveryUseCase(&c, lgr), DeliveryUC: MustDeliveryUseCase(&c, lgr),
FileStorageUC: MustS3Storage(&c, lgr), FileStorageUC: MustS3Storage(&c, lgr),
MessageUC: MustMessageUseCase(&c, lgr),
Logger: lgr, Logger: lgr,
} }
} }

View File

@ -0,0 +1,38 @@
package chat
// RoomRole 聊天室成員角色
type RoomRole string
const (
// RoomRoleMember 一般成員
RoomRoleMember RoomRole = "member"
// RoomRoleAdmin 管理員
RoomRoleAdmin RoomRole = "admin"
// RoomRoleOwner 擁有者
RoomRoleOwner RoomRole = "owner"
)
// String 返回角色的字串表示
func (r RoomRole) String() string {
return string(r)
}
// IsValid 檢查角色是否有效
func (r RoomRole) IsValid() bool {
switch r {
case RoomRoleMember, RoomRoleAdmin, RoomRoleOwner:
return true
default:
return false
}
}
// IsAdmin 檢查是否為管理員或擁有者
func (r RoomRole) IsAdmin() bool {
return r == RoomRoleAdmin || r == RoomRoleOwner
}
// IsOwner 檢查是否為擁有者
func (r RoomRole) IsOwner() bool {
return r == RoomRoleOwner
}

View File

@ -0,0 +1,23 @@
package entity
import (
"github.com/gocql/gocql"
"github.com/google/uuid"
)
// Message 對應 Cassandra 的 messages_by_room 表
// Primary Key: (room_id, bucket_day)
// Clustering Key: ts DESC
type Message struct {
RoomID gocql.UUID `db:"room_id" partition_key:"true"`
BucketDay string `db:"bucket_day" partition_key:"true"` // yyyyMMdd
TS int64 `db:"ts" clustering_key:"true"` // timestamp
UID string `db:"uid"`
Content string `db:"content"`
MsgID uuid.UUID `db:"msg_id"`
}
// TableName 返回表名
func (m Message) TableName() string {
return "messages_by_room"
}

View File

@ -0,0 +1,21 @@
package entity
import (
"github.com/gocql/gocql"
)
// MessageDedup 對應 Cassandra 的 message_dedup 表
// Primary Key: ((room_id, uid), bucket_sec, content_md5)
// TTL: 2 秒
type MessageDedup struct {
RoomID gocql.UUID `db:"room_id" partition_key:"true"`
UID string `db:"uid" partition_key:"true"`
BucketSec int64 `db:"bucket_sec" clustering_key:"true"` // Unix timestamp in seconds
ContentMD5 string `db:"content_md5" clustering_key:"true"` // MD5 hash of content
}
// TableName 返回表名
func (m MessageDedup) TableName() string {
return "message_dedup"
}

View File

@ -0,0 +1,59 @@
package entity
import "github.com/gocql/gocql"
// Room 對應 Cassandra 的 room 表
// Primary Key: (room_id)
// 設計說明:存儲聊天室本身的基本資訊
type Room struct {
RoomID gocql.UUID `db:"room_id" partition_key:"true"`
Name string `db:"name"` // 聊天室名稱
Status string `db:"status"` // 狀態active, archived, deleted
CreatedAt int64 `db:"created_at"` // 創建時間
UpdatedAt int64 `db:"updated_at"` // 更新時間
}
// TableName 返回表名
func (r Room) TableName() string {
return "room"
}
// RoomMember 對應 Cassandra 的 room_member 表
// Primary Key: (room_id)
// Clustering Key: uid
// 設計說明:
// - room_id 作為 partition key可以高效地查詢/刪除整個聊天室的所有成員
// - uid 作為 clustering key可以高效地查詢/刪除特定成員
// - 支援三種操作:
// 1. 刪除整個聊天室:刪除整個 partition需要查詢後批量刪除或使用原生 CQL
// 2. 刪除特定成員:使用 Delete(roomID, uid)
// 3. 查詢聊天室所有成員:使用 Query().Where(Eq("room_id", roomID)).Scan()
type RoomMember struct {
RoomID gocql.UUID `db:"room_id" partition_key:"true"`
UID string `db:"uid" clustering_key:"true"`
Role string `db:"role"` // Role 角色member一般成員、admin管理員、owner擁有者
JoinedAt int64 `db:"joined_at"` // JoinedAt 加入時間(可選,用於記錄加入時間)
}
// TableName 返回表名
func (m RoomMember) TableName() string {
return "room_member"
}
// UserRoom 對應 Cassandra 的 user_room 表(反向查詢表)
// Primary Key: (uid)
// Clustering Key: room_id
// 設計說明:
// - uid 作為 partition key可以高效地查詢某個用戶所在的所有聊天室
// - room_id 作為 clustering key可以高效地查詢/刪除特定關聯
// - 這個表用於支援「查詢用戶在哪些聊天室中」的需求
type UserRoom struct {
UID string `db:"uid" partition_key:"true"`
RoomID gocql.UUID `db:"room_id" clustering_key:"true"`
JoinedAt int64 `db:"joined_at"` // 加入時間
}
// TableName 返回表名
func (u UserRoom) TableName() string {
return "user_room"
}

View File

@ -0,0 +1,41 @@
package repository
import (
"backend/pkg/chat/domain/entity"
"context"
)
// MessageRepository 定義訊息相關的資料存取介面
type MessageRepository interface {
// Insert 插入訊息
Insert(ctx context.Context, msg *entity.Message) error
// ListMessages 查詢訊息列表(分頁)
ListMessages(ctx context.Context, param ListMessagesReq) ([]entity.Message, error)
// Count 計算符合條件的訊息總數
Count(ctx context.Context, RoomID string) (int64, error)
// CheckAndInsertDedup 檢查並插入去重記錄,如果已存在則返回 true表示重複
CheckAndInsertDedup(ctx context.Context, param CheckDupReq) (bool, error)
}
type SendMessageReq struct {
RoomID string
UID string
Content string
ClientMsgID string
}
type ListMessagesReq struct {
RoomID string
BucketDay string
PageSize int
// LastTS 用於 cursor-based pagination獲取 ts < LastTS 的訊息
// 如果為 0則獲取最新的訊息
LastTS int64
}
type CheckDupReq struct {
RoomID string
UID string
BucketSec int64
ContentMD5 string
}

View File

@ -0,0 +1,51 @@
package repository
import (
"backend/pkg/chat/domain/entity"
"context"
)
type RoomRepository interface {
Room
Member
User
}
// ListRoomsReq 查詢聊天室列表的請求參數
type ListRoomsReq struct {
Status string // 聊天室狀態active, archived 等)
PageSize int // 每頁大小
LastID string // 用於 cursor-based pagination
}
// CountRoomsReq 統計聊天室的請求參數
type CountRoomsReq struct {
Status string // 可選的狀態篩選
}
type Room interface {
Create(ctx context.Context, room *entity.Room) error // Create 創建聊天室
RoomGet(ctx context.Context, roomID string) (*entity.Room, error) // Get 獲取聊天室資訊
RoomUpdate(ctx context.Context, room *entity.Room) error // Update 更新聊天室資訊
RoomDelete(ctx context.Context, roomID string) error // Delete 刪除聊天室(同時需要刪除相關的成員和訊息)
RoomList(ctx context.Context, param ListRoomsReq) ([]entity.Room, error) // List 查詢聊天室列表(支援分頁和篩選)
RoomCount(ctx context.Context, param CountRoomsReq) (int64, error) // Count 統計聊天室總數
RoomExists(ctx context.Context, roomID string) (bool, error) // Exists 檢查聊天室是否存在
RoomGetByID(ctx context.Context, roomIDs []string) ([]entity.Room, error) // 取得 Room by id
}
type Member interface {
Insert(ctx context.Context, member *entity.RoomMember) error // Insert 添加成員到聊天室
Get(ctx context.Context, roomID, uid string) (*entity.RoomMember, error) // Get 獲取特定成員資訊
AllMembers(ctx context.Context, roomID string) ([]entity.RoomMember, error) // AllMembers 查詢聊天室所有成員
UpdateRole(ctx context.Context, member *entity.RoomMember) error // UpdateRole 更新成員資訊(例如更新角色)
DeleteMember(ctx context.Context, roomID, uid string) error // DeleteMember 刪除特定成員(某人退出聊天室)
DeleteRoom(ctx context.Context, roomID string) error // DeleteRoom 刪除整個聊天室的所有成員
Count(ctx context.Context, roomID string) (int64, error) // Count 計算聊天室成員數量
}
type User interface {
GetUserRooms(ctx context.Context, uid string) ([]entity.UserRoom, error) // GetUserRooms 查詢用戶所在的所有聊天室
CountUserRooms(ctx context.Context, uid string) (int64, error) // CountUserRooms 統計用戶所在的聊天室數量
IsUserInRoom(ctx context.Context, uid, roomID string) (bool, error) // IsUserInRoom 檢查用戶是否在某個聊天室中
}

View File

@ -0,0 +1,35 @@
package usecase
import (
"context"
)
// MessageUseCase 定義訊息相關的業務邏輯介面
type MessageUseCase interface {
// SendMessage 發送訊息
SendMessage(ctx context.Context, param SendMessageReq) error
// ListMessages 查詢訊息列表(分頁)
ListMessages(ctx context.Context, req ListMessagesReq) ([]Message, int64, error)
}
type SendMessageReq struct {
RoomID string
UID string
Content string
}
type ListMessagesReq struct {
RoomID string
UID string
BucketDay string
PageSize int64
LastTS int64
}
type Message struct {
RoomID string `json:"room_id"`
BucketDay string `json:"bucket_day"` // yyyyMMdd
TS int64 `json:"ts"` // timestamp
UID string `json:"uid"`
Content string `json:"content"`
}

View File

@ -0,0 +1,165 @@
package repository
import (
"backend/pkg/chat/domain/entity"
"backend/pkg/chat/domain/repository"
"backend/pkg/library/cassandra"
"context"
"fmt"
"time"
"github.com/gocql/gocql"
)
type messageRepository struct {
repo cassandra.Repository[entity.Message]
dedupRepo cassandra.Repository[entity.MessageDedup]
db *cassandra.DB
keyspace string
}
// MessageRepositoryParam 創建 MessageRepository 所需的參數
type MessageRepositoryParam struct {
DB *cassandra.DB
Keyspace string
}
// MustMessageRepository 創建 MessageRepository如果失敗會 panic
func MustMessageRepository(param MessageRepositoryParam) repository.MessageRepository {
repo, err := NewMessageRepository(param.DB, param.Keyspace)
if err != nil {
panic(fmt.Sprintf("failed to create message repository: %v", err))
}
return repo
}
// NewMessageRepository 創建新的訊息 Repository
func NewMessageRepository(db *cassandra.DB, keyspace string) (repository.MessageRepository, error) {
repo, err := cassandra.NewRepository[entity.Message](db, keyspace)
if err != nil {
return nil, err
}
dedupRepo, err := cassandra.NewRepository[entity.MessageDedup](db, keyspace)
if err != nil {
return nil, err
}
return &messageRepository{
repo: repo,
dedupRepo: dedupRepo,
db: db,
keyspace: keyspace,
}, nil
}
func (message *messageRepository) Insert(ctx context.Context, msg *entity.Message) error {
now := time.Now().UTC()
if msg.TS == 0 {
msg.TS = now.UnixNano()
}
// 只在 BucketDay 為空時才自動設置,保留先前傳入的值
if msg.BucketDay == "" {
msg.BucketDay = now.Format(time.DateOnly)
}
return message.repo.Insert(ctx, *msg)
}
func (message *messageRepository) ListMessages(ctx context.Context, param repository.ListMessagesReq) ([]entity.Message, error) {
// 設定預設分頁大小
if param.PageSize <= 0 {
param.PageSize = 20
}
// 將字串 RoomID 轉換為 UUID
roomUUID, err := gocql.ParseUUID(param.RoomID)
if err != nil {
return nil, err
}
// 構建查詢條件
query := message.repo.Query().
Where(cassandra.Eq("room_id", roomUUID)).
Where(cassandra.Eq("bucket_day", param.BucketDay))
// 使用 cursor-based pagination如果提供了 LastTS則查詢 ts < LastTS 的訊息
// 因為排序是 DESC所以使用 < 來獲取更早的訊息(下一頁)
if param.LastTS > 0 {
query = query.Where(cassandra.Lt("ts", param.LastTS))
}
// 添加排序和限制
query = query.
OrderBy("ts", cassandra.DESC).
Limit(int(param.PageSize))
// 執行查詢
var messages []entity.Message
if err := query.Scan(ctx, &messages); err != nil {
return nil, err
}
return messages, nil
}
func (message *messageRepository) Count(ctx context.Context, roomID string) (int64, error) {
// 將字串 RoomID 轉換為 UUID
roomUUID, err := gocql.ParseUUID(roomID)
if err != nil {
return 0, err
}
// 注意:由於 partition key 是 (room_id, bucket_day),只用 room_id 查詢需要 ALLOW FILTERING
// 這在生產環境中效能較差,建議改用按 bucket_day 分別查詢後加總
count, err := message.repo.Query().
Where(cassandra.Eq("room_id", roomUUID)).
AllowFiltering().
Count(ctx)
if err != nil {
return 0, err
}
return count, nil
}
// CheckAndInsertDedup 檢查並插入去重記錄
// 使用 IF NOT EXISTS 來實現原子性的去重檢查
// 返回值true 表示已存在重複false 表示成功插入(不重複)
func (message *messageRepository) CheckAndInsertDedup(ctx context.Context, param repository.CheckDupReq) (bool, error) {
// 將字串 RoomID 轉換為 UUID
roomUUID, err := gocql.ParseUUID(param.RoomID)
if err != nil {
return false, err
}
// 使用 IF NOT EXISTS 來實現原子性的去重檢查
// 如果記錄已存在INSERT 不會插入,且 applied = false
dedup := entity.MessageDedup{
RoomID: roomUUID,
UID: param.UID,
BucketSec: param.BucketSec,
ContentMD5: param.ContentMD5,
}
// 使用原生 CQL 語句來實現 IF NOT EXISTS
tableName := dedup.TableName()
stmt := fmt.Sprintf(
"INSERT INTO %s.%s (room_id, uid, bucket_sec, content_md5) VALUES (?, ?, ?, ?) IF NOT EXISTS",
message.keyspace,
tableName,
)
// 執行 INSERT IF NOT EXISTS
applied, err := message.db.GetSession().Query(stmt, nil).
Bind(roomUUID, param.UID, param.BucketSec, param.ContentMD5).
WithContext(ctx).
MapScanCAS(make(map[string]interface{}))
if err != nil {
return false, fmt.Errorf("failed to check dedup: %w", err)
}
// applied = false 表示記錄已存在(重複)
// applied = true 表示成功插入(不重複)
return !applied, nil
}

View File

@ -0,0 +1,526 @@
package repository
import (
"backend/pkg/chat/domain/entity"
"backend/pkg/chat/domain/repository"
"backend/pkg/library/cassandra"
"context"
"os"
"strconv"
"testing"
"time"
"github.com/gocql/gocql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
)
var (
// 全局變量:所有測試共享的資源
testDB *cassandra.DB
testRepo repository.MessageRepository
testContainer testcontainers.Container
cleanupContainer func()
)
// TestMain 在所有測試之前執行一次,設置共享的資料庫
func TestMain(m *testing.M) {
// 使用更長的 context 超時時間Cassandra 啟動需要較長時間)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// 啟動 Cassandra 容器(只執行一次)
// 不指定固定 port讓 testcontainers 自動分配,避免與本地開發環境衝突
req := testcontainers.ContainerRequest{
Image: "cassandra:4.1",
ExposedPorts: []string{"9042/tcp"},
// 等待 Cassandra 日誌確認啟動完成Cassandra 啟動需要較長時間)
WaitingFor: wait.ForLog("Startup complete").
WithStartupTimeout(3 * time.Minute).
WithOccurrence(1),
Env: map[string]string{
"CASSANDRA_CLUSTER_NAME": "test-cluster",
"HEAP_NEWSIZE": "128M",
"MAX_HEAP_SIZE": "512M",
},
}
var err error
testContainer, err = testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
ContainerRequest: req,
Started: true,
})
if err != nil {
panic("Failed to start Cassandra container: " + err.Error())
}
port, err := testContainer.MappedPort(ctx, "9042")
if err != nil {
testContainer.Terminate(ctx)
panic("Failed to get mapped port: " + err.Error())
}
host, err := testContainer.Host(ctx)
if err != nil {
testContainer.Terminate(ctx)
panic("Failed to get host: " + err.Error())
}
portInt, err := strconv.Atoi(port.Port())
if err != nil {
testContainer.Terminate(ctx)
panic("Failed to convert port to int: " + err.Error())
}
// 先創建 DB 連接(不指定 keyspace因為 keyspace 還不存在)
testDB, err = cassandra.New(
cassandra.WithHosts(host),
cassandra.WithPort(portInt),
// 不指定 keyspace先連接後再創建
)
if err != nil {
testContainer.Terminate(ctx)
panic("Failed to create DB: " + err.Error())
}
// 創建 keyspace需要在連接後才能創建
createKeyspaceStmt := "CREATE KEYSPACE IF NOT EXISTS test_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}"
if err := testDB.GetSession().Query(createKeyspaceStmt, nil).Exec(); err != nil {
testDB.Close()
testContainer.Terminate(ctx)
panic("Failed to create keyspace: " + err.Error())
}
// 等待 keyspace 創建完成
time.Sleep(500 * time.Millisecond)
// 創建 messages_by_room 表(需要指定 keyspace
createTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.messages_by_room (
room_id uuid,
bucket_day text,
ts bigint,
uid text,
content text,
PRIMARY KEY ((room_id, bucket_day), ts)
) WITH CLUSTERING ORDER BY (ts DESC)`
if err := testDB.GetSession().Query(createTableStmt, nil).Exec(); err != nil {
testDB.Close()
testContainer.Terminate(ctx)
panic("Failed to create table: " + err.Error())
}
// 創建 messages repository
testRepo, err = NewMessageRepository(testDB, "test_keyspace")
if err != nil {
testDB.Close()
testContainer.Terminate(ctx)
panic("Failed to create message repository: " + err.Error())
}
// 創建 room 相關表(供 room_test.go 使用)
createRoomTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.room (
room_id uuid PRIMARY KEY,
name text,
status text,
created_at bigint,
updated_at bigint
)`
if err := testDB.GetSession().Query(createRoomTableStmt, nil).Exec(); err != nil {
testDB.Close()
testContainer.Terminate(ctx)
panic("Failed to create room table: " + err.Error())
}
createMemberTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.room_member (
room_id uuid,
uid text,
role text,
joined_at bigint,
PRIMARY KEY (room_id, uid)
)`
if err := testDB.GetSession().Query(createMemberTableStmt, nil).Exec(); err != nil {
testDB.Close()
testContainer.Terminate(ctx)
panic("Failed to create room_member table: " + err.Error())
}
createUserRoomTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.user_room (
uid text,
room_id uuid,
joined_at bigint,
PRIMARY KEY (uid, room_id)
)`
if err := testDB.GetSession().Query(createUserRoomTableStmt, nil).Exec(); err != nil {
testDB.Close()
testContainer.Terminate(ctx)
panic("Failed to create user_room table: " + err.Error())
}
// 設置清理函數
cleanupContainer = func() {
if testDB != nil {
testDB.Close()
}
if testContainer != nil {
_ = testContainer.Terminate(ctx)
}
}
// 執行所有測試
code := m.Run()
// 清理資源
cleanupContainer()
// 退出
os.Exit(code)
}
// clearMessages 清空 messages_by_room 表的所有數據
func clearMessages(t *testing.T) {
ctx := context.Background()
// 使用完整的 keyspace.table 名稱
truncateStmt := "TRUNCATE test_keyspace.messages_by_room"
if err := testDB.GetSession().Query(truncateStmt, nil).WithContext(ctx).Exec(); err != nil {
t.Fatalf("Failed to truncate messages_by_room table: %v", err)
}
// 等待數據清空
time.Sleep(50 * time.Millisecond)
}
func TestNewMessageRepository(t *testing.T) {
clearMessages(t)
assert.NotNil(t, testRepo)
}
func TestMessageRepository_Insert(t *testing.T) {
clearMessages(t)
ctx := context.Background()
tests := []struct {
name string
message *entity.Message
wantErr bool
}{
{
name: "successful insert",
message: &entity.Message{
RoomID: gocql.TimeUUID(),
BucketDay: "20250117",
TS: time.Now().UnixNano(),
UID: "user-1",
Content: "Hello, world!",
},
wantErr: false,
},
{
name: "insert with auto timestamp",
message: &entity.Message{
RoomID: gocql.TimeUUID(),
BucketDay: "20250117",
UID: "user-2",
Content: "Test message",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := testRepo.Insert(ctx, tt.message)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
// 驗證 TS 和 BucketDay 是否被自動設置
if tt.message.TS == 0 {
assert.NotZero(t, tt.message.TS)
}
if tt.message.BucketDay == "" {
assert.NotEmpty(t, tt.message.BucketDay)
}
}
})
}
}
func TestMessageRepository_ListMessages(t *testing.T) {
clearMessages(t)
ctx := context.Background()
roomID := gocql.TimeUUID()
roomIDStr := roomID.String()
bucketDay := time.Now().Format(time.DateOnly)
// 插入測試數據(使用相同的 roomID
messages := []*entity.Message{
{RoomID: roomID, BucketDay: bucketDay, TS: time.Now().UnixNano(), UID: "user-1", Content: "Message 1"},
{RoomID: roomID, BucketDay: bucketDay, TS: time.Now().Add(time.Second).UnixNano(), UID: "user-2", Content: "Message 2"},
{RoomID: roomID, BucketDay: bucketDay, TS: time.Now().Add(2 * time.Second).UnixNano(), UID: "user-3", Content: "Message 3"},
}
for _, msg := range messages {
require.NoError(t, testRepo.Insert(ctx, msg))
}
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
tests := []struct {
name string
param repository.ListMessagesReq
wantLen int
wantErr bool
}{
{
name: "list all messages",
param: repository.ListMessagesReq{
RoomID: roomIDStr,
BucketDay: bucketDay,
PageSize: 10,
LastTS: 0,
},
wantLen: 3,
wantErr: false,
},
{
name: "list with page size limit",
param: repository.ListMessagesReq{
RoomID: roomIDStr,
BucketDay: bucketDay,
PageSize: 2,
LastTS: 0,
},
wantLen: 2,
wantErr: false,
},
{
name: "list with cursor pagination",
param: repository.ListMessagesReq{
RoomID: roomIDStr,
BucketDay: bucketDay,
PageSize: 10,
LastTS: messages[1].TS, // 使用第二條訊息的 TS 作為 cursor
},
wantLen: 1, // 應該只返回比 LastTS 更早的訊息
wantErr: false,
},
{
name: "list with default page size",
param: repository.ListMessagesReq{
RoomID: roomIDStr,
BucketDay: bucketDay,
PageSize: 0, // 應該使用預設值 20
LastTS: 0,
},
wantLen: 3,
wantErr: false,
},
{
name: "list empty room",
param: repository.ListMessagesReq{
RoomID: gocql.TimeUUID().String(), // 使用不存在的 UUID
BucketDay: bucketDay,
PageSize: 10,
LastTS: 0,
},
wantLen: 0,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := testRepo.ListMessages(ctx, tt.param)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Len(t, result, tt.wantLen)
// 驗證排序(應該按 TS DESC 排序)
if len(result) > 1 {
for i := 0; i < len(result)-1; i++ {
assert.GreaterOrEqual(t, result[i].TS, result[i+1].TS, "Messages should be sorted by TS DESC")
}
}
}
})
}
}
func TestMessageRepository_Count(t *testing.T) {
clearMessages(t)
ctx := context.Background()
roomID := gocql.TimeUUID()
roomIDStr := roomID.String()
bucketDay := time.Now().Format(time.DateOnly)
// 插入測試數據(使用相同的 roomID
for i := 0; i < 5; i++ {
msg := &entity.Message{
RoomID: roomID,
BucketDay: bucketDay,
TS: time.Now().Add(time.Duration(i) * time.Second).UnixNano(),
UID: "user-1",
Content: "Message",
}
require.NoError(t, testRepo.Insert(ctx, msg))
}
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
tests := []struct {
name string
roomID string
want int64
wantErr bool
}{
{
name: "count existing room",
roomID: roomIDStr,
want: 5,
wantErr: false,
},
{
name: "count non-existent room",
roomID: gocql.TimeUUID().String(), // 使用不存在的 UUID
want: 0,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
count, err := testRepo.Count(ctx, tt.roomID)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, count)
}
})
}
}
func TestMessageRepository_Insert_ListMessages_Integration(t *testing.T) {
clearMessages(t)
ctx := context.Background()
roomID := gocql.TimeUUID()
roomIDStr := roomID.String()
bucketDay := time.Now().Format(time.DateOnly)
// 插入多條訊息(使用相同的 roomID
messages := make([]*entity.Message, 10)
for i := 0; i < 10; i++ {
msg := &entity.Message{
RoomID: roomID,
BucketDay: bucketDay,
TS: time.Now().Add(time.Duration(i) * time.Second).UnixNano(),
UID: "user-1",
Content: "Message",
}
messages[i] = msg
require.NoError(t, testRepo.Insert(ctx, msg))
}
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
// 驗證可以查詢到所有訊息
result, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
RoomID: roomIDStr,
BucketDay: bucketDay,
PageSize: 20,
LastTS: 0,
})
require.NoError(t, err)
assert.Len(t, result, 10)
// 驗證分頁功能
firstPage, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
RoomID: roomIDStr,
BucketDay: bucketDay,
PageSize: 5,
LastTS: 0,
})
require.NoError(t, err)
assert.Len(t, firstPage, 5)
// 使用 cursor 獲取下一頁
if len(firstPage) > 0 {
lastTS := firstPage[len(firstPage)-1].TS
secondPage, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
RoomID: roomIDStr,
BucketDay: bucketDay,
PageSize: 5,
LastTS: lastTS,
})
require.NoError(t, err)
assert.Len(t, secondPage, 5)
// 驗證第二頁的訊息都比第一頁的最後一條更早
if len(secondPage) > 0 {
assert.Less(t, secondPage[0].TS, lastTS)
}
}
// 驗證計數
count, err := testRepo.Count(ctx, roomIDStr)
require.NoError(t, err)
assert.Equal(t, int64(10), count)
}
func TestMessageRepository_DifferentBuckets(t *testing.T) {
clearMessages(t)
ctx := context.Background()
roomID := gocql.TimeUUID()
roomIDStr := roomID.String()
today := time.Now().Format(time.DateOnly)
yesterday := time.Now().AddDate(0, 0, -1).Format(time.DateOnly)
// 插入不同 bucket 的訊息(使用相同的 roomID
todayMsg := &entity.Message{
RoomID: roomID,
BucketDay: today,
TS: time.Now().UnixNano(),
UID: "user-1",
Content: "Today's message",
}
yesterdayMsg := &entity.Message{
RoomID: roomID,
BucketDay: yesterday,
TS: time.Now().UnixNano(),
UID: "user-1",
Content: "Yesterday's message",
}
require.NoError(t, testRepo.Insert(ctx, todayMsg))
require.NoError(t, testRepo.Insert(ctx, yesterdayMsg))
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
// 查詢今天的訊息
todayMessages, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
RoomID: roomIDStr,
BucketDay: today,
PageSize: 10,
LastTS: 0,
})
require.NoError(t, err)
assert.Len(t, todayMessages, 1)
assert.Equal(t, "Today's message", todayMessages[0].Content)
// 查詢昨天的訊息
yesterdayMessages, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
RoomID: roomIDStr,
BucketDay: yesterday,
PageSize: 10,
LastTS: 0,
})
require.NoError(t, err)
assert.Len(t, yesterdayMessages, 1)
assert.Equal(t, "Yesterday's message", yesterdayMessages[0].Content)
}

404
pkg/chat/repository/room.go Normal file
View File

@ -0,0 +1,404 @@
package repository
import (
"backend/pkg/chat/domain/chat"
"backend/pkg/chat/domain/entity"
"backend/pkg/chat/domain/repository"
"backend/pkg/library/cassandra"
"context"
"fmt"
"time"
"github.com/gocql/gocql"
)
type roomRepository struct {
roomRepo cassandra.Repository[entity.Room]
memberRepo cassandra.Repository[entity.RoomMember]
userRoomRepo cassandra.Repository[entity.UserRoom]
db *cassandra.DB
keyspace string
}
// RoomRepositoryParam 創建 RoomRepository 所需的參數
type RoomRepositoryParam struct {
DB *cassandra.DB
Keyspace string
}
// MustRoomRepository 創建 RoomRepository如果失敗會 panic
func MustRoomRepository(param RoomRepositoryParam) repository.RoomRepository {
repo, err := NewRoomRepository(param.DB, param.Keyspace)
if err != nil {
panic(fmt.Sprintf("failed to create room repository: %v", err))
}
return repo
}
// NewRoomRepository 創建新的聊天室 Repository
func NewRoomRepository(db *cassandra.DB, keyspace string) (repository.RoomRepository, error) {
roomRepo, err := cassandra.NewRepository[entity.Room](db, keyspace)
if err != nil {
return nil, err
}
memberRepo, err := cassandra.NewRepository[entity.RoomMember](db, keyspace)
if err != nil {
return nil, err
}
userRoomRepo, err := cassandra.NewRepository[entity.UserRoom](db, keyspace)
if err != nil {
return nil, err
}
return &roomRepository{
roomRepo: roomRepo,
memberRepo: memberRepo,
userRoomRepo: userRoomRepo,
db: db,
keyspace: keyspace,
}, nil
}
// ==================== Room Interface 實作 ====================
func (r *roomRepository) Create(ctx context.Context, room *entity.Room) error {
now := time.Now().UTC().UnixNano()
if room.CreatedAt == 0 {
room.CreatedAt = now
}
if room.UpdatedAt == 0 {
room.UpdatedAt = now
}
room.RoomID = gocql.TimeUUID()
return r.roomRepo.Insert(ctx, *room)
}
func (r *roomRepository) RoomGet(ctx context.Context, roomID string) (*entity.Room, error) {
uuid, err := gocql.ParseUUID(roomID)
if err != nil {
return nil, err
}
room, err := r.roomRepo.Get(ctx, entity.Room{
RoomID: uuid,
})
if err != nil {
if cassandra.IsNotFound(err) {
return nil, cassandra.ErrNotFound
}
return nil, err
}
return &room, nil
}
func (r *roomRepository) RoomUpdate(ctx context.Context, room *entity.Room) error {
update := entity.Room{}
now := time.Now().UTC().UnixNano()
get, err := r.RoomGet(ctx, room.RoomID.String())
if err != nil {
return err
}
update.CreatedAt = get.CreatedAt
update.UpdatedAt = now
update.Status = room.Status
update.Name = room.Name
update.RoomID = room.RoomID
if err = r.roomRepo.Update(ctx, update); err != nil {
return err
}
return nil
}
func (r *roomRepository) RoomDelete(ctx context.Context, roomID string) error {
uuid, err := gocql.ParseUUID(roomID)
if err != nil {
return err
}
// 使用原生 CQL 語句來刪除,避免 cassandra library Delete 方法的 struct 綁定問題
e := entity.Room{}
stmt := fmt.Sprintf("DELETE FROM %s.%s WHERE room_id = ?", r.keyspace, e.TableName())
return r.db.GetSession().Query(stmt, nil).
Bind(uuid).
WithContext(ctx).
ExecRelease()
}
func (r *roomRepository) RoomList(ctx context.Context, param repository.ListRoomsReq) ([]entity.Room, error) {
query := r.roomRepo.Query()
if param.Status != "" {
query = query.Where(cassandra.Eq("status", param.Status)).AllowFiltering()
}
if param.PageSize > 0 {
query = query.Limit(param.PageSize)
}
if param.LastID != "" {
lastUUID, err := gocql.ParseUUID(param.LastID)
if err != nil {
return nil, err
}
query = query.Where(cassandra.Lt("room_id", lastUUID))
}
var rooms []entity.Room
if err := query.Scan(ctx, &rooms); err != nil {
return nil, err
}
return rooms, nil
}
func (r *roomRepository) RoomCount(ctx context.Context, param repository.CountRoomsReq) (int64, error) {
query := r.roomRepo.Query()
if param.Status != "" {
query = query.Where(cassandra.Eq("status", param.Status)).AllowFiltering()
}
return query.Count(ctx)
}
func (r *roomRepository) RoomExists(ctx context.Context, roomID string) (bool, error) {
uuid, err := gocql.ParseUUID(roomID)
if err != nil {
return false, err
}
_, err = r.roomRepo.Get(ctx, entity.Room{
RoomID: uuid,
})
if err != nil {
if cassandra.IsNotFound(err) {
return false, nil
}
return false, err
}
return true, nil
}
func (r *roomRepository) RoomGetByID(ctx context.Context, roomIDs []string) ([]entity.Room, error) {
if len(roomIDs) == 0 {
return []entity.Room{}, nil
}
// 將字串 ID 轉換為 UUID
uuids := make([]any, 0, len(roomIDs))
for _, id := range roomIDs {
uuid, err := gocql.ParseUUID(id)
if err != nil {
return nil, err
}
uuids = append(uuids, uuid)
}
var rooms []entity.Room
err := r.roomRepo.Query().
Where(cassandra.In("room_id", uuids)).
Scan(ctx, &rooms)
if err != nil {
return nil, err
}
return rooms, nil
}
// ==================== Member Interface 實作 ====================
func (r *roomRepository) Insert(ctx context.Context, member *entity.RoomMember) error {
now := time.Now().UTC().UnixNano()
if member.JoinedAt == 0 {
member.JoinedAt = now
}
if member.Role == "" {
member.Role = chat.RoomRoleMember.String()
}
// 同時插入到 user_room 表(反向查詢表)
userRoom := entity.UserRoom{
UID: member.UID,
RoomID: member.RoomID,
JoinedAt: member.JoinedAt,
}
if err := r.memberRepo.Insert(ctx, *member); err != nil {
return err
}
if err := r.userRoomRepo.Insert(ctx, userRoom); err != nil {
return err
}
return nil
}
func (r *roomRepository) Get(ctx context.Context, roomID, uid string) (*entity.RoomMember, error) {
uuid, err := gocql.ParseUUID(roomID)
if err != nil {
return nil, err
}
member, err := r.memberRepo.Get(ctx, entity.RoomMember{
RoomID: uuid,
UID: uid,
})
if err != nil {
if cassandra.IsNotFound(err) {
return nil, cassandra.ErrNotFound
}
return nil, err
}
return &member, nil
}
func (r *roomRepository) AllMembers(ctx context.Context, roomID string) ([]entity.RoomMember, error) {
uuid, err := gocql.ParseUUID(roomID)
if err != nil {
return nil, err
}
var members []entity.RoomMember
err = r.memberRepo.Query().
Where(cassandra.Eq("room_id", uuid)).
Scan(ctx, &members)
if err != nil {
return nil, err
}
return members, nil
}
func (r *roomRepository) UpdateRole(ctx context.Context, member *entity.RoomMember) error {
get, err := r.Get(ctx, member.RoomID.String(), member.UID)
if err != nil {
return err
}
update := entity.RoomMember{
RoomID: member.RoomID,
UID: member.UID,
Role: member.Role,
JoinedAt: get.JoinedAt,
}
return r.memberRepo.Update(ctx, update)
}
func (r *roomRepository) DeleteMember(ctx context.Context, roomID, uid string) error {
uuid, err := gocql.ParseUUID(roomID)
if err != nil {
return err
}
// 同時從兩個表中刪除
if err := r.memberRepo.Delete(ctx, entity.RoomMember{
RoomID: uuid,
UID: uid,
}); err != nil {
return err
}
if err := r.userRoomRepo.Delete(ctx, entity.UserRoom{
UID: uid,
RoomID: uuid,
}); err != nil {
return err
}
return nil
}
func (r *roomRepository) DeleteRoom(ctx context.Context, roomID string) error {
uuid, err := gocql.ParseUUID(roomID)
if err != nil {
return err
}
// 先查詢所有成員(必須在刪除之前查詢)
members, err := r.AllMembers(ctx, roomID)
if err != nil {
return err
}
// 刪除 user_room 表中的所有關聯
for _, member := range members {
if err := r.userRoomRepo.Delete(ctx, entity.UserRoom{
UID: member.UID,
RoomID: uuid,
}); err != nil {
return err
}
}
// 最後刪除 room_member 表中的所有成員
e := entity.RoomMember{}
stmt := fmt.Sprintf("DELETE FROM %s.%s WHERE room_id = ?", r.keyspace, e.TableName())
if err := r.db.GetSession().Query(stmt, nil).
Bind(uuid).
WithContext(ctx).
WithTimestamp(time.Now().UnixNano() / 1e3).
ExecRelease(); err != nil {
return err
}
return nil
}
func (r *roomRepository) Count(ctx context.Context, roomID string) (int64, error) {
uuid, err := gocql.ParseUUID(roomID)
if err != nil {
return 0, err
}
return r.memberRepo.Query().
Where(cassandra.Eq("room_id", uuid)).
Count(ctx)
}
// ==================== User Interface 實作 ====================
func (r *roomRepository) GetUserRooms(ctx context.Context, uid string) ([]entity.UserRoom, error) {
var userRooms []entity.UserRoom
err := r.userRoomRepo.Query().
Where(cassandra.Eq("uid", uid)).
Scan(ctx, &userRooms)
if err != nil {
return nil, err
}
return userRooms, nil
}
func (r *roomRepository) CountUserRooms(ctx context.Context, uid string) (int64, error) {
return r.userRoomRepo.Query().
Where(cassandra.Eq("uid", uid)).
Count(ctx)
}
func (r *roomRepository) IsUserInRoom(ctx context.Context, uid, roomID string) (bool, error) {
uuid, err := gocql.ParseUUID(roomID)
if err != nil {
return false, err
}
_, err = r.userRoomRepo.Get(ctx, entity.UserRoom{
UID: uid,
RoomID: uuid,
})
if err != nil {
if cassandra.IsNotFound(err) {
return false, nil
}
return false, err
}
return true, nil
}

View File

@ -0,0 +1,953 @@
package repository
import (
"backend/pkg/chat/domain/chat"
"backend/pkg/chat/domain/entity"
"backend/pkg/chat/domain/repository"
"backend/pkg/library/cassandra"
"context"
"strconv"
"testing"
"time"
"github.com/gocql/gocql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
// 全局變量:所有測試共享的資源(與 message_test.go 共享)
roomTestRepo repository.RoomRepository
)
// ensureRoomRepository 確保 Room Repository 已初始化
func ensureRoomRepository(t *testing.T) {
if roomTestRepo == nil {
if testDB == nil {
t.Fatal("testDB is not initialized. Make sure TestMain in message_test.go runs first.")
}
var err error
roomTestRepo, err = NewRoomRepository(testDB, "test_keyspace")
if err != nil {
t.Fatalf("Failed to create room repository: %v", err)
}
}
}
// clearRoomData 清空所有 room 相關表的數據
func clearRoomData(t *testing.T) {
ctx := context.Background()
tables := []string{"user_room", "room_member", "room"}
for _, table := range tables {
truncateStmt := "TRUNCATE test_keyspace." + table
if err := testDB.GetSession().Query(truncateStmt, nil).WithContext(ctx).Exec(); err != nil {
t.Fatalf("Failed to truncate %s table: %v", table, err)
}
}
// 等待數據清空
time.Sleep(50 * time.Millisecond)
}
// ==================== Room Interface 測試 ====================
func TestNewRoomRepository(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
assert.NotNil(t, roomTestRepo)
}
func TestRoomRepository_Create(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
tests := []struct {
name string
room *entity.Room
wantErr bool
}{
{
name: "successful create",
room: &entity.Room{
Name: "Test Room",
Status: "active",
},
wantErr: false,
},
{
name: "create with auto timestamp",
room: &entity.Room{
Name: "Auto Timestamp Room",
Status: "active",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := roomTestRepo.Create(ctx, tt.room)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
// 驗證 RoomID 和時間戳是否被自動設置
assert.NotZero(t, tt.room.RoomID)
assert.NotZero(t, tt.room.CreatedAt)
assert.NotZero(t, tt.room.UpdatedAt)
}
})
}
}
func TestRoomRepository_RoomGet(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建測試房間
room := &entity.Room{
Name: "Get Test Room",
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
roomIDStr := room.RoomID.String()
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
tests := []struct {
name string
roomID string
wantErr bool
}{
{
name: "get existing room",
roomID: roomIDStr,
wantErr: false,
},
{
name: "get non-existent room",
roomID: gocql.TimeUUID().String(),
wantErr: true,
},
{
name: "get with invalid UUID",
roomID: "invalid-uuid",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := roomTestRepo.RoomGet(ctx, tt.roomID)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, result)
} else {
assert.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, room.Name, result.Name)
assert.Equal(t, room.Status, result.Status)
}
})
}
}
func TestRoomRepository_RoomUpdate(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建測試房間
room := &entity.Room{
Name: "Original Name",
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
originalCreatedAt := room.CreatedAt
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
// 更新房間
room.Name = "Updated Name"
room.Status = "archived"
err := roomTestRepo.RoomUpdate(ctx, room)
require.NoError(t, err)
// 驗證更新
updated, err := roomTestRepo.RoomGet(ctx, room.RoomID.String())
require.NoError(t, err)
assert.Equal(t, "Updated Name", updated.Name)
assert.Equal(t, "archived", updated.Status)
assert.Equal(t, originalCreatedAt, updated.CreatedAt) // CreatedAt 不應該改變
assert.Greater(t, updated.UpdatedAt, originalCreatedAt) // UpdatedAt 應該更新
}
func TestRoomRepository_RoomDelete(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建測試房間
room := &entity.Room{
Name: "Delete Test Room",
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
roomIDStr := room.RoomID.String()
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
// 刪除房間
err := roomTestRepo.RoomDelete(ctx, roomIDStr)
require.NoError(t, err)
// 驗證房間已被刪除
_, err = roomTestRepo.RoomGet(ctx, roomIDStr)
assert.Error(t, err)
assert.True(t, cassandra.IsNotFound(err))
}
func TestRoomRepository_RoomList(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建多個測試房間
rooms := make([]*entity.Room, 5)
for i := 0; i < 5; i++ {
room := &entity.Room{
Name: "Room " + strconv.Itoa(i),
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
rooms[i] = room
time.Sleep(10 * time.Millisecond) // 確保時間戳不同
}
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
tests := []struct {
name string
param repository.ListRoomsReq
wantLen int
wantErr bool
}{
{
name: "list all rooms",
param: repository.ListRoomsReq{
PageSize: 10,
},
wantLen: 5,
wantErr: false,
},
{
name: "list with status filter",
param: repository.ListRoomsReq{
Status: "active",
PageSize: 10,
},
wantLen: 5,
wantErr: false,
},
{
name: "list with page size limit",
param: repository.ListRoomsReq{
PageSize: 2,
},
wantLen: 2,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := roomTestRepo.RoomList(ctx, tt.param)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Len(t, result, tt.wantLen)
}
})
}
}
func TestRoomRepository_RoomCount(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建測試房間
for i := 0; i < 3; i++ {
room := &entity.Room{
Name: "Count Room " + strconv.Itoa(i),
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
}
// 創建一個 archived 房間
archivedRoom := &entity.Room{
Name: "Archived Room",
Status: "archived",
}
require.NoError(t, roomTestRepo.Create(ctx, archivedRoom))
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
tests := []struct {
name string
param repository.CountRoomsReq
want int64
wantErr bool
}{
{
name: "count all rooms",
param: repository.CountRoomsReq{},
want: 4,
},
{
name: "count active rooms",
param: repository.CountRoomsReq{
Status: "active",
},
want: 3,
},
{
name: "count archived rooms",
param: repository.CountRoomsReq{
Status: "archived",
},
want: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
count, err := roomTestRepo.RoomCount(ctx, tt.param)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, count)
}
})
}
}
func TestRoomRepository_RoomExists(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建測試房間
room := &entity.Room{
Name: "Exists Test Room",
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
roomIDStr := room.RoomID.String()
nonExistentID := gocql.TimeUUID().String()
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
tests := []struct {
name string
roomID string
want bool
wantErr bool
}{
{
name: "existing room",
roomID: roomIDStr,
want: true,
wantErr: false,
},
{
name: "non-existent room",
roomID: nonExistentID,
want: false,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
exists, err := roomTestRepo.RoomExists(ctx, tt.roomID)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, exists)
}
})
}
}
func TestRoomRepository_RoomGetByID(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建多個測試房間
rooms := make([]*entity.Room, 3)
roomIDs := make([]string, 3)
for i := 0; i < 3; i++ {
room := &entity.Room{
Name: "Room " + strconv.Itoa(i),
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
rooms[i] = room
roomIDs[i] = room.RoomID.String()
}
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
tests := []struct {
name string
roomIDs []string
wantLen int
wantErr bool
}{
{
name: "get multiple rooms",
roomIDs: roomIDs,
wantLen: 3,
wantErr: false,
},
{
name: "get empty list",
roomIDs: []string{},
wantLen: 0,
wantErr: false,
},
{
name: "get with non-existent ID",
roomIDs: []string{roomIDs[0], gocql.TimeUUID().String()},
wantLen: 1, // 只返回存在的房間
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := roomTestRepo.RoomGetByID(ctx, tt.roomIDs)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Len(t, result, tt.wantLen)
}
})
}
}
// ==================== Member Interface 測試 ====================
func TestRoomRepository_Insert(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建測試房間
room := &entity.Room{
Name: "Member Test Room",
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
tests := []struct {
name string
member *entity.RoomMember
wantErr bool
}{
{
name: "successful insert",
member: &entity.RoomMember{
RoomID: room.RoomID,
UID: "user-1",
Role: chat.RoomRoleMember.String(),
},
wantErr: false,
},
{
name: "insert with auto role",
member: &entity.RoomMember{
RoomID: room.RoomID,
UID: "user-2",
},
wantErr: false,
},
{
name: "insert with admin role",
member: &entity.RoomMember{
RoomID: room.RoomID,
UID: "user-3",
Role: chat.RoomRoleAdmin.String(),
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := roomTestRepo.Insert(ctx, tt.member)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
// 驗證 JoinedAt 是否被自動設置
assert.NotZero(t, tt.member.JoinedAt)
// 驗證 Role 是否被自動設置(如果為空)
if tt.member.Role == "" {
assert.Equal(t, chat.RoomRoleMember.String(), tt.member.Role)
}
// 驗證 user_room 表也被更新
userRooms, err := roomTestRepo.GetUserRooms(ctx, tt.member.UID)
require.NoError(t, err)
assert.Greater(t, len(userRooms), 0)
}
})
}
}
func TestRoomRepository_Get(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建測試房間和成員
room := &entity.Room{
Name: "Get Member Test Room",
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
member := &entity.RoomMember{
RoomID: room.RoomID,
UID: "user-1",
Role: chat.RoomRoleAdmin.String(),
}
require.NoError(t, roomTestRepo.Insert(ctx, member))
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
tests := []struct {
name string
roomID string
uid string
wantErr bool
}{
{
name: "get existing member",
roomID: room.RoomID.String(),
uid: "user-1",
wantErr: false,
},
{
name: "get non-existent member",
roomID: room.RoomID.String(),
uid: "non-existent-user",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := roomTestRepo.Get(ctx, tt.roomID, tt.uid)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, result)
} else {
assert.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, member.UID, result.UID)
assert.Equal(t, member.Role, result.Role)
}
})
}
}
func TestRoomRepository_AllMembers(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建測試房間
room := &entity.Room{
Name: "All Members Test Room",
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
// 插入多個成員
members := []*entity.RoomMember{
{RoomID: room.RoomID, UID: "user-1", Role: chat.RoomRoleMember.String()},
{RoomID: room.RoomID, UID: "user-2", Role: chat.RoomRoleAdmin.String()},
{RoomID: room.RoomID, UID: "user-3", Role: chat.RoomRoleMember.String()},
}
for _, member := range members {
require.NoError(t, roomTestRepo.Insert(ctx, member))
}
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
// 查詢所有成員
result, err := roomTestRepo.AllMembers(ctx, room.RoomID.String())
require.NoError(t, err)
assert.Len(t, result, 3)
}
func TestRoomRepository_UpdateRole(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建測試房間和成員
room := &entity.Room{
Name: "Update Role Test Room",
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
member := &entity.RoomMember{
RoomID: room.RoomID,
UID: "user-1",
Role: chat.RoomRoleMember.String(),
}
require.NoError(t, roomTestRepo.Insert(ctx, member))
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
// 更新角色
member.Role = chat.RoomRoleAdmin.String()
err := roomTestRepo.UpdateRole(ctx, member)
require.NoError(t, err)
// 驗證更新
updated, err := roomTestRepo.Get(ctx, room.RoomID.String(), "user-1")
require.NoError(t, err)
assert.Equal(t, chat.RoomRoleAdmin.String(), updated.Role)
}
func TestRoomRepository_DeleteMember(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建測試房間和成員
room := &entity.Room{
Name: "Delete Member Test Room",
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
member := &entity.RoomMember{
RoomID: room.RoomID,
UID: "user-1",
Role: chat.RoomRoleMember.String(),
}
require.NoError(t, roomTestRepo.Insert(ctx, member))
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
// 刪除成員
err := roomTestRepo.DeleteMember(ctx, room.RoomID.String(), "user-1")
require.NoError(t, err)
// 驗證成員已被刪除
_, err = roomTestRepo.Get(ctx, room.RoomID.String(), "user-1")
assert.Error(t, err)
assert.True(t, cassandra.IsNotFound(err))
// 驗證 user_room 表也被刪除
userRooms, err := roomTestRepo.GetUserRooms(ctx, "user-1")
require.NoError(t, err)
assert.Len(t, userRooms, 0)
}
func TestRoomRepository_DeleteRoom(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建測試房間
room := &entity.Room{
Name: "Delete Room Test",
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
// 插入多個成員
members := []*entity.RoomMember{
{RoomID: room.RoomID, UID: "user-1", Role: chat.RoomRoleMember.String()},
{RoomID: room.RoomID, UID: "user-2", Role: chat.RoomRoleAdmin.String()},
}
for _, member := range members {
require.NoError(t, roomTestRepo.Insert(ctx, member))
}
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
// 刪除整個房間
err := roomTestRepo.DeleteRoom(ctx, room.RoomID.String())
require.NoError(t, err)
// 驗證所有成員已被刪除
allMembers, err := roomTestRepo.AllMembers(ctx, room.RoomID.String())
require.NoError(t, err)
assert.Len(t, allMembers, 0)
// 驗證 user_room 表中的關聯也被刪除
for _, member := range members {
userRooms, err := roomTestRepo.GetUserRooms(ctx, member.UID)
require.NoError(t, err)
// 驗證該用戶的 user_room 記錄中不包含這個房間
found := false
for _, ur := range userRooms {
if ur.RoomID == room.RoomID {
found = true
break
}
}
assert.False(t, found, "user_room should not contain deleted room")
}
}
func TestRoomRepository_Count(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建測試房間
room := &entity.Room{
Name: "Count Members Test Room",
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
// 插入多個成員
for i := 0; i < 5; i++ {
member := &entity.RoomMember{
RoomID: room.RoomID,
UID: "user-" + strconv.Itoa(i),
Role: chat.RoomRoleMember.String(),
}
require.NoError(t, roomTestRepo.Insert(ctx, member))
}
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
// 計算成員數
count, err := roomTestRepo.Count(ctx, room.RoomID.String())
require.NoError(t, err)
assert.Equal(t, int64(5), count)
}
// ==================== User Interface 測試 ====================
func TestRoomRepository_GetUserRooms(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建多個測試房間
rooms := make([]*entity.Room, 3)
for i := 0; i < 3; i++ {
room := &entity.Room{
Name: "User Room " + strconv.Itoa(i),
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
rooms[i] = room
}
// 將用戶加入多個房間
uid := "user-1"
for _, room := range rooms {
member := &entity.RoomMember{
RoomID: room.RoomID,
UID: uid,
Role: chat.RoomRoleMember.String(),
}
require.NoError(t, roomTestRepo.Insert(ctx, member))
}
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
// 查詢用戶所在的所有房間
userRooms, err := roomTestRepo.GetUserRooms(ctx, uid)
require.NoError(t, err)
assert.Len(t, userRooms, 3)
}
func TestRoomRepository_CountUserRooms(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建測試房間
rooms := make([]*entity.Room, 5)
for i := 0; i < 5; i++ {
room := &entity.Room{
Name: "Count User Room " + strconv.Itoa(i),
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
rooms[i] = room
}
// 將用戶加入多個房間
uid := "user-1"
for _, room := range rooms {
member := &entity.RoomMember{
RoomID: room.RoomID,
UID: uid,
Role: chat.RoomRoleMember.String(),
}
require.NoError(t, roomTestRepo.Insert(ctx, member))
}
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
// 計算用戶所在的房間數
count, err := roomTestRepo.CountUserRooms(ctx, uid)
require.NoError(t, err)
assert.Equal(t, int64(5), count)
}
func TestRoomRepository_IsUserInRoom(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建測試房間
room := &entity.Room{
Name: "Is User In Room Test",
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
// 插入成員
member := &entity.RoomMember{
RoomID: room.RoomID,
UID: "user-1",
Role: chat.RoomRoleMember.String(),
}
require.NoError(t, roomTestRepo.Insert(ctx, member))
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
tests := []struct {
name string
uid string
roomID string
want bool
wantErr bool
}{
{
name: "user in room",
uid: "user-1",
roomID: room.RoomID.String(),
want: true,
wantErr: false,
},
{
name: "user not in room",
uid: "user-2",
roomID: room.RoomID.String(),
want: false,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := roomTestRepo.IsUserInRoom(ctx, tt.uid, tt.roomID)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, result)
}
})
}
}
// ==================== 整合測試 ====================
func TestRoomRepository_Integration(t *testing.T) {
ensureRoomRepository(t)
clearRoomData(t)
ctx := context.Background()
// 創建房間
room := &entity.Room{
Name: "Integration Test Room",
Status: "active",
}
require.NoError(t, roomTestRepo.Create(ctx, room))
// 添加多個成員
members := []*entity.RoomMember{
{RoomID: room.RoomID, UID: "user-1", Role: chat.RoomRoleOwner.String()},
{RoomID: room.RoomID, UID: "user-2", Role: chat.RoomRoleAdmin.String()},
{RoomID: room.RoomID, UID: "user-3", Role: chat.RoomRoleMember.String()},
}
for _, member := range members {
require.NoError(t, roomTestRepo.Insert(ctx, member))
}
// 等待數據寫入
time.Sleep(100 * time.Millisecond)
// 驗證房間存在
exists, err := roomTestRepo.RoomExists(ctx, room.RoomID.String())
require.NoError(t, err)
assert.True(t, exists)
// 驗證成員數量
count, err := roomTestRepo.Count(ctx, room.RoomID.String())
require.NoError(t, err)
assert.Equal(t, int64(3), count)
// 驗證所有成員
allMembers, err := roomTestRepo.AllMembers(ctx, room.RoomID.String())
require.NoError(t, err)
assert.Len(t, allMembers, 3)
// 驗證用戶所在的房間
for _, member := range members {
userRooms, err := roomTestRepo.GetUserRooms(ctx, member.UID)
require.NoError(t, err)
assert.Greater(t, len(userRooms), 0)
inRoom, err := roomTestRepo.IsUserInRoom(ctx, member.UID, room.RoomID.String())
require.NoError(t, err)
assert.True(t, inRoom)
}
}

272
pkg/chat/usecase/message.go Normal file
View File

@ -0,0 +1,272 @@
package usecase
import (
"backend/pkg/chat/domain/entity"
"backend/pkg/chat/domain/repository"
"backend/pkg/chat/domain/usecase"
"backend/pkg/library/centrifugo"
errs "backend/pkg/library/errors"
"backend/pkg/utils"
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"math"
"time"
"github.com/gocql/gocql"
"github.com/google/uuid"
)
const (
defaultPageSize = 20
maxPageSize = 100 // 設置最大分頁大小
)
type MessageUseCaseParam struct {
MessageRepo repository.MessageRepository
RoomRepo repository.RoomRepository
MsgClient *centrifugo.Client
Logger errs.Logger
}
type MessageUseCase struct {
MessageUseCaseParam
}
// NewMessageUseCase 創建新的訊息 UseCase
func NewMessageUseCase(param MessageUseCaseParam) usecase.MessageUseCase {
return &MessageUseCase{
param,
}
}
func (use *MessageUseCase) SendMessage(ctx context.Context, param usecase.SendMessageReq) error {
// 驗證輸入參數
if err := use.validateSendMessageReq(param); err != nil {
return err
}
// 驗證使用者是否在房間中(優先檢查,避免不必要的去重操作)
if err := use.verifyRoomMembership(ctx, param.UID, param.RoomID); err != nil {
return err
}
// 去重檢查
now := time.Now().UTC()
bucketSec := now.Unix()
contentMD5 := calculateMD5(param.Content)
isDuplicate, err := use.MessageRepo.CheckAndInsertDedup(ctx, repository.CheckDupReq{
RoomID: param.RoomID,
UID: param.UID, // 修復:應該是 param.UID 而不是 param.RoomID
BucketSec: bucketSec,
ContentMD5: contentMD5,
})
if err != nil {
return use.logError("messageRepo.CheckAndInsertDedup", param, err, "failed to check message deduplication")
}
if isDuplicate {
return errs.InputInvalidFormatError("duplicate message detected")
}
// 建立並儲存訊息
msg, err := use.createMessage(param, now)
if err != nil {
return err
}
if err = use.MessageRepo.Insert(ctx, msg); err != nil {
return use.logError("messageRepo.Insert", param, err, "failed to insert message")
}
// 發布到 Centrifugo非阻塞失敗不影響主流程
use.publishToCentrifugo(ctx, param.RoomID, msg, now)
return nil
}
// validateSendMessageReq 驗證發送訊息的請求參數
func (use *MessageUseCase) validateSendMessageReq(param usecase.SendMessageReq) error {
if param.Content == "" {
return errs.InputInvalidFormatError("content cannot be empty")
}
if param.RoomID == "" {
return errs.InputInvalidFormatError("room_id cannot be empty")
}
if param.UID == "" {
return errs.InputInvalidFormatError("uid cannot be empty")
}
return nil
}
// verifyRoomMembership 驗證使用者是否在房間中
func (use *MessageUseCase) verifyRoomMembership(ctx context.Context, uid, roomID string) error {
isMember, err := use.RoomRepo.IsUserInRoom(ctx, uid, roomID)
if err != nil {
return use.logError("roomRepo.IsUserInRoom", map[string]interface{}{
"uid": uid,
"roomID": roomID,
}, err, "failed to check room membership")
}
if !isMember {
return errs.AuthForbiddenErrorL(
use.Logger,
[]errs.LogField{
{Key: "uid", Val: uid},
{Key: "roomID", Val: roomID},
},
"user is not a member of the room")
}
return nil
}
// createMessage 建立訊息實體
func (use *MessageUseCase) createMessage(param usecase.SendMessageReq, now time.Time) (*entity.Message, error) {
roomID, err := gocql.ParseUUID(param.RoomID)
if err != nil {
return nil, errs.InputInvalidFormatError("invalid room_id format").Wrap(err)
}
msgID, err := uuid.NewV7()
if err != nil {
return nil, errs.SysInternalError("failed to generate message id").Wrap(err)
}
return &entity.Message{
RoomID: roomID,
BucketDay: utils.GetBucketDay(now),
UID: param.UID,
MsgID: msgID,
Content: param.Content,
}, nil
}
// publishToCentrifugo 發布訊息到 Centrifugo
func (use *MessageUseCase) publishToCentrifugo(ctx context.Context, roomID string, msg *entity.Message, now time.Time) {
channel := fmt.Sprintf("room:%s", roomID)
messageData := map[string]interface{}{
"msg_id": msg.MsgID.String(),
"uid": msg.UID,
"content": msg.Content,
"timestamp": now.UnixNano(), // 使用實際時間戳,而不是 msg.TS可能為 0
"room_id": msg.RoomID.String(),
}
if _, err := use.MsgClient.PublishJSON(ctx, channel, messageData); err != nil {
// 記錄錯誤但不影響主流程,因為訊息已經成功儲存
if use.Logger != nil {
use.Logger.WithFields(
errs.LogField{Key: "roomID", Val: roomID},
errs.LogField{Key: "msgID", Val: msg.MsgID.String()},
errs.LogField{Key: "error", Val: err.Error()},
).Error(fmt.Sprintf("failed to publish message to Centrifugo: %v", err))
}
}
}
// logError 統一的錯誤記錄方法
func (use *MessageUseCase) logError(funcName string, param interface{}, err error, message string) error {
return errs.DBErrorErrorL(
use.Logger,
[]errs.LogField{
{Key: "param", Val: param},
{Key: "func", Val: funcName},
{Key: "err", Val: err.Error()},
},
message)
}
func (use *MessageUseCase) ListMessages(ctx context.Context, req usecase.ListMessagesReq) ([]usecase.Message, int64, error) {
// 驗證輸入參數
if err := use.validateListMessagesReq(req); err != nil {
return nil, 0, err
}
// 驗證使用者是否在房間中
if err := use.verifyRoomMembership(ctx, req.UID, req.RoomID); err != nil {
return nil, 0, err
}
// 取得 bucket_day如果未提供則使用今天
bucketDay := req.BucketDay
if bucketDay == "" {
bucketDay = utils.GetTodayBucketDay()
}
// 防止 PageSize overflow 並設置合理的範圍
pageSize := use.normalizePageSize(req.PageSize)
// 查詢訊息
messages, err := use.MessageRepo.ListMessages(ctx, repository.ListMessagesReq{
RoomID: req.RoomID,
BucketDay: bucketDay,
PageSize: pageSize,
LastTS: req.LastTS,
})
if err != nil {
return nil, 0, use.logError("messageRepo.ListMessages", req, err, "failed to list messages")
}
// 轉換為 usecase.Message
result := make([]usecase.Message, 0, len(messages))
for _, msg := range messages {
result = append(result, usecase.Message{
RoomID: msg.RoomID.String(),
BucketDay: msg.BucketDay,
TS: msg.TS,
UID: msg.UID,
Content: msg.Content,
})
}
// 計算總數(只在第一頁時計算)
var total int64
if req.LastTS == 0 {
total, err = use.MessageRepo.Count(ctx, req.RoomID)
if err != nil {
return nil, 0, use.logError("messageRepo.Count", req, err, "failed to count messages")
}
}
return result, total, nil
}
// validateListMessagesReq 驗證查詢訊息的請求參數
func (use *MessageUseCase) validateListMessagesReq(req usecase.ListMessagesReq) error {
if req.RoomID == "" {
return errs.InputInvalidFormatError("room_id cannot be empty")
}
if req.UID == "" {
return errs.InputInvalidFormatError("uid cannot be empty")
}
return nil
}
// normalizePageSize 正規化分頁大小
func (use *MessageUseCase) normalizePageSize(pageSize int64) int {
// 檢查是否超過 int 的最大值
if pageSize > int64(math.MaxInt) {
return maxPageSize
}
size := int(pageSize)
if size <= 0 {
return defaultPageSize
}
if size > maxPageSize {
return maxPageSize
}
return size
}
// calculateMD5 計算字串的 MD5 雜湊值
func calculateMD5(content string) string {
hash := md5.Sum([]byte(content))
return hex.EncodeToString(hash[:])
}

View File

@ -75,6 +75,7 @@ type QueryBuilder[T Table] interface {
OrderBy(column string, order Order) QueryBuilder[T] OrderBy(column string, order Order) QueryBuilder[T]
Limit(n int) QueryBuilder[T] Limit(n int) QueryBuilder[T]
Select(columns ...string) QueryBuilder[T] Select(columns ...string) QueryBuilder[T]
AllowFiltering() QueryBuilder[T]
Scan(ctx context.Context, dest *[]T) error Scan(ctx context.Context, dest *[]T) error
One(ctx context.Context) (T, error) One(ctx context.Context) (T, error)
Count(ctx context.Context) (int64, error) Count(ctx context.Context) (int64, error)
@ -82,11 +83,12 @@ type QueryBuilder[T Table] interface {
// queryBuilder 是 QueryBuilder 的具體實作 // queryBuilder 是 QueryBuilder 的具體實作
type queryBuilder[T Table] struct { type queryBuilder[T Table] struct {
repo *repository[T] repo *repository[T]
conditions []Condition conditions []Condition
orders []orderBy orders []orderBy
limit int limit int
columns []string columns []string
allowFiltering bool
} }
type orderBy struct { type orderBy struct {
@ -125,6 +127,12 @@ func (q *queryBuilder[T]) Select(columns ...string) QueryBuilder[T] {
return q return q
} }
// AllowFiltering 允許不使用 partition key 的查詢(效能較差,慎用)
func (q *queryBuilder[T]) AllowFiltering() QueryBuilder[T] {
q.allowFiltering = true
return q
}
// Scan 執行查詢並將結果掃描到 dest // Scan 執行查詢並將結果掃描到 dest
func (q *queryBuilder[T]) Scan(ctx context.Context, dest *[]T) error { func (q *queryBuilder[T]) Scan(ctx context.Context, dest *[]T) error {
if dest == nil { if dest == nil {
@ -171,6 +179,11 @@ func (q *queryBuilder[T]) Scan(ctx context.Context, dest *[]T) error {
builder = builder.Limit(uint(q.limit)) builder = builder.Limit(uint(q.limit))
} }
// 添加 ALLOW FILTERING
if q.allowFiltering {
builder = builder.AllowFiltering()
}
stmt, names := builder.ToCql() stmt, names := builder.ToCql()
query := q.repo.db.withContextAndTimestamp(ctx, query := q.repo.db.withContextAndTimestamp(ctx,
q.repo.db.session.Query(stmt, names).BindMap(bindMap)) q.repo.db.session.Query(stmt, names).BindMap(bindMap))
@ -213,6 +226,11 @@ func (q *queryBuilder[T]) Count(ctx context.Context) (int64, error) {
builder = builder.Where(cmps...) builder = builder.Where(cmps...)
} }
// 添加 ALLOW FILTERING
if q.allowFiltering {
builder = builder.AllowFiltering()
}
stmt, names := builder.ToCql() stmt, names := builder.ToCql()
query := q.repo.db.withContextAndTimestamp(ctx, query := q.repo.db.withContextAndTimestamp(ctx,
q.repo.db.session.Query(stmt, names).BindMap(bindMap)) q.repo.db.session.Query(stmt, names).BindMap(bindMap))

View File

@ -130,8 +130,23 @@ func (r *repository[T]) updateSelective(ctx context.Context, doc T, includeZero
func (r *repository[T]) Delete(ctx context.Context, pk any) error { func (r *repository[T]) Delete(ctx context.Context, pk any) error {
t := table.New(r.metadata) t := table.New(r.metadata)
stmt, names := t.Delete() stmt, names := t.Delete()
q := r.db.withContextAndTimestamp(ctx,
r.db.session.Query(stmt, names).Bind(pk)) // 如果 pk 是 struct使用 BindStruct否則使用 Bind
var q *gocqlx.Queryx
if reflect.TypeOf(pk).Kind() == reflect.Struct {
q = r.db.withContextAndTimestamp(ctx,
r.db.session.Query(stmt, names).BindStruct(pk))
} else {
// 單一主鍵欄位的情況
// 注意:這只適用於單一 Partition Key 且無 Clustering Key 的情況
if len(r.metadata.PartKey) != 1 || len(r.metadata.SortKey) > 0 {
return ErrInvalidInput.WithTable(r.table).WithError(
fmt.Errorf("single value primary key only supported for single partition key without clustering key"),
)
}
q = r.db.withContextAndTimestamp(ctx,
r.db.session.Query(stmt, names).Bind(pk))
}
return q.ExecRelease() return q.ExecRelease()
} }

View File

@ -0,0 +1,474 @@
# Centrifugo Client Library
Go 語言的 Centrifugo 即時訊息服務客戶端庫,提供完整的 Server-side API 支援。
## 功能特色
- ✅ **HTTP API 客戶端** - 發布訊息、訂閱管理、在線狀態、歷史訊息
- ✅ **JWT Token 生成** - 連線認證和私有頻道訂閱
- ✅ **Token 黑名單** - 撤銷單一 Token 或用戶所有 Token
- ✅ **在線狀態追蹤** - Redis 存儲 + Centrifugo Presence API
- ✅ **統一服務入口** - 簡單易用的 `Service` 整合介面
## 安裝依賴
```bash
go get github.com/golang-jwt/jwt/v5
go get github.com/zeromicro/go-zero/core/stores/redis
```
## 快速開始
### 1. 創建服務實例(推薦方式)
```go
import (
"backend/pkg/library/centrifugo"
"github.com/zeromicro/go-zero/core/stores/redis"
)
// 創建 Redis 客戶端
rds, _ := redis.NewRedis(redis.RedisConf{
Host: "localhost:6379",
Type: "node",
})
// 創建 Centrifugo 服務
svc := centrifugo.NewService(centrifugo.ServiceConfig{
APIURL: "http://localhost:8000",
APIKey: "your-api-key",
TokenSecret: "your-jwt-secret",
Redis: rds, // 可選,用於黑名單和在線狀態
})
```
### 2. 發布訊息
```go
ctx := context.Background()
// 方法 1: 使用 Service 便捷方法
result, err := svc.PublishJSON(ctx, "chat:room-123", map[string]interface{}{
"message": "Hello, World!",
"user": "daniel",
})
// 方法 2: 使用 Client 直接調用
result, err := svc.Client().Publish(ctx, "chat:room-123", []byte(`{"message": "Hello!"}`))
// 批量發布到多個頻道
channels := []string{"user:1", "user:2", "user:3"}
err := svc.BroadcastJSON(ctx, channels, map[string]string{
"type": "notification",
"message": "System maintenance",
})
```
### 3. Token 生成
```go
// 快速生成連線 Token
token, err := svc.GenerateToken("user-123")
// 生成帶用戶資訊的 Token
token, err := svc.GenerateTokenWithInfo("user-123", map[string]interface{}{
"name": "Daniel",
"avatar": "https://example.com/avatar.jpg",
})
// 完整選項
token, err := svc.Token().GenerateConnectionToken(centrifugo.ConnectionTokenOptions{
UserID: "user-123",
Info: map[string]interface{}{"role": "admin"},
Channels: []string{"chat:room-1", "chat:room-2"}, // 自動訂閱
})
// 訂閱 Token用於私有頻道
token, err := svc.Token().QuickSubscriptionToken("user-123", "private:room-456")
```
### 4. 撤銷 Token踢人
```go
// 最常用:撤銷用戶所有 Token 並斷開連線
// 適用於:用戶被封禁、密碼變更、用戶登出全部設備
err := svc.InvalidateUser(ctx, "user-123")
// 只斷開連線(不撤銷 Token
err := svc.Disconnect(ctx, "user-123")
// 撤銷特定 Token需要 JTI
err := svc.Blacklist().RevokeToken(ctx, jti, time.Hour)
// 撤銷用戶所有 Token不斷開連線
err := svc.Blacklist().RevokeUserTokens(ctx, "user-123")
```
### 5. 在線狀態追蹤
```go
// 檢查單一用戶是否在線
online, err := svc.IsUserOnline(ctx, "user-123")
// 批量獲取在線狀態
status, err := svc.GetUsersOnlineStatus(ctx, []string{"user-1", "user-2", "user-3"})
// status = map[string]bool{"user-1": true, "user-2": false, "user-3": true}
// 處理 Centrifugo Connect/Disconnect Proxy 事件
svc.Online().HandleConnect(ctx, "user-123")
svc.Online().HandleDisconnect(ctx, "user-123")
// 使用 Centrifugo Presence API頻道級別
users, err := svc.Online().GetChannelOnlineUsers(ctx, "chat:room-123")
stats, err := svc.Online().GetChannelStats(ctx, "chat:room-123")
```
---
## 獨立使用各元件
如果不需要完整的 Service可以獨立使用各元件
### HTTP API Client
```go
// 創建客戶端
client := centrifugo.NewClient("http://localhost:8000", "your-api-key")
// 使用自定義配置
client := centrifugo.NewClientWithConfig(centrifugo.ClientConfig{
APIURL: "http://localhost:8000",
APIKey: "your-api-key",
Timeout: 5 * time.Second,
MaxIdleConns: 200,
MaxIdleConnsPerHost: 50,
})
// API 調用
client.Publish(ctx, channel, data)
client.PublishJSON(ctx, channel, data)
client.Broadcast(ctx, channels, data)
client.Subscribe(ctx, user, channel)
client.Unsubscribe(ctx, user, channel)
client.Disconnect(ctx, user)
client.DisconnectWithCode(ctx, user, code, reason)
client.Presence(ctx, channel)
client.PresenceStats(ctx, channel)
client.History(ctx, channel, limit)
client.HistoryReverse(ctx, channel, limit)
client.Channels(ctx)
client.ChannelsWithPattern(ctx, pattern)
client.Info(ctx)
client.Ping(ctx)
```
### Token Generator
```go
// 創建生成器
tokenGen := centrifugo.NewTokenGenerator("your-jwt-secret")
// 使用自定義配置
tokenGen := centrifugo.NewTokenGeneratorWithConfig(centrifugo.TokenConfig{
Secret: "your-jwt-secret",
ExpireIn: 24 * time.Hour,
})
// 生成 Token
tokenGen.QuickConnectionToken(userID)
tokenGen.QuickSubscriptionToken(userID, channel)
tokenGen.GenerateConnectionToken(opts)
tokenGen.GenerateSubscriptionToken(opts)
tokenGen.GenerateAnonymousToken()
```
### Token Blacklist
```go
// 創建黑名單管理器
blacklist := centrifugo.NewTokenBlacklist(redisClient)
// 撤銷操作
blacklist.RevokeToken(ctx, jti, ttl) // 撤銷單一 Token
blacklist.RevokeUserTokens(ctx, userID) // 撤銷用戶所有 Token
// 驗證操作
blacklist.IsTokenRevoked(ctx, jti) // 檢查 Token 是否被撤銷
blacklist.GetUserTokenVersion(ctx, userID) // 獲取用戶 Token 版本
blacklist.IsTokenVersionValid(ctx, userID, v) // 檢查版本是否有效
```
### Online Manager
```go
// 創建管理器
store := centrifugo.NewRedisOnlineStore(redisClient)
onlineManager := centrifugo.NewOnlineManagerWithTTL(client, store, 5*time.Minute)
// Redis 存儲操作
onlineManager.HandleConnect(ctx, userID)
onlineManager.HandleDisconnect(ctx, userID)
onlineManager.IsUserOnline(ctx, userID)
onlineManager.GetUsersOnlineStatus(ctx, userIDs)
onlineManager.RefreshOnline(ctx, userID)
// Centrifugo Presence API
onlineManager.IsUserInChannel(ctx, userID, channel)
onlineManager.GetChannelOnlineUsers(ctx, channel)
onlineManager.GetChannelStats(ctx, channel)
```
---
## 前端整合範例
### JavaScript (使用 centrifuge-js)
```javascript
import { Centrifuge } from 'centrifuge';
// 從後端 API 獲取 Token
const getToken = async () => {
const response = await fetch('/api/centrifugo/token');
const data = await response.json();
return data.token;
};
// 創建連線
const centrifuge = new Centrifuge('ws://localhost:8000/connection/websocket', {
getToken: getToken,
});
// 訂閱頻道
const sub = centrifuge.newSubscription('chat:room-123');
sub.on('publication', (ctx) => {
console.log('Received:', ctx.data);
});
sub.subscribe();
// 連線
centrifuge.connect();
```
### 後端 Token API
```go
// handlers/centrifugo.go
func (h *Handler) GetConnectionToken(c *gin.Context) {
userID := c.GetString("user_id") // 從 JWT 或 session 獲取
token, err := h.svc.GenerateToken(userID)
if err != nil {
c.JSON(500, gin.H{"error": "failed to generate token"})
return
}
c.JSON(200, gin.H{"token": token})
}
func (h *Handler) GetSubscriptionToken(c *gin.Context) {
userID := c.GetString("user_id")
channel := c.Query("channel")
// 驗證用戶是否有權限訂閱此頻道
if !h.canSubscribe(userID, channel) {
c.JSON(403, gin.H{"error": "forbidden"})
return
}
token, err := h.svc.Token().QuickSubscriptionToken(userID, channel)
if err != nil {
c.JSON(500, gin.H{"error": "failed to generate token"})
return
}
c.JSON(200, gin.H{"token": token})
}
```
---
## Centrifugo Proxy 整合
### Connect Proxy
```go
// POST /centrifugo/connect
func (h *Handler) CentrifugoConnect(c *gin.Context) {
var req struct {
Client string `json:"client"`
Transport string `json:"transport"`
Protocol string `json:"protocol"`
Data []byte `json:"data"`
}
c.BindJSON(&req)
// 從 Token 驗證用戶Centrifugo 會傳遞)
userID := extractUserID(req.Data)
// 記錄連線
h.svc.Online().HandleConnect(c, userID)
c.JSON(200, gin.H{
"result": map[string]interface{}{
"user": userID,
},
})
}
```
### Disconnect Proxy
```go
// POST /centrifugo/disconnect
func (h *Handler) CentrifugoDisconnect(c *gin.Context) {
var req struct {
Client string `json:"client"`
User string `json:"user"`
}
c.BindJSON(&req)
// 記錄離線
h.svc.Online().HandleDisconnect(c, req.User)
c.JSON(200, gin.H{"result": map[string]interface{}{}})
}
```
---
## Centrifugo 配置參考
```json
{
"token_hmac_secret_key": "your-jwt-secret",
"api_key": "your-api-key",
"admin": true,
"allowed_origins": ["http://localhost:3000"],
"proxy_connect_endpoint": "http://localhost:8080/centrifugo/connect",
"proxy_disconnect_endpoint": "http://localhost:8080/centrifugo/disconnect",
"namespaces": [
{
"name": "chat",
"presence": true,
"history_size": 100,
"history_ttl": "300s"
},
{
"name": "private",
"presence": true,
"protected": true
}
]
}
```
---
## API 參考
### Service 方法
| 方法 | 說明 |
|------|------|
| `Client()` | 返回 HTTP API 客戶端 |
| `Token()` | 返回 Token 生成器 |
| `Blacklist()` | 返回黑名單管理器(可能為 nil |
| `Online()` | 返回在線狀態管理器(可能為 nil |
| `PublishJSON(ctx, channel, data)` | 發布 JSON 訊息 |
| `BroadcastJSON(ctx, channels, data)` | 批量發布 JSON 訊息 |
| `Disconnect(ctx, userID)` | 斷開用戶連線 |
| `GenerateToken(userID)` | 快速生成連線 Token |
| `GenerateTokenWithInfo(userID, info)` | 生成帶資訊的連線 Token |
| `InvalidateUser(ctx, userID)` | 撤銷所有 Token 並斷開連線 |
| `IsUserOnline(ctx, userID)` | 檢查用戶是否在線 |
| `GetUsersOnlineStatus(ctx, userIDs)` | 批量獲取在線狀態 |
### Client 方法
| 方法 | 說明 | 返回值 |
|------|------|--------|
| `Publish(ctx, channel, data)` | 發布訊息 | `*PublishResult, error` |
| `PublishJSON(ctx, channel, data)` | 發布 JSON | `*PublishResult, error` |
| `Broadcast(ctx, channels, data)` | 批量發布 | `error` |
| `BroadcastJSON(ctx, channels, data)` | 批量發布 JSON | `error` |
| `Subscribe(ctx, user, channel)` | 訂閱用戶 | `error` |
| `Unsubscribe(ctx, user, channel)` | 取消訂閱 | `error` |
| `Disconnect(ctx, user)` | 斷開連線 | `error` |
| `DisconnectWithCode(ctx, user, code, reason)` | 帶代碼斷開連線 | `error` |
| `Presence(ctx, channel)` | 在線用戶 | `*PresenceResult, error` |
| `PresenceStats(ctx, channel)` | 在線統計 | `*PresenceStatsResult, error` |
| `History(ctx, channel, limit)` | 歷史訊息 | `*HistoryResult, error` |
| `HistoryReverse(ctx, channel, limit)` | 歷史訊息(倒序) | `*HistoryResult, error` |
| `Channels(ctx)` | 活躍頻道 | `*ChannelsResult, error` |
| `ChannelsWithPattern(ctx, pattern)` | 匹配頻道 | `*ChannelsResult, error` |
| `Info(ctx)` | 伺服器資訊 | `*InfoResult, error` |
| `Ping(ctx)` | 健康檢查 | `error` |
### TokenGenerator 方法
| 方法 | 說明 |
|------|------|
| `GenerateConnectionToken(opts)` | 生成連線 Token完整選項 |
| `GenerateSubscriptionToken(opts)` | 生成訂閱 Token完整選項 |
| `GenerateAnonymousToken()` | 生成匿名 Token |
| `QuickConnectionToken(userID)` | 快速生成連線 Token |
| `QuickSubscriptionToken(userID, channel)` | 快速生成訂閱 Token |
### TokenBlacklist 方法
| 方法 | 說明 |
|------|------|
| `RevokeToken(ctx, jti, ttl)` | 撤銷特定 Token |
| `RevokeUserTokens(ctx, userID)` | 撤銷用戶所有 Token |
| `IsTokenRevoked(ctx, jti)` | 檢查 Token 是否被撤銷 |
| `GetUserTokenVersion(ctx, userID)` | 獲取用戶 Token 版本 |
| `IsTokenVersionValid(ctx, userID, version)` | 檢查版本是否有效 |
---
## 錯誤處理
```go
result, err := svc.Client().Publish(ctx, channel, data)
if err != nil {
// 檢查是否為 Centrifugo API 錯誤
if apiErr, ok := err.(*centrifugo.APIError); ok {
fmt.Printf("Centrifugo error code: %d, message: %s\n",
apiErr.Code, apiErr.Message)
} else {
// 網路錯誤或其他錯誤
fmt.Printf("Error: %v\n", err)
}
}
// 檢查特定錯誤
if errors.Is(err, centrifugo.ErrBlacklistNotConfigured) {
// 黑名單未配置
}
if errors.Is(err, centrifugo.ErrOnlineStoreNotConfigured) {
// 在線狀態存儲未配置
}
```
---
## 檔案結構
```
pkg/library/centrifugo/
├── centrifugo.go # 主入口Service 整合介面
├── client.go # HTTP API 客戶端
├── token.go # JWT Token 生成器
├── blacklist.go # Token 黑名單管理
├── online.go # 在線狀態管理介面
├── online_redis.go # Redis 在線狀態實作
├── README.md # 文檔
└── *_test.go # 測試文件
```
---
## License
MIT

View File

@ -0,0 +1,130 @@
package centrifugo
import (
"context"
"errors"
"fmt"
"time"
"github.com/zeromicro/go-zero/core/stores/redis"
)
// 錯誤定義
var (
ErrBlacklistNotConfigured = errors.New("token blacklist is not configured")
ErrOnlineStoreNotConfigured = errors.New("online store is not configured")
)
// TokenBlacklist Token 黑名單管理器
// 提供兩種撤銷機制:
// 1. 單一 Token 撤銷:使用 JTIJWT ID將特定 Token 加入黑名單
// 2. 用戶全部撤銷:使用版本號機制,使用戶之前所有 Token 失效
type TokenBlacklist struct {
redis *redis.Redis
prefix string
}
// NewTokenBlacklist 創建 Token 黑名單管理器
func NewTokenBlacklist(redisClient *redis.Redis) *TokenBlacklist {
return &TokenBlacklist{
redis: redisClient,
prefix: "centrifugo:blacklist:",
}
}
// NewTokenBlacklistWithPrefix 創建帶自定義前綴的 Token 黑名單管理器
func NewTokenBlacklistWithPrefix(redisClient *redis.Redis, prefix string) *TokenBlacklist {
return &TokenBlacklist{
redis: redisClient,
prefix: prefix,
}
}
// ==================== 撤銷操作 ====================
// RevokeToken 撤銷特定 Token使用 JTI
// ttl: 黑名單過期時間,應設置為 Token 的剩餘有效時間
//
// 使用場景:
// - 用戶登出單一設備
// - 檢測到可疑活動的特定 session
func (b *TokenBlacklist) RevokeToken(ctx context.Context, jti string, ttl time.Duration) error {
if jti == "" {
return errors.New("jti cannot be empty")
}
key := b.tokenKey(jti)
return b.redis.SetexCtx(ctx, key, "revoked", int(ttl.Seconds()))
}
// RevokeUserTokens 撤銷用戶的所有 Token使用版本控制
// 通過更新版本號,使該用戶之前發出的所有 Token 失效
//
// 使用場景:
// - 用戶被封禁
// - 密碼變更
// - 用戶主動登出全部設備
func (b *TokenBlacklist) RevokeUserTokens(ctx context.Context, userID string) error {
if userID == "" {
return errors.New("userID cannot be empty")
}
key := b.userVersionKey(userID)
version := time.Now().UnixNano()
// 設置 7 天過期,足夠長於任何 Token 的有效期
return b.redis.SetexCtx(ctx, key, fmt.Sprintf("%d", version), 7*24*3600)
}
// ==================== 驗證操作 ====================
// IsTokenRevoked 檢查 Token 是否被撤銷(使用 JTI
func (b *TokenBlacklist) IsTokenRevoked(ctx context.Context, jti string) (bool, error) {
if jti == "" {
return false, nil
}
key := b.tokenKey(jti)
exists, err := b.redis.ExistsCtx(ctx, key)
if err != nil {
return false, err
}
return exists, nil
}
// GetUserTokenVersion 獲取用戶的 Token 版本
// 返回 0 表示沒有設置版本(用戶從未被撤銷過)
func (b *TokenBlacklist) GetUserTokenVersion(ctx context.Context, userID string) (int64, error) {
if userID == "" {
return 0, nil
}
key := b.userVersionKey(userID)
val, err := b.redis.GetCtx(ctx, key)
if err != nil {
return 0, err
}
if val == "" {
return 0, nil
}
var version int64
_, err = fmt.Sscanf(val, "%d", &version)
return version, err
}
// IsTokenVersionValid 檢查 Token 版本是否有效
// tokenVersion: Token 內嵌的版本號
// 如果 currentVersion > tokenVersion表示 Token 已被撤銷
func (b *TokenBlacklist) IsTokenVersionValid(ctx context.Context, userID string, tokenVersion int64) (bool, error) {
currentVersion, err := b.GetUserTokenVersion(ctx, userID)
if err != nil {
return false, err
}
// 如果沒有設置版本,或 Token 版本 >= 當前版本,則有效
return currentVersion == 0 || tokenVersion >= currentVersion, nil
}
// ==================== Key 生成 ====================
func (b *TokenBlacklist) tokenKey(jti string) string {
return b.prefix + "token:" + jti
}
func (b *TokenBlacklist) userVersionKey(userID string) string {
return b.prefix + "user_version:" + userID
}

View File

@ -0,0 +1,246 @@
package centrifugo
import (
"context"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zeromicro/go-zero/core/stores/redis"
)
func setupTestRedis(t *testing.T) (*redis.Redis, func()) {
mr, err := miniredis.Run()
require.NoError(t, err)
rds, err := redis.NewRedis(redis.RedisConf{
Host: mr.Addr(),
Type: "node",
})
require.NoError(t, err)
return rds, func() {
mr.Close()
}
}
func TestNewTokenBlacklist(t *testing.T) {
rds, cleanup := setupTestRedis(t)
defer cleanup()
blacklist := NewTokenBlacklist(rds)
assert.NotNil(t, blacklist)
assert.Equal(t, "centrifugo:blacklist:", blacklist.prefix)
}
func TestNewTokenBlacklistWithPrefix(t *testing.T) {
rds, cleanup := setupTestRedis(t)
defer cleanup()
blacklist := NewTokenBlacklistWithPrefix(rds, "custom:prefix:")
assert.NotNil(t, blacklist)
assert.Equal(t, "custom:prefix:", blacklist.prefix)
}
func TestRevokeToken(t *testing.T) {
rds, cleanup := setupTestRedis(t)
defer cleanup()
blacklist := NewTokenBlacklist(rds)
ctx := context.Background()
jti := "test-jti-123"
ttl := 1 * time.Hour
// 撤銷 Token
err := blacklist.RevokeToken(ctx, jti, ttl)
require.NoError(t, err)
// 檢查是否被撤銷
revoked, err := blacklist.IsTokenRevoked(ctx, jti)
require.NoError(t, err)
assert.True(t, revoked)
}
func TestRevokeToken_EmptyJTI(t *testing.T) {
rds, cleanup := setupTestRedis(t)
defer cleanup()
blacklist := NewTokenBlacklist(rds)
ctx := context.Background()
err := blacklist.RevokeToken(ctx, "", time.Hour)
assert.Error(t, err)
assert.Contains(t, err.Error(), "jti cannot be empty")
}
func TestIsTokenRevoked_NotRevoked(t *testing.T) {
rds, cleanup := setupTestRedis(t)
defer cleanup()
blacklist := NewTokenBlacklist(rds)
ctx := context.Background()
// 檢查未撤銷的 Token
revoked, err := blacklist.IsTokenRevoked(ctx, "non-existent-jti")
require.NoError(t, err)
assert.False(t, revoked)
}
func TestIsTokenRevoked_EmptyJTI(t *testing.T) {
rds, cleanup := setupTestRedis(t)
defer cleanup()
blacklist := NewTokenBlacklist(rds)
ctx := context.Background()
// 空 JTI 應該返回 false
revoked, err := blacklist.IsTokenRevoked(ctx, "")
require.NoError(t, err)
assert.False(t, revoked)
}
func TestRevokeUserTokens(t *testing.T) {
rds, cleanup := setupTestRedis(t)
defer cleanup()
blacklist := NewTokenBlacklist(rds)
ctx := context.Background()
userID := "user-123"
// 撤銷用戶所有 Token
err := blacklist.RevokeUserTokens(ctx, userID)
require.NoError(t, err)
// 獲取版本
version, err := blacklist.GetUserTokenVersion(ctx, userID)
require.NoError(t, err)
assert.Greater(t, version, int64(0))
}
func TestRevokeUserTokens_EmptyUserID(t *testing.T) {
rds, cleanup := setupTestRedis(t)
defer cleanup()
blacklist := NewTokenBlacklist(rds)
ctx := context.Background()
err := blacklist.RevokeUserTokens(ctx, "")
assert.Error(t, err)
assert.Contains(t, err.Error(), "userID cannot be empty")
}
func TestGetUserTokenVersion_NoVersion(t *testing.T) {
rds, cleanup := setupTestRedis(t)
defer cleanup()
blacklist := NewTokenBlacklist(rds)
ctx := context.Background()
// 未設置版本的用戶應該返回 0
version, err := blacklist.GetUserTokenVersion(ctx, "new-user")
require.NoError(t, err)
assert.Equal(t, int64(0), version)
}
func TestGetUserTokenVersion_EmptyUserID(t *testing.T) {
rds, cleanup := setupTestRedis(t)
defer cleanup()
blacklist := NewTokenBlacklist(rds)
ctx := context.Background()
version, err := blacklist.GetUserTokenVersion(ctx, "")
require.NoError(t, err)
assert.Equal(t, int64(0), version)
}
func TestIsTokenVersionValid(t *testing.T) {
rds, cleanup := setupTestRedis(t)
defer cleanup()
blacklist := NewTokenBlacklist(rds)
ctx := context.Background()
userID := "user-123"
// 未設置版本時,任何版本都應該有效
valid, err := blacklist.IsTokenVersionValid(ctx, userID, 0)
require.NoError(t, err)
assert.True(t, valid)
// 撤銷用戶 Token
err = blacklist.RevokeUserTokens(ctx, userID)
require.NoError(t, err)
// 獲取當前版本
currentVersion, err := blacklist.GetUserTokenVersion(ctx, userID)
require.NoError(t, err)
// 舊版本應該無效
valid, err = blacklist.IsTokenVersionValid(ctx, userID, currentVersion-1)
require.NoError(t, err)
assert.False(t, valid)
// 當前版本應該有效
valid, err = blacklist.IsTokenVersionValid(ctx, userID, currentVersion)
require.NoError(t, err)
assert.True(t, valid)
// 更新版本應該有效
valid, err = blacklist.IsTokenVersionValid(ctx, userID, currentVersion+1)
require.NoError(t, err)
assert.True(t, valid)
}
func TestRevokeUserTokens_MultipleRevokes(t *testing.T) {
rds, cleanup := setupTestRedis(t)
defer cleanup()
blacklist := NewTokenBlacklist(rds)
ctx := context.Background()
userID := "user-123"
// 第一次撤銷
err := blacklist.RevokeUserTokens(ctx, userID)
require.NoError(t, err)
version1, err := blacklist.GetUserTokenVersion(ctx, userID)
require.NoError(t, err)
// 等待一點時間確保時間戳不同
time.Sleep(10 * time.Millisecond)
// 第二次撤銷
err = blacklist.RevokeUserTokens(ctx, userID)
require.NoError(t, err)
version2, err := blacklist.GetUserTokenVersion(ctx, userID)
require.NoError(t, err)
// 第二次的版本應該更大
assert.Greater(t, version2, version1)
// 第一次的版本應該已經無效
valid, err := blacklist.IsTokenVersionValid(ctx, userID, version1)
require.NoError(t, err)
assert.False(t, valid)
}
func TestKeyGeneration(t *testing.T) {
rds, cleanup := setupTestRedis(t)
defer cleanup()
blacklist := NewTokenBlacklistWithPrefix(rds, "test:")
// 測試 key 生成
assert.Equal(t, "test:token:jti-123", blacklist.tokenKey("jti-123"))
assert.Equal(t, "test:user_version:user-456", blacklist.userVersionKey("user-456"))
}

View File

@ -0,0 +1,188 @@
// Package centrifugo 提供 Centrifugo 即時訊息服務的完整 Go 客戶端
//
// 功能包含:
// - HTTP API 客戶端(發布訊息、訂閱管理、在線狀態等)
// - JWT Token 生成(連線認證、私有頻道訂閱)
// - Token 黑名單管理(撤銷單一 Token、撤銷用戶所有 Token
// - 在線狀態追蹤Redis 或記憶體存儲)
//
// 基本使用:
//
// // 創建服務實例
// svc := centrifugo.NewService(centrifugo.ServiceConfig{
// APIURL: "http://localhost:8000",
// APIKey: "your-api-key",
// TokenSecret: "your-jwt-secret",
// Redis: redisClient, // 可選,用於黑名單和在線狀態
// })
//
// // 發布訊息
// svc.Client().PublishJSON(ctx, "chat:room-1", data)
//
// // 生成 Token
// token, _ := svc.Token().QuickConnectionToken("user-123")
//
// // 撤銷用戶所有 Token 並踢出
// svc.InvalidateUser(ctx, "user-123")
package centrifugo
import (
"context"
"time"
"github.com/zeromicro/go-zero/core/stores/redis"
)
// Service Centrifugo 服務整合介面
// 提供 HTTP API、Token 生成、黑名單管理、在線狀態追蹤的統一入口
type Service struct {
client *Client
token *TokenGenerator
blacklist *TokenBlacklist
online *OnlineManager
}
// ServiceConfig 服務配置
type ServiceConfig struct {
// APIURL Centrifugo HTTP API 地址(必填)
APIURL string
// APIKey Centrifugo API 密鑰(必填)
APIKey string
// TokenSecret JWT Token 簽名密鑰(必填)
TokenSecret string
// TokenExpire Token 過期時間(預設 1 小時)
TokenExpire time.Duration
// Redis 客戶端(可選,用於黑名單和在線狀態)
Redis *redis.Redis
// ClientConfig HTTP 客戶端配置(可選)
ClientConfig *ClientConfig
// OnlineTTL 在線狀態過期時間(預設 5 分鐘)
OnlineTTL time.Duration
// KeyPrefix Redis key 前綴(預設 "centrifugo:"
KeyPrefix string
}
// NewService 創建 Centrifugo 服務實例
func NewService(cfg ServiceConfig) *Service {
// 設定預設值
if cfg.TokenExpire == 0 {
cfg.TokenExpire = time.Hour
}
if cfg.OnlineTTL == 0 {
cfg.OnlineTTL = 5 * time.Minute
}
if cfg.KeyPrefix == "" {
cfg.KeyPrefix = "centrifugo:"
}
// 創建 HTTP 客戶端
var client *Client
if cfg.ClientConfig != nil {
client = NewClientWithConfig(*cfg.ClientConfig)
} else {
client = NewClient(cfg.APIURL, cfg.APIKey)
}
// 創建 Token 生成器
token := NewTokenGeneratorWithConfig(TokenConfig{
Secret: cfg.TokenSecret,
ExpireIn: cfg.TokenExpire,
})
svc := &Service{
client: client,
token: token,
}
// 如果有 Redis創建黑名單管理器和在線狀態管理器
if cfg.Redis != nil {
svc.blacklist = NewTokenBlacklistWithPrefix(cfg.Redis, cfg.KeyPrefix+"blacklist:")
store := NewRedisOnlineStoreWithPrefix(cfg.Redis, cfg.KeyPrefix+"online:")
svc.online = NewOnlineManagerWithTTL(client, store, cfg.OnlineTTL)
}
return svc
}
// Client 返回 HTTP API 客戶端
func (s *Service) Client() *Client {
return s.client
}
// Token 返回 Token 生成器
func (s *Service) Token() *TokenGenerator {
return s.token
}
// Blacklist 返回黑名單管理器(可能為 nil
func (s *Service) Blacklist() *TokenBlacklist {
return s.blacklist
}
// Online 返回在線狀態管理器(可能為 nil
func (s *Service) Online() *OnlineManager {
return s.online
}
// ==================== 便捷方法 ====================
// PublishJSON 發布 JSON 訊息到頻道
func (s *Service) PublishJSON(ctx context.Context, channel string, data interface{}) (*PublishResult, error) {
return s.client.PublishJSON(ctx, channel, data)
}
// BroadcastJSON 批量發布 JSON 訊息到多個頻道
func (s *Service) BroadcastJSON(ctx context.Context, channels []string, data interface{}) error {
return s.client.BroadcastJSON(ctx, channels, data)
}
// Disconnect 斷開用戶連線
func (s *Service) Disconnect(ctx context.Context, userID string) error {
return s.client.Disconnect(ctx, userID)
}
// GenerateToken 快速生成連線 Token
func (s *Service) GenerateToken(userID string) (string, error) {
return s.token.QuickConnectionToken(userID)
}
// GenerateTokenWithInfo 生成帶用戶資訊的連線 Token
func (s *Service) GenerateTokenWithInfo(userID string, info map[string]interface{}) (string, error) {
return s.token.GenerateConnectionToken(ConnectionTokenOptions{
UserID: userID,
Info: info,
})
}
// InvalidateUser 撤銷用戶所有 Token 並斷開連線
// 這是最常用的「踢人」方法,適用於:
// - 用戶被封禁
// - 密碼變更
// - 用戶登出(全設備)
func (s *Service) InvalidateUser(ctx context.Context, userID string) error {
// 撤銷所有 Token
if s.blacklist != nil {
if err := s.blacklist.RevokeUserTokens(ctx, userID); err != nil {
return err
}
}
// 斷開連線
return s.client.Disconnect(ctx, userID)
}
// IsUserOnline 檢查用戶是否在線
func (s *Service) IsUserOnline(ctx context.Context, userID string) (bool, error) {
if s.online == nil {
return false, ErrOnlineStoreNotConfigured
}
return s.online.IsUserOnline(ctx, userID)
}
// GetUsersOnlineStatus 批量獲取用戶在線狀態
func (s *Service) GetUsersOnlineStatus(ctx context.Context, userIDs []string) (map[string]bool, error) {
if s.online == nil {
return nil, ErrOnlineStoreNotConfigured
}
return s.online.GetUsersOnlineStatus(ctx, userIDs)
}

View File

@ -0,0 +1,303 @@
package centrifugo
import (
"context"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zeromicro/go-zero/core/stores/redis"
)
func setupServiceTestRedis(t *testing.T) (*redis.Redis, func()) {
mr, err := miniredis.Run()
require.NoError(t, err)
rds, err := redis.NewRedis(redis.RedisConf{
Host: mr.Addr(),
Type: "node",
})
require.NoError(t, err)
return rds, func() {
mr.Close()
}
}
func TestNewService(t *testing.T) {
rds, cleanup := setupServiceTestRedis(t)
defer cleanup()
svc := NewService(ServiceConfig{
APIURL: "http://localhost:8000",
APIKey: "test-api-key",
TokenSecret: "test-secret",
Redis: rds,
})
assert.NotNil(t, svc)
assert.NotNil(t, svc.Client())
assert.NotNil(t, svc.Token())
assert.NotNil(t, svc.Blacklist())
assert.NotNil(t, svc.Online())
}
func TestNewService_WithoutRedis(t *testing.T) {
svc := NewService(ServiceConfig{
APIURL: "http://localhost:8000",
APIKey: "test-api-key",
TokenSecret: "test-secret",
Redis: nil, // 沒有 Redis
})
assert.NotNil(t, svc)
assert.NotNil(t, svc.Client())
assert.NotNil(t, svc.Token())
assert.Nil(t, svc.Blacklist()) // 應該為 nil
assert.Nil(t, svc.Online()) // 應該為 nil
}
func TestNewService_DefaultValues(t *testing.T) {
svc := NewService(ServiceConfig{
APIURL: "http://localhost:8000",
APIKey: "test-api-key",
TokenSecret: "test-secret",
})
assert.NotNil(t, svc)
// 檢查預設值已被應用(通過生成 Token 間接驗證)
token, err := svc.GenerateToken("user-123")
require.NoError(t, err)
assert.NotEmpty(t, token)
}
func TestNewService_WithCustomConfig(t *testing.T) {
rds, cleanup := setupServiceTestRedis(t)
defer cleanup()
customExpire := 24 * time.Hour
customTTL := 10 * time.Minute
customPrefix := "custom:"
svc := NewService(ServiceConfig{
APIURL: "http://localhost:8000",
APIKey: "test-api-key",
TokenSecret: "test-secret",
TokenExpire: customExpire,
Redis: rds,
OnlineTTL: customTTL,
KeyPrefix: customPrefix,
})
assert.NotNil(t, svc)
}
func TestService_GenerateToken(t *testing.T) {
svc := NewService(ServiceConfig{
APIURL: "http://localhost:8000",
APIKey: "test-api-key",
TokenSecret: "test-secret",
})
token, err := svc.GenerateToken("user-123")
require.NoError(t, err)
assert.NotEmpty(t, token)
}
func TestService_GenerateTokenWithInfo(t *testing.T) {
svc := NewService(ServiceConfig{
APIURL: "http://localhost:8000",
APIKey: "test-api-key",
TokenSecret: "test-secret",
})
info := map[string]interface{}{
"name": "Daniel",
"avatar": "https://example.com/avatar.jpg",
}
token, err := svc.GenerateTokenWithInfo("user-123", info)
require.NoError(t, err)
assert.NotEmpty(t, token)
}
func TestService_IsUserOnline_WithoutRedis(t *testing.T) {
svc := NewService(ServiceConfig{
APIURL: "http://localhost:8000",
APIKey: "test-api-key",
TokenSecret: "test-secret",
Redis: nil,
})
_, err := svc.IsUserOnline(context.Background(), "user-123")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrOnlineStoreNotConfigured)
}
func TestService_GetUsersOnlineStatus_WithoutRedis(t *testing.T) {
svc := NewService(ServiceConfig{
APIURL: "http://localhost:8000",
APIKey: "test-api-key",
TokenSecret: "test-secret",
Redis: nil,
})
_, err := svc.GetUsersOnlineStatus(context.Background(), []string{"user-1", "user-2"})
assert.Error(t, err)
assert.ErrorIs(t, err, ErrOnlineStoreNotConfigured)
}
func TestService_IsUserOnline_WithRedis(t *testing.T) {
rds, cleanup := setupServiceTestRedis(t)
defer cleanup()
svc := NewService(ServiceConfig{
APIURL: "http://localhost:8000",
APIKey: "test-api-key",
TokenSecret: "test-secret",
Redis: rds,
})
ctx := context.Background()
userID := "user-123"
// 初始狀態應該是離線
online, err := svc.IsUserOnline(ctx, userID)
require.NoError(t, err)
assert.False(t, online)
// 處理連線事件
err = svc.Online().HandleConnect(ctx, userID)
require.NoError(t, err)
// 現在應該在線
online, err = svc.IsUserOnline(ctx, userID)
require.NoError(t, err)
assert.True(t, online)
// 處理斷線事件
err = svc.Online().HandleDisconnect(ctx, userID)
require.NoError(t, err)
// 現在應該離線
online, err = svc.IsUserOnline(ctx, userID)
require.NoError(t, err)
assert.False(t, online)
}
func TestService_GetUsersOnlineStatus_WithRedis(t *testing.T) {
rds, cleanup := setupServiceTestRedis(t)
defer cleanup()
svc := NewService(ServiceConfig{
APIURL: "http://localhost:8000",
APIKey: "test-api-key",
TokenSecret: "test-secret",
Redis: rds,
})
ctx := context.Background()
// 設置一些用戶在線
err := svc.Online().HandleConnect(ctx, "user-1")
require.NoError(t, err)
err = svc.Online().HandleConnect(ctx, "user-3")
require.NoError(t, err)
// 批量獲取在線狀態
status, err := svc.GetUsersOnlineStatus(ctx, []string{"user-1", "user-2", "user-3"})
require.NoError(t, err)
assert.True(t, status["user-1"])
assert.False(t, status["user-2"])
assert.True(t, status["user-3"])
}
func TestService_Blacklist_Integration(t *testing.T) {
rds, cleanup := setupServiceTestRedis(t)
defer cleanup()
svc := NewService(ServiceConfig{
APIURL: "http://localhost:8000",
APIKey: "test-api-key",
TokenSecret: "test-secret",
Redis: rds,
})
ctx := context.Background()
userID := "user-123"
// 撤銷用戶所有 Token
err := svc.Blacklist().RevokeUserTokens(ctx, userID)
require.NoError(t, err)
// 獲取版本
version, err := svc.Blacklist().GetUserTokenVersion(ctx, userID)
require.NoError(t, err)
assert.Greater(t, version, int64(0))
// 檢查舊版本無效
valid, err := svc.Blacklist().IsTokenVersionValid(ctx, userID, version-1)
require.NoError(t, err)
assert.False(t, valid)
// 檢查當前版本有效
valid, err = svc.Blacklist().IsTokenVersionValid(ctx, userID, version)
require.NoError(t, err)
assert.True(t, valid)
}
func TestService_MultipleConnections(t *testing.T) {
rds, cleanup := setupServiceTestRedis(t)
defer cleanup()
svc := NewService(ServiceConfig{
APIURL: "http://localhost:8000",
APIKey: "test-api-key",
TokenSecret: "test-secret",
Redis: rds,
})
ctx := context.Background()
userID := "user-123"
// 模擬多個設備連線
err := svc.Online().HandleConnect(ctx, userID)
require.NoError(t, err)
err = svc.Online().HandleConnect(ctx, userID)
require.NoError(t, err)
err = svc.Online().HandleConnect(ctx, userID)
require.NoError(t, err)
// 用戶應該在線
online, err := svc.IsUserOnline(ctx, userID)
require.NoError(t, err)
assert.True(t, online)
// 斷開一個設備
err = svc.Online().HandleDisconnect(ctx, userID)
require.NoError(t, err)
// 用戶仍然在線(還有 2 個連線)
online, err = svc.IsUserOnline(ctx, userID)
require.NoError(t, err)
assert.True(t, online)
// 斷開剩餘設備
err = svc.Online().HandleDisconnect(ctx, userID)
require.NoError(t, err)
err = svc.Online().HandleDisconnect(ctx, userID)
require.NoError(t, err)
// 用戶現在離線
online, err = svc.IsUserOnline(ctx, userID)
require.NoError(t, err)
assert.False(t, online)
}

View File

@ -0,0 +1,519 @@
package centrifugo
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"time"
)
// Client Centrifugo 客戶端
type Client struct {
apiURL string
apiKey string
client *http.Client
}
// ClientConfig 客戶端配置
type ClientConfig struct {
// APIURL Centrifugo API 地址(必填)
APIURL string
// APIKey API 密鑰(必填)
APIKey string
// Timeout 整體請求超時時間(預設 10 秒)
Timeout time.Duration
// MaxIdleConns 最大閒置連線數(預設 100
MaxIdleConns int
// MaxIdleConnsPerHost 每個 host 最大閒置連線數(預設 20適合高併發
MaxIdleConnsPerHost int
// IdleConnTimeout 閒置連線超時時間(預設 90 秒)
IdleConnTimeout time.Duration
// DialTimeout 建立連線超時時間(預設 5 秒)
DialTimeout time.Duration
// TLSHandshakeTimeout TLS 握手超時時間(預設 5 秒)
TLSHandshakeTimeout time.Duration
// ResponseHeaderTimeout 等待響應頭超時時間(預設 10 秒)
ResponseHeaderTimeout time.Duration
}
// DefaultConfig 返回預設配置
func DefaultConfig(apiURL, apiKey string) ClientConfig {
return ClientConfig{
APIURL: apiURL,
APIKey: apiKey,
Timeout: 10 * time.Second,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 20, // 提高以支援高併發
IdleConnTimeout: 90 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 5 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
}
}
// HighPerformanceConfig 返回高效能配置(適合高併發場景)
func HighPerformanceConfig(apiURL, apiKey string) ClientConfig {
return ClientConfig{
APIURL: apiURL,
APIKey: apiKey,
Timeout: 5 * time.Second, // 更短的超時
MaxIdleConns: 200, // 更多閒置連線
MaxIdleConnsPerHost: 50, // 更多每 host 連線
IdleConnTimeout: 120 * time.Second,
DialTimeout: 3 * time.Second,
TLSHandshakeTimeout: 3 * time.Second,
ResponseHeaderTimeout: 5 * time.Second,
}
}
// NewClient 創建新的 Centrifugo 客戶端(使用預設配置)
func NewClient(apiURL, apiKey string) *Client {
return NewClientWithConfig(DefaultConfig(apiURL, apiKey))
}
// NewClientWithConfig 創建使用自定義配置的 Centrifugo 客戶端
func NewClientWithConfig(config ClientConfig) *Client {
// 設定預設值
if config.Timeout == 0 {
config.Timeout = 10 * time.Second
}
if config.MaxIdleConns == 0 {
config.MaxIdleConns = 100
}
if config.MaxIdleConnsPerHost == 0 {
config.MaxIdleConnsPerHost = 20
}
if config.IdleConnTimeout == 0 {
config.IdleConnTimeout = 90 * time.Second
}
if config.DialTimeout == 0 {
config.DialTimeout = 5 * time.Second
}
if config.TLSHandshakeTimeout == 0 {
config.TLSHandshakeTimeout = 5 * time.Second
}
if config.ResponseHeaderTimeout == 0 {
config.ResponseHeaderTimeout = 10 * time.Second
}
transport := &http.Transport{
DialContext: (&net.Dialer{
Timeout: config.DialTimeout,
KeepAlive: 30 * time.Second, // TCP keep-alive
}).DialContext,
MaxIdleConns: config.MaxIdleConns,
MaxIdleConnsPerHost: config.MaxIdleConnsPerHost,
IdleConnTimeout: config.IdleConnTimeout,
TLSHandshakeTimeout: config.TLSHandshakeTimeout,
ResponseHeaderTimeout: config.ResponseHeaderTimeout,
ForceAttemptHTTP2: true, // 嘗試使用 HTTP/2
}
return &Client{
apiURL: config.APIURL,
apiKey: config.APIKey,
client: &http.Client{
Timeout: config.Timeout,
Transport: transport,
},
}
}
// NewClientWithHTTP 創建使用自定義 HTTP 客戶端的 Centrifugo 客戶端
func NewClientWithHTTP(apiURL, apiKey string, httpClient *http.Client) *Client {
return &Client{
apiURL: apiURL,
apiKey: apiKey,
client: httpClient,
}
}
// ==================== Request/Response 結構 ====================
// PublishRequest 發布請求
type PublishRequest struct {
Channel string `json:"channel"`
Data interface{} `json:"data"`
}
// BroadcastRequest 批量發布請求
type BroadcastRequest struct {
Channels []string `json:"channels"`
Data interface{} `json:"data"`
}
// SubscribeRequest 訂閱請求
type SubscribeRequest struct {
User string `json:"user"`
Channel string `json:"channel"`
}
// UnsubscribeRequest 取消訂閱請求
type UnsubscribeRequest struct {
User string `json:"user"`
Channel string `json:"channel"`
}
// DisconnectRequest 斷開連線請求
type DisconnectRequest struct {
User string `json:"user"`
}
// DisconnectWithCodeRequest 帶代碼的斷開連線請求
type DisconnectWithCodeRequest struct {
User string `json:"user"`
Disconnect DisconnectInfo `json:"disconnect,omitempty"`
}
// DisconnectInfo 斷開連線資訊
type DisconnectInfo struct {
Code uint32 `json:"code"`
Reason string `json:"reason"`
}
// PresenceRequest 在線狀態請求
type PresenceRequest struct {
Channel string `json:"channel"`
}
// PresenceStatsRequest 在線統計請求
type PresenceStatsRequest struct {
Channel string `json:"channel"`
}
// HistoryRequest 歷史訊息請求
type HistoryRequest struct {
Channel string `json:"channel"`
Limit int `json:"limit,omitempty"`
Reverse bool `json:"reverse,omitempty"`
}
// ChannelsRequest 頻道列表請求
type ChannelsRequest struct {
Pattern string `json:"pattern,omitempty"`
}
// InfoRequest 伺服器資訊請求
type InfoRequest struct{}
// APIResponse Centrifugo API 通用響應
type APIResponse struct {
Error *APIError `json:"error,omitempty"`
Result interface{} `json:"result,omitempty"`
}
// APIError Centrifugo API 錯誤
type APIError struct {
Code int `json:"code"`
Message string `json:"message"`
}
func (e *APIError) Error() string {
return fmt.Sprintf("centrifugo error %d: %s", e.Code, e.Message)
}
// PublishResult 發布結果
type PublishResult struct {
Offset uint64 `json:"offset,omitempty"`
Epoch string `json:"epoch,omitempty"`
}
// PresenceResult 在線狀態結果
type PresenceResult struct {
Presence map[string]ClientInfo `json:"presence"`
}
// ClientInfo 客戶端資訊
type ClientInfo struct {
User string `json:"user"`
Client string `json:"client"`
ConnInfo json.RawMessage `json:"conn_info,omitempty"`
ChanInfo json.RawMessage `json:"chan_info,omitempty"`
}
// PresenceStatsResult 在線統計結果
type PresenceStatsResult struct {
NumClients int `json:"num_clients"`
NumUsers int `json:"num_users"`
}
// HistoryResult 歷史訊息結果
type HistoryResult struct {
Publications []Publication `json:"publications"`
Offset uint64 `json:"offset,omitempty"`
Epoch string `json:"epoch,omitempty"`
}
// Publication 發布的訊息
type Publication struct {
Offset uint64 `json:"offset,omitempty"`
Data json.RawMessage `json:"data"`
Info *ClientInfo `json:"info,omitempty"`
}
// ChannelsResult 頻道列表結果
type ChannelsResult struct {
Channels map[string]ChannelInfo `json:"channels"`
}
// ChannelInfo 頻道資訊
type ChannelInfo struct {
NumClients int `json:"num_clients"`
}
// InfoResult 伺服器資訊結果
type InfoResult struct {
Nodes []NodeInfo `json:"nodes"`
}
// NodeInfo 節點資訊
type NodeInfo struct {
UID string `json:"uid"`
Name string `json:"name"`
Version string `json:"version"`
NumClients int `json:"num_clients"`
NumUsers int `json:"num_users"`
NumChannels int `json:"num_channels"`
Uptime int `json:"uptime"`
}
// ==================== 發布相關方法 ====================
// Publish 發布訊息到指定頻道
func (c *Client) Publish(ctx context.Context, channel string, data []byte) (*PublishResult, error) {
req := PublishRequest{
Channel: channel,
Data: json.RawMessage(data),
}
return c.publish(ctx, req)
}
// PublishJSON 發布 JSON 訊息到指定頻道
func (c *Client) PublishJSON(ctx context.Context, channel string, data interface{}) (*PublishResult, error) {
req := PublishRequest{
Channel: channel,
Data: data,
}
return c.publish(ctx, req)
}
func (c *Client) publish(ctx context.Context, req PublishRequest) (*PublishResult, error) {
var result PublishResult
if err := c.callAPI(ctx, "publish", req, &result); err != nil {
return nil, err
}
return &result, nil
}
// Broadcast 批量發布訊息到多個頻道
func (c *Client) Broadcast(ctx context.Context, channels []string, data []byte) error {
req := BroadcastRequest{
Channels: channels,
Data: json.RawMessage(data),
}
return c.callAPI(ctx, "broadcast", req, nil)
}
// BroadcastJSON 批量發布 JSON 訊息到多個頻道
func (c *Client) BroadcastJSON(ctx context.Context, channels []string, data interface{}) error {
req := BroadcastRequest{
Channels: channels,
Data: data,
}
return c.callAPI(ctx, "broadcast", req, nil)
}
// ==================== 訂閱管理方法 ====================
// Subscribe 訂閱用戶到頻道
func (c *Client) Subscribe(ctx context.Context, user, channel string) error {
req := SubscribeRequest{
User: user,
Channel: channel,
}
return c.callAPI(ctx, "subscribe", req, nil)
}
// Unsubscribe 取消用戶訂閱
func (c *Client) Unsubscribe(ctx context.Context, user, channel string) error {
req := UnsubscribeRequest{
User: user,
Channel: channel,
}
return c.callAPI(ctx, "unsubscribe", req, nil)
}
// Disconnect 強制斷開用戶連線
func (c *Client) Disconnect(ctx context.Context, user string) error {
req := DisconnectRequest{
User: user,
}
return c.callAPI(ctx, "disconnect", req, nil)
}
// DisconnectWithCode 強制斷開用戶連線(帶斷開代碼和原因)
func (c *Client) DisconnectWithCode(ctx context.Context, user string, code uint32, reason string) error {
req := DisconnectWithCodeRequest{
User: user,
Disconnect: DisconnectInfo{
Code: code,
Reason: reason,
},
}
return c.callAPI(ctx, "disconnect", req, nil)
}
// ==================== 在線狀態方法 ====================
// Presence 獲取頻道在線用戶
func (c *Client) Presence(ctx context.Context, channel string) (*PresenceResult, error) {
req := PresenceRequest{
Channel: channel,
}
var result PresenceResult
if err := c.callAPI(ctx, "presence", req, &result); err != nil {
return nil, err
}
return &result, nil
}
// PresenceStats 獲取頻道在線統計
func (c *Client) PresenceStats(ctx context.Context, channel string) (*PresenceStatsResult, error) {
req := PresenceStatsRequest{
Channel: channel,
}
var result PresenceStatsResult
if err := c.callAPI(ctx, "presence_stats", req, &result); err != nil {
return nil, err
}
return &result, nil
}
// ==================== 歷史訊息方法 ====================
// History 獲取頻道歷史訊息
func (c *Client) History(ctx context.Context, channel string, limit int) (*HistoryResult, error) {
req := HistoryRequest{
Channel: channel,
Limit: limit,
}
var result HistoryResult
if err := c.callAPI(ctx, "history", req, &result); err != nil {
return nil, err
}
return &result, nil
}
// HistoryReverse 獲取頻道歷史訊息(倒序)
func (c *Client) HistoryReverse(ctx context.Context, channel string, limit int) (*HistoryResult, error) {
req := HistoryRequest{
Channel: channel,
Limit: limit,
Reverse: true,
}
var result HistoryResult
if err := c.callAPI(ctx, "history", req, &result); err != nil {
return nil, err
}
return &result, nil
}
// ==================== 頻道管理方法 ====================
// Channels 獲取所有活躍頻道
func (c *Client) Channels(ctx context.Context) (*ChannelsResult, error) {
req := ChannelsRequest{}
var result ChannelsResult
if err := c.callAPI(ctx, "channels", req, &result); err != nil {
return nil, err
}
return &result, nil
}
// ChannelsWithPattern 獲取匹配模式的活躍頻道
func (c *Client) ChannelsWithPattern(ctx context.Context, pattern string) (*ChannelsResult, error) {
req := ChannelsRequest{
Pattern: pattern,
}
var result ChannelsResult
if err := c.callAPI(ctx, "channels", req, &result); err != nil {
return nil, err
}
return &result, nil
}
// ==================== 伺服器狀態方法 ====================
// Info 獲取伺服器狀態資訊
func (c *Client) Info(ctx context.Context) (*InfoResult, error) {
req := InfoRequest{}
var result InfoResult
if err := c.callAPI(ctx, "info", req, &result); err != nil {
return nil, err
}
return &result, nil
}
// Ping 檢查伺服器是否健康
func (c *Client) Ping(ctx context.Context) error {
_, err := c.Info(ctx)
return err
}
// ==================== 內部方法 ====================
// callAPI 調用 Centrifugo API
func (c *Client) callAPI(ctx context.Context, method string, params interface{}, result interface{}) error {
body, err := json.Marshal(params)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}
url := fmt.Sprintf("%s/api/%s", c.apiURL, method)
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
if c.apiKey != "" {
httpReq.Header.Set("Authorization", "apikey "+c.apiKey)
}
resp, err := c.client.Do(httpReq)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("centrifugo returned status %d: %s", resp.StatusCode, string(respBody))
}
// 解析響應
var apiResp APIResponse
if result != nil {
apiResp.Result = result
}
if err := json.Unmarshal(respBody, &apiResp); err != nil {
return fmt.Errorf("failed to unmarshal response: %w", err)
}
if apiResp.Error != nil {
return apiResp.Error
}
return nil
}

View File

@ -0,0 +1,175 @@
package centrifugo
import (
"context"
"fmt"
"time"
)
// OnlineStatus 在線狀態
type OnlineStatus struct {
UserID string `json:"user_id"`
IsOnline bool `json:"is_online"`
LastSeenAt time.Time `json:"last_seen_at,omitempty"`
Clients int `json:"clients,omitempty"` // 連線數(可能多個設備)
}
// OnlineStore 在線狀態存儲介面
// 可以用 Redis、Memory 或其他存儲實作
type OnlineStore interface {
// SetOnline 設置用戶在線
SetOnline(ctx context.Context, userID string, ttl time.Duration) error
// SetOffline 設置用戶離線
SetOffline(ctx context.Context, userID string) error
// IsOnline 檢查用戶是否在線
IsOnline(ctx context.Context, userID string) (bool, error)
// GetOnlineUsers 獲取在線用戶列表
GetOnlineUsers(ctx context.Context, userIDs []string) (map[string]bool, error)
// IncrClient 增加用戶連線數
IncrClient(ctx context.Context, userID string) (int64, error)
// DecrClient 減少用戶連線數
DecrClient(ctx context.Context, userID string) (int64, error)
}
// OnlineManager 在線狀態管理器
// 結合 Redis 存儲和 Centrifugo Presence API 提供在線狀態追蹤
type OnlineManager struct {
client *Client
store OnlineStore
ttl time.Duration
}
// NewOnlineManager 創建在線狀態管理器
// store 可以為 nil此時只使用 Centrifugo Presence API
func NewOnlineManager(client *Client, store OnlineStore) *OnlineManager {
return &OnlineManager{
client: client,
store: store,
ttl: 5 * time.Minute, // 預設 5 分鐘過期
}
}
// NewOnlineManagerWithTTL 創建帶 TTL 的在線狀態管理器
func NewOnlineManagerWithTTL(client *Client, store OnlineStore, ttl time.Duration) *OnlineManager {
return &OnlineManager{
client: client,
store: store,
ttl: ttl,
}
}
// ==================== 連線事件處理 ====================
// HandleConnect 處理用戶連線事件(用於 Centrifugo Connect Proxy
func (m *OnlineManager) HandleConnect(ctx context.Context, userID string) error {
if m.store == nil {
return nil
}
// 增加連線數
count, err := m.store.IncrClient(ctx, userID)
if err != nil {
return fmt.Errorf("failed to incr client: %w", err)
}
// 如果是第一個連線,設置在線狀態
if count == 1 {
if err := m.store.SetOnline(ctx, userID, m.ttl); err != nil {
return fmt.Errorf("failed to set online: %w", err)
}
}
return nil
}
// HandleDisconnect 處理用戶斷線事件(用於 Centrifugo Disconnect Proxy
func (m *OnlineManager) HandleDisconnect(ctx context.Context, userID string) error {
if m.store == nil {
return nil
}
// 減少連線數
count, err := m.store.DecrClient(ctx, userID)
if err != nil {
return fmt.Errorf("failed to decr client: %w", err)
}
// 如果沒有連線了,設置離線狀態
if count <= 0 {
if err := m.store.SetOffline(ctx, userID); err != nil {
return fmt.Errorf("failed to set offline: %w", err)
}
}
return nil
}
// ==================== 在線狀態查詢 ====================
// IsUserOnline 檢查用戶是否在線(使用 Store
func (m *OnlineManager) IsUserOnline(ctx context.Context, userID string) (bool, error) {
if m.store == nil {
return false, ErrOnlineStoreNotConfigured
}
return m.store.IsOnline(ctx, userID)
}
// GetUsersOnlineStatus 批量獲取用戶在線狀態
func (m *OnlineManager) GetUsersOnlineStatus(ctx context.Context, userIDs []string) (map[string]bool, error) {
if m.store == nil {
return nil, ErrOnlineStoreNotConfigured
}
return m.store.GetOnlineUsers(ctx, userIDs)
}
// RefreshOnline 刷新用戶在線狀態(用於心跳)
func (m *OnlineManager) RefreshOnline(ctx context.Context, userID string) error {
if m.store == nil {
return nil
}
return m.store.SetOnline(ctx, userID, m.ttl)
}
// ==================== Centrifugo Presence API ====================
// IsUserInChannel 檢查用戶是否在指定頻道中(使用 Centrifugo Presence
func (m *OnlineManager) IsUserInChannel(ctx context.Context, userID, channel string) (bool, error) {
presence, err := m.client.Presence(ctx, channel)
if err != nil {
return false, err
}
for _, info := range presence.Presence {
if info.User == userID {
return true, nil
}
}
return false, nil
}
// GetChannelOnlineUsers 獲取頻道中的在線用戶(使用 Centrifugo Presence
func (m *OnlineManager) GetChannelOnlineUsers(ctx context.Context, channel string) ([]string, error) {
presence, err := m.client.Presence(ctx, channel)
if err != nil {
return nil, err
}
// 去重(一個用戶可能有多個連線)
userMap := make(map[string]bool)
for _, info := range presence.Presence {
userMap[info.User] = true
}
users := make([]string, 0, len(userMap))
for userID := range userMap {
users = append(users, userID)
}
return users, nil
}
// GetChannelStats 獲取頻道在線統計
func (m *OnlineManager) GetChannelStats(ctx context.Context, channel string) (*PresenceStatsResult, error) {
return m.client.PresenceStats(ctx, channel)
}

View File

@ -0,0 +1,141 @@
package centrifugo
import (
"context"
"fmt"
"strconv"
"time"
"github.com/zeromicro/go-zero/core/stores/redis"
)
// RedisOnlineStore 使用 Redis 實作的在線狀態存儲
type RedisOnlineStore struct {
client *redis.Redis
keyPrefix string
}
// NewRedisOnlineStore 創建 Redis 在線狀態存儲
func NewRedisOnlineStore(client *redis.Redis) *RedisOnlineStore {
return &RedisOnlineStore{
client: client,
keyPrefix: "online:",
}
}
// NewRedisOnlineStoreWithPrefix 創建帶自定義前綴的 Redis 在線狀態存儲
func NewRedisOnlineStoreWithPrefix(client *redis.Redis, prefix string) *RedisOnlineStore {
return &RedisOnlineStore{
client: client,
keyPrefix: prefix,
}
}
// key 生成 Redis key
func (s *RedisOnlineStore) key(userID string) string {
return s.keyPrefix + userID
}
// clientCountKey 生成連線數 key
func (s *RedisOnlineStore) clientCountKey(userID string) string {
return s.keyPrefix + "clients:" + userID
}
// SetOnline 設置用戶在線
func (s *RedisOnlineStore) SetOnline(ctx context.Context, userID string, ttl time.Duration) error {
return s.client.SetexCtx(ctx, s.key(userID), fmt.Sprintf("%d", time.Now().Unix()), int(ttl.Seconds()))
}
// SetOffline 設置用戶離線
func (s *RedisOnlineStore) SetOffline(ctx context.Context, userID string) error {
// 刪除在線狀態
_, err := s.client.DelCtx(ctx, s.key(userID))
if err != nil {
return err
}
// 刪除連線數
_, err = s.client.DelCtx(ctx, s.clientCountKey(userID))
return err
}
// IsOnline 檢查用戶是否在線
func (s *RedisOnlineStore) IsOnline(ctx context.Context, userID string) (bool, error) {
return s.client.ExistsCtx(ctx, s.key(userID))
}
// GetOnlineUsers 批量獲取在線用戶狀態
func (s *RedisOnlineStore) GetOnlineUsers(ctx context.Context, userIDs []string) (map[string]bool, error) {
if len(userIDs) == 0 {
return make(map[string]bool), nil
}
result := make(map[string]bool, len(userIDs))
for _, userID := range userIDs {
exists, err := s.client.ExistsCtx(ctx, s.key(userID))
if err != nil {
return nil, fmt.Errorf("failed to check online status for %s: %w", userID, err)
}
result[userID] = exists
}
return result, nil
}
// IncrClient 增加用戶連線數
func (s *RedisOnlineStore) IncrClient(ctx context.Context, userID string) (int64, error) {
count, err := s.client.IncrCtx(ctx, s.clientCountKey(userID))
if err != nil {
return 0, err
}
return int64(count), nil
}
// DecrClient 減少用戶連線數
func (s *RedisOnlineStore) DecrClient(ctx context.Context, userID string) (int64, error) {
count, err := s.client.DecrCtx(ctx, s.clientCountKey(userID))
if err != nil {
return 0, err
}
// 確保不會變成負數
if count < 0 {
_ = s.client.SetCtx(ctx, s.clientCountKey(userID), "0")
return 0, nil
}
return int64(count), nil
}
// GetClientCount 獲取用戶連線數
func (s *RedisOnlineStore) GetClientCount(ctx context.Context, userID string) (int64, error) {
val, err := s.client.GetCtx(ctx, s.clientCountKey(userID))
if err != nil {
return 0, err
}
if val == "" {
return 0, nil
}
return strconv.ParseInt(val, 10, 64)
}
// GetAllOnlineUserIDs 獲取所有在線用戶 ID
// 注意:此方法使用 KEYS 命令,在大規模生產環境中可能有性能問題
// 建議在需要時使用 Centrifugo Presence API 替代
func (s *RedisOnlineStore) GetAllOnlineUserIDs(ctx context.Context) ([]string, error) {
// 使用 KEYS 查找所有在線用戶(排除 clients: 開頭的 key
pattern := s.keyPrefix + "[^c]*"
keys, err := s.client.KeysCtx(ctx, pattern)
if err != nil {
return nil, err
}
userIDs := make([]string, 0, len(keys))
prefixLen := len(s.keyPrefix)
for _, key := range keys {
if len(key) > prefixLen {
userIDs = append(userIDs, key[prefixLen:])
}
}
return userIDs, nil
}

View File

@ -0,0 +1,272 @@
package centrifugo
import (
"context"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zeromicro/go-zero/core/stores/redis"
)
func setupOnlineTestRedis(t *testing.T) (*redis.Redis, func()) {
mr, err := miniredis.Run()
require.NoError(t, err)
rds, err := redis.NewRedis(redis.RedisConf{
Host: mr.Addr(),
Type: "node",
})
require.NoError(t, err)
return rds, func() {
mr.Close()
}
}
func TestNewRedisOnlineStore(t *testing.T) {
rds, cleanup := setupOnlineTestRedis(t)
defer cleanup()
store := NewRedisOnlineStore(rds)
assert.NotNil(t, store)
assert.Equal(t, "online:", store.keyPrefix)
}
func TestNewRedisOnlineStoreWithPrefix(t *testing.T) {
rds, cleanup := setupOnlineTestRedis(t)
defer cleanup()
store := NewRedisOnlineStoreWithPrefix(rds, "custom:online:")
assert.NotNil(t, store)
assert.Equal(t, "custom:online:", store.keyPrefix)
}
func TestSetOnline(t *testing.T) {
rds, cleanup := setupOnlineTestRedis(t)
defer cleanup()
store := NewRedisOnlineStore(rds)
ctx := context.Background()
userID := "user-123"
ttl := 5 * time.Minute
// 設置在線
err := store.SetOnline(ctx, userID, ttl)
require.NoError(t, err)
// 檢查是否在線
online, err := store.IsOnline(ctx, userID)
require.NoError(t, err)
assert.True(t, online)
}
func TestSetOffline(t *testing.T) {
rds, cleanup := setupOnlineTestRedis(t)
defer cleanup()
store := NewRedisOnlineStore(rds)
ctx := context.Background()
userID := "user-123"
// 先設置在線
err := store.SetOnline(ctx, userID, 5*time.Minute)
require.NoError(t, err)
// 設置離線
err = store.SetOffline(ctx, userID)
require.NoError(t, err)
// 檢查是否離線
online, err := store.IsOnline(ctx, userID)
require.NoError(t, err)
assert.False(t, online)
}
func TestIsOnline_NotOnline(t *testing.T) {
rds, cleanup := setupOnlineTestRedis(t)
defer cleanup()
store := NewRedisOnlineStore(rds)
ctx := context.Background()
// 未設置在線的用戶應該返回 false
online, err := store.IsOnline(ctx, "non-existent-user")
require.NoError(t, err)
assert.False(t, online)
}
func TestGetOnlineUsers(t *testing.T) {
rds, cleanup := setupOnlineTestRedis(t)
defer cleanup()
store := NewRedisOnlineStore(rds)
ctx := context.Background()
// 設置一些用戶在線
err := store.SetOnline(ctx, "user-1", 5*time.Minute)
require.NoError(t, err)
err = store.SetOnline(ctx, "user-3", 5*time.Minute)
require.NoError(t, err)
// 批量獲取在線狀態
userIDs := []string{"user-1", "user-2", "user-3"}
status, err := store.GetOnlineUsers(ctx, userIDs)
require.NoError(t, err)
assert.True(t, status["user-1"])
assert.False(t, status["user-2"])
assert.True(t, status["user-3"])
}
func TestGetOnlineUsers_Empty(t *testing.T) {
rds, cleanup := setupOnlineTestRedis(t)
defer cleanup()
store := NewRedisOnlineStore(rds)
ctx := context.Background()
// 空列表應該返回空 map
status, err := store.GetOnlineUsers(ctx, []string{})
require.NoError(t, err)
assert.Empty(t, status)
}
func TestIncrClient(t *testing.T) {
rds, cleanup := setupOnlineTestRedis(t)
defer cleanup()
store := NewRedisOnlineStore(rds)
ctx := context.Background()
userID := "user-123"
// 第一次增加
count, err := store.IncrClient(ctx, userID)
require.NoError(t, err)
assert.Equal(t, int64(1), count)
// 第二次增加
count, err = store.IncrClient(ctx, userID)
require.NoError(t, err)
assert.Equal(t, int64(2), count)
// 獲取連線數
count, err = store.GetClientCount(ctx, userID)
require.NoError(t, err)
assert.Equal(t, int64(2), count)
}
func TestDecrClient(t *testing.T) {
rds, cleanup := setupOnlineTestRedis(t)
defer cleanup()
store := NewRedisOnlineStore(rds)
ctx := context.Background()
userID := "user-123"
// 先增加到 2
_, err := store.IncrClient(ctx, userID)
require.NoError(t, err)
_, err = store.IncrClient(ctx, userID)
require.NoError(t, err)
// 減少一次
count, err := store.DecrClient(ctx, userID)
require.NoError(t, err)
assert.Equal(t, int64(1), count)
// 再減少一次
count, err = store.DecrClient(ctx, userID)
require.NoError(t, err)
assert.Equal(t, int64(0), count)
}
func TestDecrClient_NegativeProtection(t *testing.T) {
rds, cleanup := setupOnlineTestRedis(t)
defer cleanup()
store := NewRedisOnlineStore(rds)
ctx := context.Background()
userID := "user-123"
// 直接減少(沒有先增加)
count, err := store.DecrClient(ctx, userID)
require.NoError(t, err)
assert.Equal(t, int64(0), count) // 應該被保護為 0
// 再次減少
count, err = store.DecrClient(ctx, userID)
require.NoError(t, err)
assert.Equal(t, int64(0), count) // 仍然是 0
}
func TestGetClientCount_NoClient(t *testing.T) {
rds, cleanup := setupOnlineTestRedis(t)
defer cleanup()
store := NewRedisOnlineStore(rds)
ctx := context.Background()
// 未設置連線數的用戶應該返回 0
count, err := store.GetClientCount(ctx, "non-existent-user")
require.NoError(t, err)
assert.Equal(t, int64(0), count)
}
func TestSetOffline_ClearsClientCount(t *testing.T) {
rds, cleanup := setupOnlineTestRedis(t)
defer cleanup()
store := NewRedisOnlineStore(rds)
ctx := context.Background()
userID := "user-123"
// 設置在線和連線數
err := store.SetOnline(ctx, userID, 5*time.Minute)
require.NoError(t, err)
_, err = store.IncrClient(ctx, userID)
require.NoError(t, err)
_, err = store.IncrClient(ctx, userID)
require.NoError(t, err)
// 設置離線
err = store.SetOffline(ctx, userID)
require.NoError(t, err)
// 連線數也應該被清除
count, err := store.GetClientCount(ctx, userID)
require.NoError(t, err)
assert.Equal(t, int64(0), count)
}
func TestKeyGeneration_OnlineStore(t *testing.T) {
rds, cleanup := setupOnlineTestRedis(t)
defer cleanup()
store := NewRedisOnlineStoreWithPrefix(rds, "test:")
// 測試 key 生成
assert.Equal(t, "test:user-123", store.key("user-123"))
assert.Equal(t, "test:clients:user-456", store.clientCountKey("user-456"))
}
func TestOnlineStore_ImplementsInterface(t *testing.T) {
rds, cleanup := setupOnlineTestRedis(t)
defer cleanup()
store := NewRedisOnlineStore(rds)
// 確保 RedisOnlineStore 實現了 OnlineStore 介面
var _ OnlineStore = store
}

View File

@ -0,0 +1,203 @@
package centrifugo
import (
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// TokenConfig JWT Token 配置
type TokenConfig struct {
// Secret 用於簽名的密鑰(與 Centrifugo 配置的 token_hmac_secret_key 一致)
Secret string
// ExpireIn Token 過期時間(預設 1 小時)
ExpireIn time.Duration
}
// TokenGenerator JWT Token 生成器
type TokenGenerator struct {
config TokenConfig
}
// NewTokenGenerator 創建新的 Token 生成器
func NewTokenGenerator(secret string) *TokenGenerator {
return &TokenGenerator{
config: TokenConfig{
Secret: secret,
ExpireIn: time.Hour,
},
}
}
// NewTokenGeneratorWithConfig 創建使用自定義配置的 Token 生成器
func NewTokenGeneratorWithConfig(config TokenConfig) *TokenGenerator {
if config.ExpireIn == 0 {
config.ExpireIn = time.Hour
}
return &TokenGenerator{
config: config,
}
}
// ConnectionClaims 連線 Token 的 Claims
type ConnectionClaims struct {
jwt.RegisteredClaims
// Sub 用戶 ID必填
Sub string `json:"sub"`
// Info 用戶資訊(可選,會在 presence 中顯示)
Info map[string]interface{} `json:"info,omitempty"`
// Channels 自動訂閱的頻道列表(可選)
Channels []string `json:"channels,omitempty"`
// TokenVersion Token 版本(用於批量撤銷)
TokenVersion int64 `json:"tv,omitempty"`
}
// SubscriptionClaims 訂閱 Token 的 Claims用於私有頻道
type SubscriptionClaims struct {
jwt.RegisteredClaims
// Sub 用戶 ID必填
Sub string `json:"sub"`
// Channel 頻道名稱(必填)
Channel string `json:"channel"`
// Info 頻道特定的用戶資訊(可選)
Info map[string]interface{} `json:"info,omitempty"`
// TokenVersion Token 版本(用於批量撤銷)
TokenVersion int64 `json:"tv,omitempty"`
}
// ConnectionTokenOptions 連線 Token 選項
type ConnectionTokenOptions struct {
// UserID 用戶 ID必填
UserID string
// Info 用戶資訊(可選)
Info map[string]interface{}
// Channels 自動訂閱的頻道列表(可選)
Channels []string
// ExpireAt 自定義過期時間(可選,為空則使用預設)
ExpireAt *time.Time
// TokenVersion Token 版本(可選,用於黑名單機制)
TokenVersion int64
}
// SubscriptionTokenOptions 訂閱 Token 選項
type SubscriptionTokenOptions struct {
// UserID 用戶 ID必填
UserID string
// Channel 頻道名稱(必填)
Channel string
// Info 頻道特定的用戶資訊(可選)
Info map[string]interface{}
// ExpireAt 自定義過期時間(可選,為空則使用預設)
ExpireAt *time.Time
// TokenVersion Token 版本(可選,用於黑名單機制)
TokenVersion int64
}
// TokenResult 生成 Token 的結果
type TokenResult struct {
Token string // JWT Token 字串
JTI string // JWT ID用於撤銷單一 Token
ExpiresAt time.Time // 過期時間
}
// GenerateConnectionToken 生成連線 Token
// 用於前端建立 WebSocket 連線時的身份驗證
func (g *TokenGenerator) GenerateConnectionToken(opts ConnectionTokenOptions) (string, error) {
result, err := g.GenerateConnectionTokenWithJTI(opts)
if err != nil {
return "", err
}
return result.Token, nil
}
// GenerateConnectionTokenWithJTI 生成連線 Token 並返回 JTI
// 用於需要支援單一 Token 撤銷的場景
func (g *TokenGenerator) GenerateConnectionTokenWithJTI(opts ConnectionTokenOptions) (*TokenResult, error) {
now := time.Now()
expireAt := now.Add(g.config.ExpireIn)
if opts.ExpireAt != nil {
expireAt = *opts.ExpireAt
}
// 生成唯一的 JTI
jti := uuid.New().String()
// 如果沒有指定 TokenVersion使用當前時間戳
tokenVersion := opts.TokenVersion
if tokenVersion == 0 {
tokenVersion = now.UnixNano()
}
claims := ConnectionClaims{
RegisteredClaims: jwt.RegisteredClaims{
ID: jti,
Subject: opts.UserID,
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(expireAt),
},
Sub: opts.UserID,
Info: opts.Info,
Channels: opts.Channels,
TokenVersion: tokenVersion,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := token.SignedString([]byte(g.config.Secret))
if err != nil {
return nil, err
}
return &TokenResult{
Token: tokenStr,
JTI: jti,
ExpiresAt: expireAt,
}, nil
}
// GenerateSubscriptionToken 生成訂閱 Token
// 用於訂閱私有頻道時的身份驗證
func (g *TokenGenerator) GenerateSubscriptionToken(opts SubscriptionTokenOptions) (string, error) {
now := time.Now()
expireAt := now.Add(g.config.ExpireIn)
if opts.ExpireAt != nil {
expireAt = *opts.ExpireAt
}
claims := SubscriptionClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: opts.UserID,
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(expireAt),
},
Sub: opts.UserID,
Channel: opts.Channel,
Info: opts.Info,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(g.config.Secret))
}
// GenerateAnonymousToken 生成匿名連線 Token
// 用於允許匿名用戶連線(但仍需要驗證)
func (g *TokenGenerator) GenerateAnonymousToken() (string, error) {
return g.GenerateConnectionToken(ConnectionTokenOptions{
UserID: "", // 空的 UserID 表示匿名用戶
})
}
// QuickConnectionToken 快速生成連線 Token只需用戶 ID
func (g *TokenGenerator) QuickConnectionToken(userID string) (string, error) {
return g.GenerateConnectionToken(ConnectionTokenOptions{
UserID: userID,
})
}
// QuickSubscriptionToken 快速生成訂閱 Token只需用戶 ID 和頻道)
func (g *TokenGenerator) QuickSubscriptionToken(userID, channel string) (string, error) {
return g.GenerateSubscriptionToken(SubscriptionTokenOptions{
UserID: userID,
Channel: channel,
})
}

View File

@ -0,0 +1,263 @@
package centrifugo
import (
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewTokenGenerator(t *testing.T) {
secret := "test-secret"
gen := NewTokenGenerator(secret)
assert.NotNil(t, gen)
assert.Equal(t, secret, gen.config.Secret)
assert.Equal(t, time.Hour, gen.config.ExpireIn)
}
func TestNewTokenGeneratorWithConfig(t *testing.T) {
config := TokenConfig{
Secret: "custom-secret",
ExpireIn: 24 * time.Hour,
}
gen := NewTokenGeneratorWithConfig(config)
assert.NotNil(t, gen)
assert.Equal(t, config.Secret, gen.config.Secret)
assert.Equal(t, config.ExpireIn, gen.config.ExpireIn)
}
func TestNewTokenGeneratorWithConfig_DefaultExpire(t *testing.T) {
config := TokenConfig{
Secret: "test-secret",
ExpireIn: 0, // 應該使用預設值
}
gen := NewTokenGeneratorWithConfig(config)
assert.Equal(t, time.Hour, gen.config.ExpireIn)
}
func TestQuickConnectionToken(t *testing.T) {
gen := NewTokenGenerator("test-secret")
userID := "user-123"
token, err := gen.QuickConnectionToken(userID)
require.NoError(t, err)
assert.NotEmpty(t, token)
// 驗證 Token 內容
claims := &ConnectionClaims{}
parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return []byte("test-secret"), nil
})
require.NoError(t, err)
assert.True(t, parsedToken.Valid)
assert.Equal(t, userID, claims.Sub)
assert.NotNil(t, claims.ExpiresAt)
assert.NotNil(t, claims.IssuedAt)
}
func TestGenerateConnectionToken(t *testing.T) {
gen := NewTokenGenerator("test-secret")
tests := []struct {
name string
opts ConnectionTokenOptions
checkFn func(t *testing.T, claims *ConnectionClaims)
}{
{
name: "basic token",
opts: ConnectionTokenOptions{
UserID: "user-123",
},
checkFn: func(t *testing.T, claims *ConnectionClaims) {
assert.Equal(t, "user-123", claims.Sub)
assert.Nil(t, claims.Info)
assert.Nil(t, claims.Channels)
},
},
{
name: "token with info",
opts: ConnectionTokenOptions{
UserID: "user-456",
Info: map[string]interface{}{
"name": "Daniel",
"role": "admin",
},
},
checkFn: func(t *testing.T, claims *ConnectionClaims) {
assert.Equal(t, "user-456", claims.Sub)
assert.NotNil(t, claims.Info)
assert.Equal(t, "Daniel", claims.Info["name"])
assert.Equal(t, "admin", claims.Info["role"])
},
},
{
name: "token with channels",
opts: ConnectionTokenOptions{
UserID: "user-789",
Channels: []string{"chat:room-1", "chat:room-2"},
},
checkFn: func(t *testing.T, claims *ConnectionClaims) {
assert.Equal(t, "user-789", claims.Sub)
assert.Equal(t, []string{"chat:room-1", "chat:room-2"}, claims.Channels)
},
},
{
name: "token with custom expire",
opts: ConnectionTokenOptions{
UserID: "user-abc",
ExpireAt: ptrTime(time.Now().Add(48 * time.Hour)),
},
checkFn: func(t *testing.T, claims *ConnectionClaims) {
assert.Equal(t, "user-abc", claims.Sub)
// 檢查過期時間大約在 48 小時後
expireTime := claims.ExpiresAt.Time
assert.True(t, expireTime.After(time.Now().Add(47*time.Hour)))
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token, err := gen.GenerateConnectionToken(tt.opts)
require.NoError(t, err)
assert.NotEmpty(t, token)
// 解析 Token
claims := &ConnectionClaims{}
parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return []byte("test-secret"), nil
})
require.NoError(t, err)
assert.True(t, parsedToken.Valid)
tt.checkFn(t, claims)
})
}
}
func TestQuickSubscriptionToken(t *testing.T) {
gen := NewTokenGenerator("test-secret")
userID := "user-123"
channel := "private:room-456"
token, err := gen.QuickSubscriptionToken(userID, channel)
require.NoError(t, err)
assert.NotEmpty(t, token)
// 驗證 Token 內容
claims := &SubscriptionClaims{}
parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return []byte("test-secret"), nil
})
require.NoError(t, err)
assert.True(t, parsedToken.Valid)
assert.Equal(t, userID, claims.Sub)
assert.Equal(t, channel, claims.Channel)
}
func TestGenerateSubscriptionToken(t *testing.T) {
gen := NewTokenGenerator("test-secret")
opts := SubscriptionTokenOptions{
UserID: "user-123",
Channel: "private:room-456",
Info: map[string]interface{}{
"role": "moderator",
},
}
token, err := gen.GenerateSubscriptionToken(opts)
require.NoError(t, err)
assert.NotEmpty(t, token)
// 驗證 Token 內容
claims := &SubscriptionClaims{}
parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return []byte("test-secret"), nil
})
require.NoError(t, err)
assert.True(t, parsedToken.Valid)
assert.Equal(t, opts.UserID, claims.Sub)
assert.Equal(t, opts.Channel, claims.Channel)
assert.Equal(t, "moderator", claims.Info["role"])
}
func TestGenerateAnonymousToken(t *testing.T) {
gen := NewTokenGenerator("test-secret")
token, err := gen.GenerateAnonymousToken()
require.NoError(t, err)
assert.NotEmpty(t, token)
// 驗證 Token 內容
claims := &ConnectionClaims{}
parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return []byte("test-secret"), nil
})
require.NoError(t, err)
assert.True(t, parsedToken.Valid)
assert.Equal(t, "", claims.Sub) // 匿名用戶的 UserID 為空
}
func TestTokenExpiration(t *testing.T) {
// 創建一個很短過期時間的生成器
gen := NewTokenGeneratorWithConfig(TokenConfig{
Secret: "test-secret",
ExpireIn: 1 * time.Second,
})
token, err := gen.QuickConnectionToken("user-123")
require.NoError(t, err)
// 立即驗證應該成功
claims := &ConnectionClaims{}
parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return []byte("test-secret"), nil
})
require.NoError(t, err)
assert.True(t, parsedToken.Valid)
// 等待過期
time.Sleep(2 * time.Second)
// 過期後驗證應該失敗
claims2 := &ConnectionClaims{}
_, err = jwt.ParseWithClaims(token, claims2, func(token *jwt.Token) (interface{}, error) {
return []byte("test-secret"), nil
})
assert.Error(t, err)
assert.Contains(t, err.Error(), "token is expired")
}
func TestTokenWithWrongSecret(t *testing.T) {
gen := NewTokenGenerator("correct-secret")
token, err := gen.QuickConnectionToken("user-123")
require.NoError(t, err)
// 使用錯誤的密鑰驗證
claims := &ConnectionClaims{}
_, err = jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return []byte("wrong-secret"), nil
})
assert.Error(t, err)
}
// 輔助函數
func ptrTime(t time.Time) *time.Time {
return &t
}

14
pkg/utils/time.go Normal file
View File

@ -0,0 +1,14 @@
package utils
import "time"
// GetBucketDay 取得 bucket_dayyyyyMMdd 格式)
func GetBucketDay(t time.Time) string {
return t.Format("20060102")
}
// GetTodayBucketDay 取得今天的 bucket_day
func GetTodayBucketDay() string {
return GetBucketDay(time.Now())
}

20
pkg/utils/uuid.go Normal file
View File

@ -0,0 +1,20 @@
package utils
import (
"github.com/google/uuid"
)
//// GenerateUID 生成匿名 UID
//func GenerateUID() string {
// return fmt.Sprintf("%s%s", consts.AnonUIDPrefix, uuid.New().String()[:8])
//}
//
//// GenerateRoomID 生成房間 ID
//func GenerateRoomID() string {
// return fmt.Sprintf("%s%s", consts.RoomIDPrefix, uuid.New().String())
//}
// GenerateMessageID 生成訊息 ID
func GenerateMessageID() string {
return uuid.New().String()
}