diff --git a/deployment/centrifugo.json b/deployment/centrifugo.json new file mode 100644 index 0000000..807b47d --- /dev/null +++ b/deployment/centrifugo.json @@ -0,0 +1,44 @@ +{ + "token_hmac_secret_key": "your-secret-key-change-in-production", + "admin_password": "admin", + "admin_secret": "admin-secret", + "api_key": "api-key", + "allowed_origins": [ + "*" + ], + "log_level": "info", + "log_handler": "stdout", + "websocket_compression": true, + "websocket_read_buffer_size": 1024, + "websocket_write_buffer_size": 1024, + "namespaces": [ + { + "name": "default", + "publish": true, + "subscribe_to_publish": true, + "presence": true, + "join_leave": true, + "history_size": 100, + "history_ttl": "300s" + }, + { + "name": "room", + "publish": true, + "subscribe_to_publish": true, + "presence": true, + "join_leave": true, + "history_size": 100, + "history_ttl": "300s" + }, + { + "name": "user", + "publish": false, + "subscribe_to_publish": false, + "presence": false, + "join_leave": false, + "history_size": 0, + "history_ttl": "0s" + } + ], + "redis_address": "redis:6379" +} \ No newline at end of file diff --git a/deployment/docker-compose.yaml b/deployment/docker-compose.yaml index cf0f4eb..ae092cb 100644 --- a/deployment/docker-compose.yaml +++ b/deployment/docker-compose.yaml @@ -27,7 +27,6 @@ services: restart: always ports: - "6379:6379" - minio: image: minio/minio container_name: minio @@ -39,3 +38,36 @@ services: MINIO_ROOT_PASSWORD: minioadmin # Replace with your desired root password # MINIO_DEFAULT_BUCKETS: mybucket # Optional: Create a default bucket on startup command: server /data --console-address ":9001" # Start MinIO server and specify console address + centrifugo: + image: centrifugo/centrifugo:v5 + container_name: centrifugo + restart: always + ports: + - "8000:8000" # HTTP API + - "8001:8001" # WebSocket + volumes: + - ./centrifugo.json:/centrifugo/config.json:ro + command: centrifugo --config=/centrifugo/config.json + healthcheck: + test: ["CMD", "wget", "--quiet", "--tries=1", "--spider", "http://localhost:8000/health"] + interval: 10s + timeout: 5s + retries: 3 + depends_on: + - redis + cassandra: + image: cassandra:5.0.4 + restart: always + ports: + - "9042:9042" + environment: + TZ: ${TIMEZONE:-UTC} + MAX_HEAP_SIZE: 4G + HEAP_NEWSIZE: 2G + healthcheck: + test: [ "CMD", "cqlsh", "-k", "sccflex" ] + interval: 10s + timeout: 10s + retries: 12 + mem_limit: 8g # <--- 單機 docker-compose up 時建議明確加這行 + memswap_limit: 8g # <--- 關掉 swap \ No newline at end of file diff --git a/gateway.json b/gateway.json index b6254a3..b4382ac 100644 --- a/gateway.json +++ b/gateway.json @@ -1259,7 +1259,7 @@ "url": "https://localhost:8888" } ], - "x-date": "2025-11-12 14:59:58", + "x-date": "2026-01-05 10:01:16", "x-description": "This is a go-doc generated swagger file.", "x-generator": "go-doc", "x-github": "https://github.com/danielchan-25/go-doc", diff --git a/generate/database/cassandra/2026010611150001_chat_message.up.sql b/generate/database/cassandra/2026010611150001_chat_message.up.sql new file mode 100644 index 0000000..b9c21c1 --- /dev/null +++ b/generate/database/cassandra/2026010611150001_chat_message.up.sql @@ -0,0 +1,9 @@ +CREATE TABLE messages ( + room_id uuid, + bucket_day date, + ts bigint, + msg_id uuid, + uid text, + content text, + PRIMARY KEY ((room_id, bucket_day), ts, msg_id) +) WITH CLUSTERING ORDER BY (ts ASC); diff --git a/generate/database/cassandra/2026010611150002_chat_message_dedupe.up.sql b/generate/database/cassandra/2026010611150002_chat_message_dedupe.up.sql new file mode 100644 index 0000000..423e300 --- /dev/null +++ b/generate/database/cassandra/2026010611150002_chat_message_dedupe.up.sql @@ -0,0 +1,8 @@ +CREATE TABLE message_dedup ( + room_id uuid, + uid text, + content_md5 text, + bucket_sec bigint, + PRIMARY KEY ((room_id, uid), bucket_sec, content_md5) +) WITH default_time_to_live = 2; + diff --git a/go.mod b/go.mod index 897e95a..ef0c965 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/go-playground/validator/v10 v10.28.0 github.com/gocql/gocql v1.7.0 github.com/golang-jwt/jwt/v4 v4.5.2 + github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/uuid v1.6.0 github.com/matcornic/hermes/v2 v2.1.0 github.com/minchao/go-mitake v1.0.0 diff --git a/go.sum b/go.sum index 0bdeaa4..26f1bad 100644 --- a/go.sum +++ b/go.sum @@ -101,6 +101,8 @@ github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus= github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= diff --git a/internal/config/config.go b/internal/config/config.go index 639607f..616d5a1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -108,4 +108,20 @@ type Config struct { SecretKey string CloudFrontID string } + + // Cassandra 配置 + Cassandra struct { + Hosts []string + Port int + Keyspace string + Username string + Password string + UseAuth bool + } + + // Centrifugo 配置 + Centrifugo struct { + APIURL string + APIKey string + } } diff --git a/internal/svc/chat.go b/internal/svc/chat.go new file mode 100644 index 0000000..0911de2 --- /dev/null +++ b/internal/svc/chat.go @@ -0,0 +1,67 @@ +package svc + +import ( + "backend/internal/config" + "backend/pkg/chat/domain/usecase" + repo "backend/pkg/chat/repository" + uc "backend/pkg/chat/usecase" + "backend/pkg/library/cassandra" + "backend/pkg/library/centrifugo" + errs "backend/pkg/library/errors" + "fmt" +) + +func MustMessageUseCase(c *config.Config, logger errs.Logger) usecase.MessageUseCase { + // 初始化 Cassandra DB + cassandraDB, err := initCassandraDB(c) + if err != nil { + panic(fmt.Sprintf("failed to initialize Cassandra DB: %v", err)) + } + + // 初始化 Message Repository + messageRepo := repo.MustMessageRepository(repo.MessageRepositoryParam{ + DB: cassandraDB, + Keyspace: c.Cassandra.Keyspace, + }) + + // 初始化 Room Repository + roomRepo := repo.MustRoomRepository(repo.RoomRepositoryParam{ + DB: cassandraDB, + Keyspace: c.Cassandra.Keyspace, + }) + + // 初始化 Centrifugo Client + msgClient := centrifugo.NewClientWithConfig(centrifugo.HighPerformanceConfig(c.Centrifugo.APIURL, c.Centrifugo.APIKey)) + + return uc.NewMessageUseCase(uc.MessageUseCaseParam{ + MessageRepo: messageRepo, + RoomRepo: roomRepo, + MsgClient: msgClient, + Logger: logger, + }) +} + +// initCassandraDB 初始化 Cassandra 資料庫連接 +func initCassandraDB(c *config.Config) (*cassandra.DB, error) { + if len(c.Cassandra.Hosts) == 0 { + return nil, fmt.Errorf("cassandra hosts are required") + } + + opts := []cassandra.Option{ + cassandra.WithHosts(c.Cassandra.Hosts...), + } + + if c.Cassandra.Port > 0 { + opts = append(opts, cassandra.WithPort(c.Cassandra.Port)) + } + + if c.Cassandra.Keyspace != "" { + opts = append(opts, cassandra.WithKeyspace(c.Cassandra.Keyspace)) + } + + if c.Cassandra.UseAuth { + opts = append(opts, cassandra.WithAuth(c.Cassandra.Username, c.Cassandra.Password)) + } + + return cassandra.New(opts...) +} diff --git a/internal/svc/service_context.go b/internal/svc/service_context.go index 0448754..bc6f412 100644 --- a/internal/svc/service_context.go +++ b/internal/svc/service_context.go @@ -9,6 +9,7 @@ import ( "github.com/zeromicro/go-zero/core/logx" + chatUC "backend/pkg/chat/domain/usecase" fileStorageUC "backend/pkg/fileStorage/domain/usecase" vi "backend/pkg/library/validator" memberUC "backend/pkg/member/domain/usecase" @@ -31,6 +32,7 @@ type ServiceContext struct { UserRoleUC tokenUC.UserRoleUseCase DeliveryUC deliveryUC.DeliveryUseCase FileStorageUC fileStorageUC.FileStorageUseCase + MessageUC chatUC.MessageUseCase Redis *redis.Redis Logger errs.Logger } @@ -62,6 +64,7 @@ func NewServiceContext(c config.Config) *ServiceContext { Redis: rds, DeliveryUC: MustDeliveryUseCase(&c, lgr), FileStorageUC: MustS3Storage(&c, lgr), + MessageUC: MustMessageUseCase(&c, lgr), Logger: lgr, } } diff --git a/pkg/chat/domain/chat/room_role.go b/pkg/chat/domain/chat/room_role.go new file mode 100644 index 0000000..d1a729c --- /dev/null +++ b/pkg/chat/domain/chat/room_role.go @@ -0,0 +1,38 @@ +package chat + +// RoomRole 聊天室成員角色 +type RoomRole string + +const ( + // RoomRoleMember 一般成員 + RoomRoleMember RoomRole = "member" + // RoomRoleAdmin 管理員 + RoomRoleAdmin RoomRole = "admin" + // RoomRoleOwner 擁有者 + RoomRoleOwner RoomRole = "owner" +) + +// String 返回角色的字串表示 +func (r RoomRole) String() string { + return string(r) +} + +// IsValid 檢查角色是否有效 +func (r RoomRole) IsValid() bool { + switch r { + case RoomRoleMember, RoomRoleAdmin, RoomRoleOwner: + return true + default: + return false + } +} + +// IsAdmin 檢查是否為管理員或擁有者 +func (r RoomRole) IsAdmin() bool { + return r == RoomRoleAdmin || r == RoomRoleOwner +} + +// IsOwner 檢查是否為擁有者 +func (r RoomRole) IsOwner() bool { + return r == RoomRoleOwner +} diff --git a/pkg/chat/domain/entity/message.go b/pkg/chat/domain/entity/message.go new file mode 100644 index 0000000..1a05fce --- /dev/null +++ b/pkg/chat/domain/entity/message.go @@ -0,0 +1,23 @@ +package entity + +import ( + "github.com/gocql/gocql" + "github.com/google/uuid" +) + +// Message 對應 Cassandra 的 messages_by_room 表 +// Primary Key: (room_id, bucket_day) +// Clustering Key: ts DESC +type Message struct { + RoomID gocql.UUID `db:"room_id" partition_key:"true"` + BucketDay string `db:"bucket_day" partition_key:"true"` // yyyyMMdd + TS int64 `db:"ts" clustering_key:"true"` // timestamp + UID string `db:"uid"` + Content string `db:"content"` + MsgID uuid.UUID `db:"msg_id"` +} + +// TableName 返回表名 +func (m Message) TableName() string { + return "messages_by_room" +} diff --git a/pkg/chat/domain/entity/message_dedup.go b/pkg/chat/domain/entity/message_dedup.go new file mode 100644 index 0000000..63f19d9 --- /dev/null +++ b/pkg/chat/domain/entity/message_dedup.go @@ -0,0 +1,21 @@ +package entity + +import ( + "github.com/gocql/gocql" +) + +// MessageDedup 對應 Cassandra 的 message_dedup 表 +// Primary Key: ((room_id, uid), bucket_sec, content_md5) +// TTL: 2 秒 +type MessageDedup struct { + RoomID gocql.UUID `db:"room_id" partition_key:"true"` + UID string `db:"uid" partition_key:"true"` + BucketSec int64 `db:"bucket_sec" clustering_key:"true"` // Unix timestamp in seconds + ContentMD5 string `db:"content_md5" clustering_key:"true"` // MD5 hash of content +} + +// TableName 返回表名 +func (m MessageDedup) TableName() string { + return "message_dedup" +} + diff --git a/pkg/chat/domain/entity/room.go b/pkg/chat/domain/entity/room.go new file mode 100644 index 0000000..98ca329 --- /dev/null +++ b/pkg/chat/domain/entity/room.go @@ -0,0 +1,59 @@ +package entity + +import "github.com/gocql/gocql" + +// Room 對應 Cassandra 的 room 表 +// Primary Key: (room_id) +// 設計說明:存儲聊天室本身的基本資訊 +type Room struct { + RoomID gocql.UUID `db:"room_id" partition_key:"true"` + Name string `db:"name"` // 聊天室名稱 + Status string `db:"status"` // 狀態:active, archived, deleted + CreatedAt int64 `db:"created_at"` // 創建時間 + UpdatedAt int64 `db:"updated_at"` // 更新時間 +} + +// TableName 返回表名 +func (r Room) TableName() string { + return "room" +} + +// RoomMember 對應 Cassandra 的 room_member 表 +// Primary Key: (room_id) +// Clustering Key: uid +// 設計說明: +// - room_id 作為 partition key,可以高效地查詢/刪除整個聊天室的所有成員 +// - uid 作為 clustering key,可以高效地查詢/刪除特定成員 +// - 支援三種操作: +// 1. 刪除整個聊天室:刪除整個 partition(需要查詢後批量刪除或使用原生 CQL) +// 2. 刪除特定成員:使用 Delete(roomID, uid) +// 3. 查詢聊天室所有成員:使用 Query().Where(Eq("room_id", roomID)).Scan() +type RoomMember struct { + RoomID gocql.UUID `db:"room_id" partition_key:"true"` + UID string `db:"uid" clustering_key:"true"` + Role string `db:"role"` // Role 角色:member(一般成員)、admin(管理員)、owner(擁有者)等 + JoinedAt int64 `db:"joined_at"` // JoinedAt 加入時間(可選,用於記錄加入時間) +} + +// TableName 返回表名 +func (m RoomMember) TableName() string { + return "room_member" +} + +// UserRoom 對應 Cassandra 的 user_room 表(反向查詢表) +// Primary Key: (uid) +// Clustering Key: room_id +// 設計說明: +// - uid 作為 partition key,可以高效地查詢某個用戶所在的所有聊天室 +// - room_id 作為 clustering key,可以高效地查詢/刪除特定關聯 +// - 這個表用於支援「查詢用戶在哪些聊天室中」的需求 +type UserRoom struct { + UID string `db:"uid" partition_key:"true"` + RoomID gocql.UUID `db:"room_id" clustering_key:"true"` + JoinedAt int64 `db:"joined_at"` // 加入時間 +} + +// TableName 返回表名 +func (u UserRoom) TableName() string { + return "user_room" +} diff --git a/pkg/chat/domain/repository/message.go b/pkg/chat/domain/repository/message.go new file mode 100644 index 0000000..3ada9cf --- /dev/null +++ b/pkg/chat/domain/repository/message.go @@ -0,0 +1,41 @@ +package repository + +import ( + "backend/pkg/chat/domain/entity" + "context" +) + +// MessageRepository 定義訊息相關的資料存取介面 +type MessageRepository interface { + // Insert 插入訊息 + Insert(ctx context.Context, msg *entity.Message) error + // ListMessages 查詢訊息列表(分頁) + ListMessages(ctx context.Context, param ListMessagesReq) ([]entity.Message, error) + // Count 計算符合條件的訊息總數 + Count(ctx context.Context, RoomID string) (int64, error) + // CheckAndInsertDedup 檢查並插入去重記錄,如果已存在則返回 true(表示重複) + CheckAndInsertDedup(ctx context.Context, param CheckDupReq) (bool, error) +} + +type SendMessageReq struct { + RoomID string + UID string + Content string + ClientMsgID string +} + +type ListMessagesReq struct { + RoomID string + BucketDay string + PageSize int + // LastTS 用於 cursor-based pagination,獲取 ts < LastTS 的訊息 + // 如果為 0,則獲取最新的訊息 + LastTS int64 +} + +type CheckDupReq struct { + RoomID string + UID string + BucketSec int64 + ContentMD5 string +} diff --git a/pkg/chat/domain/repository/room.go b/pkg/chat/domain/repository/room.go new file mode 100644 index 0000000..418923e --- /dev/null +++ b/pkg/chat/domain/repository/room.go @@ -0,0 +1,51 @@ +package repository + +import ( + "backend/pkg/chat/domain/entity" + "context" +) + +type RoomRepository interface { + Room + Member + User +} + +// ListRoomsReq 查詢聊天室列表的請求參數 +type ListRoomsReq struct { + Status string // 聊天室狀態(active, archived 等) + PageSize int // 每頁大小 + LastID string // 用於 cursor-based pagination +} + +// CountRoomsReq 統計聊天室的請求參數 +type CountRoomsReq struct { + Status string // 可選的狀態篩選 +} + +type Room interface { + Create(ctx context.Context, room *entity.Room) error // Create 創建聊天室 + RoomGet(ctx context.Context, roomID string) (*entity.Room, error) // Get 獲取聊天室資訊 + RoomUpdate(ctx context.Context, room *entity.Room) error // Update 更新聊天室資訊 + RoomDelete(ctx context.Context, roomID string) error // Delete 刪除聊天室(同時需要刪除相關的成員和訊息) + RoomList(ctx context.Context, param ListRoomsReq) ([]entity.Room, error) // List 查詢聊天室列表(支援分頁和篩選) + RoomCount(ctx context.Context, param CountRoomsReq) (int64, error) // Count 統計聊天室總數 + RoomExists(ctx context.Context, roomID string) (bool, error) // Exists 檢查聊天室是否存在 + RoomGetByID(ctx context.Context, roomIDs []string) ([]entity.Room, error) // 取得 Room by id +} + +type Member interface { + Insert(ctx context.Context, member *entity.RoomMember) error // Insert 添加成員到聊天室 + Get(ctx context.Context, roomID, uid string) (*entity.RoomMember, error) // Get 獲取特定成員資訊 + AllMembers(ctx context.Context, roomID string) ([]entity.RoomMember, error) // AllMembers 查詢聊天室所有成員 + UpdateRole(ctx context.Context, member *entity.RoomMember) error // UpdateRole 更新成員資訊(例如更新角色) + DeleteMember(ctx context.Context, roomID, uid string) error // DeleteMember 刪除特定成員(某人退出聊天室) + DeleteRoom(ctx context.Context, roomID string) error // DeleteRoom 刪除整個聊天室的所有成員 + Count(ctx context.Context, roomID string) (int64, error) // Count 計算聊天室成員數量 +} + +type User interface { + GetUserRooms(ctx context.Context, uid string) ([]entity.UserRoom, error) // GetUserRooms 查詢用戶所在的所有聊天室 + CountUserRooms(ctx context.Context, uid string) (int64, error) // CountUserRooms 統計用戶所在的聊天室數量 + IsUserInRoom(ctx context.Context, uid, roomID string) (bool, error) // IsUserInRoom 檢查用戶是否在某個聊天室中 +} diff --git a/pkg/chat/domain/usecase/message.go b/pkg/chat/domain/usecase/message.go new file mode 100644 index 0000000..d7ed87a --- /dev/null +++ b/pkg/chat/domain/usecase/message.go @@ -0,0 +1,35 @@ +package usecase + +import ( + "context" +) + +// MessageUseCase 定義訊息相關的業務邏輯介面 +type MessageUseCase interface { + // SendMessage 發送訊息 + SendMessage(ctx context.Context, param SendMessageReq) error + // ListMessages 查詢訊息列表(分頁) + ListMessages(ctx context.Context, req ListMessagesReq) ([]Message, int64, error) +} + +type SendMessageReq struct { + RoomID string + UID string + Content string +} + +type ListMessagesReq struct { + RoomID string + UID string + BucketDay string + PageSize int64 + LastTS int64 +} + +type Message struct { + RoomID string `json:"room_id"` + BucketDay string `json:"bucket_day"` // yyyyMMdd + TS int64 `json:"ts"` // timestamp + UID string `json:"uid"` + Content string `json:"content"` +} diff --git a/pkg/chat/repository/message.go b/pkg/chat/repository/message.go new file mode 100644 index 0000000..8ea826f --- /dev/null +++ b/pkg/chat/repository/message.go @@ -0,0 +1,165 @@ +package repository + +import ( + "backend/pkg/chat/domain/entity" + "backend/pkg/chat/domain/repository" + "backend/pkg/library/cassandra" + "context" + "fmt" + "time" + + "github.com/gocql/gocql" +) + +type messageRepository struct { + repo cassandra.Repository[entity.Message] + dedupRepo cassandra.Repository[entity.MessageDedup] + db *cassandra.DB + keyspace string +} + +// MessageRepositoryParam 創建 MessageRepository 所需的參數 +type MessageRepositoryParam struct { + DB *cassandra.DB + Keyspace string +} + +// MustMessageRepository 創建 MessageRepository(如果失敗會 panic) +func MustMessageRepository(param MessageRepositoryParam) repository.MessageRepository { + repo, err := NewMessageRepository(param.DB, param.Keyspace) + if err != nil { + panic(fmt.Sprintf("failed to create message repository: %v", err)) + } + return repo +} + +// NewMessageRepository 創建新的訊息 Repository +func NewMessageRepository(db *cassandra.DB, keyspace string) (repository.MessageRepository, error) { + repo, err := cassandra.NewRepository[entity.Message](db, keyspace) + if err != nil { + return nil, err + } + + dedupRepo, err := cassandra.NewRepository[entity.MessageDedup](db, keyspace) + if err != nil { + return nil, err + } + + return &messageRepository{ + repo: repo, + dedupRepo: dedupRepo, + db: db, + keyspace: keyspace, + }, nil +} + +func (message *messageRepository) Insert(ctx context.Context, msg *entity.Message) error { + now := time.Now().UTC() + if msg.TS == 0 { + msg.TS = now.UnixNano() + } + // 只在 BucketDay 為空時才自動設置,保留先前傳入的值 + if msg.BucketDay == "" { + msg.BucketDay = now.Format(time.DateOnly) + } + + return message.repo.Insert(ctx, *msg) +} + +func (message *messageRepository) ListMessages(ctx context.Context, param repository.ListMessagesReq) ([]entity.Message, error) { + // 設定預設分頁大小 + if param.PageSize <= 0 { + param.PageSize = 20 + } + + // 將字串 RoomID 轉換為 UUID + roomUUID, err := gocql.ParseUUID(param.RoomID) + if err != nil { + return nil, err + } + + // 構建查詢條件 + query := message.repo.Query(). + Where(cassandra.Eq("room_id", roomUUID)). + Where(cassandra.Eq("bucket_day", param.BucketDay)) + + // 使用 cursor-based pagination:如果提供了 LastTS,則查詢 ts < LastTS 的訊息 + // 因為排序是 DESC,所以使用 < 來獲取更早的訊息(下一頁) + if param.LastTS > 0 { + query = query.Where(cassandra.Lt("ts", param.LastTS)) + } + + // 添加排序和限制 + query = query. + OrderBy("ts", cassandra.DESC). + Limit(int(param.PageSize)) + + // 執行查詢 + var messages []entity.Message + if err := query.Scan(ctx, &messages); err != nil { + return nil, err + } + + return messages, nil +} + +func (message *messageRepository) Count(ctx context.Context, roomID string) (int64, error) { + // 將字串 RoomID 轉換為 UUID + roomUUID, err := gocql.ParseUUID(roomID) + if err != nil { + return 0, err + } + + // 注意:由於 partition key 是 (room_id, bucket_day),只用 room_id 查詢需要 ALLOW FILTERING + // 這在生產環境中效能較差,建議改用按 bucket_day 分別查詢後加總 + count, err := message.repo.Query(). + Where(cassandra.Eq("room_id", roomUUID)). + AllowFiltering(). + Count(ctx) + if err != nil { + return 0, err + } + return count, nil +} + +// CheckAndInsertDedup 檢查並插入去重記錄 +// 使用 IF NOT EXISTS 來實現原子性的去重檢查 +// 返回值:true 表示已存在(重複),false 表示成功插入(不重複) +func (message *messageRepository) CheckAndInsertDedup(ctx context.Context, param repository.CheckDupReq) (bool, error) { + // 將字串 RoomID 轉換為 UUID + roomUUID, err := gocql.ParseUUID(param.RoomID) + if err != nil { + return false, err + } + + // 使用 IF NOT EXISTS 來實現原子性的去重檢查 + // 如果記錄已存在,INSERT 不會插入,且 applied = false + dedup := entity.MessageDedup{ + RoomID: roomUUID, + UID: param.UID, + BucketSec: param.BucketSec, + ContentMD5: param.ContentMD5, + } + + // 使用原生 CQL 語句來實現 IF NOT EXISTS + tableName := dedup.TableName() + stmt := fmt.Sprintf( + "INSERT INTO %s.%s (room_id, uid, bucket_sec, content_md5) VALUES (?, ?, ?, ?) IF NOT EXISTS", + message.keyspace, + tableName, + ) + + // 執行 INSERT IF NOT EXISTS + applied, err := message.db.GetSession().Query(stmt, nil). + Bind(roomUUID, param.UID, param.BucketSec, param.ContentMD5). + WithContext(ctx). + MapScanCAS(make(map[string]interface{})) + + if err != nil { + return false, fmt.Errorf("failed to check dedup: %w", err) + } + + // applied = false 表示記錄已存在(重複) + // applied = true 表示成功插入(不重複) + return !applied, nil +} diff --git a/pkg/chat/repository/message_test.go b/pkg/chat/repository/message_test.go new file mode 100644 index 0000000..2c8fc68 --- /dev/null +++ b/pkg/chat/repository/message_test.go @@ -0,0 +1,526 @@ +package repository + +import ( + "backend/pkg/chat/domain/entity" + "backend/pkg/chat/domain/repository" + "backend/pkg/library/cassandra" + "context" + "os" + "strconv" + "testing" + "time" + + "github.com/gocql/gocql" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +var ( + // 全局變量:所有測試共享的資源 + testDB *cassandra.DB + testRepo repository.MessageRepository + testContainer testcontainers.Container + cleanupContainer func() +) + +// TestMain 在所有測試之前執行一次,設置共享的資料庫 +func TestMain(m *testing.M) { + // 使用更長的 context 超時時間(Cassandra 啟動需要較長時間) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + // 啟動 Cassandra 容器(只執行一次) + // 不指定固定 port,讓 testcontainers 自動分配,避免與本地開發環境衝突 + req := testcontainers.ContainerRequest{ + Image: "cassandra:4.1", + ExposedPorts: []string{"9042/tcp"}, + // 等待 Cassandra 日誌確認啟動完成(Cassandra 啟動需要較長時間) + WaitingFor: wait.ForLog("Startup complete"). + WithStartupTimeout(3 * time.Minute). + WithOccurrence(1), + Env: map[string]string{ + "CASSANDRA_CLUSTER_NAME": "test-cluster", + "HEAP_NEWSIZE": "128M", + "MAX_HEAP_SIZE": "512M", + }, + } + + var err error + testContainer, err = testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + if err != nil { + panic("Failed to start Cassandra container: " + err.Error()) + } + + port, err := testContainer.MappedPort(ctx, "9042") + if err != nil { + testContainer.Terminate(ctx) + panic("Failed to get mapped port: " + err.Error()) + } + + host, err := testContainer.Host(ctx) + if err != nil { + testContainer.Terminate(ctx) + panic("Failed to get host: " + err.Error()) + } + + portInt, err := strconv.Atoi(port.Port()) + if err != nil { + testContainer.Terminate(ctx) + panic("Failed to convert port to int: " + err.Error()) + } + + // 先創建 DB 連接(不指定 keyspace,因為 keyspace 還不存在) + testDB, err = cassandra.New( + cassandra.WithHosts(host), + cassandra.WithPort(portInt), + // 不指定 keyspace,先連接後再創建 + ) + if err != nil { + testContainer.Terminate(ctx) + panic("Failed to create DB: " + err.Error()) + } + + // 創建 keyspace(需要在連接後才能創建) + createKeyspaceStmt := "CREATE KEYSPACE IF NOT EXISTS test_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" + if err := testDB.GetSession().Query(createKeyspaceStmt, nil).Exec(); err != nil { + testDB.Close() + testContainer.Terminate(ctx) + panic("Failed to create keyspace: " + err.Error()) + } + + // 等待 keyspace 創建完成 + time.Sleep(500 * time.Millisecond) + + // 創建 messages_by_room 表(需要指定 keyspace) + createTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.messages_by_room ( + room_id uuid, + bucket_day text, + ts bigint, + uid text, + content text, + PRIMARY KEY ((room_id, bucket_day), ts) + ) WITH CLUSTERING ORDER BY (ts DESC)` + if err := testDB.GetSession().Query(createTableStmt, nil).Exec(); err != nil { + testDB.Close() + testContainer.Terminate(ctx) + panic("Failed to create table: " + err.Error()) + } + + // 創建 messages repository + testRepo, err = NewMessageRepository(testDB, "test_keyspace") + if err != nil { + testDB.Close() + testContainer.Terminate(ctx) + panic("Failed to create message repository: " + err.Error()) + } + + // 創建 room 相關表(供 room_test.go 使用) + createRoomTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.room ( + room_id uuid PRIMARY KEY, + name text, + status text, + created_at bigint, + updated_at bigint + )` + if err := testDB.GetSession().Query(createRoomTableStmt, nil).Exec(); err != nil { + testDB.Close() + testContainer.Terminate(ctx) + panic("Failed to create room table: " + err.Error()) + } + + createMemberTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.room_member ( + room_id uuid, + uid text, + role text, + joined_at bigint, + PRIMARY KEY (room_id, uid) + )` + if err := testDB.GetSession().Query(createMemberTableStmt, nil).Exec(); err != nil { + testDB.Close() + testContainer.Terminate(ctx) + panic("Failed to create room_member table: " + err.Error()) + } + + createUserRoomTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.user_room ( + uid text, + room_id uuid, + joined_at bigint, + PRIMARY KEY (uid, room_id) + )` + if err := testDB.GetSession().Query(createUserRoomTableStmt, nil).Exec(); err != nil { + testDB.Close() + testContainer.Terminate(ctx) + panic("Failed to create user_room table: " + err.Error()) + } + + // 設置清理函數 + cleanupContainer = func() { + if testDB != nil { + testDB.Close() + } + if testContainer != nil { + _ = testContainer.Terminate(ctx) + } + } + + // 執行所有測試 + code := m.Run() + + // 清理資源 + cleanupContainer() + + // 退出 + os.Exit(code) +} + +// clearMessages 清空 messages_by_room 表的所有數據 +func clearMessages(t *testing.T) { + ctx := context.Background() + // 使用完整的 keyspace.table 名稱 + truncateStmt := "TRUNCATE test_keyspace.messages_by_room" + if err := testDB.GetSession().Query(truncateStmt, nil).WithContext(ctx).Exec(); err != nil { + t.Fatalf("Failed to truncate messages_by_room table: %v", err) + } + // 等待數據清空 + time.Sleep(50 * time.Millisecond) +} + +func TestNewMessageRepository(t *testing.T) { + clearMessages(t) + assert.NotNil(t, testRepo) +} + +func TestMessageRepository_Insert(t *testing.T) { + clearMessages(t) + ctx := context.Background() + + tests := []struct { + name string + message *entity.Message + wantErr bool + }{ + { + name: "successful insert", + message: &entity.Message{ + RoomID: gocql.TimeUUID(), + BucketDay: "20250117", + TS: time.Now().UnixNano(), + UID: "user-1", + Content: "Hello, world!", + }, + wantErr: false, + }, + { + name: "insert with auto timestamp", + message: &entity.Message{ + RoomID: gocql.TimeUUID(), + BucketDay: "20250117", + UID: "user-2", + Content: "Test message", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := testRepo.Insert(ctx, tt.message) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + // 驗證 TS 和 BucketDay 是否被自動設置 + if tt.message.TS == 0 { + assert.NotZero(t, tt.message.TS) + } + if tt.message.BucketDay == "" { + assert.NotEmpty(t, tt.message.BucketDay) + } + } + }) + } +} + +func TestMessageRepository_ListMessages(t *testing.T) { + clearMessages(t) + ctx := context.Background() + roomID := gocql.TimeUUID() + roomIDStr := roomID.String() + bucketDay := time.Now().Format(time.DateOnly) + + // 插入測試數據(使用相同的 roomID) + messages := []*entity.Message{ + {RoomID: roomID, BucketDay: bucketDay, TS: time.Now().UnixNano(), UID: "user-1", Content: "Message 1"}, + {RoomID: roomID, BucketDay: bucketDay, TS: time.Now().Add(time.Second).UnixNano(), UID: "user-2", Content: "Message 2"}, + {RoomID: roomID, BucketDay: bucketDay, TS: time.Now().Add(2 * time.Second).UnixNano(), UID: "user-3", Content: "Message 3"}, + } + + for _, msg := range messages { + require.NoError(t, testRepo.Insert(ctx, msg)) + } + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + tests := []struct { + name string + param repository.ListMessagesReq + wantLen int + wantErr bool + }{ + { + name: "list all messages", + param: repository.ListMessagesReq{ + RoomID: roomIDStr, + BucketDay: bucketDay, + PageSize: 10, + LastTS: 0, + }, + wantLen: 3, + wantErr: false, + }, + { + name: "list with page size limit", + param: repository.ListMessagesReq{ + RoomID: roomIDStr, + BucketDay: bucketDay, + PageSize: 2, + LastTS: 0, + }, + wantLen: 2, + wantErr: false, + }, + { + name: "list with cursor pagination", + param: repository.ListMessagesReq{ + RoomID: roomIDStr, + BucketDay: bucketDay, + PageSize: 10, + LastTS: messages[1].TS, // 使用第二條訊息的 TS 作為 cursor + }, + wantLen: 1, // 應該只返回比 LastTS 更早的訊息 + wantErr: false, + }, + { + name: "list with default page size", + param: repository.ListMessagesReq{ + RoomID: roomIDStr, + BucketDay: bucketDay, + PageSize: 0, // 應該使用預設值 20 + LastTS: 0, + }, + wantLen: 3, + wantErr: false, + }, + { + name: "list empty room", + param: repository.ListMessagesReq{ + RoomID: gocql.TimeUUID().String(), // 使用不存在的 UUID + BucketDay: bucketDay, + PageSize: 10, + LastTS: 0, + }, + wantLen: 0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := testRepo.ListMessages(ctx, tt.param) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Len(t, result, tt.wantLen) + // 驗證排序(應該按 TS DESC 排序) + if len(result) > 1 { + for i := 0; i < len(result)-1; i++ { + assert.GreaterOrEqual(t, result[i].TS, result[i+1].TS, "Messages should be sorted by TS DESC") + } + } + } + }) + } +} + +func TestMessageRepository_Count(t *testing.T) { + clearMessages(t) + ctx := context.Background() + roomID := gocql.TimeUUID() + roomIDStr := roomID.String() + bucketDay := time.Now().Format(time.DateOnly) + + // 插入測試數據(使用相同的 roomID) + for i := 0; i < 5; i++ { + msg := &entity.Message{ + RoomID: roomID, + BucketDay: bucketDay, + TS: time.Now().Add(time.Duration(i) * time.Second).UnixNano(), + UID: "user-1", + Content: "Message", + } + require.NoError(t, testRepo.Insert(ctx, msg)) + } + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + tests := []struct { + name string + roomID string + want int64 + wantErr bool + }{ + { + name: "count existing room", + roomID: roomIDStr, + want: 5, + wantErr: false, + }, + { + name: "count non-existent room", + roomID: gocql.TimeUUID().String(), // 使用不存在的 UUID + want: 0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + count, err := testRepo.Count(ctx, tt.roomID) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, count) + } + }) + } +} + +func TestMessageRepository_Insert_ListMessages_Integration(t *testing.T) { + clearMessages(t) + ctx := context.Background() + roomID := gocql.TimeUUID() + roomIDStr := roomID.String() + bucketDay := time.Now().Format(time.DateOnly) + + // 插入多條訊息(使用相同的 roomID) + messages := make([]*entity.Message, 10) + for i := 0; i < 10; i++ { + msg := &entity.Message{ + RoomID: roomID, + BucketDay: bucketDay, + TS: time.Now().Add(time.Duration(i) * time.Second).UnixNano(), + UID: "user-1", + Content: "Message", + } + messages[i] = msg + require.NoError(t, testRepo.Insert(ctx, msg)) + } + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + // 驗證可以查詢到所有訊息 + result, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{ + RoomID: roomIDStr, + BucketDay: bucketDay, + PageSize: 20, + LastTS: 0, + }) + require.NoError(t, err) + assert.Len(t, result, 10) + + // 驗證分頁功能 + firstPage, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{ + RoomID: roomIDStr, + BucketDay: bucketDay, + PageSize: 5, + LastTS: 0, + }) + require.NoError(t, err) + assert.Len(t, firstPage, 5) + + // 使用 cursor 獲取下一頁 + if len(firstPage) > 0 { + lastTS := firstPage[len(firstPage)-1].TS + secondPage, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{ + RoomID: roomIDStr, + BucketDay: bucketDay, + PageSize: 5, + LastTS: lastTS, + }) + require.NoError(t, err) + assert.Len(t, secondPage, 5) + // 驗證第二頁的訊息都比第一頁的最後一條更早 + if len(secondPage) > 0 { + assert.Less(t, secondPage[0].TS, lastTS) + } + } + + // 驗證計數 + count, err := testRepo.Count(ctx, roomIDStr) + require.NoError(t, err) + assert.Equal(t, int64(10), count) +} + +func TestMessageRepository_DifferentBuckets(t *testing.T) { + clearMessages(t) + ctx := context.Background() + roomID := gocql.TimeUUID() + roomIDStr := roomID.String() + today := time.Now().Format(time.DateOnly) + yesterday := time.Now().AddDate(0, 0, -1).Format(time.DateOnly) + + // 插入不同 bucket 的訊息(使用相同的 roomID) + todayMsg := &entity.Message{ + RoomID: roomID, + BucketDay: today, + TS: time.Now().UnixNano(), + UID: "user-1", + Content: "Today's message", + } + yesterdayMsg := &entity.Message{ + RoomID: roomID, + BucketDay: yesterday, + TS: time.Now().UnixNano(), + UID: "user-1", + Content: "Yesterday's message", + } + + require.NoError(t, testRepo.Insert(ctx, todayMsg)) + require.NoError(t, testRepo.Insert(ctx, yesterdayMsg)) + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + // 查詢今天的訊息 + todayMessages, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{ + RoomID: roomIDStr, + BucketDay: today, + PageSize: 10, + LastTS: 0, + }) + require.NoError(t, err) + assert.Len(t, todayMessages, 1) + assert.Equal(t, "Today's message", todayMessages[0].Content) + + // 查詢昨天的訊息 + yesterdayMessages, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{ + RoomID: roomIDStr, + BucketDay: yesterday, + PageSize: 10, + LastTS: 0, + }) + require.NoError(t, err) + assert.Len(t, yesterdayMessages, 1) + assert.Equal(t, "Yesterday's message", yesterdayMessages[0].Content) +} diff --git a/pkg/chat/repository/room.go b/pkg/chat/repository/room.go new file mode 100644 index 0000000..68f1134 --- /dev/null +++ b/pkg/chat/repository/room.go @@ -0,0 +1,404 @@ +package repository + +import ( + "backend/pkg/chat/domain/chat" + "backend/pkg/chat/domain/entity" + "backend/pkg/chat/domain/repository" + "backend/pkg/library/cassandra" + "context" + "fmt" + "time" + + "github.com/gocql/gocql" +) + +type roomRepository struct { + roomRepo cassandra.Repository[entity.Room] + memberRepo cassandra.Repository[entity.RoomMember] + userRoomRepo cassandra.Repository[entity.UserRoom] + db *cassandra.DB + keyspace string +} + +// RoomRepositoryParam 創建 RoomRepository 所需的參數 +type RoomRepositoryParam struct { + DB *cassandra.DB + Keyspace string +} + +// MustRoomRepository 創建 RoomRepository(如果失敗會 panic) +func MustRoomRepository(param RoomRepositoryParam) repository.RoomRepository { + repo, err := NewRoomRepository(param.DB, param.Keyspace) + if err != nil { + panic(fmt.Sprintf("failed to create room repository: %v", err)) + } + return repo +} + +// NewRoomRepository 創建新的聊天室 Repository +func NewRoomRepository(db *cassandra.DB, keyspace string) (repository.RoomRepository, error) { + roomRepo, err := cassandra.NewRepository[entity.Room](db, keyspace) + if err != nil { + return nil, err + } + + memberRepo, err := cassandra.NewRepository[entity.RoomMember](db, keyspace) + if err != nil { + return nil, err + } + + userRoomRepo, err := cassandra.NewRepository[entity.UserRoom](db, keyspace) + if err != nil { + return nil, err + } + + return &roomRepository{ + roomRepo: roomRepo, + memberRepo: memberRepo, + userRoomRepo: userRoomRepo, + db: db, + keyspace: keyspace, + }, nil +} + +// ==================== Room Interface 實作 ==================== + +func (r *roomRepository) Create(ctx context.Context, room *entity.Room) error { + now := time.Now().UTC().UnixNano() + if room.CreatedAt == 0 { + room.CreatedAt = now + } + if room.UpdatedAt == 0 { + room.UpdatedAt = now + } + room.RoomID = gocql.TimeUUID() + + return r.roomRepo.Insert(ctx, *room) +} + +func (r *roomRepository) RoomGet(ctx context.Context, roomID string) (*entity.Room, error) { + uuid, err := gocql.ParseUUID(roomID) + if err != nil { + return nil, err + } + room, err := r.roomRepo.Get(ctx, entity.Room{ + RoomID: uuid, + }) + if err != nil { + if cassandra.IsNotFound(err) { + return nil, cassandra.ErrNotFound + } + + return nil, err + } + + return &room, nil +} + +func (r *roomRepository) RoomUpdate(ctx context.Context, room *entity.Room) error { + update := entity.Room{} + now := time.Now().UTC().UnixNano() + + get, err := r.RoomGet(ctx, room.RoomID.String()) + if err != nil { + return err + } + + update.CreatedAt = get.CreatedAt + update.UpdatedAt = now + update.Status = room.Status + update.Name = room.Name + update.RoomID = room.RoomID + + if err = r.roomRepo.Update(ctx, update); err != nil { + return err + } + + return nil +} + +func (r *roomRepository) RoomDelete(ctx context.Context, roomID string) error { + uuid, err := gocql.ParseUUID(roomID) + if err != nil { + return err + } + + // 使用原生 CQL 語句來刪除,避免 cassandra library Delete 方法的 struct 綁定問題 + e := entity.Room{} + stmt := fmt.Sprintf("DELETE FROM %s.%s WHERE room_id = ?", r.keyspace, e.TableName()) + return r.db.GetSession().Query(stmt, nil). + Bind(uuid). + WithContext(ctx). + ExecRelease() +} + +func (r *roomRepository) RoomList(ctx context.Context, param repository.ListRoomsReq) ([]entity.Room, error) { + query := r.roomRepo.Query() + + if param.Status != "" { + query = query.Where(cassandra.Eq("status", param.Status)).AllowFiltering() + } + + if param.PageSize > 0 { + query = query.Limit(param.PageSize) + } + + if param.LastID != "" { + lastUUID, err := gocql.ParseUUID(param.LastID) + if err != nil { + return nil, err + } + query = query.Where(cassandra.Lt("room_id", lastUUID)) + } + + var rooms []entity.Room + if err := query.Scan(ctx, &rooms); err != nil { + return nil, err + } + + return rooms, nil +} + +func (r *roomRepository) RoomCount(ctx context.Context, param repository.CountRoomsReq) (int64, error) { + query := r.roomRepo.Query() + if param.Status != "" { + query = query.Where(cassandra.Eq("status", param.Status)).AllowFiltering() + } + return query.Count(ctx) +} + +func (r *roomRepository) RoomExists(ctx context.Context, roomID string) (bool, error) { + uuid, err := gocql.ParseUUID(roomID) + if err != nil { + return false, err + } + _, err = r.roomRepo.Get(ctx, entity.Room{ + RoomID: uuid, + }) + + if err != nil { + if cassandra.IsNotFound(err) { + return false, nil + } + + return false, err + } + + return true, nil +} + +func (r *roomRepository) RoomGetByID(ctx context.Context, roomIDs []string) ([]entity.Room, error) { + if len(roomIDs) == 0 { + return []entity.Room{}, nil + } + + // 將字串 ID 轉換為 UUID + uuids := make([]any, 0, len(roomIDs)) + for _, id := range roomIDs { + uuid, err := gocql.ParseUUID(id) + if err != nil { + return nil, err + } + uuids = append(uuids, uuid) + } + + var rooms []entity.Room + err := r.roomRepo.Query(). + Where(cassandra.In("room_id", uuids)). + Scan(ctx, &rooms) + if err != nil { + return nil, err + } + + return rooms, nil +} + +// ==================== Member Interface 實作 ==================== + +func (r *roomRepository) Insert(ctx context.Context, member *entity.RoomMember) error { + now := time.Now().UTC().UnixNano() + if member.JoinedAt == 0 { + member.JoinedAt = now + } + if member.Role == "" { + member.Role = chat.RoomRoleMember.String() + } + + // 同時插入到 user_room 表(反向查詢表) + userRoom := entity.UserRoom{ + UID: member.UID, + RoomID: member.RoomID, + JoinedAt: member.JoinedAt, + } + + if err := r.memberRepo.Insert(ctx, *member); err != nil { + return err + } + + if err := r.userRoomRepo.Insert(ctx, userRoom); err != nil { + return err + } + + return nil +} + +func (r *roomRepository) Get(ctx context.Context, roomID, uid string) (*entity.RoomMember, error) { + uuid, err := gocql.ParseUUID(roomID) + if err != nil { + return nil, err + } + + member, err := r.memberRepo.Get(ctx, entity.RoomMember{ + RoomID: uuid, + UID: uid, + }) + if err != nil { + if cassandra.IsNotFound(err) { + return nil, cassandra.ErrNotFound + } + return nil, err + } + return &member, nil +} + +func (r *roomRepository) AllMembers(ctx context.Context, roomID string) ([]entity.RoomMember, error) { + uuid, err := gocql.ParseUUID(roomID) + if err != nil { + return nil, err + } + + var members []entity.RoomMember + err = r.memberRepo.Query(). + Where(cassandra.Eq("room_id", uuid)). + Scan(ctx, &members) + if err != nil { + return nil, err + } + return members, nil +} + +func (r *roomRepository) UpdateRole(ctx context.Context, member *entity.RoomMember) error { + get, err := r.Get(ctx, member.RoomID.String(), member.UID) + if err != nil { + return err + } + + update := entity.RoomMember{ + RoomID: member.RoomID, + UID: member.UID, + Role: member.Role, + JoinedAt: get.JoinedAt, + } + + return r.memberRepo.Update(ctx, update) +} + +func (r *roomRepository) DeleteMember(ctx context.Context, roomID, uid string) error { + uuid, err := gocql.ParseUUID(roomID) + if err != nil { + return err + } + + // 同時從兩個表中刪除 + if err := r.memberRepo.Delete(ctx, entity.RoomMember{ + RoomID: uuid, + UID: uid, + }); err != nil { + return err + } + + if err := r.userRoomRepo.Delete(ctx, entity.UserRoom{ + UID: uid, + RoomID: uuid, + }); err != nil { + return err + } + + return nil +} + +func (r *roomRepository) DeleteRoom(ctx context.Context, roomID string) error { + uuid, err := gocql.ParseUUID(roomID) + if err != nil { + return err + } + + // 先查詢所有成員(必須在刪除之前查詢) + members, err := r.AllMembers(ctx, roomID) + if err != nil { + return err + } + + // 刪除 user_room 表中的所有關聯 + for _, member := range members { + if err := r.userRoomRepo.Delete(ctx, entity.UserRoom{ + UID: member.UID, + RoomID: uuid, + }); err != nil { + return err + } + } + + // 最後刪除 room_member 表中的所有成員 + e := entity.RoomMember{} + stmt := fmt.Sprintf("DELETE FROM %s.%s WHERE room_id = ?", r.keyspace, e.TableName()) + if err := r.db.GetSession().Query(stmt, nil). + Bind(uuid). + WithContext(ctx). + WithTimestamp(time.Now().UnixNano() / 1e3). + ExecRelease(); err != nil { + return err + } + + return nil +} + +func (r *roomRepository) Count(ctx context.Context, roomID string) (int64, error) { + uuid, err := gocql.ParseUUID(roomID) + if err != nil { + return 0, err + } + + return r.memberRepo.Query(). + Where(cassandra.Eq("room_id", uuid)). + Count(ctx) +} + +// ==================== User Interface 實作 ==================== + +func (r *roomRepository) GetUserRooms(ctx context.Context, uid string) ([]entity.UserRoom, error) { + var userRooms []entity.UserRoom + err := r.userRoomRepo.Query(). + Where(cassandra.Eq("uid", uid)). + Scan(ctx, &userRooms) + if err != nil { + return nil, err + } + + return userRooms, nil +} + +func (r *roomRepository) CountUserRooms(ctx context.Context, uid string) (int64, error) { + return r.userRoomRepo.Query(). + Where(cassandra.Eq("uid", uid)). + Count(ctx) +} + +func (r *roomRepository) IsUserInRoom(ctx context.Context, uid, roomID string) (bool, error) { + uuid, err := gocql.ParseUUID(roomID) + if err != nil { + return false, err + } + + _, err = r.userRoomRepo.Get(ctx, entity.UserRoom{ + UID: uid, + RoomID: uuid, + }) + if err != nil { + if cassandra.IsNotFound(err) { + return false, nil + } + return false, err + } + return true, nil +} diff --git a/pkg/chat/repository/room_test.go b/pkg/chat/repository/room_test.go new file mode 100644 index 0000000..b75b13c --- /dev/null +++ b/pkg/chat/repository/room_test.go @@ -0,0 +1,953 @@ +package repository + +import ( + "backend/pkg/chat/domain/chat" + "backend/pkg/chat/domain/entity" + "backend/pkg/chat/domain/repository" + "backend/pkg/library/cassandra" + "context" + "strconv" + "testing" + "time" + + "github.com/gocql/gocql" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + // 全局變量:所有測試共享的資源(與 message_test.go 共享) + roomTestRepo repository.RoomRepository +) + +// ensureRoomRepository 確保 Room Repository 已初始化 +func ensureRoomRepository(t *testing.T) { + if roomTestRepo == nil { + if testDB == nil { + t.Fatal("testDB is not initialized. Make sure TestMain in message_test.go runs first.") + } + var err error + roomTestRepo, err = NewRoomRepository(testDB, "test_keyspace") + if err != nil { + t.Fatalf("Failed to create room repository: %v", err) + } + } +} + +// clearRoomData 清空所有 room 相關表的數據 +func clearRoomData(t *testing.T) { + ctx := context.Background() + tables := []string{"user_room", "room_member", "room"} + for _, table := range tables { + truncateStmt := "TRUNCATE test_keyspace." + table + if err := testDB.GetSession().Query(truncateStmt, nil).WithContext(ctx).Exec(); err != nil { + t.Fatalf("Failed to truncate %s table: %v", table, err) + } + } + // 等待數據清空 + time.Sleep(50 * time.Millisecond) +} + +// ==================== Room Interface 測試 ==================== + +func TestNewRoomRepository(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + assert.NotNil(t, roomTestRepo) +} + +func TestRoomRepository_Create(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + tests := []struct { + name string + room *entity.Room + wantErr bool + }{ + { + name: "successful create", + room: &entity.Room{ + Name: "Test Room", + Status: "active", + }, + wantErr: false, + }, + { + name: "create with auto timestamp", + room: &entity.Room{ + Name: "Auto Timestamp Room", + Status: "active", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := roomTestRepo.Create(ctx, tt.room) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + // 驗證 RoomID 和時間戳是否被自動設置 + assert.NotZero(t, tt.room.RoomID) + assert.NotZero(t, tt.room.CreatedAt) + assert.NotZero(t, tt.room.UpdatedAt) + } + }) + } +} + +func TestRoomRepository_RoomGet(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建測試房間 + room := &entity.Room{ + Name: "Get Test Room", + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + roomIDStr := room.RoomID.String() + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + tests := []struct { + name string + roomID string + wantErr bool + }{ + { + name: "get existing room", + roomID: roomIDStr, + wantErr: false, + }, + { + name: "get non-existent room", + roomID: gocql.TimeUUID().String(), + wantErr: true, + }, + { + name: "get with invalid UUID", + roomID: "invalid-uuid", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := roomTestRepo.RoomGet(ctx, tt.roomID) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, result) + } else { + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, room.Name, result.Name) + assert.Equal(t, room.Status, result.Status) + } + }) + } +} + +func TestRoomRepository_RoomUpdate(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建測試房間 + room := &entity.Room{ + Name: "Original Name", + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + originalCreatedAt := room.CreatedAt + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + // 更新房間 + room.Name = "Updated Name" + room.Status = "archived" + err := roomTestRepo.RoomUpdate(ctx, room) + require.NoError(t, err) + + // 驗證更新 + updated, err := roomTestRepo.RoomGet(ctx, room.RoomID.String()) + require.NoError(t, err) + assert.Equal(t, "Updated Name", updated.Name) + assert.Equal(t, "archived", updated.Status) + assert.Equal(t, originalCreatedAt, updated.CreatedAt) // CreatedAt 不應該改變 + assert.Greater(t, updated.UpdatedAt, originalCreatedAt) // UpdatedAt 應該更新 +} + +func TestRoomRepository_RoomDelete(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建測試房間 + room := &entity.Room{ + Name: "Delete Test Room", + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + roomIDStr := room.RoomID.String() + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + // 刪除房間 + err := roomTestRepo.RoomDelete(ctx, roomIDStr) + require.NoError(t, err) + + // 驗證房間已被刪除 + _, err = roomTestRepo.RoomGet(ctx, roomIDStr) + assert.Error(t, err) + assert.True(t, cassandra.IsNotFound(err)) +} + +func TestRoomRepository_RoomList(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建多個測試房間 + rooms := make([]*entity.Room, 5) + for i := 0; i < 5; i++ { + room := &entity.Room{ + Name: "Room " + strconv.Itoa(i), + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + rooms[i] = room + time.Sleep(10 * time.Millisecond) // 確保時間戳不同 + } + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + tests := []struct { + name string + param repository.ListRoomsReq + wantLen int + wantErr bool + }{ + { + name: "list all rooms", + param: repository.ListRoomsReq{ + PageSize: 10, + }, + wantLen: 5, + wantErr: false, + }, + { + name: "list with status filter", + param: repository.ListRoomsReq{ + Status: "active", + PageSize: 10, + }, + wantLen: 5, + wantErr: false, + }, + { + name: "list with page size limit", + param: repository.ListRoomsReq{ + PageSize: 2, + }, + wantLen: 2, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := roomTestRepo.RoomList(ctx, tt.param) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Len(t, result, tt.wantLen) + } + }) + } +} + +func TestRoomRepository_RoomCount(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建測試房間 + for i := 0; i < 3; i++ { + room := &entity.Room{ + Name: "Count Room " + strconv.Itoa(i), + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + } + + // 創建一個 archived 房間 + archivedRoom := &entity.Room{ + Name: "Archived Room", + Status: "archived", + } + require.NoError(t, roomTestRepo.Create(ctx, archivedRoom)) + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + tests := []struct { + name string + param repository.CountRoomsReq + want int64 + wantErr bool + }{ + { + name: "count all rooms", + param: repository.CountRoomsReq{}, + want: 4, + }, + { + name: "count active rooms", + param: repository.CountRoomsReq{ + Status: "active", + }, + want: 3, + }, + { + name: "count archived rooms", + param: repository.CountRoomsReq{ + Status: "archived", + }, + want: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + count, err := roomTestRepo.RoomCount(ctx, tt.param) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, count) + } + }) + } +} + +func TestRoomRepository_RoomExists(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建測試房間 + room := &entity.Room{ + Name: "Exists Test Room", + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + roomIDStr := room.RoomID.String() + nonExistentID := gocql.TimeUUID().String() + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + tests := []struct { + name string + roomID string + want bool + wantErr bool + }{ + { + name: "existing room", + roomID: roomIDStr, + want: true, + wantErr: false, + }, + { + name: "non-existent room", + roomID: nonExistentID, + want: false, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + exists, err := roomTestRepo.RoomExists(ctx, tt.roomID) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, exists) + } + }) + } +} + +func TestRoomRepository_RoomGetByID(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建多個測試房間 + rooms := make([]*entity.Room, 3) + roomIDs := make([]string, 3) + for i := 0; i < 3; i++ { + room := &entity.Room{ + Name: "Room " + strconv.Itoa(i), + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + rooms[i] = room + roomIDs[i] = room.RoomID.String() + } + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + tests := []struct { + name string + roomIDs []string + wantLen int + wantErr bool + }{ + { + name: "get multiple rooms", + roomIDs: roomIDs, + wantLen: 3, + wantErr: false, + }, + { + name: "get empty list", + roomIDs: []string{}, + wantLen: 0, + wantErr: false, + }, + { + name: "get with non-existent ID", + roomIDs: []string{roomIDs[0], gocql.TimeUUID().String()}, + wantLen: 1, // 只返回存在的房間 + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := roomTestRepo.RoomGetByID(ctx, tt.roomIDs) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Len(t, result, tt.wantLen) + } + }) + } +} + +// ==================== Member Interface 測試 ==================== + +func TestRoomRepository_Insert(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建測試房間 + room := &entity.Room{ + Name: "Member Test Room", + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + + tests := []struct { + name string + member *entity.RoomMember + wantErr bool + }{ + { + name: "successful insert", + member: &entity.RoomMember{ + RoomID: room.RoomID, + UID: "user-1", + Role: chat.RoomRoleMember.String(), + }, + wantErr: false, + }, + { + name: "insert with auto role", + member: &entity.RoomMember{ + RoomID: room.RoomID, + UID: "user-2", + }, + wantErr: false, + }, + { + name: "insert with admin role", + member: &entity.RoomMember{ + RoomID: room.RoomID, + UID: "user-3", + Role: chat.RoomRoleAdmin.String(), + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := roomTestRepo.Insert(ctx, tt.member) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + // 驗證 JoinedAt 是否被自動設置 + assert.NotZero(t, tt.member.JoinedAt) + // 驗證 Role 是否被自動設置(如果為空) + if tt.member.Role == "" { + assert.Equal(t, chat.RoomRoleMember.String(), tt.member.Role) + } + + // 驗證 user_room 表也被更新 + userRooms, err := roomTestRepo.GetUserRooms(ctx, tt.member.UID) + require.NoError(t, err) + assert.Greater(t, len(userRooms), 0) + } + }) + } +} + +func TestRoomRepository_Get(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建測試房間和成員 + room := &entity.Room{ + Name: "Get Member Test Room", + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + + member := &entity.RoomMember{ + RoomID: room.RoomID, + UID: "user-1", + Role: chat.RoomRoleAdmin.String(), + } + require.NoError(t, roomTestRepo.Insert(ctx, member)) + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + tests := []struct { + name string + roomID string + uid string + wantErr bool + }{ + { + name: "get existing member", + roomID: room.RoomID.String(), + uid: "user-1", + wantErr: false, + }, + { + name: "get non-existent member", + roomID: room.RoomID.String(), + uid: "non-existent-user", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := roomTestRepo.Get(ctx, tt.roomID, tt.uid) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, result) + } else { + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, member.UID, result.UID) + assert.Equal(t, member.Role, result.Role) + } + }) + } +} + +func TestRoomRepository_AllMembers(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建測試房間 + room := &entity.Room{ + Name: "All Members Test Room", + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + + // 插入多個成員 + members := []*entity.RoomMember{ + {RoomID: room.RoomID, UID: "user-1", Role: chat.RoomRoleMember.String()}, + {RoomID: room.RoomID, UID: "user-2", Role: chat.RoomRoleAdmin.String()}, + {RoomID: room.RoomID, UID: "user-3", Role: chat.RoomRoleMember.String()}, + } + + for _, member := range members { + require.NoError(t, roomTestRepo.Insert(ctx, member)) + } + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + // 查詢所有成員 + result, err := roomTestRepo.AllMembers(ctx, room.RoomID.String()) + require.NoError(t, err) + assert.Len(t, result, 3) +} + +func TestRoomRepository_UpdateRole(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建測試房間和成員 + room := &entity.Room{ + Name: "Update Role Test Room", + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + + member := &entity.RoomMember{ + RoomID: room.RoomID, + UID: "user-1", + Role: chat.RoomRoleMember.String(), + } + require.NoError(t, roomTestRepo.Insert(ctx, member)) + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + // 更新角色 + member.Role = chat.RoomRoleAdmin.String() + err := roomTestRepo.UpdateRole(ctx, member) + require.NoError(t, err) + + // 驗證更新 + updated, err := roomTestRepo.Get(ctx, room.RoomID.String(), "user-1") + require.NoError(t, err) + assert.Equal(t, chat.RoomRoleAdmin.String(), updated.Role) +} + +func TestRoomRepository_DeleteMember(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建測試房間和成員 + room := &entity.Room{ + Name: "Delete Member Test Room", + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + + member := &entity.RoomMember{ + RoomID: room.RoomID, + UID: "user-1", + Role: chat.RoomRoleMember.String(), + } + require.NoError(t, roomTestRepo.Insert(ctx, member)) + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + // 刪除成員 + err := roomTestRepo.DeleteMember(ctx, room.RoomID.String(), "user-1") + require.NoError(t, err) + + // 驗證成員已被刪除 + _, err = roomTestRepo.Get(ctx, room.RoomID.String(), "user-1") + assert.Error(t, err) + assert.True(t, cassandra.IsNotFound(err)) + + // 驗證 user_room 表也被刪除 + userRooms, err := roomTestRepo.GetUserRooms(ctx, "user-1") + require.NoError(t, err) + assert.Len(t, userRooms, 0) +} + +func TestRoomRepository_DeleteRoom(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建測試房間 + room := &entity.Room{ + Name: "Delete Room Test", + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + + // 插入多個成員 + members := []*entity.RoomMember{ + {RoomID: room.RoomID, UID: "user-1", Role: chat.RoomRoleMember.String()}, + {RoomID: room.RoomID, UID: "user-2", Role: chat.RoomRoleAdmin.String()}, + } + + for _, member := range members { + require.NoError(t, roomTestRepo.Insert(ctx, member)) + } + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + // 刪除整個房間 + err := roomTestRepo.DeleteRoom(ctx, room.RoomID.String()) + require.NoError(t, err) + + // 驗證所有成員已被刪除 + allMembers, err := roomTestRepo.AllMembers(ctx, room.RoomID.String()) + require.NoError(t, err) + assert.Len(t, allMembers, 0) + + // 驗證 user_room 表中的關聯也被刪除 + for _, member := range members { + userRooms, err := roomTestRepo.GetUserRooms(ctx, member.UID) + require.NoError(t, err) + // 驗證該用戶的 user_room 記錄中不包含這個房間 + found := false + for _, ur := range userRooms { + if ur.RoomID == room.RoomID { + found = true + break + } + } + assert.False(t, found, "user_room should not contain deleted room") + } +} + +func TestRoomRepository_Count(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建測試房間 + room := &entity.Room{ + Name: "Count Members Test Room", + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + + // 插入多個成員 + for i := 0; i < 5; i++ { + member := &entity.RoomMember{ + RoomID: room.RoomID, + UID: "user-" + strconv.Itoa(i), + Role: chat.RoomRoleMember.String(), + } + require.NoError(t, roomTestRepo.Insert(ctx, member)) + } + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + // 計算成員數 + count, err := roomTestRepo.Count(ctx, room.RoomID.String()) + require.NoError(t, err) + assert.Equal(t, int64(5), count) +} + +// ==================== User Interface 測試 ==================== + +func TestRoomRepository_GetUserRooms(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建多個測試房間 + rooms := make([]*entity.Room, 3) + for i := 0; i < 3; i++ { + room := &entity.Room{ + Name: "User Room " + strconv.Itoa(i), + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + rooms[i] = room + } + + // 將用戶加入多個房間 + uid := "user-1" + for _, room := range rooms { + member := &entity.RoomMember{ + RoomID: room.RoomID, + UID: uid, + Role: chat.RoomRoleMember.String(), + } + require.NoError(t, roomTestRepo.Insert(ctx, member)) + } + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + // 查詢用戶所在的所有房間 + userRooms, err := roomTestRepo.GetUserRooms(ctx, uid) + require.NoError(t, err) + assert.Len(t, userRooms, 3) +} + +func TestRoomRepository_CountUserRooms(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建測試房間 + rooms := make([]*entity.Room, 5) + for i := 0; i < 5; i++ { + room := &entity.Room{ + Name: "Count User Room " + strconv.Itoa(i), + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + rooms[i] = room + } + + // 將用戶加入多個房間 + uid := "user-1" + for _, room := range rooms { + member := &entity.RoomMember{ + RoomID: room.RoomID, + UID: uid, + Role: chat.RoomRoleMember.String(), + } + require.NoError(t, roomTestRepo.Insert(ctx, member)) + } + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + // 計算用戶所在的房間數 + count, err := roomTestRepo.CountUserRooms(ctx, uid) + require.NoError(t, err) + assert.Equal(t, int64(5), count) +} + +func TestRoomRepository_IsUserInRoom(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建測試房間 + room := &entity.Room{ + Name: "Is User In Room Test", + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + + // 插入成員 + member := &entity.RoomMember{ + RoomID: room.RoomID, + UID: "user-1", + Role: chat.RoomRoleMember.String(), + } + require.NoError(t, roomTestRepo.Insert(ctx, member)) + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + tests := []struct { + name string + uid string + roomID string + want bool + wantErr bool + }{ + { + name: "user in room", + uid: "user-1", + roomID: room.RoomID.String(), + want: true, + wantErr: false, + }, + { + name: "user not in room", + uid: "user-2", + roomID: room.RoomID.String(), + want: false, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := roomTestRepo.IsUserInRoom(ctx, tt.uid, tt.roomID) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, result) + } + }) + } +} + +// ==================== 整合測試 ==================== + +func TestRoomRepository_Integration(t *testing.T) { + ensureRoomRepository(t) + clearRoomData(t) + ctx := context.Background() + + // 創建房間 + room := &entity.Room{ + Name: "Integration Test Room", + Status: "active", + } + require.NoError(t, roomTestRepo.Create(ctx, room)) + + // 添加多個成員 + members := []*entity.RoomMember{ + {RoomID: room.RoomID, UID: "user-1", Role: chat.RoomRoleOwner.String()}, + {RoomID: room.RoomID, UID: "user-2", Role: chat.RoomRoleAdmin.String()}, + {RoomID: room.RoomID, UID: "user-3", Role: chat.RoomRoleMember.String()}, + } + + for _, member := range members { + require.NoError(t, roomTestRepo.Insert(ctx, member)) + } + + // 等待數據寫入 + time.Sleep(100 * time.Millisecond) + + // 驗證房間存在 + exists, err := roomTestRepo.RoomExists(ctx, room.RoomID.String()) + require.NoError(t, err) + assert.True(t, exists) + + // 驗證成員數量 + count, err := roomTestRepo.Count(ctx, room.RoomID.String()) + require.NoError(t, err) + assert.Equal(t, int64(3), count) + + // 驗證所有成員 + allMembers, err := roomTestRepo.AllMembers(ctx, room.RoomID.String()) + require.NoError(t, err) + assert.Len(t, allMembers, 3) + + // 驗證用戶所在的房間 + for _, member := range members { + userRooms, err := roomTestRepo.GetUserRooms(ctx, member.UID) + require.NoError(t, err) + assert.Greater(t, len(userRooms), 0) + + inRoom, err := roomTestRepo.IsUserInRoom(ctx, member.UID, room.RoomID.String()) + require.NoError(t, err) + assert.True(t, inRoom) + } +} + diff --git a/pkg/chat/usecase/message.go b/pkg/chat/usecase/message.go new file mode 100644 index 0000000..f14ea8b --- /dev/null +++ b/pkg/chat/usecase/message.go @@ -0,0 +1,272 @@ +package usecase + +import ( + "backend/pkg/chat/domain/entity" + "backend/pkg/chat/domain/repository" + "backend/pkg/chat/domain/usecase" + "backend/pkg/library/centrifugo" + errs "backend/pkg/library/errors" + "backend/pkg/utils" + "context" + "crypto/md5" + "encoding/hex" + "fmt" + "math" + "time" + + "github.com/gocql/gocql" + "github.com/google/uuid" +) + +const ( + defaultPageSize = 20 + maxPageSize = 100 // 設置最大分頁大小 +) + +type MessageUseCaseParam struct { + MessageRepo repository.MessageRepository + RoomRepo repository.RoomRepository + MsgClient *centrifugo.Client + Logger errs.Logger +} + +type MessageUseCase struct { + MessageUseCaseParam +} + +// NewMessageUseCase 創建新的訊息 UseCase +func NewMessageUseCase(param MessageUseCaseParam) usecase.MessageUseCase { + return &MessageUseCase{ + param, + } +} + +func (use *MessageUseCase) SendMessage(ctx context.Context, param usecase.SendMessageReq) error { + // 驗證輸入參數 + if err := use.validateSendMessageReq(param); err != nil { + return err + } + + // 驗證使用者是否在房間中(優先檢查,避免不必要的去重操作) + if err := use.verifyRoomMembership(ctx, param.UID, param.RoomID); err != nil { + return err + } + + // 去重檢查 + now := time.Now().UTC() + bucketSec := now.Unix() + contentMD5 := calculateMD5(param.Content) + + isDuplicate, err := use.MessageRepo.CheckAndInsertDedup(ctx, repository.CheckDupReq{ + RoomID: param.RoomID, + UID: param.UID, // 修復:應該是 param.UID 而不是 param.RoomID + BucketSec: bucketSec, + ContentMD5: contentMD5, + }) + if err != nil { + return use.logError("messageRepo.CheckAndInsertDedup", param, err, "failed to check message deduplication") + } + + if isDuplicate { + return errs.InputInvalidFormatError("duplicate message detected") + } + + // 建立並儲存訊息 + msg, err := use.createMessage(param, now) + if err != nil { + return err + } + + if err = use.MessageRepo.Insert(ctx, msg); err != nil { + return use.logError("messageRepo.Insert", param, err, "failed to insert message") + } + + // 發布到 Centrifugo(非阻塞,失敗不影響主流程) + use.publishToCentrifugo(ctx, param.RoomID, msg, now) + + return nil +} + +// validateSendMessageReq 驗證發送訊息的請求參數 +func (use *MessageUseCase) validateSendMessageReq(param usecase.SendMessageReq) error { + if param.Content == "" { + return errs.InputInvalidFormatError("content cannot be empty") + } + if param.RoomID == "" { + return errs.InputInvalidFormatError("room_id cannot be empty") + } + if param.UID == "" { + return errs.InputInvalidFormatError("uid cannot be empty") + } + return nil +} + +// verifyRoomMembership 驗證使用者是否在房間中 +func (use *MessageUseCase) verifyRoomMembership(ctx context.Context, uid, roomID string) error { + isMember, err := use.RoomRepo.IsUserInRoom(ctx, uid, roomID) + if err != nil { + return use.logError("roomRepo.IsUserInRoom", map[string]interface{}{ + "uid": uid, + "roomID": roomID, + }, err, "failed to check room membership") + } + + if !isMember { + return errs.AuthForbiddenErrorL( + use.Logger, + []errs.LogField{ + {Key: "uid", Val: uid}, + {Key: "roomID", Val: roomID}, + }, + "user is not a member of the room") + } + + return nil +} + +// createMessage 建立訊息實體 +func (use *MessageUseCase) createMessage(param usecase.SendMessageReq, now time.Time) (*entity.Message, error) { + roomID, err := gocql.ParseUUID(param.RoomID) + if err != nil { + return nil, errs.InputInvalidFormatError("invalid room_id format").Wrap(err) + } + + msgID, err := uuid.NewV7() + if err != nil { + return nil, errs.SysInternalError("failed to generate message id").Wrap(err) + } + + return &entity.Message{ + RoomID: roomID, + BucketDay: utils.GetBucketDay(now), + UID: param.UID, + MsgID: msgID, + Content: param.Content, + }, nil +} + +// publishToCentrifugo 發布訊息到 Centrifugo +func (use *MessageUseCase) publishToCentrifugo(ctx context.Context, roomID string, msg *entity.Message, now time.Time) { + channel := fmt.Sprintf("room:%s", roomID) + messageData := map[string]interface{}{ + "msg_id": msg.MsgID.String(), + "uid": msg.UID, + "content": msg.Content, + "timestamp": now.UnixNano(), // 使用實際時間戳,而不是 msg.TS(可能為 0) + "room_id": msg.RoomID.String(), + } + + if _, err := use.MsgClient.PublishJSON(ctx, channel, messageData); err != nil { + // 記錄錯誤但不影響主流程,因為訊息已經成功儲存 + if use.Logger != nil { + use.Logger.WithFields( + errs.LogField{Key: "roomID", Val: roomID}, + errs.LogField{Key: "msgID", Val: msg.MsgID.String()}, + errs.LogField{Key: "error", Val: err.Error()}, + ).Error(fmt.Sprintf("failed to publish message to Centrifugo: %v", err)) + } + } +} + +// logError 統一的錯誤記錄方法 +func (use *MessageUseCase) logError(funcName string, param interface{}, err error, message string) error { + return errs.DBErrorErrorL( + use.Logger, + []errs.LogField{ + {Key: "param", Val: param}, + {Key: "func", Val: funcName}, + {Key: "err", Val: err.Error()}, + }, + message) +} + +func (use *MessageUseCase) ListMessages(ctx context.Context, req usecase.ListMessagesReq) ([]usecase.Message, int64, error) { + // 驗證輸入參數 + if err := use.validateListMessagesReq(req); err != nil { + return nil, 0, err + } + + // 驗證使用者是否在房間中 + if err := use.verifyRoomMembership(ctx, req.UID, req.RoomID); err != nil { + return nil, 0, err + } + + // 取得 bucket_day(如果未提供則使用今天) + bucketDay := req.BucketDay + if bucketDay == "" { + bucketDay = utils.GetTodayBucketDay() + } + + // 防止 PageSize overflow 並設置合理的範圍 + pageSize := use.normalizePageSize(req.PageSize) + + // 查詢訊息 + messages, err := use.MessageRepo.ListMessages(ctx, repository.ListMessagesReq{ + RoomID: req.RoomID, + BucketDay: bucketDay, + PageSize: pageSize, + LastTS: req.LastTS, + }) + if err != nil { + return nil, 0, use.logError("messageRepo.ListMessages", req, err, "failed to list messages") + } + + // 轉換為 usecase.Message + result := make([]usecase.Message, 0, len(messages)) + for _, msg := range messages { + result = append(result, usecase.Message{ + RoomID: msg.RoomID.String(), + BucketDay: msg.BucketDay, + TS: msg.TS, + UID: msg.UID, + Content: msg.Content, + }) + } + + // 計算總數(只在第一頁時計算) + var total int64 + if req.LastTS == 0 { + total, err = use.MessageRepo.Count(ctx, req.RoomID) + if err != nil { + return nil, 0, use.logError("messageRepo.Count", req, err, "failed to count messages") + } + } + + return result, total, nil +} + +// validateListMessagesReq 驗證查詢訊息的請求參數 +func (use *MessageUseCase) validateListMessagesReq(req usecase.ListMessagesReq) error { + if req.RoomID == "" { + return errs.InputInvalidFormatError("room_id cannot be empty") + } + if req.UID == "" { + return errs.InputInvalidFormatError("uid cannot be empty") + } + return nil +} + +// normalizePageSize 正規化分頁大小 +func (use *MessageUseCase) normalizePageSize(pageSize int64) int { + // 檢查是否超過 int 的最大值 + if pageSize > int64(math.MaxInt) { + return maxPageSize + } + + size := int(pageSize) + if size <= 0 { + return defaultPageSize + } + + if size > maxPageSize { + return maxPageSize + } + + return size +} + +// calculateMD5 計算字串的 MD5 雜湊值 +func calculateMD5(content string) string { + hash := md5.Sum([]byte(content)) + return hex.EncodeToString(hash[:]) +} diff --git a/pkg/library/cassandra/query.go b/pkg/library/cassandra/query.go index 81dbc9d..0398189 100644 --- a/pkg/library/cassandra/query.go +++ b/pkg/library/cassandra/query.go @@ -75,6 +75,7 @@ type QueryBuilder[T Table] interface { OrderBy(column string, order Order) QueryBuilder[T] Limit(n int) QueryBuilder[T] Select(columns ...string) QueryBuilder[T] + AllowFiltering() QueryBuilder[T] Scan(ctx context.Context, dest *[]T) error One(ctx context.Context) (T, error) Count(ctx context.Context) (int64, error) @@ -82,11 +83,12 @@ type QueryBuilder[T Table] interface { // queryBuilder 是 QueryBuilder 的具體實作 type queryBuilder[T Table] struct { - repo *repository[T] - conditions []Condition - orders []orderBy - limit int - columns []string + repo *repository[T] + conditions []Condition + orders []orderBy + limit int + columns []string + allowFiltering bool } type orderBy struct { @@ -125,6 +127,12 @@ func (q *queryBuilder[T]) Select(columns ...string) QueryBuilder[T] { return q } +// AllowFiltering 允許不使用 partition key 的查詢(效能較差,慎用) +func (q *queryBuilder[T]) AllowFiltering() QueryBuilder[T] { + q.allowFiltering = true + return q +} + // Scan 執行查詢並將結果掃描到 dest func (q *queryBuilder[T]) Scan(ctx context.Context, dest *[]T) error { if dest == nil { @@ -171,6 +179,11 @@ func (q *queryBuilder[T]) Scan(ctx context.Context, dest *[]T) error { builder = builder.Limit(uint(q.limit)) } + // 添加 ALLOW FILTERING + if q.allowFiltering { + builder = builder.AllowFiltering() + } + stmt, names := builder.ToCql() query := q.repo.db.withContextAndTimestamp(ctx, q.repo.db.session.Query(stmt, names).BindMap(bindMap)) @@ -213,6 +226,11 @@ func (q *queryBuilder[T]) Count(ctx context.Context) (int64, error) { builder = builder.Where(cmps...) } + // 添加 ALLOW FILTERING + if q.allowFiltering { + builder = builder.AllowFiltering() + } + stmt, names := builder.ToCql() query := q.repo.db.withContextAndTimestamp(ctx, q.repo.db.session.Query(stmt, names).BindMap(bindMap)) diff --git a/pkg/library/cassandra/repository.go b/pkg/library/cassandra/repository.go index 6f6333e..1bef99d 100644 --- a/pkg/library/cassandra/repository.go +++ b/pkg/library/cassandra/repository.go @@ -130,8 +130,23 @@ func (r *repository[T]) updateSelective(ctx context.Context, doc T, includeZero func (r *repository[T]) Delete(ctx context.Context, pk any) error { t := table.New(r.metadata) stmt, names := t.Delete() - q := r.db.withContextAndTimestamp(ctx, - r.db.session.Query(stmt, names).Bind(pk)) + + // 如果 pk 是 struct,使用 BindStruct;否則使用 Bind + var q *gocqlx.Queryx + if reflect.TypeOf(pk).Kind() == reflect.Struct { + q = r.db.withContextAndTimestamp(ctx, + r.db.session.Query(stmt, names).BindStruct(pk)) + } else { + // 單一主鍵欄位的情況 + // 注意:這只適用於單一 Partition Key 且無 Clustering Key 的情況 + if len(r.metadata.PartKey) != 1 || len(r.metadata.SortKey) > 0 { + return ErrInvalidInput.WithTable(r.table).WithError( + fmt.Errorf("single value primary key only supported for single partition key without clustering key"), + ) + } + q = r.db.withContextAndTimestamp(ctx, + r.db.session.Query(stmt, names).Bind(pk)) + } return q.ExecRelease() } diff --git a/pkg/library/centrifugo/README.md b/pkg/library/centrifugo/README.md new file mode 100644 index 0000000..a5b12f9 --- /dev/null +++ b/pkg/library/centrifugo/README.md @@ -0,0 +1,474 @@ +# Centrifugo Client Library + +Go 語言的 Centrifugo 即時訊息服務客戶端庫,提供完整的 Server-side API 支援。 + +## 功能特色 + +- ✅ **HTTP API 客戶端** - 發布訊息、訂閱管理、在線狀態、歷史訊息 +- ✅ **JWT Token 生成** - 連線認證和私有頻道訂閱 +- ✅ **Token 黑名單** - 撤銷單一 Token 或用戶所有 Token +- ✅ **在線狀態追蹤** - Redis 存儲 + Centrifugo Presence API +- ✅ **統一服務入口** - 簡單易用的 `Service` 整合介面 + +## 安裝依賴 + +```bash +go get github.com/golang-jwt/jwt/v5 +go get github.com/zeromicro/go-zero/core/stores/redis +``` + +## 快速開始 + +### 1. 創建服務實例(推薦方式) + +```go +import ( + "backend/pkg/library/centrifugo" + "github.com/zeromicro/go-zero/core/stores/redis" +) + +// 創建 Redis 客戶端 +rds, _ := redis.NewRedis(redis.RedisConf{ + Host: "localhost:6379", + Type: "node", +}) + +// 創建 Centrifugo 服務 +svc := centrifugo.NewService(centrifugo.ServiceConfig{ + APIURL: "http://localhost:8000", + APIKey: "your-api-key", + TokenSecret: "your-jwt-secret", + Redis: rds, // 可選,用於黑名單和在線狀態 +}) +``` + +### 2. 發布訊息 + +```go +ctx := context.Background() + +// 方法 1: 使用 Service 便捷方法 +result, err := svc.PublishJSON(ctx, "chat:room-123", map[string]interface{}{ + "message": "Hello, World!", + "user": "daniel", +}) + +// 方法 2: 使用 Client 直接調用 +result, err := svc.Client().Publish(ctx, "chat:room-123", []byte(`{"message": "Hello!"}`)) + +// 批量發布到多個頻道 +channels := []string{"user:1", "user:2", "user:3"} +err := svc.BroadcastJSON(ctx, channels, map[string]string{ + "type": "notification", + "message": "System maintenance", +}) +``` + +### 3. Token 生成 + +```go +// 快速生成連線 Token +token, err := svc.GenerateToken("user-123") + +// 生成帶用戶資訊的 Token +token, err := svc.GenerateTokenWithInfo("user-123", map[string]interface{}{ + "name": "Daniel", + "avatar": "https://example.com/avatar.jpg", +}) + +// 完整選項 +token, err := svc.Token().GenerateConnectionToken(centrifugo.ConnectionTokenOptions{ + UserID: "user-123", + Info: map[string]interface{}{"role": "admin"}, + Channels: []string{"chat:room-1", "chat:room-2"}, // 自動訂閱 +}) + +// 訂閱 Token(用於私有頻道) +token, err := svc.Token().QuickSubscriptionToken("user-123", "private:room-456") +``` + +### 4. 撤銷 Token(踢人) + +```go +// 最常用:撤銷用戶所有 Token 並斷開連線 +// 適用於:用戶被封禁、密碼變更、用戶登出全部設備 +err := svc.InvalidateUser(ctx, "user-123") + +// 只斷開連線(不撤銷 Token) +err := svc.Disconnect(ctx, "user-123") + +// 撤銷特定 Token(需要 JTI) +err := svc.Blacklist().RevokeToken(ctx, jti, time.Hour) + +// 撤銷用戶所有 Token(不斷開連線) +err := svc.Blacklist().RevokeUserTokens(ctx, "user-123") +``` + +### 5. 在線狀態追蹤 + +```go +// 檢查單一用戶是否在線 +online, err := svc.IsUserOnline(ctx, "user-123") + +// 批量獲取在線狀態 +status, err := svc.GetUsersOnlineStatus(ctx, []string{"user-1", "user-2", "user-3"}) +// status = map[string]bool{"user-1": true, "user-2": false, "user-3": true} + +// 處理 Centrifugo Connect/Disconnect Proxy 事件 +svc.Online().HandleConnect(ctx, "user-123") +svc.Online().HandleDisconnect(ctx, "user-123") + +// 使用 Centrifugo Presence API(頻道級別) +users, err := svc.Online().GetChannelOnlineUsers(ctx, "chat:room-123") +stats, err := svc.Online().GetChannelStats(ctx, "chat:room-123") +``` + +--- + +## 獨立使用各元件 + +如果不需要完整的 Service,可以獨立使用各元件: + +### HTTP API Client + +```go +// 創建客戶端 +client := centrifugo.NewClient("http://localhost:8000", "your-api-key") + +// 使用自定義配置 +client := centrifugo.NewClientWithConfig(centrifugo.ClientConfig{ + APIURL: "http://localhost:8000", + APIKey: "your-api-key", + Timeout: 5 * time.Second, + MaxIdleConns: 200, + MaxIdleConnsPerHost: 50, +}) + +// API 調用 +client.Publish(ctx, channel, data) +client.PublishJSON(ctx, channel, data) +client.Broadcast(ctx, channels, data) +client.Subscribe(ctx, user, channel) +client.Unsubscribe(ctx, user, channel) +client.Disconnect(ctx, user) +client.DisconnectWithCode(ctx, user, code, reason) +client.Presence(ctx, channel) +client.PresenceStats(ctx, channel) +client.History(ctx, channel, limit) +client.HistoryReverse(ctx, channel, limit) +client.Channels(ctx) +client.ChannelsWithPattern(ctx, pattern) +client.Info(ctx) +client.Ping(ctx) +``` + +### Token Generator + +```go +// 創建生成器 +tokenGen := centrifugo.NewTokenGenerator("your-jwt-secret") + +// 使用自定義配置 +tokenGen := centrifugo.NewTokenGeneratorWithConfig(centrifugo.TokenConfig{ + Secret: "your-jwt-secret", + ExpireIn: 24 * time.Hour, +}) + +// 生成 Token +tokenGen.QuickConnectionToken(userID) +tokenGen.QuickSubscriptionToken(userID, channel) +tokenGen.GenerateConnectionToken(opts) +tokenGen.GenerateSubscriptionToken(opts) +tokenGen.GenerateAnonymousToken() +``` + +### Token Blacklist + +```go +// 創建黑名單管理器 +blacklist := centrifugo.NewTokenBlacklist(redisClient) + +// 撤銷操作 +blacklist.RevokeToken(ctx, jti, ttl) // 撤銷單一 Token +blacklist.RevokeUserTokens(ctx, userID) // 撤銷用戶所有 Token + +// 驗證操作 +blacklist.IsTokenRevoked(ctx, jti) // 檢查 Token 是否被撤銷 +blacklist.GetUserTokenVersion(ctx, userID) // 獲取用戶 Token 版本 +blacklist.IsTokenVersionValid(ctx, userID, v) // 檢查版本是否有效 +``` + +### Online Manager + +```go +// 創建管理器 +store := centrifugo.NewRedisOnlineStore(redisClient) +onlineManager := centrifugo.NewOnlineManagerWithTTL(client, store, 5*time.Minute) + +// Redis 存儲操作 +onlineManager.HandleConnect(ctx, userID) +onlineManager.HandleDisconnect(ctx, userID) +onlineManager.IsUserOnline(ctx, userID) +onlineManager.GetUsersOnlineStatus(ctx, userIDs) +onlineManager.RefreshOnline(ctx, userID) + +// Centrifugo Presence API +onlineManager.IsUserInChannel(ctx, userID, channel) +onlineManager.GetChannelOnlineUsers(ctx, channel) +onlineManager.GetChannelStats(ctx, channel) +``` + +--- + +## 前端整合範例 + +### JavaScript (使用 centrifuge-js) + +```javascript +import { Centrifuge } from 'centrifuge'; + +// 從後端 API 獲取 Token +const getToken = async () => { + const response = await fetch('/api/centrifugo/token'); + const data = await response.json(); + return data.token; +}; + +// 創建連線 +const centrifuge = new Centrifuge('ws://localhost:8000/connection/websocket', { + getToken: getToken, +}); + +// 訂閱頻道 +const sub = centrifuge.newSubscription('chat:room-123'); +sub.on('publication', (ctx) => { + console.log('Received:', ctx.data); +}); +sub.subscribe(); + +// 連線 +centrifuge.connect(); +``` + +### 後端 Token API + +```go +// handlers/centrifugo.go +func (h *Handler) GetConnectionToken(c *gin.Context) { + userID := c.GetString("user_id") // 從 JWT 或 session 獲取 + + token, err := h.svc.GenerateToken(userID) + if err != nil { + c.JSON(500, gin.H{"error": "failed to generate token"}) + return + } + + c.JSON(200, gin.H{"token": token}) +} + +func (h *Handler) GetSubscriptionToken(c *gin.Context) { + userID := c.GetString("user_id") + channel := c.Query("channel") + + // 驗證用戶是否有權限訂閱此頻道 + if !h.canSubscribe(userID, channel) { + c.JSON(403, gin.H{"error": "forbidden"}) + return + } + + token, err := h.svc.Token().QuickSubscriptionToken(userID, channel) + if err != nil { + c.JSON(500, gin.H{"error": "failed to generate token"}) + return + } + + c.JSON(200, gin.H{"token": token}) +} +``` + +--- + +## Centrifugo Proxy 整合 + +### Connect Proxy + +```go +// POST /centrifugo/connect +func (h *Handler) CentrifugoConnect(c *gin.Context) { + var req struct { + Client string `json:"client"` + Transport string `json:"transport"` + Protocol string `json:"protocol"` + Data []byte `json:"data"` + } + c.BindJSON(&req) + + // 從 Token 驗證用戶(Centrifugo 會傳遞) + userID := extractUserID(req.Data) + + // 記錄連線 + h.svc.Online().HandleConnect(c, userID) + + c.JSON(200, gin.H{ + "result": map[string]interface{}{ + "user": userID, + }, + }) +} +``` + +### Disconnect Proxy + +```go +// POST /centrifugo/disconnect +func (h *Handler) CentrifugoDisconnect(c *gin.Context) { + var req struct { + Client string `json:"client"` + User string `json:"user"` + } + c.BindJSON(&req) + + // 記錄離線 + h.svc.Online().HandleDisconnect(c, req.User) + + c.JSON(200, gin.H{"result": map[string]interface{}{}}) +} +``` + +--- + +## Centrifugo 配置參考 + +```json +{ + "token_hmac_secret_key": "your-jwt-secret", + "api_key": "your-api-key", + "admin": true, + "allowed_origins": ["http://localhost:3000"], + "proxy_connect_endpoint": "http://localhost:8080/centrifugo/connect", + "proxy_disconnect_endpoint": "http://localhost:8080/centrifugo/disconnect", + "namespaces": [ + { + "name": "chat", + "presence": true, + "history_size": 100, + "history_ttl": "300s" + }, + { + "name": "private", + "presence": true, + "protected": true + } + ] +} +``` + +--- + +## API 參考 + +### Service 方法 + +| 方法 | 說明 | +|------|------| +| `Client()` | 返回 HTTP API 客戶端 | +| `Token()` | 返回 Token 生成器 | +| `Blacklist()` | 返回黑名單管理器(可能為 nil) | +| `Online()` | 返回在線狀態管理器(可能為 nil) | +| `PublishJSON(ctx, channel, data)` | 發布 JSON 訊息 | +| `BroadcastJSON(ctx, channels, data)` | 批量發布 JSON 訊息 | +| `Disconnect(ctx, userID)` | 斷開用戶連線 | +| `GenerateToken(userID)` | 快速生成連線 Token | +| `GenerateTokenWithInfo(userID, info)` | 生成帶資訊的連線 Token | +| `InvalidateUser(ctx, userID)` | 撤銷所有 Token 並斷開連線 | +| `IsUserOnline(ctx, userID)` | 檢查用戶是否在線 | +| `GetUsersOnlineStatus(ctx, userIDs)` | 批量獲取在線狀態 | + +### Client 方法 + +| 方法 | 說明 | 返回值 | +|------|------|--------| +| `Publish(ctx, channel, data)` | 發布訊息 | `*PublishResult, error` | +| `PublishJSON(ctx, channel, data)` | 發布 JSON | `*PublishResult, error` | +| `Broadcast(ctx, channels, data)` | 批量發布 | `error` | +| `BroadcastJSON(ctx, channels, data)` | 批量發布 JSON | `error` | +| `Subscribe(ctx, user, channel)` | 訂閱用戶 | `error` | +| `Unsubscribe(ctx, user, channel)` | 取消訂閱 | `error` | +| `Disconnect(ctx, user)` | 斷開連線 | `error` | +| `DisconnectWithCode(ctx, user, code, reason)` | 帶代碼斷開連線 | `error` | +| `Presence(ctx, channel)` | 在線用戶 | `*PresenceResult, error` | +| `PresenceStats(ctx, channel)` | 在線統計 | `*PresenceStatsResult, error` | +| `History(ctx, channel, limit)` | 歷史訊息 | `*HistoryResult, error` | +| `HistoryReverse(ctx, channel, limit)` | 歷史訊息(倒序) | `*HistoryResult, error` | +| `Channels(ctx)` | 活躍頻道 | `*ChannelsResult, error` | +| `ChannelsWithPattern(ctx, pattern)` | 匹配頻道 | `*ChannelsResult, error` | +| `Info(ctx)` | 伺服器資訊 | `*InfoResult, error` | +| `Ping(ctx)` | 健康檢查 | `error` | + +### TokenGenerator 方法 + +| 方法 | 說明 | +|------|------| +| `GenerateConnectionToken(opts)` | 生成連線 Token(完整選項) | +| `GenerateSubscriptionToken(opts)` | 生成訂閱 Token(完整選項) | +| `GenerateAnonymousToken()` | 生成匿名 Token | +| `QuickConnectionToken(userID)` | 快速生成連線 Token | +| `QuickSubscriptionToken(userID, channel)` | 快速生成訂閱 Token | + +### TokenBlacklist 方法 + +| 方法 | 說明 | +|------|------| +| `RevokeToken(ctx, jti, ttl)` | 撤銷特定 Token | +| `RevokeUserTokens(ctx, userID)` | 撤銷用戶所有 Token | +| `IsTokenRevoked(ctx, jti)` | 檢查 Token 是否被撤銷 | +| `GetUserTokenVersion(ctx, userID)` | 獲取用戶 Token 版本 | +| `IsTokenVersionValid(ctx, userID, version)` | 檢查版本是否有效 | + +--- + +## 錯誤處理 + +```go +result, err := svc.Client().Publish(ctx, channel, data) +if err != nil { + // 檢查是否為 Centrifugo API 錯誤 + if apiErr, ok := err.(*centrifugo.APIError); ok { + fmt.Printf("Centrifugo error code: %d, message: %s\n", + apiErr.Code, apiErr.Message) + } else { + // 網路錯誤或其他錯誤 + fmt.Printf("Error: %v\n", err) + } +} + +// 檢查特定錯誤 +if errors.Is(err, centrifugo.ErrBlacklistNotConfigured) { + // 黑名單未配置 +} +if errors.Is(err, centrifugo.ErrOnlineStoreNotConfigured) { + // 在線狀態存儲未配置 +} +``` + +--- + +## 檔案結構 + +``` +pkg/library/centrifugo/ +├── centrifugo.go # 主入口,Service 整合介面 +├── client.go # HTTP API 客戶端 +├── token.go # JWT Token 生成器 +├── blacklist.go # Token 黑名單管理 +├── online.go # 在線狀態管理介面 +├── online_redis.go # Redis 在線狀態實作 +├── README.md # 文檔 +└── *_test.go # 測試文件 +``` + +--- + +## License + +MIT diff --git a/pkg/library/centrifugo/blacklist.go b/pkg/library/centrifugo/blacklist.go new file mode 100644 index 0000000..405e81c --- /dev/null +++ b/pkg/library/centrifugo/blacklist.go @@ -0,0 +1,130 @@ +package centrifugo + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/zeromicro/go-zero/core/stores/redis" +) + +// 錯誤定義 +var ( + ErrBlacklistNotConfigured = errors.New("token blacklist is not configured") + ErrOnlineStoreNotConfigured = errors.New("online store is not configured") +) + +// TokenBlacklist Token 黑名單管理器 +// 提供兩種撤銷機制: +// 1. 單一 Token 撤銷:使用 JTI(JWT ID)將特定 Token 加入黑名單 +// 2. 用戶全部撤銷:使用版本號機制,使用戶之前所有 Token 失效 +type TokenBlacklist struct { + redis *redis.Redis + prefix string +} + +// NewTokenBlacklist 創建 Token 黑名單管理器 +func NewTokenBlacklist(redisClient *redis.Redis) *TokenBlacklist { + return &TokenBlacklist{ + redis: redisClient, + prefix: "centrifugo:blacklist:", + } +} + +// NewTokenBlacklistWithPrefix 創建帶自定義前綴的 Token 黑名單管理器 +func NewTokenBlacklistWithPrefix(redisClient *redis.Redis, prefix string) *TokenBlacklist { + return &TokenBlacklist{ + redis: redisClient, + prefix: prefix, + } +} + +// ==================== 撤銷操作 ==================== + +// RevokeToken 撤銷特定 Token(使用 JTI) +// ttl: 黑名單過期時間,應設置為 Token 的剩餘有效時間 +// +// 使用場景: +// - 用戶登出單一設備 +// - 檢測到可疑活動的特定 session +func (b *TokenBlacklist) RevokeToken(ctx context.Context, jti string, ttl time.Duration) error { + if jti == "" { + return errors.New("jti cannot be empty") + } + key := b.tokenKey(jti) + return b.redis.SetexCtx(ctx, key, "revoked", int(ttl.Seconds())) +} + +// RevokeUserTokens 撤銷用戶的所有 Token(使用版本控制) +// 通過更新版本號,使該用戶之前發出的所有 Token 失效 +// +// 使用場景: +// - 用戶被封禁 +// - 密碼變更 +// - 用戶主動登出全部設備 +func (b *TokenBlacklist) RevokeUserTokens(ctx context.Context, userID string) error { + if userID == "" { + return errors.New("userID cannot be empty") + } + key := b.userVersionKey(userID) + version := time.Now().UnixNano() + // 設置 7 天過期,足夠長於任何 Token 的有效期 + return b.redis.SetexCtx(ctx, key, fmt.Sprintf("%d", version), 7*24*3600) +} + +// ==================== 驗證操作 ==================== + +// IsTokenRevoked 檢查 Token 是否被撤銷(使用 JTI) +func (b *TokenBlacklist) IsTokenRevoked(ctx context.Context, jti string) (bool, error) { + if jti == "" { + return false, nil + } + key := b.tokenKey(jti) + exists, err := b.redis.ExistsCtx(ctx, key) + if err != nil { + return false, err + } + return exists, nil +} + +// GetUserTokenVersion 獲取用戶的 Token 版本 +// 返回 0 表示沒有設置版本(用戶從未被撤銷過) +func (b *TokenBlacklist) GetUserTokenVersion(ctx context.Context, userID string) (int64, error) { + if userID == "" { + return 0, nil + } + key := b.userVersionKey(userID) + val, err := b.redis.GetCtx(ctx, key) + if err != nil { + return 0, err + } + if val == "" { + return 0, nil + } + var version int64 + _, err = fmt.Sscanf(val, "%d", &version) + return version, err +} + +// IsTokenVersionValid 檢查 Token 版本是否有效 +// tokenVersion: Token 內嵌的版本號 +// 如果 currentVersion > tokenVersion,表示 Token 已被撤銷 +func (b *TokenBlacklist) IsTokenVersionValid(ctx context.Context, userID string, tokenVersion int64) (bool, error) { + currentVersion, err := b.GetUserTokenVersion(ctx, userID) + if err != nil { + return false, err + } + // 如果沒有設置版本,或 Token 版本 >= 當前版本,則有效 + return currentVersion == 0 || tokenVersion >= currentVersion, nil +} + +// ==================== Key 生成 ==================== + +func (b *TokenBlacklist) tokenKey(jti string) string { + return b.prefix + "token:" + jti +} + +func (b *TokenBlacklist) userVersionKey(userID string) string { + return b.prefix + "user_version:" + userID +} diff --git a/pkg/library/centrifugo/blacklist_test.go b/pkg/library/centrifugo/blacklist_test.go new file mode 100644 index 0000000..0b9f9ab --- /dev/null +++ b/pkg/library/centrifugo/blacklist_test.go @@ -0,0 +1,246 @@ +package centrifugo + +import ( + "context" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zeromicro/go-zero/core/stores/redis" +) + +func setupTestRedis(t *testing.T) (*redis.Redis, func()) { + mr, err := miniredis.Run() + require.NoError(t, err) + + rds, err := redis.NewRedis(redis.RedisConf{ + Host: mr.Addr(), + Type: "node", + }) + require.NoError(t, err) + + return rds, func() { + mr.Close() + } +} + +func TestNewTokenBlacklist(t *testing.T) { + rds, cleanup := setupTestRedis(t) + defer cleanup() + + blacklist := NewTokenBlacklist(rds) + + assert.NotNil(t, blacklist) + assert.Equal(t, "centrifugo:blacklist:", blacklist.prefix) +} + +func TestNewTokenBlacklistWithPrefix(t *testing.T) { + rds, cleanup := setupTestRedis(t) + defer cleanup() + + blacklist := NewTokenBlacklistWithPrefix(rds, "custom:prefix:") + + assert.NotNil(t, blacklist) + assert.Equal(t, "custom:prefix:", blacklist.prefix) +} + +func TestRevokeToken(t *testing.T) { + rds, cleanup := setupTestRedis(t) + defer cleanup() + + blacklist := NewTokenBlacklist(rds) + ctx := context.Background() + + jti := "test-jti-123" + ttl := 1 * time.Hour + + // 撤銷 Token + err := blacklist.RevokeToken(ctx, jti, ttl) + require.NoError(t, err) + + // 檢查是否被撤銷 + revoked, err := blacklist.IsTokenRevoked(ctx, jti) + require.NoError(t, err) + assert.True(t, revoked) +} + +func TestRevokeToken_EmptyJTI(t *testing.T) { + rds, cleanup := setupTestRedis(t) + defer cleanup() + + blacklist := NewTokenBlacklist(rds) + ctx := context.Background() + + err := blacklist.RevokeToken(ctx, "", time.Hour) + assert.Error(t, err) + assert.Contains(t, err.Error(), "jti cannot be empty") +} + +func TestIsTokenRevoked_NotRevoked(t *testing.T) { + rds, cleanup := setupTestRedis(t) + defer cleanup() + + blacklist := NewTokenBlacklist(rds) + ctx := context.Background() + + // 檢查未撤銷的 Token + revoked, err := blacklist.IsTokenRevoked(ctx, "non-existent-jti") + require.NoError(t, err) + assert.False(t, revoked) +} + +func TestIsTokenRevoked_EmptyJTI(t *testing.T) { + rds, cleanup := setupTestRedis(t) + defer cleanup() + + blacklist := NewTokenBlacklist(rds) + ctx := context.Background() + + // 空 JTI 應該返回 false + revoked, err := blacklist.IsTokenRevoked(ctx, "") + require.NoError(t, err) + assert.False(t, revoked) +} + +func TestRevokeUserTokens(t *testing.T) { + rds, cleanup := setupTestRedis(t) + defer cleanup() + + blacklist := NewTokenBlacklist(rds) + ctx := context.Background() + + userID := "user-123" + + // 撤銷用戶所有 Token + err := blacklist.RevokeUserTokens(ctx, userID) + require.NoError(t, err) + + // 獲取版本 + version, err := blacklist.GetUserTokenVersion(ctx, userID) + require.NoError(t, err) + assert.Greater(t, version, int64(0)) +} + +func TestRevokeUserTokens_EmptyUserID(t *testing.T) { + rds, cleanup := setupTestRedis(t) + defer cleanup() + + blacklist := NewTokenBlacklist(rds) + ctx := context.Background() + + err := blacklist.RevokeUserTokens(ctx, "") + assert.Error(t, err) + assert.Contains(t, err.Error(), "userID cannot be empty") +} + +func TestGetUserTokenVersion_NoVersion(t *testing.T) { + rds, cleanup := setupTestRedis(t) + defer cleanup() + + blacklist := NewTokenBlacklist(rds) + ctx := context.Background() + + // 未設置版本的用戶應該返回 0 + version, err := blacklist.GetUserTokenVersion(ctx, "new-user") + require.NoError(t, err) + assert.Equal(t, int64(0), version) +} + +func TestGetUserTokenVersion_EmptyUserID(t *testing.T) { + rds, cleanup := setupTestRedis(t) + defer cleanup() + + blacklist := NewTokenBlacklist(rds) + ctx := context.Background() + + version, err := blacklist.GetUserTokenVersion(ctx, "") + require.NoError(t, err) + assert.Equal(t, int64(0), version) +} + +func TestIsTokenVersionValid(t *testing.T) { + rds, cleanup := setupTestRedis(t) + defer cleanup() + + blacklist := NewTokenBlacklist(rds) + ctx := context.Background() + + userID := "user-123" + + // 未設置版本時,任何版本都應該有效 + valid, err := blacklist.IsTokenVersionValid(ctx, userID, 0) + require.NoError(t, err) + assert.True(t, valid) + + // 撤銷用戶 Token + err = blacklist.RevokeUserTokens(ctx, userID) + require.NoError(t, err) + + // 獲取當前版本 + currentVersion, err := blacklist.GetUserTokenVersion(ctx, userID) + require.NoError(t, err) + + // 舊版本應該無效 + valid, err = blacklist.IsTokenVersionValid(ctx, userID, currentVersion-1) + require.NoError(t, err) + assert.False(t, valid) + + // 當前版本應該有效 + valid, err = blacklist.IsTokenVersionValid(ctx, userID, currentVersion) + require.NoError(t, err) + assert.True(t, valid) + + // 更新版本應該有效 + valid, err = blacklist.IsTokenVersionValid(ctx, userID, currentVersion+1) + require.NoError(t, err) + assert.True(t, valid) +} + +func TestRevokeUserTokens_MultipleRevokes(t *testing.T) { + rds, cleanup := setupTestRedis(t) + defer cleanup() + + blacklist := NewTokenBlacklist(rds) + ctx := context.Background() + + userID := "user-123" + + // 第一次撤銷 + err := blacklist.RevokeUserTokens(ctx, userID) + require.NoError(t, err) + + version1, err := blacklist.GetUserTokenVersion(ctx, userID) + require.NoError(t, err) + + // 等待一點時間確保時間戳不同 + time.Sleep(10 * time.Millisecond) + + // 第二次撤銷 + err = blacklist.RevokeUserTokens(ctx, userID) + require.NoError(t, err) + + version2, err := blacklist.GetUserTokenVersion(ctx, userID) + require.NoError(t, err) + + // 第二次的版本應該更大 + assert.Greater(t, version2, version1) + + // 第一次的版本應該已經無效 + valid, err := blacklist.IsTokenVersionValid(ctx, userID, version1) + require.NoError(t, err) + assert.False(t, valid) +} + +func TestKeyGeneration(t *testing.T) { + rds, cleanup := setupTestRedis(t) + defer cleanup() + + blacklist := NewTokenBlacklistWithPrefix(rds, "test:") + + // 測試 key 生成 + assert.Equal(t, "test:token:jti-123", blacklist.tokenKey("jti-123")) + assert.Equal(t, "test:user_version:user-456", blacklist.userVersionKey("user-456")) +} + diff --git a/pkg/library/centrifugo/centrifugo.go b/pkg/library/centrifugo/centrifugo.go new file mode 100644 index 0000000..3a459f8 --- /dev/null +++ b/pkg/library/centrifugo/centrifugo.go @@ -0,0 +1,188 @@ +// Package centrifugo 提供 Centrifugo 即時訊息服務的完整 Go 客戶端 +// +// 功能包含: +// - HTTP API 客戶端(發布訊息、訂閱管理、在線狀態等) +// - JWT Token 生成(連線認證、私有頻道訂閱) +// - Token 黑名單管理(撤銷單一 Token、撤銷用戶所有 Token) +// - 在線狀態追蹤(Redis 或記憶體存儲) +// +// 基本使用: +// +// // 創建服務實例 +// svc := centrifugo.NewService(centrifugo.ServiceConfig{ +// APIURL: "http://localhost:8000", +// APIKey: "your-api-key", +// TokenSecret: "your-jwt-secret", +// Redis: redisClient, // 可選,用於黑名單和在線狀態 +// }) +// +// // 發布訊息 +// svc.Client().PublishJSON(ctx, "chat:room-1", data) +// +// // 生成 Token +// token, _ := svc.Token().QuickConnectionToken("user-123") +// +// // 撤銷用戶所有 Token 並踢出 +// svc.InvalidateUser(ctx, "user-123") +package centrifugo + +import ( + "context" + "time" + + "github.com/zeromicro/go-zero/core/stores/redis" +) + +// Service Centrifugo 服務整合介面 +// 提供 HTTP API、Token 生成、黑名單管理、在線狀態追蹤的統一入口 +type Service struct { + client *Client + token *TokenGenerator + blacklist *TokenBlacklist + online *OnlineManager +} + +// ServiceConfig 服務配置 +type ServiceConfig struct { + // APIURL Centrifugo HTTP API 地址(必填) + APIURL string + // APIKey Centrifugo API 密鑰(必填) + APIKey string + // TokenSecret JWT Token 簽名密鑰(必填) + TokenSecret string + // TokenExpire Token 過期時間(預設 1 小時) + TokenExpire time.Duration + // Redis 客戶端(可選,用於黑名單和在線狀態) + Redis *redis.Redis + // ClientConfig HTTP 客戶端配置(可選) + ClientConfig *ClientConfig + // OnlineTTL 在線狀態過期時間(預設 5 分鐘) + OnlineTTL time.Duration + // KeyPrefix Redis key 前綴(預設 "centrifugo:") + KeyPrefix string +} + +// NewService 創建 Centrifugo 服務實例 +func NewService(cfg ServiceConfig) *Service { + // 設定預設值 + if cfg.TokenExpire == 0 { + cfg.TokenExpire = time.Hour + } + if cfg.OnlineTTL == 0 { + cfg.OnlineTTL = 5 * time.Minute + } + if cfg.KeyPrefix == "" { + cfg.KeyPrefix = "centrifugo:" + } + + // 創建 HTTP 客戶端 + var client *Client + if cfg.ClientConfig != nil { + client = NewClientWithConfig(*cfg.ClientConfig) + } else { + client = NewClient(cfg.APIURL, cfg.APIKey) + } + + // 創建 Token 生成器 + token := NewTokenGeneratorWithConfig(TokenConfig{ + Secret: cfg.TokenSecret, + ExpireIn: cfg.TokenExpire, + }) + + svc := &Service{ + client: client, + token: token, + } + + // 如果有 Redis,創建黑名單管理器和在線狀態管理器 + if cfg.Redis != nil { + svc.blacklist = NewTokenBlacklistWithPrefix(cfg.Redis, cfg.KeyPrefix+"blacklist:") + store := NewRedisOnlineStoreWithPrefix(cfg.Redis, cfg.KeyPrefix+"online:") + svc.online = NewOnlineManagerWithTTL(client, store, cfg.OnlineTTL) + } + + return svc +} + +// Client 返回 HTTP API 客戶端 +func (s *Service) Client() *Client { + return s.client +} + +// Token 返回 Token 生成器 +func (s *Service) Token() *TokenGenerator { + return s.token +} + +// Blacklist 返回黑名單管理器(可能為 nil) +func (s *Service) Blacklist() *TokenBlacklist { + return s.blacklist +} + +// Online 返回在線狀態管理器(可能為 nil) +func (s *Service) Online() *OnlineManager { + return s.online +} + +// ==================== 便捷方法 ==================== + +// PublishJSON 發布 JSON 訊息到頻道 +func (s *Service) PublishJSON(ctx context.Context, channel string, data interface{}) (*PublishResult, error) { + return s.client.PublishJSON(ctx, channel, data) +} + +// BroadcastJSON 批量發布 JSON 訊息到多個頻道 +func (s *Service) BroadcastJSON(ctx context.Context, channels []string, data interface{}) error { + return s.client.BroadcastJSON(ctx, channels, data) +} + +// Disconnect 斷開用戶連線 +func (s *Service) Disconnect(ctx context.Context, userID string) error { + return s.client.Disconnect(ctx, userID) +} + +// GenerateToken 快速生成連線 Token +func (s *Service) GenerateToken(userID string) (string, error) { + return s.token.QuickConnectionToken(userID) +} + +// GenerateTokenWithInfo 生成帶用戶資訊的連線 Token +func (s *Service) GenerateTokenWithInfo(userID string, info map[string]interface{}) (string, error) { + return s.token.GenerateConnectionToken(ConnectionTokenOptions{ + UserID: userID, + Info: info, + }) +} + +// InvalidateUser 撤銷用戶所有 Token 並斷開連線 +// 這是最常用的「踢人」方法,適用於: +// - 用戶被封禁 +// - 密碼變更 +// - 用戶登出(全設備) +func (s *Service) InvalidateUser(ctx context.Context, userID string) error { + // 撤銷所有 Token + if s.blacklist != nil { + if err := s.blacklist.RevokeUserTokens(ctx, userID); err != nil { + return err + } + } + // 斷開連線 + return s.client.Disconnect(ctx, userID) +} + +// IsUserOnline 檢查用戶是否在線 +func (s *Service) IsUserOnline(ctx context.Context, userID string) (bool, error) { + if s.online == nil { + return false, ErrOnlineStoreNotConfigured + } + return s.online.IsUserOnline(ctx, userID) +} + +// GetUsersOnlineStatus 批量獲取用戶在線狀態 +func (s *Service) GetUsersOnlineStatus(ctx context.Context, userIDs []string) (map[string]bool, error) { + if s.online == nil { + return nil, ErrOnlineStoreNotConfigured + } + return s.online.GetUsersOnlineStatus(ctx, userIDs) +} + diff --git a/pkg/library/centrifugo/centrifugo_test.go b/pkg/library/centrifugo/centrifugo_test.go new file mode 100644 index 0000000..391a2a3 --- /dev/null +++ b/pkg/library/centrifugo/centrifugo_test.go @@ -0,0 +1,303 @@ +package centrifugo + +import ( + "context" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zeromicro/go-zero/core/stores/redis" +) + +func setupServiceTestRedis(t *testing.T) (*redis.Redis, func()) { + mr, err := miniredis.Run() + require.NoError(t, err) + + rds, err := redis.NewRedis(redis.RedisConf{ + Host: mr.Addr(), + Type: "node", + }) + require.NoError(t, err) + + return rds, func() { + mr.Close() + } +} + +func TestNewService(t *testing.T) { + rds, cleanup := setupServiceTestRedis(t) + defer cleanup() + + svc := NewService(ServiceConfig{ + APIURL: "http://localhost:8000", + APIKey: "test-api-key", + TokenSecret: "test-secret", + Redis: rds, + }) + + assert.NotNil(t, svc) + assert.NotNil(t, svc.Client()) + assert.NotNil(t, svc.Token()) + assert.NotNil(t, svc.Blacklist()) + assert.NotNil(t, svc.Online()) +} + +func TestNewService_WithoutRedis(t *testing.T) { + svc := NewService(ServiceConfig{ + APIURL: "http://localhost:8000", + APIKey: "test-api-key", + TokenSecret: "test-secret", + Redis: nil, // 沒有 Redis + }) + + assert.NotNil(t, svc) + assert.NotNil(t, svc.Client()) + assert.NotNil(t, svc.Token()) + assert.Nil(t, svc.Blacklist()) // 應該為 nil + assert.Nil(t, svc.Online()) // 應該為 nil +} + +func TestNewService_DefaultValues(t *testing.T) { + svc := NewService(ServiceConfig{ + APIURL: "http://localhost:8000", + APIKey: "test-api-key", + TokenSecret: "test-secret", + }) + + assert.NotNil(t, svc) + // 檢查預設值已被應用(通過生成 Token 間接驗證) + token, err := svc.GenerateToken("user-123") + require.NoError(t, err) + assert.NotEmpty(t, token) +} + +func TestNewService_WithCustomConfig(t *testing.T) { + rds, cleanup := setupServiceTestRedis(t) + defer cleanup() + + customExpire := 24 * time.Hour + customTTL := 10 * time.Minute + customPrefix := "custom:" + + svc := NewService(ServiceConfig{ + APIURL: "http://localhost:8000", + APIKey: "test-api-key", + TokenSecret: "test-secret", + TokenExpire: customExpire, + Redis: rds, + OnlineTTL: customTTL, + KeyPrefix: customPrefix, + }) + + assert.NotNil(t, svc) +} + +func TestService_GenerateToken(t *testing.T) { + svc := NewService(ServiceConfig{ + APIURL: "http://localhost:8000", + APIKey: "test-api-key", + TokenSecret: "test-secret", + }) + + token, err := svc.GenerateToken("user-123") + + require.NoError(t, err) + assert.NotEmpty(t, token) +} + +func TestService_GenerateTokenWithInfo(t *testing.T) { + svc := NewService(ServiceConfig{ + APIURL: "http://localhost:8000", + APIKey: "test-api-key", + TokenSecret: "test-secret", + }) + + info := map[string]interface{}{ + "name": "Daniel", + "avatar": "https://example.com/avatar.jpg", + } + + token, err := svc.GenerateTokenWithInfo("user-123", info) + + require.NoError(t, err) + assert.NotEmpty(t, token) +} + +func TestService_IsUserOnline_WithoutRedis(t *testing.T) { + svc := NewService(ServiceConfig{ + APIURL: "http://localhost:8000", + APIKey: "test-api-key", + TokenSecret: "test-secret", + Redis: nil, + }) + + _, err := svc.IsUserOnline(context.Background(), "user-123") + + assert.Error(t, err) + assert.ErrorIs(t, err, ErrOnlineStoreNotConfigured) +} + +func TestService_GetUsersOnlineStatus_WithoutRedis(t *testing.T) { + svc := NewService(ServiceConfig{ + APIURL: "http://localhost:8000", + APIKey: "test-api-key", + TokenSecret: "test-secret", + Redis: nil, + }) + + _, err := svc.GetUsersOnlineStatus(context.Background(), []string{"user-1", "user-2"}) + + assert.Error(t, err) + assert.ErrorIs(t, err, ErrOnlineStoreNotConfigured) +} + +func TestService_IsUserOnline_WithRedis(t *testing.T) { + rds, cleanup := setupServiceTestRedis(t) + defer cleanup() + + svc := NewService(ServiceConfig{ + APIURL: "http://localhost:8000", + APIKey: "test-api-key", + TokenSecret: "test-secret", + Redis: rds, + }) + + ctx := context.Background() + userID := "user-123" + + // 初始狀態應該是離線 + online, err := svc.IsUserOnline(ctx, userID) + require.NoError(t, err) + assert.False(t, online) + + // 處理連線事件 + err = svc.Online().HandleConnect(ctx, userID) + require.NoError(t, err) + + // 現在應該在線 + online, err = svc.IsUserOnline(ctx, userID) + require.NoError(t, err) + assert.True(t, online) + + // 處理斷線事件 + err = svc.Online().HandleDisconnect(ctx, userID) + require.NoError(t, err) + + // 現在應該離線 + online, err = svc.IsUserOnline(ctx, userID) + require.NoError(t, err) + assert.False(t, online) +} + +func TestService_GetUsersOnlineStatus_WithRedis(t *testing.T) { + rds, cleanup := setupServiceTestRedis(t) + defer cleanup() + + svc := NewService(ServiceConfig{ + APIURL: "http://localhost:8000", + APIKey: "test-api-key", + TokenSecret: "test-secret", + Redis: rds, + }) + + ctx := context.Background() + + // 設置一些用戶在線 + err := svc.Online().HandleConnect(ctx, "user-1") + require.NoError(t, err) + err = svc.Online().HandleConnect(ctx, "user-3") + require.NoError(t, err) + + // 批量獲取在線狀態 + status, err := svc.GetUsersOnlineStatus(ctx, []string{"user-1", "user-2", "user-3"}) + require.NoError(t, err) + + assert.True(t, status["user-1"]) + assert.False(t, status["user-2"]) + assert.True(t, status["user-3"]) +} + +func TestService_Blacklist_Integration(t *testing.T) { + rds, cleanup := setupServiceTestRedis(t) + defer cleanup() + + svc := NewService(ServiceConfig{ + APIURL: "http://localhost:8000", + APIKey: "test-api-key", + TokenSecret: "test-secret", + Redis: rds, + }) + + ctx := context.Background() + userID := "user-123" + + // 撤銷用戶所有 Token + err := svc.Blacklist().RevokeUserTokens(ctx, userID) + require.NoError(t, err) + + // 獲取版本 + version, err := svc.Blacklist().GetUserTokenVersion(ctx, userID) + require.NoError(t, err) + assert.Greater(t, version, int64(0)) + + // 檢查舊版本無效 + valid, err := svc.Blacklist().IsTokenVersionValid(ctx, userID, version-1) + require.NoError(t, err) + assert.False(t, valid) + + // 檢查當前版本有效 + valid, err = svc.Blacklist().IsTokenVersionValid(ctx, userID, version) + require.NoError(t, err) + assert.True(t, valid) +} + +func TestService_MultipleConnections(t *testing.T) { + rds, cleanup := setupServiceTestRedis(t) + defer cleanup() + + svc := NewService(ServiceConfig{ + APIURL: "http://localhost:8000", + APIKey: "test-api-key", + TokenSecret: "test-secret", + Redis: rds, + }) + + ctx := context.Background() + userID := "user-123" + + // 模擬多個設備連線 + err := svc.Online().HandleConnect(ctx, userID) + require.NoError(t, err) + err = svc.Online().HandleConnect(ctx, userID) + require.NoError(t, err) + err = svc.Online().HandleConnect(ctx, userID) + require.NoError(t, err) + + // 用戶應該在線 + online, err := svc.IsUserOnline(ctx, userID) + require.NoError(t, err) + assert.True(t, online) + + // 斷開一個設備 + err = svc.Online().HandleDisconnect(ctx, userID) + require.NoError(t, err) + + // 用戶仍然在線(還有 2 個連線) + online, err = svc.IsUserOnline(ctx, userID) + require.NoError(t, err) + assert.True(t, online) + + // 斷開剩餘設備 + err = svc.Online().HandleDisconnect(ctx, userID) + require.NoError(t, err) + err = svc.Online().HandleDisconnect(ctx, userID) + require.NoError(t, err) + + // 用戶現在離線 + online, err = svc.IsUserOnline(ctx, userID) + require.NoError(t, err) + assert.False(t, online) +} + diff --git a/pkg/library/centrifugo/client.go b/pkg/library/centrifugo/client.go new file mode 100644 index 0000000..b6bc8d9 --- /dev/null +++ b/pkg/library/centrifugo/client.go @@ -0,0 +1,519 @@ +package centrifugo + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "time" +) + +// Client Centrifugo 客戶端 +type Client struct { + apiURL string + apiKey string + client *http.Client +} + +// ClientConfig 客戶端配置 +type ClientConfig struct { + // APIURL Centrifugo API 地址(必填) + APIURL string + // APIKey API 密鑰(必填) + APIKey string + + // Timeout 整體請求超時時間(預設 10 秒) + Timeout time.Duration + // MaxIdleConns 最大閒置連線數(預設 100) + MaxIdleConns int + // MaxIdleConnsPerHost 每個 host 最大閒置連線數(預設 20,適合高併發) + MaxIdleConnsPerHost int + // IdleConnTimeout 閒置連線超時時間(預設 90 秒) + IdleConnTimeout time.Duration + // DialTimeout 建立連線超時時間(預設 5 秒) + DialTimeout time.Duration + // TLSHandshakeTimeout TLS 握手超時時間(預設 5 秒) + TLSHandshakeTimeout time.Duration + // ResponseHeaderTimeout 等待響應頭超時時間(預設 10 秒) + ResponseHeaderTimeout time.Duration +} + +// DefaultConfig 返回預設配置 +func DefaultConfig(apiURL, apiKey string) ClientConfig { + return ClientConfig{ + APIURL: apiURL, + APIKey: apiKey, + Timeout: 10 * time.Second, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 20, // 提高以支援高併發 + IdleConnTimeout: 90 * time.Second, + DialTimeout: 5 * time.Second, + TLSHandshakeTimeout: 5 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + } +} + +// HighPerformanceConfig 返回高效能配置(適合高併發場景) +func HighPerformanceConfig(apiURL, apiKey string) ClientConfig { + return ClientConfig{ + APIURL: apiURL, + APIKey: apiKey, + Timeout: 5 * time.Second, // 更短的超時 + MaxIdleConns: 200, // 更多閒置連線 + MaxIdleConnsPerHost: 50, // 更多每 host 連線 + IdleConnTimeout: 120 * time.Second, + DialTimeout: 3 * time.Second, + TLSHandshakeTimeout: 3 * time.Second, + ResponseHeaderTimeout: 5 * time.Second, + } +} + +// NewClient 創建新的 Centrifugo 客戶端(使用預設配置) +func NewClient(apiURL, apiKey string) *Client { + return NewClientWithConfig(DefaultConfig(apiURL, apiKey)) +} + +// NewClientWithConfig 創建使用自定義配置的 Centrifugo 客戶端 +func NewClientWithConfig(config ClientConfig) *Client { + // 設定預設值 + if config.Timeout == 0 { + config.Timeout = 10 * time.Second + } + if config.MaxIdleConns == 0 { + config.MaxIdleConns = 100 + } + if config.MaxIdleConnsPerHost == 0 { + config.MaxIdleConnsPerHost = 20 + } + if config.IdleConnTimeout == 0 { + config.IdleConnTimeout = 90 * time.Second + } + if config.DialTimeout == 0 { + config.DialTimeout = 5 * time.Second + } + if config.TLSHandshakeTimeout == 0 { + config.TLSHandshakeTimeout = 5 * time.Second + } + if config.ResponseHeaderTimeout == 0 { + config.ResponseHeaderTimeout = 10 * time.Second + } + + transport := &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: config.DialTimeout, + KeepAlive: 30 * time.Second, // TCP keep-alive + }).DialContext, + MaxIdleConns: config.MaxIdleConns, + MaxIdleConnsPerHost: config.MaxIdleConnsPerHost, + IdleConnTimeout: config.IdleConnTimeout, + TLSHandshakeTimeout: config.TLSHandshakeTimeout, + ResponseHeaderTimeout: config.ResponseHeaderTimeout, + ForceAttemptHTTP2: true, // 嘗試使用 HTTP/2 + } + + return &Client{ + apiURL: config.APIURL, + apiKey: config.APIKey, + client: &http.Client{ + Timeout: config.Timeout, + Transport: transport, + }, + } +} + +// NewClientWithHTTP 創建使用自定義 HTTP 客戶端的 Centrifugo 客戶端 +func NewClientWithHTTP(apiURL, apiKey string, httpClient *http.Client) *Client { + return &Client{ + apiURL: apiURL, + apiKey: apiKey, + client: httpClient, + } +} + +// ==================== Request/Response 結構 ==================== + +// PublishRequest 發布請求 +type PublishRequest struct { + Channel string `json:"channel"` + Data interface{} `json:"data"` +} + +// BroadcastRequest 批量發布請求 +type BroadcastRequest struct { + Channels []string `json:"channels"` + Data interface{} `json:"data"` +} + +// SubscribeRequest 訂閱請求 +type SubscribeRequest struct { + User string `json:"user"` + Channel string `json:"channel"` +} + +// UnsubscribeRequest 取消訂閱請求 +type UnsubscribeRequest struct { + User string `json:"user"` + Channel string `json:"channel"` +} + +// DisconnectRequest 斷開連線請求 +type DisconnectRequest struct { + User string `json:"user"` +} + +// DisconnectWithCodeRequest 帶代碼的斷開連線請求 +type DisconnectWithCodeRequest struct { + User string `json:"user"` + Disconnect DisconnectInfo `json:"disconnect,omitempty"` +} + +// DisconnectInfo 斷開連線資訊 +type DisconnectInfo struct { + Code uint32 `json:"code"` + Reason string `json:"reason"` +} + +// PresenceRequest 在線狀態請求 +type PresenceRequest struct { + Channel string `json:"channel"` +} + +// PresenceStatsRequest 在線統計請求 +type PresenceStatsRequest struct { + Channel string `json:"channel"` +} + +// HistoryRequest 歷史訊息請求 +type HistoryRequest struct { + Channel string `json:"channel"` + Limit int `json:"limit,omitempty"` + Reverse bool `json:"reverse,omitempty"` +} + +// ChannelsRequest 頻道列表請求 +type ChannelsRequest struct { + Pattern string `json:"pattern,omitempty"` +} + +// InfoRequest 伺服器資訊請求 +type InfoRequest struct{} + +// APIResponse Centrifugo API 通用響應 +type APIResponse struct { + Error *APIError `json:"error,omitempty"` + Result interface{} `json:"result,omitempty"` +} + +// APIError Centrifugo API 錯誤 +type APIError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +func (e *APIError) Error() string { + return fmt.Sprintf("centrifugo error %d: %s", e.Code, e.Message) +} + +// PublishResult 發布結果 +type PublishResult struct { + Offset uint64 `json:"offset,omitempty"` + Epoch string `json:"epoch,omitempty"` +} + +// PresenceResult 在線狀態結果 +type PresenceResult struct { + Presence map[string]ClientInfo `json:"presence"` +} + +// ClientInfo 客戶端資訊 +type ClientInfo struct { + User string `json:"user"` + Client string `json:"client"` + ConnInfo json.RawMessage `json:"conn_info,omitempty"` + ChanInfo json.RawMessage `json:"chan_info,omitempty"` +} + +// PresenceStatsResult 在線統計結果 +type PresenceStatsResult struct { + NumClients int `json:"num_clients"` + NumUsers int `json:"num_users"` +} + +// HistoryResult 歷史訊息結果 +type HistoryResult struct { + Publications []Publication `json:"publications"` + Offset uint64 `json:"offset,omitempty"` + Epoch string `json:"epoch,omitempty"` +} + +// Publication 發布的訊息 +type Publication struct { + Offset uint64 `json:"offset,omitempty"` + Data json.RawMessage `json:"data"` + Info *ClientInfo `json:"info,omitempty"` +} + +// ChannelsResult 頻道列表結果 +type ChannelsResult struct { + Channels map[string]ChannelInfo `json:"channels"` +} + +// ChannelInfo 頻道資訊 +type ChannelInfo struct { + NumClients int `json:"num_clients"` +} + +// InfoResult 伺服器資訊結果 +type InfoResult struct { + Nodes []NodeInfo `json:"nodes"` +} + +// NodeInfo 節點資訊 +type NodeInfo struct { + UID string `json:"uid"` + Name string `json:"name"` + Version string `json:"version"` + NumClients int `json:"num_clients"` + NumUsers int `json:"num_users"` + NumChannels int `json:"num_channels"` + Uptime int `json:"uptime"` +} + +// ==================== 發布相關方法 ==================== + +// Publish 發布訊息到指定頻道 +func (c *Client) Publish(ctx context.Context, channel string, data []byte) (*PublishResult, error) { + req := PublishRequest{ + Channel: channel, + Data: json.RawMessage(data), + } + return c.publish(ctx, req) +} + +// PublishJSON 發布 JSON 訊息到指定頻道 +func (c *Client) PublishJSON(ctx context.Context, channel string, data interface{}) (*PublishResult, error) { + req := PublishRequest{ + Channel: channel, + Data: data, + } + return c.publish(ctx, req) +} + +func (c *Client) publish(ctx context.Context, req PublishRequest) (*PublishResult, error) { + var result PublishResult + if err := c.callAPI(ctx, "publish", req, &result); err != nil { + return nil, err + } + return &result, nil +} + +// Broadcast 批量發布訊息到多個頻道 +func (c *Client) Broadcast(ctx context.Context, channels []string, data []byte) error { + req := BroadcastRequest{ + Channels: channels, + Data: json.RawMessage(data), + } + return c.callAPI(ctx, "broadcast", req, nil) +} + +// BroadcastJSON 批量發布 JSON 訊息到多個頻道 +func (c *Client) BroadcastJSON(ctx context.Context, channels []string, data interface{}) error { + req := BroadcastRequest{ + Channels: channels, + Data: data, + } + return c.callAPI(ctx, "broadcast", req, nil) +} + +// ==================== 訂閱管理方法 ==================== + +// Subscribe 訂閱用戶到頻道 +func (c *Client) Subscribe(ctx context.Context, user, channel string) error { + req := SubscribeRequest{ + User: user, + Channel: channel, + } + return c.callAPI(ctx, "subscribe", req, nil) +} + +// Unsubscribe 取消用戶訂閱 +func (c *Client) Unsubscribe(ctx context.Context, user, channel string) error { + req := UnsubscribeRequest{ + User: user, + Channel: channel, + } + return c.callAPI(ctx, "unsubscribe", req, nil) +} + +// Disconnect 強制斷開用戶連線 +func (c *Client) Disconnect(ctx context.Context, user string) error { + req := DisconnectRequest{ + User: user, + } + return c.callAPI(ctx, "disconnect", req, nil) +} + +// DisconnectWithCode 強制斷開用戶連線(帶斷開代碼和原因) +func (c *Client) DisconnectWithCode(ctx context.Context, user string, code uint32, reason string) error { + req := DisconnectWithCodeRequest{ + User: user, + Disconnect: DisconnectInfo{ + Code: code, + Reason: reason, + }, + } + return c.callAPI(ctx, "disconnect", req, nil) +} + + +// ==================== 在線狀態方法 ==================== + +// Presence 獲取頻道在線用戶 +func (c *Client) Presence(ctx context.Context, channel string) (*PresenceResult, error) { + req := PresenceRequest{ + Channel: channel, + } + var result PresenceResult + if err := c.callAPI(ctx, "presence", req, &result); err != nil { + return nil, err + } + return &result, nil +} + +// PresenceStats 獲取頻道在線統計 +func (c *Client) PresenceStats(ctx context.Context, channel string) (*PresenceStatsResult, error) { + req := PresenceStatsRequest{ + Channel: channel, + } + var result PresenceStatsResult + if err := c.callAPI(ctx, "presence_stats", req, &result); err != nil { + return nil, err + } + return &result, nil +} + +// ==================== 歷史訊息方法 ==================== + +// History 獲取頻道歷史訊息 +func (c *Client) History(ctx context.Context, channel string, limit int) (*HistoryResult, error) { + req := HistoryRequest{ + Channel: channel, + Limit: limit, + } + var result HistoryResult + if err := c.callAPI(ctx, "history", req, &result); err != nil { + return nil, err + } + return &result, nil +} + +// HistoryReverse 獲取頻道歷史訊息(倒序) +func (c *Client) HistoryReverse(ctx context.Context, channel string, limit int) (*HistoryResult, error) { + req := HistoryRequest{ + Channel: channel, + Limit: limit, + Reverse: true, + } + var result HistoryResult + if err := c.callAPI(ctx, "history", req, &result); err != nil { + return nil, err + } + return &result, nil +} + +// ==================== 頻道管理方法 ==================== + +// Channels 獲取所有活躍頻道 +func (c *Client) Channels(ctx context.Context) (*ChannelsResult, error) { + req := ChannelsRequest{} + var result ChannelsResult + if err := c.callAPI(ctx, "channels", req, &result); err != nil { + return nil, err + } + return &result, nil +} + +// ChannelsWithPattern 獲取匹配模式的活躍頻道 +func (c *Client) ChannelsWithPattern(ctx context.Context, pattern string) (*ChannelsResult, error) { + req := ChannelsRequest{ + Pattern: pattern, + } + var result ChannelsResult + if err := c.callAPI(ctx, "channels", req, &result); err != nil { + return nil, err + } + return &result, nil +} + +// ==================== 伺服器狀態方法 ==================== + +// Info 獲取伺服器狀態資訊 +func (c *Client) Info(ctx context.Context) (*InfoResult, error) { + req := InfoRequest{} + var result InfoResult + if err := c.callAPI(ctx, "info", req, &result); err != nil { + return nil, err + } + return &result, nil +} + +// Ping 檢查伺服器是否健康 +func (c *Client) Ping(ctx context.Context) error { + _, err := c.Info(ctx) + return err +} + +// ==================== 內部方法 ==================== + +// callAPI 調用 Centrifugo API +func (c *Client) callAPI(ctx context.Context, method string, params interface{}, result interface{}) error { + body, err := json.Marshal(params) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + url := fmt.Sprintf("%s/api/%s", c.apiURL, method) + httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + if c.apiKey != "" { + httpReq.Header.Set("Authorization", "apikey "+c.apiKey) + } + + resp, err := c.client.Do(httpReq) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("centrifugo returned status %d: %s", resp.StatusCode, string(respBody)) + } + + // 解析響應 + var apiResp APIResponse + if result != nil { + apiResp.Result = result + } + + if err := json.Unmarshal(respBody, &apiResp); err != nil { + return fmt.Errorf("failed to unmarshal response: %w", err) + } + + if apiResp.Error != nil { + return apiResp.Error + } + + return nil +} diff --git a/pkg/library/centrifugo/online.go b/pkg/library/centrifugo/online.go new file mode 100644 index 0000000..691a904 --- /dev/null +++ b/pkg/library/centrifugo/online.go @@ -0,0 +1,175 @@ +package centrifugo + +import ( + "context" + "fmt" + "time" +) + +// OnlineStatus 在線狀態 +type OnlineStatus struct { + UserID string `json:"user_id"` + IsOnline bool `json:"is_online"` + LastSeenAt time.Time `json:"last_seen_at,omitempty"` + Clients int `json:"clients,omitempty"` // 連線數(可能多個設備) +} + +// OnlineStore 在線狀態存儲介面 +// 可以用 Redis、Memory 或其他存儲實作 +type OnlineStore interface { + // SetOnline 設置用戶在線 + SetOnline(ctx context.Context, userID string, ttl time.Duration) error + // SetOffline 設置用戶離線 + SetOffline(ctx context.Context, userID string) error + // IsOnline 檢查用戶是否在線 + IsOnline(ctx context.Context, userID string) (bool, error) + // GetOnlineUsers 獲取在線用戶列表 + GetOnlineUsers(ctx context.Context, userIDs []string) (map[string]bool, error) + // IncrClient 增加用戶連線數 + IncrClient(ctx context.Context, userID string) (int64, error) + // DecrClient 減少用戶連線數 + DecrClient(ctx context.Context, userID string) (int64, error) +} + +// OnlineManager 在線狀態管理器 +// 結合 Redis 存儲和 Centrifugo Presence API 提供在線狀態追蹤 +type OnlineManager struct { + client *Client + store OnlineStore + ttl time.Duration +} + +// NewOnlineManager 創建在線狀態管理器 +// store 可以為 nil,此時只使用 Centrifugo Presence API +func NewOnlineManager(client *Client, store OnlineStore) *OnlineManager { + return &OnlineManager{ + client: client, + store: store, + ttl: 5 * time.Minute, // 預設 5 分鐘過期 + } +} + +// NewOnlineManagerWithTTL 創建帶 TTL 的在線狀態管理器 +func NewOnlineManagerWithTTL(client *Client, store OnlineStore, ttl time.Duration) *OnlineManager { + return &OnlineManager{ + client: client, + store: store, + ttl: ttl, + } +} + +// ==================== 連線事件處理 ==================== + +// HandleConnect 處理用戶連線事件(用於 Centrifugo Connect Proxy) +func (m *OnlineManager) HandleConnect(ctx context.Context, userID string) error { + if m.store == nil { + return nil + } + + // 增加連線數 + count, err := m.store.IncrClient(ctx, userID) + if err != nil { + return fmt.Errorf("failed to incr client: %w", err) + } + + // 如果是第一個連線,設置在線狀態 + if count == 1 { + if err := m.store.SetOnline(ctx, userID, m.ttl); err != nil { + return fmt.Errorf("failed to set online: %w", err) + } + } + + return nil +} + +// HandleDisconnect 處理用戶斷線事件(用於 Centrifugo Disconnect Proxy) +func (m *OnlineManager) HandleDisconnect(ctx context.Context, userID string) error { + if m.store == nil { + return nil + } + + // 減少連線數 + count, err := m.store.DecrClient(ctx, userID) + if err != nil { + return fmt.Errorf("failed to decr client: %w", err) + } + + // 如果沒有連線了,設置離線狀態 + if count <= 0 { + if err := m.store.SetOffline(ctx, userID); err != nil { + return fmt.Errorf("failed to set offline: %w", err) + } + } + + return nil +} + +// ==================== 在線狀態查詢 ==================== + +// IsUserOnline 檢查用戶是否在線(使用 Store) +func (m *OnlineManager) IsUserOnline(ctx context.Context, userID string) (bool, error) { + if m.store == nil { + return false, ErrOnlineStoreNotConfigured + } + return m.store.IsOnline(ctx, userID) +} + +// GetUsersOnlineStatus 批量獲取用戶在線狀態 +func (m *OnlineManager) GetUsersOnlineStatus(ctx context.Context, userIDs []string) (map[string]bool, error) { + if m.store == nil { + return nil, ErrOnlineStoreNotConfigured + } + return m.store.GetOnlineUsers(ctx, userIDs) +} + +// RefreshOnline 刷新用戶在線狀態(用於心跳) +func (m *OnlineManager) RefreshOnline(ctx context.Context, userID string) error { + if m.store == nil { + return nil + } + return m.store.SetOnline(ctx, userID, m.ttl) +} + +// ==================== Centrifugo Presence API ==================== + +// IsUserInChannel 檢查用戶是否在指定頻道中(使用 Centrifugo Presence) +func (m *OnlineManager) IsUserInChannel(ctx context.Context, userID, channel string) (bool, error) { + presence, err := m.client.Presence(ctx, channel) + if err != nil { + return false, err + } + + for _, info := range presence.Presence { + if info.User == userID { + return true, nil + } + } + + return false, nil +} + +// GetChannelOnlineUsers 獲取頻道中的在線用戶(使用 Centrifugo Presence) +func (m *OnlineManager) GetChannelOnlineUsers(ctx context.Context, channel string) ([]string, error) { + presence, err := m.client.Presence(ctx, channel) + if err != nil { + return nil, err + } + + // 去重(一個用戶可能有多個連線) + userMap := make(map[string]bool) + for _, info := range presence.Presence { + userMap[info.User] = true + } + + users := make([]string, 0, len(userMap)) + for userID := range userMap { + users = append(users, userID) + } + + return users, nil +} + +// GetChannelStats 獲取頻道在線統計 +func (m *OnlineManager) GetChannelStats(ctx context.Context, channel string) (*PresenceStatsResult, error) { + return m.client.PresenceStats(ctx, channel) +} diff --git a/pkg/library/centrifugo/online_redis.go b/pkg/library/centrifugo/online_redis.go new file mode 100644 index 0000000..ddcb941 --- /dev/null +++ b/pkg/library/centrifugo/online_redis.go @@ -0,0 +1,141 @@ +package centrifugo + +import ( + "context" + "fmt" + "strconv" + "time" + + "github.com/zeromicro/go-zero/core/stores/redis" +) + +// RedisOnlineStore 使用 Redis 實作的在線狀態存儲 +type RedisOnlineStore struct { + client *redis.Redis + keyPrefix string +} + +// NewRedisOnlineStore 創建 Redis 在線狀態存儲 +func NewRedisOnlineStore(client *redis.Redis) *RedisOnlineStore { + return &RedisOnlineStore{ + client: client, + keyPrefix: "online:", + } +} + +// NewRedisOnlineStoreWithPrefix 創建帶自定義前綴的 Redis 在線狀態存儲 +func NewRedisOnlineStoreWithPrefix(client *redis.Redis, prefix string) *RedisOnlineStore { + return &RedisOnlineStore{ + client: client, + keyPrefix: prefix, + } +} + +// key 生成 Redis key +func (s *RedisOnlineStore) key(userID string) string { + return s.keyPrefix + userID +} + +// clientCountKey 生成連線數 key +func (s *RedisOnlineStore) clientCountKey(userID string) string { + return s.keyPrefix + "clients:" + userID +} + +// SetOnline 設置用戶在線 +func (s *RedisOnlineStore) SetOnline(ctx context.Context, userID string, ttl time.Duration) error { + return s.client.SetexCtx(ctx, s.key(userID), fmt.Sprintf("%d", time.Now().Unix()), int(ttl.Seconds())) +} + +// SetOffline 設置用戶離線 +func (s *RedisOnlineStore) SetOffline(ctx context.Context, userID string) error { + // 刪除在線狀態 + _, err := s.client.DelCtx(ctx, s.key(userID)) + if err != nil { + return err + } + // 刪除連線數 + _, err = s.client.DelCtx(ctx, s.clientCountKey(userID)) + return err +} + +// IsOnline 檢查用戶是否在線 +func (s *RedisOnlineStore) IsOnline(ctx context.Context, userID string) (bool, error) { + return s.client.ExistsCtx(ctx, s.key(userID)) +} + +// GetOnlineUsers 批量獲取在線用戶狀態 +func (s *RedisOnlineStore) GetOnlineUsers(ctx context.Context, userIDs []string) (map[string]bool, error) { + if len(userIDs) == 0 { + return make(map[string]bool), nil + } + + result := make(map[string]bool, len(userIDs)) + for _, userID := range userIDs { + exists, err := s.client.ExistsCtx(ctx, s.key(userID)) + if err != nil { + return nil, fmt.Errorf("failed to check online status for %s: %w", userID, err) + } + result[userID] = exists + } + + return result, nil +} + +// IncrClient 增加用戶連線數 +func (s *RedisOnlineStore) IncrClient(ctx context.Context, userID string) (int64, error) { + count, err := s.client.IncrCtx(ctx, s.clientCountKey(userID)) + if err != nil { + return 0, err + } + return int64(count), nil +} + +// DecrClient 減少用戶連線數 +func (s *RedisOnlineStore) DecrClient(ctx context.Context, userID string) (int64, error) { + count, err := s.client.DecrCtx(ctx, s.clientCountKey(userID)) + if err != nil { + return 0, err + } + + // 確保不會變成負數 + if count < 0 { + _ = s.client.SetCtx(ctx, s.clientCountKey(userID), "0") + return 0, nil + } + + return int64(count), nil +} + +// GetClientCount 獲取用戶連線數 +func (s *RedisOnlineStore) GetClientCount(ctx context.Context, userID string) (int64, error) { + val, err := s.client.GetCtx(ctx, s.clientCountKey(userID)) + if err != nil { + return 0, err + } + if val == "" { + return 0, nil + } + return strconv.ParseInt(val, 10, 64) +} + +// GetAllOnlineUserIDs 獲取所有在線用戶 ID +// 注意:此方法使用 KEYS 命令,在大規模生產環境中可能有性能問題 +// 建議在需要時使用 Centrifugo Presence API 替代 +func (s *RedisOnlineStore) GetAllOnlineUserIDs(ctx context.Context) ([]string, error) { + // 使用 KEYS 查找所有在線用戶(排除 clients: 開頭的 key) + pattern := s.keyPrefix + "[^c]*" + keys, err := s.client.KeysCtx(ctx, pattern) + if err != nil { + return nil, err + } + + userIDs := make([]string, 0, len(keys)) + prefixLen := len(s.keyPrefix) + for _, key := range keys { + if len(key) > prefixLen { + userIDs = append(userIDs, key[prefixLen:]) + } + } + + return userIDs, nil +} diff --git a/pkg/library/centrifugo/online_redis_test.go b/pkg/library/centrifugo/online_redis_test.go new file mode 100644 index 0000000..8cabae3 --- /dev/null +++ b/pkg/library/centrifugo/online_redis_test.go @@ -0,0 +1,272 @@ +package centrifugo + +import ( + "context" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zeromicro/go-zero/core/stores/redis" +) + +func setupOnlineTestRedis(t *testing.T) (*redis.Redis, func()) { + mr, err := miniredis.Run() + require.NoError(t, err) + + rds, err := redis.NewRedis(redis.RedisConf{ + Host: mr.Addr(), + Type: "node", + }) + require.NoError(t, err) + + return rds, func() { + mr.Close() + } +} + +func TestNewRedisOnlineStore(t *testing.T) { + rds, cleanup := setupOnlineTestRedis(t) + defer cleanup() + + store := NewRedisOnlineStore(rds) + + assert.NotNil(t, store) + assert.Equal(t, "online:", store.keyPrefix) +} + +func TestNewRedisOnlineStoreWithPrefix(t *testing.T) { + rds, cleanup := setupOnlineTestRedis(t) + defer cleanup() + + store := NewRedisOnlineStoreWithPrefix(rds, "custom:online:") + + assert.NotNil(t, store) + assert.Equal(t, "custom:online:", store.keyPrefix) +} + +func TestSetOnline(t *testing.T) { + rds, cleanup := setupOnlineTestRedis(t) + defer cleanup() + + store := NewRedisOnlineStore(rds) + ctx := context.Background() + + userID := "user-123" + ttl := 5 * time.Minute + + // 設置在線 + err := store.SetOnline(ctx, userID, ttl) + require.NoError(t, err) + + // 檢查是否在線 + online, err := store.IsOnline(ctx, userID) + require.NoError(t, err) + assert.True(t, online) +} + +func TestSetOffline(t *testing.T) { + rds, cleanup := setupOnlineTestRedis(t) + defer cleanup() + + store := NewRedisOnlineStore(rds) + ctx := context.Background() + + userID := "user-123" + + // 先設置在線 + err := store.SetOnline(ctx, userID, 5*time.Minute) + require.NoError(t, err) + + // 設置離線 + err = store.SetOffline(ctx, userID) + require.NoError(t, err) + + // 檢查是否離線 + online, err := store.IsOnline(ctx, userID) + require.NoError(t, err) + assert.False(t, online) +} + +func TestIsOnline_NotOnline(t *testing.T) { + rds, cleanup := setupOnlineTestRedis(t) + defer cleanup() + + store := NewRedisOnlineStore(rds) + ctx := context.Background() + + // 未設置在線的用戶應該返回 false + online, err := store.IsOnline(ctx, "non-existent-user") + require.NoError(t, err) + assert.False(t, online) +} + +func TestGetOnlineUsers(t *testing.T) { + rds, cleanup := setupOnlineTestRedis(t) + defer cleanup() + + store := NewRedisOnlineStore(rds) + ctx := context.Background() + + // 設置一些用戶在線 + err := store.SetOnline(ctx, "user-1", 5*time.Minute) + require.NoError(t, err) + err = store.SetOnline(ctx, "user-3", 5*time.Minute) + require.NoError(t, err) + + // 批量獲取在線狀態 + userIDs := []string{"user-1", "user-2", "user-3"} + status, err := store.GetOnlineUsers(ctx, userIDs) + require.NoError(t, err) + + assert.True(t, status["user-1"]) + assert.False(t, status["user-2"]) + assert.True(t, status["user-3"]) +} + +func TestGetOnlineUsers_Empty(t *testing.T) { + rds, cleanup := setupOnlineTestRedis(t) + defer cleanup() + + store := NewRedisOnlineStore(rds) + ctx := context.Background() + + // 空列表應該返回空 map + status, err := store.GetOnlineUsers(ctx, []string{}) + require.NoError(t, err) + assert.Empty(t, status) +} + +func TestIncrClient(t *testing.T) { + rds, cleanup := setupOnlineTestRedis(t) + defer cleanup() + + store := NewRedisOnlineStore(rds) + ctx := context.Background() + + userID := "user-123" + + // 第一次增加 + count, err := store.IncrClient(ctx, userID) + require.NoError(t, err) + assert.Equal(t, int64(1), count) + + // 第二次增加 + count, err = store.IncrClient(ctx, userID) + require.NoError(t, err) + assert.Equal(t, int64(2), count) + + // 獲取連線數 + count, err = store.GetClientCount(ctx, userID) + require.NoError(t, err) + assert.Equal(t, int64(2), count) +} + +func TestDecrClient(t *testing.T) { + rds, cleanup := setupOnlineTestRedis(t) + defer cleanup() + + store := NewRedisOnlineStore(rds) + ctx := context.Background() + + userID := "user-123" + + // 先增加到 2 + _, err := store.IncrClient(ctx, userID) + require.NoError(t, err) + _, err = store.IncrClient(ctx, userID) + require.NoError(t, err) + + // 減少一次 + count, err := store.DecrClient(ctx, userID) + require.NoError(t, err) + assert.Equal(t, int64(1), count) + + // 再減少一次 + count, err = store.DecrClient(ctx, userID) + require.NoError(t, err) + assert.Equal(t, int64(0), count) +} + +func TestDecrClient_NegativeProtection(t *testing.T) { + rds, cleanup := setupOnlineTestRedis(t) + defer cleanup() + + store := NewRedisOnlineStore(rds) + ctx := context.Background() + + userID := "user-123" + + // 直接減少(沒有先增加) + count, err := store.DecrClient(ctx, userID) + require.NoError(t, err) + assert.Equal(t, int64(0), count) // 應該被保護為 0 + + // 再次減少 + count, err = store.DecrClient(ctx, userID) + require.NoError(t, err) + assert.Equal(t, int64(0), count) // 仍然是 0 +} + +func TestGetClientCount_NoClient(t *testing.T) { + rds, cleanup := setupOnlineTestRedis(t) + defer cleanup() + + store := NewRedisOnlineStore(rds) + ctx := context.Background() + + // 未設置連線數的用戶應該返回 0 + count, err := store.GetClientCount(ctx, "non-existent-user") + require.NoError(t, err) + assert.Equal(t, int64(0), count) +} + +func TestSetOffline_ClearsClientCount(t *testing.T) { + rds, cleanup := setupOnlineTestRedis(t) + defer cleanup() + + store := NewRedisOnlineStore(rds) + ctx := context.Background() + + userID := "user-123" + + // 設置在線和連線數 + err := store.SetOnline(ctx, userID, 5*time.Minute) + require.NoError(t, err) + _, err = store.IncrClient(ctx, userID) + require.NoError(t, err) + _, err = store.IncrClient(ctx, userID) + require.NoError(t, err) + + // 設置離線 + err = store.SetOffline(ctx, userID) + require.NoError(t, err) + + // 連線數也應該被清除 + count, err := store.GetClientCount(ctx, userID) + require.NoError(t, err) + assert.Equal(t, int64(0), count) +} + +func TestKeyGeneration_OnlineStore(t *testing.T) { + rds, cleanup := setupOnlineTestRedis(t) + defer cleanup() + + store := NewRedisOnlineStoreWithPrefix(rds, "test:") + + // 測試 key 生成 + assert.Equal(t, "test:user-123", store.key("user-123")) + assert.Equal(t, "test:clients:user-456", store.clientCountKey("user-456")) +} + +func TestOnlineStore_ImplementsInterface(t *testing.T) { + rds, cleanup := setupOnlineTestRedis(t) + defer cleanup() + + store := NewRedisOnlineStore(rds) + + // 確保 RedisOnlineStore 實現了 OnlineStore 介面 + var _ OnlineStore = store +} + diff --git a/pkg/library/centrifugo/token.go b/pkg/library/centrifugo/token.go new file mode 100644 index 0000000..d65127d --- /dev/null +++ b/pkg/library/centrifugo/token.go @@ -0,0 +1,203 @@ +package centrifugo + +import ( + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +// TokenConfig JWT Token 配置 +type TokenConfig struct { + // Secret 用於簽名的密鑰(與 Centrifugo 配置的 token_hmac_secret_key 一致) + Secret string + // ExpireIn Token 過期時間(預設 1 小時) + ExpireIn time.Duration +} + +// TokenGenerator JWT Token 生成器 +type TokenGenerator struct { + config TokenConfig +} + +// NewTokenGenerator 創建新的 Token 生成器 +func NewTokenGenerator(secret string) *TokenGenerator { + return &TokenGenerator{ + config: TokenConfig{ + Secret: secret, + ExpireIn: time.Hour, + }, + } +} + +// NewTokenGeneratorWithConfig 創建使用自定義配置的 Token 生成器 +func NewTokenGeneratorWithConfig(config TokenConfig) *TokenGenerator { + if config.ExpireIn == 0 { + config.ExpireIn = time.Hour + } + return &TokenGenerator{ + config: config, + } +} + +// ConnectionClaims 連線 Token 的 Claims +type ConnectionClaims struct { + jwt.RegisteredClaims + // Sub 用戶 ID(必填) + Sub string `json:"sub"` + // Info 用戶資訊(可選,會在 presence 中顯示) + Info map[string]interface{} `json:"info,omitempty"` + // Channels 自動訂閱的頻道列表(可選) + Channels []string `json:"channels,omitempty"` + // TokenVersion Token 版本(用於批量撤銷) + TokenVersion int64 `json:"tv,omitempty"` +} + +// SubscriptionClaims 訂閱 Token 的 Claims(用於私有頻道) +type SubscriptionClaims struct { + jwt.RegisteredClaims + // Sub 用戶 ID(必填) + Sub string `json:"sub"` + // Channel 頻道名稱(必填) + Channel string `json:"channel"` + // Info 頻道特定的用戶資訊(可選) + Info map[string]interface{} `json:"info,omitempty"` + // TokenVersion Token 版本(用於批量撤銷) + TokenVersion int64 `json:"tv,omitempty"` +} + +// ConnectionTokenOptions 連線 Token 選項 +type ConnectionTokenOptions struct { + // UserID 用戶 ID(必填) + UserID string + // Info 用戶資訊(可選) + Info map[string]interface{} + // Channels 自動訂閱的頻道列表(可選) + Channels []string + // ExpireAt 自定義過期時間(可選,為空則使用預設) + ExpireAt *time.Time + // TokenVersion Token 版本(可選,用於黑名單機制) + TokenVersion int64 +} + +// SubscriptionTokenOptions 訂閱 Token 選項 +type SubscriptionTokenOptions struct { + // UserID 用戶 ID(必填) + UserID string + // Channel 頻道名稱(必填) + Channel string + // Info 頻道特定的用戶資訊(可選) + Info map[string]interface{} + // ExpireAt 自定義過期時間(可選,為空則使用預設) + ExpireAt *time.Time + // TokenVersion Token 版本(可選,用於黑名單機制) + TokenVersion int64 +} + +// TokenResult 生成 Token 的結果 +type TokenResult struct { + Token string // JWT Token 字串 + JTI string // JWT ID(用於撤銷單一 Token) + ExpiresAt time.Time // 過期時間 +} + +// GenerateConnectionToken 生成連線 Token +// 用於前端建立 WebSocket 連線時的身份驗證 +func (g *TokenGenerator) GenerateConnectionToken(opts ConnectionTokenOptions) (string, error) { + result, err := g.GenerateConnectionTokenWithJTI(opts) + if err != nil { + return "", err + } + return result.Token, nil +} + +// GenerateConnectionTokenWithJTI 生成連線 Token 並返回 JTI +// 用於需要支援單一 Token 撤銷的場景 +func (g *TokenGenerator) GenerateConnectionTokenWithJTI(opts ConnectionTokenOptions) (*TokenResult, error) { + now := time.Now() + expireAt := now.Add(g.config.ExpireIn) + if opts.ExpireAt != nil { + expireAt = *opts.ExpireAt + } + + // 生成唯一的 JTI + jti := uuid.New().String() + + // 如果沒有指定 TokenVersion,使用當前時間戳 + tokenVersion := opts.TokenVersion + if tokenVersion == 0 { + tokenVersion = now.UnixNano() + } + + claims := ConnectionClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + ID: jti, + Subject: opts.UserID, + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(expireAt), + }, + Sub: opts.UserID, + Info: opts.Info, + Channels: opts.Channels, + TokenVersion: tokenVersion, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, err := token.SignedString([]byte(g.config.Secret)) + if err != nil { + return nil, err + } + + return &TokenResult{ + Token: tokenStr, + JTI: jti, + ExpiresAt: expireAt, + }, nil +} + +// GenerateSubscriptionToken 生成訂閱 Token +// 用於訂閱私有頻道時的身份驗證 +func (g *TokenGenerator) GenerateSubscriptionToken(opts SubscriptionTokenOptions) (string, error) { + now := time.Now() + expireAt := now.Add(g.config.ExpireIn) + if opts.ExpireAt != nil { + expireAt = *opts.ExpireAt + } + + claims := SubscriptionClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: opts.UserID, + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(expireAt), + }, + Sub: opts.UserID, + Channel: opts.Channel, + Info: opts.Info, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(g.config.Secret)) +} + +// GenerateAnonymousToken 生成匿名連線 Token +// 用於允許匿名用戶連線(但仍需要驗證) +func (g *TokenGenerator) GenerateAnonymousToken() (string, error) { + return g.GenerateConnectionToken(ConnectionTokenOptions{ + UserID: "", // 空的 UserID 表示匿名用戶 + }) +} + +// QuickConnectionToken 快速生成連線 Token(只需用戶 ID) +func (g *TokenGenerator) QuickConnectionToken(userID string) (string, error) { + return g.GenerateConnectionToken(ConnectionTokenOptions{ + UserID: userID, + }) +} + +// QuickSubscriptionToken 快速生成訂閱 Token(只需用戶 ID 和頻道) +func (g *TokenGenerator) QuickSubscriptionToken(userID, channel string) (string, error) { + return g.GenerateSubscriptionToken(SubscriptionTokenOptions{ + UserID: userID, + Channel: channel, + }) +} diff --git a/pkg/library/centrifugo/token_test.go b/pkg/library/centrifugo/token_test.go new file mode 100644 index 0000000..aec198d --- /dev/null +++ b/pkg/library/centrifugo/token_test.go @@ -0,0 +1,263 @@ +package centrifugo + +import ( + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewTokenGenerator(t *testing.T) { + secret := "test-secret" + gen := NewTokenGenerator(secret) + + assert.NotNil(t, gen) + assert.Equal(t, secret, gen.config.Secret) + assert.Equal(t, time.Hour, gen.config.ExpireIn) +} + +func TestNewTokenGeneratorWithConfig(t *testing.T) { + config := TokenConfig{ + Secret: "custom-secret", + ExpireIn: 24 * time.Hour, + } + gen := NewTokenGeneratorWithConfig(config) + + assert.NotNil(t, gen) + assert.Equal(t, config.Secret, gen.config.Secret) + assert.Equal(t, config.ExpireIn, gen.config.ExpireIn) +} + +func TestNewTokenGeneratorWithConfig_DefaultExpire(t *testing.T) { + config := TokenConfig{ + Secret: "test-secret", + ExpireIn: 0, // 應該使用預設值 + } + gen := NewTokenGeneratorWithConfig(config) + + assert.Equal(t, time.Hour, gen.config.ExpireIn) +} + +func TestQuickConnectionToken(t *testing.T) { + gen := NewTokenGenerator("test-secret") + userID := "user-123" + + token, err := gen.QuickConnectionToken(userID) + + require.NoError(t, err) + assert.NotEmpty(t, token) + + // 驗證 Token 內容 + claims := &ConnectionClaims{} + parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { + return []byte("test-secret"), nil + }) + + require.NoError(t, err) + assert.True(t, parsedToken.Valid) + assert.Equal(t, userID, claims.Sub) + assert.NotNil(t, claims.ExpiresAt) + assert.NotNil(t, claims.IssuedAt) +} + +func TestGenerateConnectionToken(t *testing.T) { + gen := NewTokenGenerator("test-secret") + + tests := []struct { + name string + opts ConnectionTokenOptions + checkFn func(t *testing.T, claims *ConnectionClaims) + }{ + { + name: "basic token", + opts: ConnectionTokenOptions{ + UserID: "user-123", + }, + checkFn: func(t *testing.T, claims *ConnectionClaims) { + assert.Equal(t, "user-123", claims.Sub) + assert.Nil(t, claims.Info) + assert.Nil(t, claims.Channels) + }, + }, + { + name: "token with info", + opts: ConnectionTokenOptions{ + UserID: "user-456", + Info: map[string]interface{}{ + "name": "Daniel", + "role": "admin", + }, + }, + checkFn: func(t *testing.T, claims *ConnectionClaims) { + assert.Equal(t, "user-456", claims.Sub) + assert.NotNil(t, claims.Info) + assert.Equal(t, "Daniel", claims.Info["name"]) + assert.Equal(t, "admin", claims.Info["role"]) + }, + }, + { + name: "token with channels", + opts: ConnectionTokenOptions{ + UserID: "user-789", + Channels: []string{"chat:room-1", "chat:room-2"}, + }, + checkFn: func(t *testing.T, claims *ConnectionClaims) { + assert.Equal(t, "user-789", claims.Sub) + assert.Equal(t, []string{"chat:room-1", "chat:room-2"}, claims.Channels) + }, + }, + { + name: "token with custom expire", + opts: ConnectionTokenOptions{ + UserID: "user-abc", + ExpireAt: ptrTime(time.Now().Add(48 * time.Hour)), + }, + checkFn: func(t *testing.T, claims *ConnectionClaims) { + assert.Equal(t, "user-abc", claims.Sub) + // 檢查過期時間大約在 48 小時後 + expireTime := claims.ExpiresAt.Time + assert.True(t, expireTime.After(time.Now().Add(47*time.Hour))) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token, err := gen.GenerateConnectionToken(tt.opts) + require.NoError(t, err) + assert.NotEmpty(t, token) + + // 解析 Token + claims := &ConnectionClaims{} + parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { + return []byte("test-secret"), nil + }) + + require.NoError(t, err) + assert.True(t, parsedToken.Valid) + tt.checkFn(t, claims) + }) + } +} + +func TestQuickSubscriptionToken(t *testing.T) { + gen := NewTokenGenerator("test-secret") + userID := "user-123" + channel := "private:room-456" + + token, err := gen.QuickSubscriptionToken(userID, channel) + + require.NoError(t, err) + assert.NotEmpty(t, token) + + // 驗證 Token 內容 + claims := &SubscriptionClaims{} + parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { + return []byte("test-secret"), nil + }) + + require.NoError(t, err) + assert.True(t, parsedToken.Valid) + assert.Equal(t, userID, claims.Sub) + assert.Equal(t, channel, claims.Channel) +} + +func TestGenerateSubscriptionToken(t *testing.T) { + gen := NewTokenGenerator("test-secret") + + opts := SubscriptionTokenOptions{ + UserID: "user-123", + Channel: "private:room-456", + Info: map[string]interface{}{ + "role": "moderator", + }, + } + + token, err := gen.GenerateSubscriptionToken(opts) + + require.NoError(t, err) + assert.NotEmpty(t, token) + + // 驗證 Token 內容 + claims := &SubscriptionClaims{} + parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { + return []byte("test-secret"), nil + }) + + require.NoError(t, err) + assert.True(t, parsedToken.Valid) + assert.Equal(t, opts.UserID, claims.Sub) + assert.Equal(t, opts.Channel, claims.Channel) + assert.Equal(t, "moderator", claims.Info["role"]) +} + +func TestGenerateAnonymousToken(t *testing.T) { + gen := NewTokenGenerator("test-secret") + + token, err := gen.GenerateAnonymousToken() + + require.NoError(t, err) + assert.NotEmpty(t, token) + + // 驗證 Token 內容 + claims := &ConnectionClaims{} + parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { + return []byte("test-secret"), nil + }) + + require.NoError(t, err) + assert.True(t, parsedToken.Valid) + assert.Equal(t, "", claims.Sub) // 匿名用戶的 UserID 為空 +} + +func TestTokenExpiration(t *testing.T) { + // 創建一個很短過期時間的生成器 + gen := NewTokenGeneratorWithConfig(TokenConfig{ + Secret: "test-secret", + ExpireIn: 1 * time.Second, + }) + + token, err := gen.QuickConnectionToken("user-123") + require.NoError(t, err) + + // 立即驗證應該成功 + claims := &ConnectionClaims{} + parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { + return []byte("test-secret"), nil + }) + require.NoError(t, err) + assert.True(t, parsedToken.Valid) + + // 等待過期 + time.Sleep(2 * time.Second) + + // 過期後驗證應該失敗 + claims2 := &ConnectionClaims{} + _, err = jwt.ParseWithClaims(token, claims2, func(token *jwt.Token) (interface{}, error) { + return []byte("test-secret"), nil + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token is expired") +} + +func TestTokenWithWrongSecret(t *testing.T) { + gen := NewTokenGenerator("correct-secret") + + token, err := gen.QuickConnectionToken("user-123") + require.NoError(t, err) + + // 使用錯誤的密鑰驗證 + claims := &ConnectionClaims{} + _, err = jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { + return []byte("wrong-secret"), nil + }) + + assert.Error(t, err) +} + +// 輔助函數 +func ptrTime(t time.Time) *time.Time { + return &t +} diff --git a/pkg/utils/time.go b/pkg/utils/time.go new file mode 100644 index 0000000..d0a4841 --- /dev/null +++ b/pkg/utils/time.go @@ -0,0 +1,14 @@ +package utils + +import "time" + +// GetBucketDay 取得 bucket_day(yyyyMMdd 格式) +func GetBucketDay(t time.Time) string { + return t.Format("20060102") +} + +// GetTodayBucketDay 取得今天的 bucket_day +func GetTodayBucketDay() string { + return GetBucketDay(time.Now()) +} + diff --git a/pkg/utils/uuid.go b/pkg/utils/uuid.go new file mode 100644 index 0000000..7e7b5fd --- /dev/null +++ b/pkg/utils/uuid.go @@ -0,0 +1,20 @@ +package utils + +import ( + "github.com/google/uuid" +) + +//// GenerateUID 生成匿名 UID +//func GenerateUID() string { +// return fmt.Sprintf("%s%s", consts.AnonUIDPrefix, uuid.New().String()[:8]) +//} +// +//// GenerateRoomID 生成房間 ID +//func GenerateRoomID() string { +// return fmt.Sprintf("%s%s", consts.RoomIDPrefix, uuid.New().String()) +//} + +// GenerateMessageID 生成訊息 ID +func GenerateMessageID() string { + return uuid.New().String() +}