fix api server update version
This commit is contained in:
parent
08d1cf7069
commit
377b52515d
|
|
@ -0,0 +1,44 @@
|
|||
{
|
||||
"token_hmac_secret_key": "your-secret-key-change-in-production",
|
||||
"admin_password": "admin",
|
||||
"admin_secret": "admin-secret",
|
||||
"api_key": "api-key",
|
||||
"allowed_origins": [
|
||||
"*"
|
||||
],
|
||||
"log_level": "info",
|
||||
"log_handler": "stdout",
|
||||
"websocket_compression": true,
|
||||
"websocket_read_buffer_size": 1024,
|
||||
"websocket_write_buffer_size": 1024,
|
||||
"namespaces": [
|
||||
{
|
||||
"name": "default",
|
||||
"publish": true,
|
||||
"subscribe_to_publish": true,
|
||||
"presence": true,
|
||||
"join_leave": true,
|
||||
"history_size": 100,
|
||||
"history_ttl": "300s"
|
||||
},
|
||||
{
|
||||
"name": "room",
|
||||
"publish": true,
|
||||
"subscribe_to_publish": true,
|
||||
"presence": true,
|
||||
"join_leave": true,
|
||||
"history_size": 100,
|
||||
"history_ttl": "300s"
|
||||
},
|
||||
{
|
||||
"name": "user",
|
||||
"publish": false,
|
||||
"subscribe_to_publish": false,
|
||||
"presence": false,
|
||||
"join_leave": false,
|
||||
"history_size": 0,
|
||||
"history_ttl": "0s"
|
||||
}
|
||||
],
|
||||
"redis_address": "redis:6379"
|
||||
}
|
||||
|
|
@ -27,7 +27,6 @@ services:
|
|||
restart: always
|
||||
ports:
|
||||
- "6379:6379"
|
||||
|
||||
minio:
|
||||
image: minio/minio
|
||||
container_name: minio
|
||||
|
|
@ -39,3 +38,36 @@ services:
|
|||
MINIO_ROOT_PASSWORD: minioadmin # Replace with your desired root password
|
||||
# MINIO_DEFAULT_BUCKETS: mybucket # Optional: Create a default bucket on startup
|
||||
command: server /data --console-address ":9001" # Start MinIO server and specify console address
|
||||
centrifugo:
|
||||
image: centrifugo/centrifugo:v5
|
||||
container_name: centrifugo
|
||||
restart: always
|
||||
ports:
|
||||
- "8000:8000" # HTTP API
|
||||
- "8001:8001" # WebSocket
|
||||
volumes:
|
||||
- ./centrifugo.json:/centrifugo/config.json:ro
|
||||
command: centrifugo --config=/centrifugo/config.json
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "--quiet", "--tries=1", "--spider", "http://localhost:8000/health"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
depends_on:
|
||||
- redis
|
||||
cassandra:
|
||||
image: cassandra:5.0.4
|
||||
restart: always
|
||||
ports:
|
||||
- "9042:9042"
|
||||
environment:
|
||||
TZ: ${TIMEZONE:-UTC}
|
||||
MAX_HEAP_SIZE: 4G
|
||||
HEAP_NEWSIZE: 2G
|
||||
healthcheck:
|
||||
test: [ "CMD", "cqlsh", "-k", "sccflex" ]
|
||||
interval: 10s
|
||||
timeout: 10s
|
||||
retries: 12
|
||||
mem_limit: 8g # <--- 單機 docker-compose up 時建議明確加這行
|
||||
memswap_limit: 8g # <--- 關掉 swap
|
||||
|
|
@ -1259,7 +1259,7 @@
|
|||
"url": "https://localhost:8888"
|
||||
}
|
||||
],
|
||||
"x-date": "2025-11-12 14:59:58",
|
||||
"x-date": "2026-01-05 10:01:16",
|
||||
"x-description": "This is a go-doc generated swagger file.",
|
||||
"x-generator": "go-doc",
|
||||
"x-github": "https://github.com/danielchan-25/go-doc",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
CREATE TABLE messages (
|
||||
room_id uuid,
|
||||
bucket_day date,
|
||||
ts bigint,
|
||||
msg_id uuid,
|
||||
uid text,
|
||||
content text,
|
||||
PRIMARY KEY ((room_id, bucket_day), ts, msg_id)
|
||||
) WITH CLUSTERING ORDER BY (ts ASC);
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
CREATE TABLE message_dedup (
|
||||
room_id uuid,
|
||||
uid text,
|
||||
content_md5 text,
|
||||
bucket_sec bigint,
|
||||
PRIMARY KEY ((room_id, uid), bucket_sec, content_md5)
|
||||
) WITH default_time_to_live = 2;
|
||||
|
||||
1
go.mod
1
go.mod
|
|
@ -12,6 +12,7 @@ require (
|
|||
github.com/go-playground/validator/v10 v10.28.0
|
||||
github.com/gocql/gocql v1.7.0
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/matcornic/hermes/v2 v2.1.0
|
||||
github.com/minchao/go-mitake v1.0.0
|
||||
|
|
|
|||
2
go.sum
2
go.sum
|
|
@ -101,6 +101,8 @@ github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
|
|||
github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
|
|
|
|||
|
|
@ -108,4 +108,20 @@ type Config struct {
|
|||
SecretKey string
|
||||
CloudFrontID string
|
||||
}
|
||||
|
||||
// Cassandra 配置
|
||||
Cassandra struct {
|
||||
Hosts []string
|
||||
Port int
|
||||
Keyspace string
|
||||
Username string
|
||||
Password string
|
||||
UseAuth bool
|
||||
}
|
||||
|
||||
// Centrifugo 配置
|
||||
Centrifugo struct {
|
||||
APIURL string
|
||||
APIKey string
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,67 @@
|
|||
package svc
|
||||
|
||||
import (
|
||||
"backend/internal/config"
|
||||
"backend/pkg/chat/domain/usecase"
|
||||
repo "backend/pkg/chat/repository"
|
||||
uc "backend/pkg/chat/usecase"
|
||||
"backend/pkg/library/cassandra"
|
||||
"backend/pkg/library/centrifugo"
|
||||
errs "backend/pkg/library/errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func MustMessageUseCase(c *config.Config, logger errs.Logger) usecase.MessageUseCase {
|
||||
// 初始化 Cassandra DB
|
||||
cassandraDB, err := initCassandraDB(c)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to initialize Cassandra DB: %v", err))
|
||||
}
|
||||
|
||||
// 初始化 Message Repository
|
||||
messageRepo := repo.MustMessageRepository(repo.MessageRepositoryParam{
|
||||
DB: cassandraDB,
|
||||
Keyspace: c.Cassandra.Keyspace,
|
||||
})
|
||||
|
||||
// 初始化 Room Repository
|
||||
roomRepo := repo.MustRoomRepository(repo.RoomRepositoryParam{
|
||||
DB: cassandraDB,
|
||||
Keyspace: c.Cassandra.Keyspace,
|
||||
})
|
||||
|
||||
// 初始化 Centrifugo Client
|
||||
msgClient := centrifugo.NewClientWithConfig(centrifugo.HighPerformanceConfig(c.Centrifugo.APIURL, c.Centrifugo.APIKey))
|
||||
|
||||
return uc.NewMessageUseCase(uc.MessageUseCaseParam{
|
||||
MessageRepo: messageRepo,
|
||||
RoomRepo: roomRepo,
|
||||
MsgClient: msgClient,
|
||||
Logger: logger,
|
||||
})
|
||||
}
|
||||
|
||||
// initCassandraDB 初始化 Cassandra 資料庫連接
|
||||
func initCassandraDB(c *config.Config) (*cassandra.DB, error) {
|
||||
if len(c.Cassandra.Hosts) == 0 {
|
||||
return nil, fmt.Errorf("cassandra hosts are required")
|
||||
}
|
||||
|
||||
opts := []cassandra.Option{
|
||||
cassandra.WithHosts(c.Cassandra.Hosts...),
|
||||
}
|
||||
|
||||
if c.Cassandra.Port > 0 {
|
||||
opts = append(opts, cassandra.WithPort(c.Cassandra.Port))
|
||||
}
|
||||
|
||||
if c.Cassandra.Keyspace != "" {
|
||||
opts = append(opts, cassandra.WithKeyspace(c.Cassandra.Keyspace))
|
||||
}
|
||||
|
||||
if c.Cassandra.UseAuth {
|
||||
opts = append(opts, cassandra.WithAuth(c.Cassandra.Username, c.Cassandra.Password))
|
||||
}
|
||||
|
||||
return cassandra.New(opts...)
|
||||
}
|
||||
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
|
||||
chatUC "backend/pkg/chat/domain/usecase"
|
||||
fileStorageUC "backend/pkg/fileStorage/domain/usecase"
|
||||
vi "backend/pkg/library/validator"
|
||||
memberUC "backend/pkg/member/domain/usecase"
|
||||
|
|
@ -31,6 +32,7 @@ type ServiceContext struct {
|
|||
UserRoleUC tokenUC.UserRoleUseCase
|
||||
DeliveryUC deliveryUC.DeliveryUseCase
|
||||
FileStorageUC fileStorageUC.FileStorageUseCase
|
||||
MessageUC chatUC.MessageUseCase
|
||||
Redis *redis.Redis
|
||||
Logger errs.Logger
|
||||
}
|
||||
|
|
@ -62,6 +64,7 @@ func NewServiceContext(c config.Config) *ServiceContext {
|
|||
Redis: rds,
|
||||
DeliveryUC: MustDeliveryUseCase(&c, lgr),
|
||||
FileStorageUC: MustS3Storage(&c, lgr),
|
||||
MessageUC: MustMessageUseCase(&c, lgr),
|
||||
Logger: lgr,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,38 @@
|
|||
package chat
|
||||
|
||||
// RoomRole 聊天室成員角色
|
||||
type RoomRole string
|
||||
|
||||
const (
|
||||
// RoomRoleMember 一般成員
|
||||
RoomRoleMember RoomRole = "member"
|
||||
// RoomRoleAdmin 管理員
|
||||
RoomRoleAdmin RoomRole = "admin"
|
||||
// RoomRoleOwner 擁有者
|
||||
RoomRoleOwner RoomRole = "owner"
|
||||
)
|
||||
|
||||
// String 返回角色的字串表示
|
||||
func (r RoomRole) String() string {
|
||||
return string(r)
|
||||
}
|
||||
|
||||
// IsValid 檢查角色是否有效
|
||||
func (r RoomRole) IsValid() bool {
|
||||
switch r {
|
||||
case RoomRoleMember, RoomRoleAdmin, RoomRoleOwner:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsAdmin 檢查是否為管理員或擁有者
|
||||
func (r RoomRole) IsAdmin() bool {
|
||||
return r == RoomRoleAdmin || r == RoomRoleOwner
|
||||
}
|
||||
|
||||
// IsOwner 檢查是否為擁有者
|
||||
func (r RoomRole) IsOwner() bool {
|
||||
return r == RoomRoleOwner
|
||||
}
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
package entity
|
||||
|
||||
import (
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Message 對應 Cassandra 的 messages_by_room 表
|
||||
// Primary Key: (room_id, bucket_day)
|
||||
// Clustering Key: ts DESC
|
||||
type Message struct {
|
||||
RoomID gocql.UUID `db:"room_id" partition_key:"true"`
|
||||
BucketDay string `db:"bucket_day" partition_key:"true"` // yyyyMMdd
|
||||
TS int64 `db:"ts" clustering_key:"true"` // timestamp
|
||||
UID string `db:"uid"`
|
||||
Content string `db:"content"`
|
||||
MsgID uuid.UUID `db:"msg_id"`
|
||||
}
|
||||
|
||||
// TableName 返回表名
|
||||
func (m Message) TableName() string {
|
||||
return "messages_by_room"
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
package entity
|
||||
|
||||
import (
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
// MessageDedup 對應 Cassandra 的 message_dedup 表
|
||||
// Primary Key: ((room_id, uid), bucket_sec, content_md5)
|
||||
// TTL: 2 秒
|
||||
type MessageDedup struct {
|
||||
RoomID gocql.UUID `db:"room_id" partition_key:"true"`
|
||||
UID string `db:"uid" partition_key:"true"`
|
||||
BucketSec int64 `db:"bucket_sec" clustering_key:"true"` // Unix timestamp in seconds
|
||||
ContentMD5 string `db:"content_md5" clustering_key:"true"` // MD5 hash of content
|
||||
}
|
||||
|
||||
// TableName 返回表名
|
||||
func (m MessageDedup) TableName() string {
|
||||
return "message_dedup"
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
package entity
|
||||
|
||||
import "github.com/gocql/gocql"
|
||||
|
||||
// Room 對應 Cassandra 的 room 表
|
||||
// Primary Key: (room_id)
|
||||
// 設計說明:存儲聊天室本身的基本資訊
|
||||
type Room struct {
|
||||
RoomID gocql.UUID `db:"room_id" partition_key:"true"`
|
||||
Name string `db:"name"` // 聊天室名稱
|
||||
Status string `db:"status"` // 狀態:active, archived, deleted
|
||||
CreatedAt int64 `db:"created_at"` // 創建時間
|
||||
UpdatedAt int64 `db:"updated_at"` // 更新時間
|
||||
}
|
||||
|
||||
// TableName 返回表名
|
||||
func (r Room) TableName() string {
|
||||
return "room"
|
||||
}
|
||||
|
||||
// RoomMember 對應 Cassandra 的 room_member 表
|
||||
// Primary Key: (room_id)
|
||||
// Clustering Key: uid
|
||||
// 設計說明:
|
||||
// - room_id 作為 partition key,可以高效地查詢/刪除整個聊天室的所有成員
|
||||
// - uid 作為 clustering key,可以高效地查詢/刪除特定成員
|
||||
// - 支援三種操作:
|
||||
// 1. 刪除整個聊天室:刪除整個 partition(需要查詢後批量刪除或使用原生 CQL)
|
||||
// 2. 刪除特定成員:使用 Delete(roomID, uid)
|
||||
// 3. 查詢聊天室所有成員:使用 Query().Where(Eq("room_id", roomID)).Scan()
|
||||
type RoomMember struct {
|
||||
RoomID gocql.UUID `db:"room_id" partition_key:"true"`
|
||||
UID string `db:"uid" clustering_key:"true"`
|
||||
Role string `db:"role"` // Role 角色:member(一般成員)、admin(管理員)、owner(擁有者)等
|
||||
JoinedAt int64 `db:"joined_at"` // JoinedAt 加入時間(可選,用於記錄加入時間)
|
||||
}
|
||||
|
||||
// TableName 返回表名
|
||||
func (m RoomMember) TableName() string {
|
||||
return "room_member"
|
||||
}
|
||||
|
||||
// UserRoom 對應 Cassandra 的 user_room 表(反向查詢表)
|
||||
// Primary Key: (uid)
|
||||
// Clustering Key: room_id
|
||||
// 設計說明:
|
||||
// - uid 作為 partition key,可以高效地查詢某個用戶所在的所有聊天室
|
||||
// - room_id 作為 clustering key,可以高效地查詢/刪除特定關聯
|
||||
// - 這個表用於支援「查詢用戶在哪些聊天室中」的需求
|
||||
type UserRoom struct {
|
||||
UID string `db:"uid" partition_key:"true"`
|
||||
RoomID gocql.UUID `db:"room_id" clustering_key:"true"`
|
||||
JoinedAt int64 `db:"joined_at"` // 加入時間
|
||||
}
|
||||
|
||||
// TableName 返回表名
|
||||
func (u UserRoom) TableName() string {
|
||||
return "user_room"
|
||||
}
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
package repository
|
||||
|
||||
import (
|
||||
"backend/pkg/chat/domain/entity"
|
||||
"context"
|
||||
)
|
||||
|
||||
// MessageRepository 定義訊息相關的資料存取介面
|
||||
type MessageRepository interface {
|
||||
// Insert 插入訊息
|
||||
Insert(ctx context.Context, msg *entity.Message) error
|
||||
// ListMessages 查詢訊息列表(分頁)
|
||||
ListMessages(ctx context.Context, param ListMessagesReq) ([]entity.Message, error)
|
||||
// Count 計算符合條件的訊息總數
|
||||
Count(ctx context.Context, RoomID string) (int64, error)
|
||||
// CheckAndInsertDedup 檢查並插入去重記錄,如果已存在則返回 true(表示重複)
|
||||
CheckAndInsertDedup(ctx context.Context, param CheckDupReq) (bool, error)
|
||||
}
|
||||
|
||||
type SendMessageReq struct {
|
||||
RoomID string
|
||||
UID string
|
||||
Content string
|
||||
ClientMsgID string
|
||||
}
|
||||
|
||||
type ListMessagesReq struct {
|
||||
RoomID string
|
||||
BucketDay string
|
||||
PageSize int
|
||||
// LastTS 用於 cursor-based pagination,獲取 ts < LastTS 的訊息
|
||||
// 如果為 0,則獲取最新的訊息
|
||||
LastTS int64
|
||||
}
|
||||
|
||||
type CheckDupReq struct {
|
||||
RoomID string
|
||||
UID string
|
||||
BucketSec int64
|
||||
ContentMD5 string
|
||||
}
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
package repository
|
||||
|
||||
import (
|
||||
"backend/pkg/chat/domain/entity"
|
||||
"context"
|
||||
)
|
||||
|
||||
type RoomRepository interface {
|
||||
Room
|
||||
Member
|
||||
User
|
||||
}
|
||||
|
||||
// ListRoomsReq 查詢聊天室列表的請求參數
|
||||
type ListRoomsReq struct {
|
||||
Status string // 聊天室狀態(active, archived 等)
|
||||
PageSize int // 每頁大小
|
||||
LastID string // 用於 cursor-based pagination
|
||||
}
|
||||
|
||||
// CountRoomsReq 統計聊天室的請求參數
|
||||
type CountRoomsReq struct {
|
||||
Status string // 可選的狀態篩選
|
||||
}
|
||||
|
||||
type Room interface {
|
||||
Create(ctx context.Context, room *entity.Room) error // Create 創建聊天室
|
||||
RoomGet(ctx context.Context, roomID string) (*entity.Room, error) // Get 獲取聊天室資訊
|
||||
RoomUpdate(ctx context.Context, room *entity.Room) error // Update 更新聊天室資訊
|
||||
RoomDelete(ctx context.Context, roomID string) error // Delete 刪除聊天室(同時需要刪除相關的成員和訊息)
|
||||
RoomList(ctx context.Context, param ListRoomsReq) ([]entity.Room, error) // List 查詢聊天室列表(支援分頁和篩選)
|
||||
RoomCount(ctx context.Context, param CountRoomsReq) (int64, error) // Count 統計聊天室總數
|
||||
RoomExists(ctx context.Context, roomID string) (bool, error) // Exists 檢查聊天室是否存在
|
||||
RoomGetByID(ctx context.Context, roomIDs []string) ([]entity.Room, error) // 取得 Room by id
|
||||
}
|
||||
|
||||
type Member interface {
|
||||
Insert(ctx context.Context, member *entity.RoomMember) error // Insert 添加成員到聊天室
|
||||
Get(ctx context.Context, roomID, uid string) (*entity.RoomMember, error) // Get 獲取特定成員資訊
|
||||
AllMembers(ctx context.Context, roomID string) ([]entity.RoomMember, error) // AllMembers 查詢聊天室所有成員
|
||||
UpdateRole(ctx context.Context, member *entity.RoomMember) error // UpdateRole 更新成員資訊(例如更新角色)
|
||||
DeleteMember(ctx context.Context, roomID, uid string) error // DeleteMember 刪除特定成員(某人退出聊天室)
|
||||
DeleteRoom(ctx context.Context, roomID string) error // DeleteRoom 刪除整個聊天室的所有成員
|
||||
Count(ctx context.Context, roomID string) (int64, error) // Count 計算聊天室成員數量
|
||||
}
|
||||
|
||||
type User interface {
|
||||
GetUserRooms(ctx context.Context, uid string) ([]entity.UserRoom, error) // GetUserRooms 查詢用戶所在的所有聊天室
|
||||
CountUserRooms(ctx context.Context, uid string) (int64, error) // CountUserRooms 統計用戶所在的聊天室數量
|
||||
IsUserInRoom(ctx context.Context, uid, roomID string) (bool, error) // IsUserInRoom 檢查用戶是否在某個聊天室中
|
||||
}
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
package usecase
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// MessageUseCase 定義訊息相關的業務邏輯介面
|
||||
type MessageUseCase interface {
|
||||
// SendMessage 發送訊息
|
||||
SendMessage(ctx context.Context, param SendMessageReq) error
|
||||
// ListMessages 查詢訊息列表(分頁)
|
||||
ListMessages(ctx context.Context, req ListMessagesReq) ([]Message, int64, error)
|
||||
}
|
||||
|
||||
type SendMessageReq struct {
|
||||
RoomID string
|
||||
UID string
|
||||
Content string
|
||||
}
|
||||
|
||||
type ListMessagesReq struct {
|
||||
RoomID string
|
||||
UID string
|
||||
BucketDay string
|
||||
PageSize int64
|
||||
LastTS int64
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
RoomID string `json:"room_id"`
|
||||
BucketDay string `json:"bucket_day"` // yyyyMMdd
|
||||
TS int64 `json:"ts"` // timestamp
|
||||
UID string `json:"uid"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
|
@ -0,0 +1,165 @@
|
|||
package repository
|
||||
|
||||
import (
|
||||
"backend/pkg/chat/domain/entity"
|
||||
"backend/pkg/chat/domain/repository"
|
||||
"backend/pkg/library/cassandra"
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
type messageRepository struct {
|
||||
repo cassandra.Repository[entity.Message]
|
||||
dedupRepo cassandra.Repository[entity.MessageDedup]
|
||||
db *cassandra.DB
|
||||
keyspace string
|
||||
}
|
||||
|
||||
// MessageRepositoryParam 創建 MessageRepository 所需的參數
|
||||
type MessageRepositoryParam struct {
|
||||
DB *cassandra.DB
|
||||
Keyspace string
|
||||
}
|
||||
|
||||
// MustMessageRepository 創建 MessageRepository(如果失敗會 panic)
|
||||
func MustMessageRepository(param MessageRepositoryParam) repository.MessageRepository {
|
||||
repo, err := NewMessageRepository(param.DB, param.Keyspace)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create message repository: %v", err))
|
||||
}
|
||||
return repo
|
||||
}
|
||||
|
||||
// NewMessageRepository 創建新的訊息 Repository
|
||||
func NewMessageRepository(db *cassandra.DB, keyspace string) (repository.MessageRepository, error) {
|
||||
repo, err := cassandra.NewRepository[entity.Message](db, keyspace)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dedupRepo, err := cassandra.NewRepository[entity.MessageDedup](db, keyspace)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &messageRepository{
|
||||
repo: repo,
|
||||
dedupRepo: dedupRepo,
|
||||
db: db,
|
||||
keyspace: keyspace,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (message *messageRepository) Insert(ctx context.Context, msg *entity.Message) error {
|
||||
now := time.Now().UTC()
|
||||
if msg.TS == 0 {
|
||||
msg.TS = now.UnixNano()
|
||||
}
|
||||
// 只在 BucketDay 為空時才自動設置,保留先前傳入的值
|
||||
if msg.BucketDay == "" {
|
||||
msg.BucketDay = now.Format(time.DateOnly)
|
||||
}
|
||||
|
||||
return message.repo.Insert(ctx, *msg)
|
||||
}
|
||||
|
||||
func (message *messageRepository) ListMessages(ctx context.Context, param repository.ListMessagesReq) ([]entity.Message, error) {
|
||||
// 設定預設分頁大小
|
||||
if param.PageSize <= 0 {
|
||||
param.PageSize = 20
|
||||
}
|
||||
|
||||
// 將字串 RoomID 轉換為 UUID
|
||||
roomUUID, err := gocql.ParseUUID(param.RoomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 構建查詢條件
|
||||
query := message.repo.Query().
|
||||
Where(cassandra.Eq("room_id", roomUUID)).
|
||||
Where(cassandra.Eq("bucket_day", param.BucketDay))
|
||||
|
||||
// 使用 cursor-based pagination:如果提供了 LastTS,則查詢 ts < LastTS 的訊息
|
||||
// 因為排序是 DESC,所以使用 < 來獲取更早的訊息(下一頁)
|
||||
if param.LastTS > 0 {
|
||||
query = query.Where(cassandra.Lt("ts", param.LastTS))
|
||||
}
|
||||
|
||||
// 添加排序和限制
|
||||
query = query.
|
||||
OrderBy("ts", cassandra.DESC).
|
||||
Limit(int(param.PageSize))
|
||||
|
||||
// 執行查詢
|
||||
var messages []entity.Message
|
||||
if err := query.Scan(ctx, &messages); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (message *messageRepository) Count(ctx context.Context, roomID string) (int64, error) {
|
||||
// 將字串 RoomID 轉換為 UUID
|
||||
roomUUID, err := gocql.ParseUUID(roomID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 注意:由於 partition key 是 (room_id, bucket_day),只用 room_id 查詢需要 ALLOW FILTERING
|
||||
// 這在生產環境中效能較差,建議改用按 bucket_day 分別查詢後加總
|
||||
count, err := message.repo.Query().
|
||||
Where(cassandra.Eq("room_id", roomUUID)).
|
||||
AllowFiltering().
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// CheckAndInsertDedup 檢查並插入去重記錄
|
||||
// 使用 IF NOT EXISTS 來實現原子性的去重檢查
|
||||
// 返回值:true 表示已存在(重複),false 表示成功插入(不重複)
|
||||
func (message *messageRepository) CheckAndInsertDedup(ctx context.Context, param repository.CheckDupReq) (bool, error) {
|
||||
// 將字串 RoomID 轉換為 UUID
|
||||
roomUUID, err := gocql.ParseUUID(param.RoomID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// 使用 IF NOT EXISTS 來實現原子性的去重檢查
|
||||
// 如果記錄已存在,INSERT 不會插入,且 applied = false
|
||||
dedup := entity.MessageDedup{
|
||||
RoomID: roomUUID,
|
||||
UID: param.UID,
|
||||
BucketSec: param.BucketSec,
|
||||
ContentMD5: param.ContentMD5,
|
||||
}
|
||||
|
||||
// 使用原生 CQL 語句來實現 IF NOT EXISTS
|
||||
tableName := dedup.TableName()
|
||||
stmt := fmt.Sprintf(
|
||||
"INSERT INTO %s.%s (room_id, uid, bucket_sec, content_md5) VALUES (?, ?, ?, ?) IF NOT EXISTS",
|
||||
message.keyspace,
|
||||
tableName,
|
||||
)
|
||||
|
||||
// 執行 INSERT IF NOT EXISTS
|
||||
applied, err := message.db.GetSession().Query(stmt, nil).
|
||||
Bind(roomUUID, param.UID, param.BucketSec, param.ContentMD5).
|
||||
WithContext(ctx).
|
||||
MapScanCAS(make(map[string]interface{}))
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check dedup: %w", err)
|
||||
}
|
||||
|
||||
// applied = false 表示記錄已存在(重複)
|
||||
// applied = true 表示成功插入(不重複)
|
||||
return !applied, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,526 @@
|
|||
package repository
|
||||
|
||||
import (
|
||||
"backend/pkg/chat/domain/entity"
|
||||
"backend/pkg/chat/domain/repository"
|
||||
"backend/pkg/library/cassandra"
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
)
|
||||
|
||||
var (
|
||||
// 全局變量:所有測試共享的資源
|
||||
testDB *cassandra.DB
|
||||
testRepo repository.MessageRepository
|
||||
testContainer testcontainers.Container
|
||||
cleanupContainer func()
|
||||
)
|
||||
|
||||
// TestMain 在所有測試之前執行一次,設置共享的資料庫
|
||||
func TestMain(m *testing.M) {
|
||||
// 使用更長的 context 超時時間(Cassandra 啟動需要較長時間)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// 啟動 Cassandra 容器(只執行一次)
|
||||
// 不指定固定 port,讓 testcontainers 自動分配,避免與本地開發環境衝突
|
||||
req := testcontainers.ContainerRequest{
|
||||
Image: "cassandra:4.1",
|
||||
ExposedPorts: []string{"9042/tcp"},
|
||||
// 等待 Cassandra 日誌確認啟動完成(Cassandra 啟動需要較長時間)
|
||||
WaitingFor: wait.ForLog("Startup complete").
|
||||
WithStartupTimeout(3 * time.Minute).
|
||||
WithOccurrence(1),
|
||||
Env: map[string]string{
|
||||
"CASSANDRA_CLUSTER_NAME": "test-cluster",
|
||||
"HEAP_NEWSIZE": "128M",
|
||||
"MAX_HEAP_SIZE": "512M",
|
||||
},
|
||||
}
|
||||
|
||||
var err error
|
||||
testContainer, err = testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
|
||||
ContainerRequest: req,
|
||||
Started: true,
|
||||
})
|
||||
if err != nil {
|
||||
panic("Failed to start Cassandra container: " + err.Error())
|
||||
}
|
||||
|
||||
port, err := testContainer.MappedPort(ctx, "9042")
|
||||
if err != nil {
|
||||
testContainer.Terminate(ctx)
|
||||
panic("Failed to get mapped port: " + err.Error())
|
||||
}
|
||||
|
||||
host, err := testContainer.Host(ctx)
|
||||
if err != nil {
|
||||
testContainer.Terminate(ctx)
|
||||
panic("Failed to get host: " + err.Error())
|
||||
}
|
||||
|
||||
portInt, err := strconv.Atoi(port.Port())
|
||||
if err != nil {
|
||||
testContainer.Terminate(ctx)
|
||||
panic("Failed to convert port to int: " + err.Error())
|
||||
}
|
||||
|
||||
// 先創建 DB 連接(不指定 keyspace,因為 keyspace 還不存在)
|
||||
testDB, err = cassandra.New(
|
||||
cassandra.WithHosts(host),
|
||||
cassandra.WithPort(portInt),
|
||||
// 不指定 keyspace,先連接後再創建
|
||||
)
|
||||
if err != nil {
|
||||
testContainer.Terminate(ctx)
|
||||
panic("Failed to create DB: " + err.Error())
|
||||
}
|
||||
|
||||
// 創建 keyspace(需要在連接後才能創建)
|
||||
createKeyspaceStmt := "CREATE KEYSPACE IF NOT EXISTS test_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}"
|
||||
if err := testDB.GetSession().Query(createKeyspaceStmt, nil).Exec(); err != nil {
|
||||
testDB.Close()
|
||||
testContainer.Terminate(ctx)
|
||||
panic("Failed to create keyspace: " + err.Error())
|
||||
}
|
||||
|
||||
// 等待 keyspace 創建完成
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// 創建 messages_by_room 表(需要指定 keyspace)
|
||||
createTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.messages_by_room (
|
||||
room_id uuid,
|
||||
bucket_day text,
|
||||
ts bigint,
|
||||
uid text,
|
||||
content text,
|
||||
PRIMARY KEY ((room_id, bucket_day), ts)
|
||||
) WITH CLUSTERING ORDER BY (ts DESC)`
|
||||
if err := testDB.GetSession().Query(createTableStmt, nil).Exec(); err != nil {
|
||||
testDB.Close()
|
||||
testContainer.Terminate(ctx)
|
||||
panic("Failed to create table: " + err.Error())
|
||||
}
|
||||
|
||||
// 創建 messages repository
|
||||
testRepo, err = NewMessageRepository(testDB, "test_keyspace")
|
||||
if err != nil {
|
||||
testDB.Close()
|
||||
testContainer.Terminate(ctx)
|
||||
panic("Failed to create message repository: " + err.Error())
|
||||
}
|
||||
|
||||
// 創建 room 相關表(供 room_test.go 使用)
|
||||
createRoomTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.room (
|
||||
room_id uuid PRIMARY KEY,
|
||||
name text,
|
||||
status text,
|
||||
created_at bigint,
|
||||
updated_at bigint
|
||||
)`
|
||||
if err := testDB.GetSession().Query(createRoomTableStmt, nil).Exec(); err != nil {
|
||||
testDB.Close()
|
||||
testContainer.Terminate(ctx)
|
||||
panic("Failed to create room table: " + err.Error())
|
||||
}
|
||||
|
||||
createMemberTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.room_member (
|
||||
room_id uuid,
|
||||
uid text,
|
||||
role text,
|
||||
joined_at bigint,
|
||||
PRIMARY KEY (room_id, uid)
|
||||
)`
|
||||
if err := testDB.GetSession().Query(createMemberTableStmt, nil).Exec(); err != nil {
|
||||
testDB.Close()
|
||||
testContainer.Terminate(ctx)
|
||||
panic("Failed to create room_member table: " + err.Error())
|
||||
}
|
||||
|
||||
createUserRoomTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.user_room (
|
||||
uid text,
|
||||
room_id uuid,
|
||||
joined_at bigint,
|
||||
PRIMARY KEY (uid, room_id)
|
||||
)`
|
||||
if err := testDB.GetSession().Query(createUserRoomTableStmt, nil).Exec(); err != nil {
|
||||
testDB.Close()
|
||||
testContainer.Terminate(ctx)
|
||||
panic("Failed to create user_room table: " + err.Error())
|
||||
}
|
||||
|
||||
// 設置清理函數
|
||||
cleanupContainer = func() {
|
||||
if testDB != nil {
|
||||
testDB.Close()
|
||||
}
|
||||
if testContainer != nil {
|
||||
_ = testContainer.Terminate(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// 執行所有測試
|
||||
code := m.Run()
|
||||
|
||||
// 清理資源
|
||||
cleanupContainer()
|
||||
|
||||
// 退出
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// clearMessages 清空 messages_by_room 表的所有數據
|
||||
func clearMessages(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
// 使用完整的 keyspace.table 名稱
|
||||
truncateStmt := "TRUNCATE test_keyspace.messages_by_room"
|
||||
if err := testDB.GetSession().Query(truncateStmt, nil).WithContext(ctx).Exec(); err != nil {
|
||||
t.Fatalf("Failed to truncate messages_by_room table: %v", err)
|
||||
}
|
||||
// 等待數據清空
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
func TestNewMessageRepository(t *testing.T) {
|
||||
clearMessages(t)
|
||||
assert.NotNil(t, testRepo)
|
||||
}
|
||||
|
||||
func TestMessageRepository_Insert(t *testing.T) {
|
||||
clearMessages(t)
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
message *entity.Message
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful insert",
|
||||
message: &entity.Message{
|
||||
RoomID: gocql.TimeUUID(),
|
||||
BucketDay: "20250117",
|
||||
TS: time.Now().UnixNano(),
|
||||
UID: "user-1",
|
||||
Content: "Hello, world!",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "insert with auto timestamp",
|
||||
message: &entity.Message{
|
||||
RoomID: gocql.TimeUUID(),
|
||||
BucketDay: "20250117",
|
||||
UID: "user-2",
|
||||
Content: "Test message",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := testRepo.Insert(ctx, tt.message)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
// 驗證 TS 和 BucketDay 是否被自動設置
|
||||
if tt.message.TS == 0 {
|
||||
assert.NotZero(t, tt.message.TS)
|
||||
}
|
||||
if tt.message.BucketDay == "" {
|
||||
assert.NotEmpty(t, tt.message.BucketDay)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageRepository_ListMessages(t *testing.T) {
|
||||
clearMessages(t)
|
||||
ctx := context.Background()
|
||||
roomID := gocql.TimeUUID()
|
||||
roomIDStr := roomID.String()
|
||||
bucketDay := time.Now().Format(time.DateOnly)
|
||||
|
||||
// 插入測試數據(使用相同的 roomID)
|
||||
messages := []*entity.Message{
|
||||
{RoomID: roomID, BucketDay: bucketDay, TS: time.Now().UnixNano(), UID: "user-1", Content: "Message 1"},
|
||||
{RoomID: roomID, BucketDay: bucketDay, TS: time.Now().Add(time.Second).UnixNano(), UID: "user-2", Content: "Message 2"},
|
||||
{RoomID: roomID, BucketDay: bucketDay, TS: time.Now().Add(2 * time.Second).UnixNano(), UID: "user-3", Content: "Message 3"},
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
require.NoError(t, testRepo.Insert(ctx, msg))
|
||||
}
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
param repository.ListMessagesReq
|
||||
wantLen int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "list all messages",
|
||||
param: repository.ListMessagesReq{
|
||||
RoomID: roomIDStr,
|
||||
BucketDay: bucketDay,
|
||||
PageSize: 10,
|
||||
LastTS: 0,
|
||||
},
|
||||
wantLen: 3,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "list with page size limit",
|
||||
param: repository.ListMessagesReq{
|
||||
RoomID: roomIDStr,
|
||||
BucketDay: bucketDay,
|
||||
PageSize: 2,
|
||||
LastTS: 0,
|
||||
},
|
||||
wantLen: 2,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "list with cursor pagination",
|
||||
param: repository.ListMessagesReq{
|
||||
RoomID: roomIDStr,
|
||||
BucketDay: bucketDay,
|
||||
PageSize: 10,
|
||||
LastTS: messages[1].TS, // 使用第二條訊息的 TS 作為 cursor
|
||||
},
|
||||
wantLen: 1, // 應該只返回比 LastTS 更早的訊息
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "list with default page size",
|
||||
param: repository.ListMessagesReq{
|
||||
RoomID: roomIDStr,
|
||||
BucketDay: bucketDay,
|
||||
PageSize: 0, // 應該使用預設值 20
|
||||
LastTS: 0,
|
||||
},
|
||||
wantLen: 3,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "list empty room",
|
||||
param: repository.ListMessagesReq{
|
||||
RoomID: gocql.TimeUUID().String(), // 使用不存在的 UUID
|
||||
BucketDay: bucketDay,
|
||||
PageSize: 10,
|
||||
LastTS: 0,
|
||||
},
|
||||
wantLen: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := testRepo.ListMessages(ctx, tt.param)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, tt.wantLen)
|
||||
// 驗證排序(應該按 TS DESC 排序)
|
||||
if len(result) > 1 {
|
||||
for i := 0; i < len(result)-1; i++ {
|
||||
assert.GreaterOrEqual(t, result[i].TS, result[i+1].TS, "Messages should be sorted by TS DESC")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageRepository_Count(t *testing.T) {
|
||||
clearMessages(t)
|
||||
ctx := context.Background()
|
||||
roomID := gocql.TimeUUID()
|
||||
roomIDStr := roomID.String()
|
||||
bucketDay := time.Now().Format(time.DateOnly)
|
||||
|
||||
// 插入測試數據(使用相同的 roomID)
|
||||
for i := 0; i < 5; i++ {
|
||||
msg := &entity.Message{
|
||||
RoomID: roomID,
|
||||
BucketDay: bucketDay,
|
||||
TS: time.Now().Add(time.Duration(i) * time.Second).UnixNano(),
|
||||
UID: "user-1",
|
||||
Content: "Message",
|
||||
}
|
||||
require.NoError(t, testRepo.Insert(ctx, msg))
|
||||
}
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
roomID string
|
||||
want int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "count existing room",
|
||||
roomID: roomIDStr,
|
||||
want: 5,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "count non-existent room",
|
||||
roomID: gocql.TimeUUID().String(), // 使用不存在的 UUID
|
||||
want: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
count, err := testRepo.Count(ctx, tt.roomID)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, count)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageRepository_Insert_ListMessages_Integration(t *testing.T) {
|
||||
clearMessages(t)
|
||||
ctx := context.Background()
|
||||
roomID := gocql.TimeUUID()
|
||||
roomIDStr := roomID.String()
|
||||
bucketDay := time.Now().Format(time.DateOnly)
|
||||
|
||||
// 插入多條訊息(使用相同的 roomID)
|
||||
messages := make([]*entity.Message, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
msg := &entity.Message{
|
||||
RoomID: roomID,
|
||||
BucketDay: bucketDay,
|
||||
TS: time.Now().Add(time.Duration(i) * time.Second).UnixNano(),
|
||||
UID: "user-1",
|
||||
Content: "Message",
|
||||
}
|
||||
messages[i] = msg
|
||||
require.NoError(t, testRepo.Insert(ctx, msg))
|
||||
}
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 驗證可以查詢到所有訊息
|
||||
result, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
|
||||
RoomID: roomIDStr,
|
||||
BucketDay: bucketDay,
|
||||
PageSize: 20,
|
||||
LastTS: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 10)
|
||||
|
||||
// 驗證分頁功能
|
||||
firstPage, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
|
||||
RoomID: roomIDStr,
|
||||
BucketDay: bucketDay,
|
||||
PageSize: 5,
|
||||
LastTS: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, firstPage, 5)
|
||||
|
||||
// 使用 cursor 獲取下一頁
|
||||
if len(firstPage) > 0 {
|
||||
lastTS := firstPage[len(firstPage)-1].TS
|
||||
secondPage, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
|
||||
RoomID: roomIDStr,
|
||||
BucketDay: bucketDay,
|
||||
PageSize: 5,
|
||||
LastTS: lastTS,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, secondPage, 5)
|
||||
// 驗證第二頁的訊息都比第一頁的最後一條更早
|
||||
if len(secondPage) > 0 {
|
||||
assert.Less(t, secondPage[0].TS, lastTS)
|
||||
}
|
||||
}
|
||||
|
||||
// 驗證計數
|
||||
count, err := testRepo.Count(ctx, roomIDStr)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(10), count)
|
||||
}
|
||||
|
||||
func TestMessageRepository_DifferentBuckets(t *testing.T) {
|
||||
clearMessages(t)
|
||||
ctx := context.Background()
|
||||
roomID := gocql.TimeUUID()
|
||||
roomIDStr := roomID.String()
|
||||
today := time.Now().Format(time.DateOnly)
|
||||
yesterday := time.Now().AddDate(0, 0, -1).Format(time.DateOnly)
|
||||
|
||||
// 插入不同 bucket 的訊息(使用相同的 roomID)
|
||||
todayMsg := &entity.Message{
|
||||
RoomID: roomID,
|
||||
BucketDay: today,
|
||||
TS: time.Now().UnixNano(),
|
||||
UID: "user-1",
|
||||
Content: "Today's message",
|
||||
}
|
||||
yesterdayMsg := &entity.Message{
|
||||
RoomID: roomID,
|
||||
BucketDay: yesterday,
|
||||
TS: time.Now().UnixNano(),
|
||||
UID: "user-1",
|
||||
Content: "Yesterday's message",
|
||||
}
|
||||
|
||||
require.NoError(t, testRepo.Insert(ctx, todayMsg))
|
||||
require.NoError(t, testRepo.Insert(ctx, yesterdayMsg))
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 查詢今天的訊息
|
||||
todayMessages, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
|
||||
RoomID: roomIDStr,
|
||||
BucketDay: today,
|
||||
PageSize: 10,
|
||||
LastTS: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, todayMessages, 1)
|
||||
assert.Equal(t, "Today's message", todayMessages[0].Content)
|
||||
|
||||
// 查詢昨天的訊息
|
||||
yesterdayMessages, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
|
||||
RoomID: roomIDStr,
|
||||
BucketDay: yesterday,
|
||||
PageSize: 10,
|
||||
LastTS: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, yesterdayMessages, 1)
|
||||
assert.Equal(t, "Yesterday's message", yesterdayMessages[0].Content)
|
||||
}
|
||||
|
|
@ -0,0 +1,404 @@
|
|||
package repository
|
||||
|
||||
import (
|
||||
"backend/pkg/chat/domain/chat"
|
||||
"backend/pkg/chat/domain/entity"
|
||||
"backend/pkg/chat/domain/repository"
|
||||
"backend/pkg/library/cassandra"
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
type roomRepository struct {
|
||||
roomRepo cassandra.Repository[entity.Room]
|
||||
memberRepo cassandra.Repository[entity.RoomMember]
|
||||
userRoomRepo cassandra.Repository[entity.UserRoom]
|
||||
db *cassandra.DB
|
||||
keyspace string
|
||||
}
|
||||
|
||||
// RoomRepositoryParam 創建 RoomRepository 所需的參數
|
||||
type RoomRepositoryParam struct {
|
||||
DB *cassandra.DB
|
||||
Keyspace string
|
||||
}
|
||||
|
||||
// MustRoomRepository 創建 RoomRepository(如果失敗會 panic)
|
||||
func MustRoomRepository(param RoomRepositoryParam) repository.RoomRepository {
|
||||
repo, err := NewRoomRepository(param.DB, param.Keyspace)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create room repository: %v", err))
|
||||
}
|
||||
return repo
|
||||
}
|
||||
|
||||
// NewRoomRepository 創建新的聊天室 Repository
|
||||
func NewRoomRepository(db *cassandra.DB, keyspace string) (repository.RoomRepository, error) {
|
||||
roomRepo, err := cassandra.NewRepository[entity.Room](db, keyspace)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
memberRepo, err := cassandra.NewRepository[entity.RoomMember](db, keyspace)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userRoomRepo, err := cassandra.NewRepository[entity.UserRoom](db, keyspace)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &roomRepository{
|
||||
roomRepo: roomRepo,
|
||||
memberRepo: memberRepo,
|
||||
userRoomRepo: userRoomRepo,
|
||||
db: db,
|
||||
keyspace: keyspace,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ==================== Room Interface 實作 ====================
|
||||
|
||||
func (r *roomRepository) Create(ctx context.Context, room *entity.Room) error {
|
||||
now := time.Now().UTC().UnixNano()
|
||||
if room.CreatedAt == 0 {
|
||||
room.CreatedAt = now
|
||||
}
|
||||
if room.UpdatedAt == 0 {
|
||||
room.UpdatedAt = now
|
||||
}
|
||||
room.RoomID = gocql.TimeUUID()
|
||||
|
||||
return r.roomRepo.Insert(ctx, *room)
|
||||
}
|
||||
|
||||
func (r *roomRepository) RoomGet(ctx context.Context, roomID string) (*entity.Room, error) {
|
||||
uuid, err := gocql.ParseUUID(roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
room, err := r.roomRepo.Get(ctx, entity.Room{
|
||||
RoomID: uuid,
|
||||
})
|
||||
if err != nil {
|
||||
if cassandra.IsNotFound(err) {
|
||||
return nil, cassandra.ErrNotFound
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &room, nil
|
||||
}
|
||||
|
||||
func (r *roomRepository) RoomUpdate(ctx context.Context, room *entity.Room) error {
|
||||
update := entity.Room{}
|
||||
now := time.Now().UTC().UnixNano()
|
||||
|
||||
get, err := r.RoomGet(ctx, room.RoomID.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
update.CreatedAt = get.CreatedAt
|
||||
update.UpdatedAt = now
|
||||
update.Status = room.Status
|
||||
update.Name = room.Name
|
||||
update.RoomID = room.RoomID
|
||||
|
||||
if err = r.roomRepo.Update(ctx, update); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *roomRepository) RoomDelete(ctx context.Context, roomID string) error {
|
||||
uuid, err := gocql.ParseUUID(roomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 使用原生 CQL 語句來刪除,避免 cassandra library Delete 方法的 struct 綁定問題
|
||||
e := entity.Room{}
|
||||
stmt := fmt.Sprintf("DELETE FROM %s.%s WHERE room_id = ?", r.keyspace, e.TableName())
|
||||
return r.db.GetSession().Query(stmt, nil).
|
||||
Bind(uuid).
|
||||
WithContext(ctx).
|
||||
ExecRelease()
|
||||
}
|
||||
|
||||
func (r *roomRepository) RoomList(ctx context.Context, param repository.ListRoomsReq) ([]entity.Room, error) {
|
||||
query := r.roomRepo.Query()
|
||||
|
||||
if param.Status != "" {
|
||||
query = query.Where(cassandra.Eq("status", param.Status)).AllowFiltering()
|
||||
}
|
||||
|
||||
if param.PageSize > 0 {
|
||||
query = query.Limit(param.PageSize)
|
||||
}
|
||||
|
||||
if param.LastID != "" {
|
||||
lastUUID, err := gocql.ParseUUID(param.LastID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
query = query.Where(cassandra.Lt("room_id", lastUUID))
|
||||
}
|
||||
|
||||
var rooms []entity.Room
|
||||
if err := query.Scan(ctx, &rooms); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return rooms, nil
|
||||
}
|
||||
|
||||
func (r *roomRepository) RoomCount(ctx context.Context, param repository.CountRoomsReq) (int64, error) {
|
||||
query := r.roomRepo.Query()
|
||||
if param.Status != "" {
|
||||
query = query.Where(cassandra.Eq("status", param.Status)).AllowFiltering()
|
||||
}
|
||||
return query.Count(ctx)
|
||||
}
|
||||
|
||||
func (r *roomRepository) RoomExists(ctx context.Context, roomID string) (bool, error) {
|
||||
uuid, err := gocql.ParseUUID(roomID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
_, err = r.roomRepo.Get(ctx, entity.Room{
|
||||
RoomID: uuid,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if cassandra.IsNotFound(err) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *roomRepository) RoomGetByID(ctx context.Context, roomIDs []string) ([]entity.Room, error) {
|
||||
if len(roomIDs) == 0 {
|
||||
return []entity.Room{}, nil
|
||||
}
|
||||
|
||||
// 將字串 ID 轉換為 UUID
|
||||
uuids := make([]any, 0, len(roomIDs))
|
||||
for _, id := range roomIDs {
|
||||
uuid, err := gocql.ParseUUID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
uuids = append(uuids, uuid)
|
||||
}
|
||||
|
||||
var rooms []entity.Room
|
||||
err := r.roomRepo.Query().
|
||||
Where(cassandra.In("room_id", uuids)).
|
||||
Scan(ctx, &rooms)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return rooms, nil
|
||||
}
|
||||
|
||||
// ==================== Member Interface 實作 ====================
|
||||
|
||||
func (r *roomRepository) Insert(ctx context.Context, member *entity.RoomMember) error {
|
||||
now := time.Now().UTC().UnixNano()
|
||||
if member.JoinedAt == 0 {
|
||||
member.JoinedAt = now
|
||||
}
|
||||
if member.Role == "" {
|
||||
member.Role = chat.RoomRoleMember.String()
|
||||
}
|
||||
|
||||
// 同時插入到 user_room 表(反向查詢表)
|
||||
userRoom := entity.UserRoom{
|
||||
UID: member.UID,
|
||||
RoomID: member.RoomID,
|
||||
JoinedAt: member.JoinedAt,
|
||||
}
|
||||
|
||||
if err := r.memberRepo.Insert(ctx, *member); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.userRoomRepo.Insert(ctx, userRoom); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *roomRepository) Get(ctx context.Context, roomID, uid string) (*entity.RoomMember, error) {
|
||||
uuid, err := gocql.ParseUUID(roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
member, err := r.memberRepo.Get(ctx, entity.RoomMember{
|
||||
RoomID: uuid,
|
||||
UID: uid,
|
||||
})
|
||||
if err != nil {
|
||||
if cassandra.IsNotFound(err) {
|
||||
return nil, cassandra.ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &member, nil
|
||||
}
|
||||
|
||||
func (r *roomRepository) AllMembers(ctx context.Context, roomID string) ([]entity.RoomMember, error) {
|
||||
uuid, err := gocql.ParseUUID(roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var members []entity.RoomMember
|
||||
err = r.memberRepo.Query().
|
||||
Where(cassandra.Eq("room_id", uuid)).
|
||||
Scan(ctx, &members)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return members, nil
|
||||
}
|
||||
|
||||
func (r *roomRepository) UpdateRole(ctx context.Context, member *entity.RoomMember) error {
|
||||
get, err := r.Get(ctx, member.RoomID.String(), member.UID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
update := entity.RoomMember{
|
||||
RoomID: member.RoomID,
|
||||
UID: member.UID,
|
||||
Role: member.Role,
|
||||
JoinedAt: get.JoinedAt,
|
||||
}
|
||||
|
||||
return r.memberRepo.Update(ctx, update)
|
||||
}
|
||||
|
||||
func (r *roomRepository) DeleteMember(ctx context.Context, roomID, uid string) error {
|
||||
uuid, err := gocql.ParseUUID(roomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 同時從兩個表中刪除
|
||||
if err := r.memberRepo.Delete(ctx, entity.RoomMember{
|
||||
RoomID: uuid,
|
||||
UID: uid,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.userRoomRepo.Delete(ctx, entity.UserRoom{
|
||||
UID: uid,
|
||||
RoomID: uuid,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *roomRepository) DeleteRoom(ctx context.Context, roomID string) error {
|
||||
uuid, err := gocql.ParseUUID(roomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 先查詢所有成員(必須在刪除之前查詢)
|
||||
members, err := r.AllMembers(ctx, roomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 刪除 user_room 表中的所有關聯
|
||||
for _, member := range members {
|
||||
if err := r.userRoomRepo.Delete(ctx, entity.UserRoom{
|
||||
UID: member.UID,
|
||||
RoomID: uuid,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 最後刪除 room_member 表中的所有成員
|
||||
e := entity.RoomMember{}
|
||||
stmt := fmt.Sprintf("DELETE FROM %s.%s WHERE room_id = ?", r.keyspace, e.TableName())
|
||||
if err := r.db.GetSession().Query(stmt, nil).
|
||||
Bind(uuid).
|
||||
WithContext(ctx).
|
||||
WithTimestamp(time.Now().UnixNano() / 1e3).
|
||||
ExecRelease(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *roomRepository) Count(ctx context.Context, roomID string) (int64, error) {
|
||||
uuid, err := gocql.ParseUUID(roomID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return r.memberRepo.Query().
|
||||
Where(cassandra.Eq("room_id", uuid)).
|
||||
Count(ctx)
|
||||
}
|
||||
|
||||
// ==================== User Interface 實作 ====================
|
||||
|
||||
func (r *roomRepository) GetUserRooms(ctx context.Context, uid string) ([]entity.UserRoom, error) {
|
||||
var userRooms []entity.UserRoom
|
||||
err := r.userRoomRepo.Query().
|
||||
Where(cassandra.Eq("uid", uid)).
|
||||
Scan(ctx, &userRooms)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userRooms, nil
|
||||
}
|
||||
|
||||
func (r *roomRepository) CountUserRooms(ctx context.Context, uid string) (int64, error) {
|
||||
return r.userRoomRepo.Query().
|
||||
Where(cassandra.Eq("uid", uid)).
|
||||
Count(ctx)
|
||||
}
|
||||
|
||||
func (r *roomRepository) IsUserInRoom(ctx context.Context, uid, roomID string) (bool, error) {
|
||||
uuid, err := gocql.ParseUUID(roomID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, err = r.userRoomRepo.Get(ctx, entity.UserRoom{
|
||||
UID: uid,
|
||||
RoomID: uuid,
|
||||
})
|
||||
if err != nil {
|
||||
if cassandra.IsNotFound(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,953 @@
|
|||
package repository
|
||||
|
||||
import (
|
||||
"backend/pkg/chat/domain/chat"
|
||||
"backend/pkg/chat/domain/entity"
|
||||
"backend/pkg/chat/domain/repository"
|
||||
"backend/pkg/library/cassandra"
|
||||
"context"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
// 全局變量:所有測試共享的資源(與 message_test.go 共享)
|
||||
roomTestRepo repository.RoomRepository
|
||||
)
|
||||
|
||||
// ensureRoomRepository 確保 Room Repository 已初始化
|
||||
func ensureRoomRepository(t *testing.T) {
|
||||
if roomTestRepo == nil {
|
||||
if testDB == nil {
|
||||
t.Fatal("testDB is not initialized. Make sure TestMain in message_test.go runs first.")
|
||||
}
|
||||
var err error
|
||||
roomTestRepo, err = NewRoomRepository(testDB, "test_keyspace")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create room repository: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clearRoomData 清空所有 room 相關表的數據
|
||||
func clearRoomData(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tables := []string{"user_room", "room_member", "room"}
|
||||
for _, table := range tables {
|
||||
truncateStmt := "TRUNCATE test_keyspace." + table
|
||||
if err := testDB.GetSession().Query(truncateStmt, nil).WithContext(ctx).Exec(); err != nil {
|
||||
t.Fatalf("Failed to truncate %s table: %v", table, err)
|
||||
}
|
||||
}
|
||||
// 等待數據清空
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
// ==================== Room Interface 測試 ====================
|
||||
|
||||
func TestNewRoomRepository(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
assert.NotNil(t, roomTestRepo)
|
||||
}
|
||||
|
||||
func TestRoomRepository_Create(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
room *entity.Room
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful create",
|
||||
room: &entity.Room{
|
||||
Name: "Test Room",
|
||||
Status: "active",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "create with auto timestamp",
|
||||
room: &entity.Room{
|
||||
Name: "Auto Timestamp Room",
|
||||
Status: "active",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := roomTestRepo.Create(ctx, tt.room)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
// 驗證 RoomID 和時間戳是否被自動設置
|
||||
assert.NotZero(t, tt.room.RoomID)
|
||||
assert.NotZero(t, tt.room.CreatedAt)
|
||||
assert.NotZero(t, tt.room.UpdatedAt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoomRepository_RoomGet(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建測試房間
|
||||
room := &entity.Room{
|
||||
Name: "Get Test Room",
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
roomIDStr := room.RoomID.String()
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
roomID string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "get existing room",
|
||||
roomID: roomIDStr,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "get non-existent room",
|
||||
roomID: gocql.TimeUUID().String(),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "get with invalid UUID",
|
||||
roomID: "invalid-uuid",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := roomTestRepo.RoomGet(ctx, tt.roomID)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, room.Name, result.Name)
|
||||
assert.Equal(t, room.Status, result.Status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoomRepository_RoomUpdate(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建測試房間
|
||||
room := &entity.Room{
|
||||
Name: "Original Name",
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
originalCreatedAt := room.CreatedAt
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 更新房間
|
||||
room.Name = "Updated Name"
|
||||
room.Status = "archived"
|
||||
err := roomTestRepo.RoomUpdate(ctx, room)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 驗證更新
|
||||
updated, err := roomTestRepo.RoomGet(ctx, room.RoomID.String())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated Name", updated.Name)
|
||||
assert.Equal(t, "archived", updated.Status)
|
||||
assert.Equal(t, originalCreatedAt, updated.CreatedAt) // CreatedAt 不應該改變
|
||||
assert.Greater(t, updated.UpdatedAt, originalCreatedAt) // UpdatedAt 應該更新
|
||||
}
|
||||
|
||||
func TestRoomRepository_RoomDelete(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建測試房間
|
||||
room := &entity.Room{
|
||||
Name: "Delete Test Room",
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
roomIDStr := room.RoomID.String()
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 刪除房間
|
||||
err := roomTestRepo.RoomDelete(ctx, roomIDStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 驗證房間已被刪除
|
||||
_, err = roomTestRepo.RoomGet(ctx, roomIDStr)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, cassandra.IsNotFound(err))
|
||||
}
|
||||
|
||||
func TestRoomRepository_RoomList(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建多個測試房間
|
||||
rooms := make([]*entity.Room, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
room := &entity.Room{
|
||||
Name: "Room " + strconv.Itoa(i),
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
rooms[i] = room
|
||||
time.Sleep(10 * time.Millisecond) // 確保時間戳不同
|
||||
}
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
param repository.ListRoomsReq
|
||||
wantLen int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "list all rooms",
|
||||
param: repository.ListRoomsReq{
|
||||
PageSize: 10,
|
||||
},
|
||||
wantLen: 5,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "list with status filter",
|
||||
param: repository.ListRoomsReq{
|
||||
Status: "active",
|
||||
PageSize: 10,
|
||||
},
|
||||
wantLen: 5,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "list with page size limit",
|
||||
param: repository.ListRoomsReq{
|
||||
PageSize: 2,
|
||||
},
|
||||
wantLen: 2,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := roomTestRepo.RoomList(ctx, tt.param)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoomRepository_RoomCount(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建測試房間
|
||||
for i := 0; i < 3; i++ {
|
||||
room := &entity.Room{
|
||||
Name: "Count Room " + strconv.Itoa(i),
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
}
|
||||
|
||||
// 創建一個 archived 房間
|
||||
archivedRoom := &entity.Room{
|
||||
Name: "Archived Room",
|
||||
Status: "archived",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, archivedRoom))
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
param repository.CountRoomsReq
|
||||
want int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "count all rooms",
|
||||
param: repository.CountRoomsReq{},
|
||||
want: 4,
|
||||
},
|
||||
{
|
||||
name: "count active rooms",
|
||||
param: repository.CountRoomsReq{
|
||||
Status: "active",
|
||||
},
|
||||
want: 3,
|
||||
},
|
||||
{
|
||||
name: "count archived rooms",
|
||||
param: repository.CountRoomsReq{
|
||||
Status: "archived",
|
||||
},
|
||||
want: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
count, err := roomTestRepo.RoomCount(ctx, tt.param)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, count)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoomRepository_RoomExists(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建測試房間
|
||||
room := &entity.Room{
|
||||
Name: "Exists Test Room",
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
roomIDStr := room.RoomID.String()
|
||||
nonExistentID := gocql.TimeUUID().String()
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
roomID string
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "existing room",
|
||||
roomID: roomIDStr,
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent room",
|
||||
roomID: nonExistentID,
|
||||
want: false,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
exists, err := roomTestRepo.RoomExists(ctx, tt.roomID)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, exists)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoomRepository_RoomGetByID(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建多個測試房間
|
||||
rooms := make([]*entity.Room, 3)
|
||||
roomIDs := make([]string, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
room := &entity.Room{
|
||||
Name: "Room " + strconv.Itoa(i),
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
rooms[i] = room
|
||||
roomIDs[i] = room.RoomID.String()
|
||||
}
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
roomIDs []string
|
||||
wantLen int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "get multiple rooms",
|
||||
roomIDs: roomIDs,
|
||||
wantLen: 3,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "get empty list",
|
||||
roomIDs: []string{},
|
||||
wantLen: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "get with non-existent ID",
|
||||
roomIDs: []string{roomIDs[0], gocql.TimeUUID().String()},
|
||||
wantLen: 1, // 只返回存在的房間
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := roomTestRepo.RoomGetByID(ctx, tt.roomIDs)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Member Interface 測試 ====================
|
||||
|
||||
func TestRoomRepository_Insert(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建測試房間
|
||||
room := &entity.Room{
|
||||
Name: "Member Test Room",
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
member *entity.RoomMember
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful insert",
|
||||
member: &entity.RoomMember{
|
||||
RoomID: room.RoomID,
|
||||
UID: "user-1",
|
||||
Role: chat.RoomRoleMember.String(),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "insert with auto role",
|
||||
member: &entity.RoomMember{
|
||||
RoomID: room.RoomID,
|
||||
UID: "user-2",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "insert with admin role",
|
||||
member: &entity.RoomMember{
|
||||
RoomID: room.RoomID,
|
||||
UID: "user-3",
|
||||
Role: chat.RoomRoleAdmin.String(),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := roomTestRepo.Insert(ctx, tt.member)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
// 驗證 JoinedAt 是否被自動設置
|
||||
assert.NotZero(t, tt.member.JoinedAt)
|
||||
// 驗證 Role 是否被自動設置(如果為空)
|
||||
if tt.member.Role == "" {
|
||||
assert.Equal(t, chat.RoomRoleMember.String(), tt.member.Role)
|
||||
}
|
||||
|
||||
// 驗證 user_room 表也被更新
|
||||
userRooms, err := roomTestRepo.GetUserRooms(ctx, tt.member.UID)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, len(userRooms), 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoomRepository_Get(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建測試房間和成員
|
||||
room := &entity.Room{
|
||||
Name: "Get Member Test Room",
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
|
||||
member := &entity.RoomMember{
|
||||
RoomID: room.RoomID,
|
||||
UID: "user-1",
|
||||
Role: chat.RoomRoleAdmin.String(),
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Insert(ctx, member))
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
roomID string
|
||||
uid string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "get existing member",
|
||||
roomID: room.RoomID.String(),
|
||||
uid: "user-1",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "get non-existent member",
|
||||
roomID: room.RoomID.String(),
|
||||
uid: "non-existent-user",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := roomTestRepo.Get(ctx, tt.roomID, tt.uid)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, member.UID, result.UID)
|
||||
assert.Equal(t, member.Role, result.Role)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoomRepository_AllMembers(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建測試房間
|
||||
room := &entity.Room{
|
||||
Name: "All Members Test Room",
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
|
||||
// 插入多個成員
|
||||
members := []*entity.RoomMember{
|
||||
{RoomID: room.RoomID, UID: "user-1", Role: chat.RoomRoleMember.String()},
|
||||
{RoomID: room.RoomID, UID: "user-2", Role: chat.RoomRoleAdmin.String()},
|
||||
{RoomID: room.RoomID, UID: "user-3", Role: chat.RoomRoleMember.String()},
|
||||
}
|
||||
|
||||
for _, member := range members {
|
||||
require.NoError(t, roomTestRepo.Insert(ctx, member))
|
||||
}
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 查詢所有成員
|
||||
result, err := roomTestRepo.AllMembers(ctx, room.RoomID.String())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 3)
|
||||
}
|
||||
|
||||
func TestRoomRepository_UpdateRole(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建測試房間和成員
|
||||
room := &entity.Room{
|
||||
Name: "Update Role Test Room",
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
|
||||
member := &entity.RoomMember{
|
||||
RoomID: room.RoomID,
|
||||
UID: "user-1",
|
||||
Role: chat.RoomRoleMember.String(),
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Insert(ctx, member))
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 更新角色
|
||||
member.Role = chat.RoomRoleAdmin.String()
|
||||
err := roomTestRepo.UpdateRole(ctx, member)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 驗證更新
|
||||
updated, err := roomTestRepo.Get(ctx, room.RoomID.String(), "user-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, chat.RoomRoleAdmin.String(), updated.Role)
|
||||
}
|
||||
|
||||
func TestRoomRepository_DeleteMember(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建測試房間和成員
|
||||
room := &entity.Room{
|
||||
Name: "Delete Member Test Room",
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
|
||||
member := &entity.RoomMember{
|
||||
RoomID: room.RoomID,
|
||||
UID: "user-1",
|
||||
Role: chat.RoomRoleMember.String(),
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Insert(ctx, member))
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 刪除成員
|
||||
err := roomTestRepo.DeleteMember(ctx, room.RoomID.String(), "user-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 驗證成員已被刪除
|
||||
_, err = roomTestRepo.Get(ctx, room.RoomID.String(), "user-1")
|
||||
assert.Error(t, err)
|
||||
assert.True(t, cassandra.IsNotFound(err))
|
||||
|
||||
// 驗證 user_room 表也被刪除
|
||||
userRooms, err := roomTestRepo.GetUserRooms(ctx, "user-1")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, userRooms, 0)
|
||||
}
|
||||
|
||||
func TestRoomRepository_DeleteRoom(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建測試房間
|
||||
room := &entity.Room{
|
||||
Name: "Delete Room Test",
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
|
||||
// 插入多個成員
|
||||
members := []*entity.RoomMember{
|
||||
{RoomID: room.RoomID, UID: "user-1", Role: chat.RoomRoleMember.String()},
|
||||
{RoomID: room.RoomID, UID: "user-2", Role: chat.RoomRoleAdmin.String()},
|
||||
}
|
||||
|
||||
for _, member := range members {
|
||||
require.NoError(t, roomTestRepo.Insert(ctx, member))
|
||||
}
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 刪除整個房間
|
||||
err := roomTestRepo.DeleteRoom(ctx, room.RoomID.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
// 驗證所有成員已被刪除
|
||||
allMembers, err := roomTestRepo.AllMembers(ctx, room.RoomID.String())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, allMembers, 0)
|
||||
|
||||
// 驗證 user_room 表中的關聯也被刪除
|
||||
for _, member := range members {
|
||||
userRooms, err := roomTestRepo.GetUserRooms(ctx, member.UID)
|
||||
require.NoError(t, err)
|
||||
// 驗證該用戶的 user_room 記錄中不包含這個房間
|
||||
found := false
|
||||
for _, ur := range userRooms {
|
||||
if ur.RoomID == room.RoomID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.False(t, found, "user_room should not contain deleted room")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoomRepository_Count(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建測試房間
|
||||
room := &entity.Room{
|
||||
Name: "Count Members Test Room",
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
|
||||
// 插入多個成員
|
||||
for i := 0; i < 5; i++ {
|
||||
member := &entity.RoomMember{
|
||||
RoomID: room.RoomID,
|
||||
UID: "user-" + strconv.Itoa(i),
|
||||
Role: chat.RoomRoleMember.String(),
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Insert(ctx, member))
|
||||
}
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 計算成員數
|
||||
count, err := roomTestRepo.Count(ctx, room.RoomID.String())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(5), count)
|
||||
}
|
||||
|
||||
// ==================== User Interface 測試 ====================
|
||||
|
||||
func TestRoomRepository_GetUserRooms(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建多個測試房間
|
||||
rooms := make([]*entity.Room, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
room := &entity.Room{
|
||||
Name: "User Room " + strconv.Itoa(i),
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
rooms[i] = room
|
||||
}
|
||||
|
||||
// 將用戶加入多個房間
|
||||
uid := "user-1"
|
||||
for _, room := range rooms {
|
||||
member := &entity.RoomMember{
|
||||
RoomID: room.RoomID,
|
||||
UID: uid,
|
||||
Role: chat.RoomRoleMember.String(),
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Insert(ctx, member))
|
||||
}
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 查詢用戶所在的所有房間
|
||||
userRooms, err := roomTestRepo.GetUserRooms(ctx, uid)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, userRooms, 3)
|
||||
}
|
||||
|
||||
func TestRoomRepository_CountUserRooms(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建測試房間
|
||||
rooms := make([]*entity.Room, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
room := &entity.Room{
|
||||
Name: "Count User Room " + strconv.Itoa(i),
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
rooms[i] = room
|
||||
}
|
||||
|
||||
// 將用戶加入多個房間
|
||||
uid := "user-1"
|
||||
for _, room := range rooms {
|
||||
member := &entity.RoomMember{
|
||||
RoomID: room.RoomID,
|
||||
UID: uid,
|
||||
Role: chat.RoomRoleMember.String(),
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Insert(ctx, member))
|
||||
}
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 計算用戶所在的房間數
|
||||
count, err := roomTestRepo.CountUserRooms(ctx, uid)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(5), count)
|
||||
}
|
||||
|
||||
func TestRoomRepository_IsUserInRoom(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建測試房間
|
||||
room := &entity.Room{
|
||||
Name: "Is User In Room Test",
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
|
||||
// 插入成員
|
||||
member := &entity.RoomMember{
|
||||
RoomID: room.RoomID,
|
||||
UID: "user-1",
|
||||
Role: chat.RoomRoleMember.String(),
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Insert(ctx, member))
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uid string
|
||||
roomID string
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "user in room",
|
||||
uid: "user-1",
|
||||
roomID: room.RoomID.String(),
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "user not in room",
|
||||
uid: "user-2",
|
||||
roomID: room.RoomID.String(),
|
||||
want: false,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := roomTestRepo.IsUserInRoom(ctx, tt.uid, tt.roomID)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 整合測試 ====================
|
||||
|
||||
func TestRoomRepository_Integration(t *testing.T) {
|
||||
ensureRoomRepository(t)
|
||||
clearRoomData(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 創建房間
|
||||
room := &entity.Room{
|
||||
Name: "Integration Test Room",
|
||||
Status: "active",
|
||||
}
|
||||
require.NoError(t, roomTestRepo.Create(ctx, room))
|
||||
|
||||
// 添加多個成員
|
||||
members := []*entity.RoomMember{
|
||||
{RoomID: room.RoomID, UID: "user-1", Role: chat.RoomRoleOwner.String()},
|
||||
{RoomID: room.RoomID, UID: "user-2", Role: chat.RoomRoleAdmin.String()},
|
||||
{RoomID: room.RoomID, UID: "user-3", Role: chat.RoomRoleMember.String()},
|
||||
}
|
||||
|
||||
for _, member := range members {
|
||||
require.NoError(t, roomTestRepo.Insert(ctx, member))
|
||||
}
|
||||
|
||||
// 等待數據寫入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 驗證房間存在
|
||||
exists, err := roomTestRepo.RoomExists(ctx, room.RoomID.String())
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// 驗證成員數量
|
||||
count, err := roomTestRepo.Count(ctx, room.RoomID.String())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3), count)
|
||||
|
||||
// 驗證所有成員
|
||||
allMembers, err := roomTestRepo.AllMembers(ctx, room.RoomID.String())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, allMembers, 3)
|
||||
|
||||
// 驗證用戶所在的房間
|
||||
for _, member := range members {
|
||||
userRooms, err := roomTestRepo.GetUserRooms(ctx, member.UID)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, len(userRooms), 0)
|
||||
|
||||
inRoom, err := roomTestRepo.IsUserInRoom(ctx, member.UID, room.RoomID.String())
|
||||
require.NoError(t, err)
|
||||
assert.True(t, inRoom)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,272 @@
|
|||
package usecase
|
||||
|
||||
import (
|
||||
"backend/pkg/chat/domain/entity"
|
||||
"backend/pkg/chat/domain/repository"
|
||||
"backend/pkg/chat/domain/usecase"
|
||||
"backend/pkg/library/centrifugo"
|
||||
errs "backend/pkg/library/errors"
|
||||
"backend/pkg/utils"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPageSize = 20
|
||||
maxPageSize = 100 // 設置最大分頁大小
|
||||
)
|
||||
|
||||
type MessageUseCaseParam struct {
|
||||
MessageRepo repository.MessageRepository
|
||||
RoomRepo repository.RoomRepository
|
||||
MsgClient *centrifugo.Client
|
||||
Logger errs.Logger
|
||||
}
|
||||
|
||||
type MessageUseCase struct {
|
||||
MessageUseCaseParam
|
||||
}
|
||||
|
||||
// NewMessageUseCase 創建新的訊息 UseCase
|
||||
func NewMessageUseCase(param MessageUseCaseParam) usecase.MessageUseCase {
|
||||
return &MessageUseCase{
|
||||
param,
|
||||
}
|
||||
}
|
||||
|
||||
func (use *MessageUseCase) SendMessage(ctx context.Context, param usecase.SendMessageReq) error {
|
||||
// 驗證輸入參數
|
||||
if err := use.validateSendMessageReq(param); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 驗證使用者是否在房間中(優先檢查,避免不必要的去重操作)
|
||||
if err := use.verifyRoomMembership(ctx, param.UID, param.RoomID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 去重檢查
|
||||
now := time.Now().UTC()
|
||||
bucketSec := now.Unix()
|
||||
contentMD5 := calculateMD5(param.Content)
|
||||
|
||||
isDuplicate, err := use.MessageRepo.CheckAndInsertDedup(ctx, repository.CheckDupReq{
|
||||
RoomID: param.RoomID,
|
||||
UID: param.UID, // 修復:應該是 param.UID 而不是 param.RoomID
|
||||
BucketSec: bucketSec,
|
||||
ContentMD5: contentMD5,
|
||||
})
|
||||
if err != nil {
|
||||
return use.logError("messageRepo.CheckAndInsertDedup", param, err, "failed to check message deduplication")
|
||||
}
|
||||
|
||||
if isDuplicate {
|
||||
return errs.InputInvalidFormatError("duplicate message detected")
|
||||
}
|
||||
|
||||
// 建立並儲存訊息
|
||||
msg, err := use.createMessage(param, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = use.MessageRepo.Insert(ctx, msg); err != nil {
|
||||
return use.logError("messageRepo.Insert", param, err, "failed to insert message")
|
||||
}
|
||||
|
||||
// 發布到 Centrifugo(非阻塞,失敗不影響主流程)
|
||||
use.publishToCentrifugo(ctx, param.RoomID, msg, now)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateSendMessageReq 驗證發送訊息的請求參數
|
||||
func (use *MessageUseCase) validateSendMessageReq(param usecase.SendMessageReq) error {
|
||||
if param.Content == "" {
|
||||
return errs.InputInvalidFormatError("content cannot be empty")
|
||||
}
|
||||
if param.RoomID == "" {
|
||||
return errs.InputInvalidFormatError("room_id cannot be empty")
|
||||
}
|
||||
if param.UID == "" {
|
||||
return errs.InputInvalidFormatError("uid cannot be empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyRoomMembership 驗證使用者是否在房間中
|
||||
func (use *MessageUseCase) verifyRoomMembership(ctx context.Context, uid, roomID string) error {
|
||||
isMember, err := use.RoomRepo.IsUserInRoom(ctx, uid, roomID)
|
||||
if err != nil {
|
||||
return use.logError("roomRepo.IsUserInRoom", map[string]interface{}{
|
||||
"uid": uid,
|
||||
"roomID": roomID,
|
||||
}, err, "failed to check room membership")
|
||||
}
|
||||
|
||||
if !isMember {
|
||||
return errs.AuthForbiddenErrorL(
|
||||
use.Logger,
|
||||
[]errs.LogField{
|
||||
{Key: "uid", Val: uid},
|
||||
{Key: "roomID", Val: roomID},
|
||||
},
|
||||
"user is not a member of the room")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createMessage 建立訊息實體
|
||||
func (use *MessageUseCase) createMessage(param usecase.SendMessageReq, now time.Time) (*entity.Message, error) {
|
||||
roomID, err := gocql.ParseUUID(param.RoomID)
|
||||
if err != nil {
|
||||
return nil, errs.InputInvalidFormatError("invalid room_id format").Wrap(err)
|
||||
}
|
||||
|
||||
msgID, err := uuid.NewV7()
|
||||
if err != nil {
|
||||
return nil, errs.SysInternalError("failed to generate message id").Wrap(err)
|
||||
}
|
||||
|
||||
return &entity.Message{
|
||||
RoomID: roomID,
|
||||
BucketDay: utils.GetBucketDay(now),
|
||||
UID: param.UID,
|
||||
MsgID: msgID,
|
||||
Content: param.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// publishToCentrifugo 發布訊息到 Centrifugo
|
||||
func (use *MessageUseCase) publishToCentrifugo(ctx context.Context, roomID string, msg *entity.Message, now time.Time) {
|
||||
channel := fmt.Sprintf("room:%s", roomID)
|
||||
messageData := map[string]interface{}{
|
||||
"msg_id": msg.MsgID.String(),
|
||||
"uid": msg.UID,
|
||||
"content": msg.Content,
|
||||
"timestamp": now.UnixNano(), // 使用實際時間戳,而不是 msg.TS(可能為 0)
|
||||
"room_id": msg.RoomID.String(),
|
||||
}
|
||||
|
||||
if _, err := use.MsgClient.PublishJSON(ctx, channel, messageData); err != nil {
|
||||
// 記錄錯誤但不影響主流程,因為訊息已經成功儲存
|
||||
if use.Logger != nil {
|
||||
use.Logger.WithFields(
|
||||
errs.LogField{Key: "roomID", Val: roomID},
|
||||
errs.LogField{Key: "msgID", Val: msg.MsgID.String()},
|
||||
errs.LogField{Key: "error", Val: err.Error()},
|
||||
).Error(fmt.Sprintf("failed to publish message to Centrifugo: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// logError 統一的錯誤記錄方法
|
||||
func (use *MessageUseCase) logError(funcName string, param interface{}, err error, message string) error {
|
||||
return errs.DBErrorErrorL(
|
||||
use.Logger,
|
||||
[]errs.LogField{
|
||||
{Key: "param", Val: param},
|
||||
{Key: "func", Val: funcName},
|
||||
{Key: "err", Val: err.Error()},
|
||||
},
|
||||
message)
|
||||
}
|
||||
|
||||
func (use *MessageUseCase) ListMessages(ctx context.Context, req usecase.ListMessagesReq) ([]usecase.Message, int64, error) {
|
||||
// 驗證輸入參數
|
||||
if err := use.validateListMessagesReq(req); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 驗證使用者是否在房間中
|
||||
if err := use.verifyRoomMembership(ctx, req.UID, req.RoomID); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 取得 bucket_day(如果未提供則使用今天)
|
||||
bucketDay := req.BucketDay
|
||||
if bucketDay == "" {
|
||||
bucketDay = utils.GetTodayBucketDay()
|
||||
}
|
||||
|
||||
// 防止 PageSize overflow 並設置合理的範圍
|
||||
pageSize := use.normalizePageSize(req.PageSize)
|
||||
|
||||
// 查詢訊息
|
||||
messages, err := use.MessageRepo.ListMessages(ctx, repository.ListMessagesReq{
|
||||
RoomID: req.RoomID,
|
||||
BucketDay: bucketDay,
|
||||
PageSize: pageSize,
|
||||
LastTS: req.LastTS,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, 0, use.logError("messageRepo.ListMessages", req, err, "failed to list messages")
|
||||
}
|
||||
|
||||
// 轉換為 usecase.Message
|
||||
result := make([]usecase.Message, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
result = append(result, usecase.Message{
|
||||
RoomID: msg.RoomID.String(),
|
||||
BucketDay: msg.BucketDay,
|
||||
TS: msg.TS,
|
||||
UID: msg.UID,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// 計算總數(只在第一頁時計算)
|
||||
var total int64
|
||||
if req.LastTS == 0 {
|
||||
total, err = use.MessageRepo.Count(ctx, req.RoomID)
|
||||
if err != nil {
|
||||
return nil, 0, use.logError("messageRepo.Count", req, err, "failed to count messages")
|
||||
}
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// validateListMessagesReq 驗證查詢訊息的請求參數
|
||||
func (use *MessageUseCase) validateListMessagesReq(req usecase.ListMessagesReq) error {
|
||||
if req.RoomID == "" {
|
||||
return errs.InputInvalidFormatError("room_id cannot be empty")
|
||||
}
|
||||
if req.UID == "" {
|
||||
return errs.InputInvalidFormatError("uid cannot be empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// normalizePageSize 正規化分頁大小
|
||||
func (use *MessageUseCase) normalizePageSize(pageSize int64) int {
|
||||
// 檢查是否超過 int 的最大值
|
||||
if pageSize > int64(math.MaxInt) {
|
||||
return maxPageSize
|
||||
}
|
||||
|
||||
size := int(pageSize)
|
||||
if size <= 0 {
|
||||
return defaultPageSize
|
||||
}
|
||||
|
||||
if size > maxPageSize {
|
||||
return maxPageSize
|
||||
}
|
||||
|
||||
return size
|
||||
}
|
||||
|
||||
// calculateMD5 計算字串的 MD5 雜湊值
|
||||
func calculateMD5(content string) string {
|
||||
hash := md5.Sum([]byte(content))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
|
@ -75,6 +75,7 @@ type QueryBuilder[T Table] interface {
|
|||
OrderBy(column string, order Order) QueryBuilder[T]
|
||||
Limit(n int) QueryBuilder[T]
|
||||
Select(columns ...string) QueryBuilder[T]
|
||||
AllowFiltering() QueryBuilder[T]
|
||||
Scan(ctx context.Context, dest *[]T) error
|
||||
One(ctx context.Context) (T, error)
|
||||
Count(ctx context.Context) (int64, error)
|
||||
|
|
@ -82,11 +83,12 @@ type QueryBuilder[T Table] interface {
|
|||
|
||||
// queryBuilder 是 QueryBuilder 的具體實作
|
||||
type queryBuilder[T Table] struct {
|
||||
repo *repository[T]
|
||||
conditions []Condition
|
||||
orders []orderBy
|
||||
limit int
|
||||
columns []string
|
||||
repo *repository[T]
|
||||
conditions []Condition
|
||||
orders []orderBy
|
||||
limit int
|
||||
columns []string
|
||||
allowFiltering bool
|
||||
}
|
||||
|
||||
type orderBy struct {
|
||||
|
|
@ -125,6 +127,12 @@ func (q *queryBuilder[T]) Select(columns ...string) QueryBuilder[T] {
|
|||
return q
|
||||
}
|
||||
|
||||
// AllowFiltering 允許不使用 partition key 的查詢(效能較差,慎用)
|
||||
func (q *queryBuilder[T]) AllowFiltering() QueryBuilder[T] {
|
||||
q.allowFiltering = true
|
||||
return q
|
||||
}
|
||||
|
||||
// Scan 執行查詢並將結果掃描到 dest
|
||||
func (q *queryBuilder[T]) Scan(ctx context.Context, dest *[]T) error {
|
||||
if dest == nil {
|
||||
|
|
@ -171,6 +179,11 @@ func (q *queryBuilder[T]) Scan(ctx context.Context, dest *[]T) error {
|
|||
builder = builder.Limit(uint(q.limit))
|
||||
}
|
||||
|
||||
// 添加 ALLOW FILTERING
|
||||
if q.allowFiltering {
|
||||
builder = builder.AllowFiltering()
|
||||
}
|
||||
|
||||
stmt, names := builder.ToCql()
|
||||
query := q.repo.db.withContextAndTimestamp(ctx,
|
||||
q.repo.db.session.Query(stmt, names).BindMap(bindMap))
|
||||
|
|
@ -213,6 +226,11 @@ func (q *queryBuilder[T]) Count(ctx context.Context) (int64, error) {
|
|||
builder = builder.Where(cmps...)
|
||||
}
|
||||
|
||||
// 添加 ALLOW FILTERING
|
||||
if q.allowFiltering {
|
||||
builder = builder.AllowFiltering()
|
||||
}
|
||||
|
||||
stmt, names := builder.ToCql()
|
||||
query := q.repo.db.withContextAndTimestamp(ctx,
|
||||
q.repo.db.session.Query(stmt, names).BindMap(bindMap))
|
||||
|
|
|
|||
|
|
@ -130,8 +130,23 @@ func (r *repository[T]) updateSelective(ctx context.Context, doc T, includeZero
|
|||
func (r *repository[T]) Delete(ctx context.Context, pk any) error {
|
||||
t := table.New(r.metadata)
|
||||
stmt, names := t.Delete()
|
||||
q := r.db.withContextAndTimestamp(ctx,
|
||||
r.db.session.Query(stmt, names).Bind(pk))
|
||||
|
||||
// 如果 pk 是 struct,使用 BindStruct;否則使用 Bind
|
||||
var q *gocqlx.Queryx
|
||||
if reflect.TypeOf(pk).Kind() == reflect.Struct {
|
||||
q = r.db.withContextAndTimestamp(ctx,
|
||||
r.db.session.Query(stmt, names).BindStruct(pk))
|
||||
} else {
|
||||
// 單一主鍵欄位的情況
|
||||
// 注意:這只適用於單一 Partition Key 且無 Clustering Key 的情況
|
||||
if len(r.metadata.PartKey) != 1 || len(r.metadata.SortKey) > 0 {
|
||||
return ErrInvalidInput.WithTable(r.table).WithError(
|
||||
fmt.Errorf("single value primary key only supported for single partition key without clustering key"),
|
||||
)
|
||||
}
|
||||
q = r.db.withContextAndTimestamp(ctx,
|
||||
r.db.session.Query(stmt, names).Bind(pk))
|
||||
}
|
||||
return q.ExecRelease()
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,474 @@
|
|||
# Centrifugo Client Library
|
||||
|
||||
Go 語言的 Centrifugo 即時訊息服務客戶端庫,提供完整的 Server-side API 支援。
|
||||
|
||||
## 功能特色
|
||||
|
||||
- ✅ **HTTP API 客戶端** - 發布訊息、訂閱管理、在線狀態、歷史訊息
|
||||
- ✅ **JWT Token 生成** - 連線認證和私有頻道訂閱
|
||||
- ✅ **Token 黑名單** - 撤銷單一 Token 或用戶所有 Token
|
||||
- ✅ **在線狀態追蹤** - Redis 存儲 + Centrifugo Presence API
|
||||
- ✅ **統一服務入口** - 簡單易用的 `Service` 整合介面
|
||||
|
||||
## 安裝依賴
|
||||
|
||||
```bash
|
||||
go get github.com/golang-jwt/jwt/v5
|
||||
go get github.com/zeromicro/go-zero/core/stores/redis
|
||||
```
|
||||
|
||||
## 快速開始
|
||||
|
||||
### 1. 創建服務實例(推薦方式)
|
||||
|
||||
```go
|
||||
import (
|
||||
"backend/pkg/library/centrifugo"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
)
|
||||
|
||||
// 創建 Redis 客戶端
|
||||
rds, _ := redis.NewRedis(redis.RedisConf{
|
||||
Host: "localhost:6379",
|
||||
Type: "node",
|
||||
})
|
||||
|
||||
// 創建 Centrifugo 服務
|
||||
svc := centrifugo.NewService(centrifugo.ServiceConfig{
|
||||
APIURL: "http://localhost:8000",
|
||||
APIKey: "your-api-key",
|
||||
TokenSecret: "your-jwt-secret",
|
||||
Redis: rds, // 可選,用於黑名單和在線狀態
|
||||
})
|
||||
```
|
||||
|
||||
### 2. 發布訊息
|
||||
|
||||
```go
|
||||
ctx := context.Background()
|
||||
|
||||
// 方法 1: 使用 Service 便捷方法
|
||||
result, err := svc.PublishJSON(ctx, "chat:room-123", map[string]interface{}{
|
||||
"message": "Hello, World!",
|
||||
"user": "daniel",
|
||||
})
|
||||
|
||||
// 方法 2: 使用 Client 直接調用
|
||||
result, err := svc.Client().Publish(ctx, "chat:room-123", []byte(`{"message": "Hello!"}`))
|
||||
|
||||
// 批量發布到多個頻道
|
||||
channels := []string{"user:1", "user:2", "user:3"}
|
||||
err := svc.BroadcastJSON(ctx, channels, map[string]string{
|
||||
"type": "notification",
|
||||
"message": "System maintenance",
|
||||
})
|
||||
```
|
||||
|
||||
### 3. Token 生成
|
||||
|
||||
```go
|
||||
// 快速生成連線 Token
|
||||
token, err := svc.GenerateToken("user-123")
|
||||
|
||||
// 生成帶用戶資訊的 Token
|
||||
token, err := svc.GenerateTokenWithInfo("user-123", map[string]interface{}{
|
||||
"name": "Daniel",
|
||||
"avatar": "https://example.com/avatar.jpg",
|
||||
})
|
||||
|
||||
// 完整選項
|
||||
token, err := svc.Token().GenerateConnectionToken(centrifugo.ConnectionTokenOptions{
|
||||
UserID: "user-123",
|
||||
Info: map[string]interface{}{"role": "admin"},
|
||||
Channels: []string{"chat:room-1", "chat:room-2"}, // 自動訂閱
|
||||
})
|
||||
|
||||
// 訂閱 Token(用於私有頻道)
|
||||
token, err := svc.Token().QuickSubscriptionToken("user-123", "private:room-456")
|
||||
```
|
||||
|
||||
### 4. 撤銷 Token(踢人)
|
||||
|
||||
```go
|
||||
// 最常用:撤銷用戶所有 Token 並斷開連線
|
||||
// 適用於:用戶被封禁、密碼變更、用戶登出全部設備
|
||||
err := svc.InvalidateUser(ctx, "user-123")
|
||||
|
||||
// 只斷開連線(不撤銷 Token)
|
||||
err := svc.Disconnect(ctx, "user-123")
|
||||
|
||||
// 撤銷特定 Token(需要 JTI)
|
||||
err := svc.Blacklist().RevokeToken(ctx, jti, time.Hour)
|
||||
|
||||
// 撤銷用戶所有 Token(不斷開連線)
|
||||
err := svc.Blacklist().RevokeUserTokens(ctx, "user-123")
|
||||
```
|
||||
|
||||
### 5. 在線狀態追蹤
|
||||
|
||||
```go
|
||||
// 檢查單一用戶是否在線
|
||||
online, err := svc.IsUserOnline(ctx, "user-123")
|
||||
|
||||
// 批量獲取在線狀態
|
||||
status, err := svc.GetUsersOnlineStatus(ctx, []string{"user-1", "user-2", "user-3"})
|
||||
// status = map[string]bool{"user-1": true, "user-2": false, "user-3": true}
|
||||
|
||||
// 處理 Centrifugo Connect/Disconnect Proxy 事件
|
||||
svc.Online().HandleConnect(ctx, "user-123")
|
||||
svc.Online().HandleDisconnect(ctx, "user-123")
|
||||
|
||||
// 使用 Centrifugo Presence API(頻道級別)
|
||||
users, err := svc.Online().GetChannelOnlineUsers(ctx, "chat:room-123")
|
||||
stats, err := svc.Online().GetChannelStats(ctx, "chat:room-123")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 獨立使用各元件
|
||||
|
||||
如果不需要完整的 Service,可以獨立使用各元件:
|
||||
|
||||
### HTTP API Client
|
||||
|
||||
```go
|
||||
// 創建客戶端
|
||||
client := centrifugo.NewClient("http://localhost:8000", "your-api-key")
|
||||
|
||||
// 使用自定義配置
|
||||
client := centrifugo.NewClientWithConfig(centrifugo.ClientConfig{
|
||||
APIURL: "http://localhost:8000",
|
||||
APIKey: "your-api-key",
|
||||
Timeout: 5 * time.Second,
|
||||
MaxIdleConns: 200,
|
||||
MaxIdleConnsPerHost: 50,
|
||||
})
|
||||
|
||||
// API 調用
|
||||
client.Publish(ctx, channel, data)
|
||||
client.PublishJSON(ctx, channel, data)
|
||||
client.Broadcast(ctx, channels, data)
|
||||
client.Subscribe(ctx, user, channel)
|
||||
client.Unsubscribe(ctx, user, channel)
|
||||
client.Disconnect(ctx, user)
|
||||
client.DisconnectWithCode(ctx, user, code, reason)
|
||||
client.Presence(ctx, channel)
|
||||
client.PresenceStats(ctx, channel)
|
||||
client.History(ctx, channel, limit)
|
||||
client.HistoryReverse(ctx, channel, limit)
|
||||
client.Channels(ctx)
|
||||
client.ChannelsWithPattern(ctx, pattern)
|
||||
client.Info(ctx)
|
||||
client.Ping(ctx)
|
||||
```
|
||||
|
||||
### Token Generator
|
||||
|
||||
```go
|
||||
// 創建生成器
|
||||
tokenGen := centrifugo.NewTokenGenerator("your-jwt-secret")
|
||||
|
||||
// 使用自定義配置
|
||||
tokenGen := centrifugo.NewTokenGeneratorWithConfig(centrifugo.TokenConfig{
|
||||
Secret: "your-jwt-secret",
|
||||
ExpireIn: 24 * time.Hour,
|
||||
})
|
||||
|
||||
// 生成 Token
|
||||
tokenGen.QuickConnectionToken(userID)
|
||||
tokenGen.QuickSubscriptionToken(userID, channel)
|
||||
tokenGen.GenerateConnectionToken(opts)
|
||||
tokenGen.GenerateSubscriptionToken(opts)
|
||||
tokenGen.GenerateAnonymousToken()
|
||||
```
|
||||
|
||||
### Token Blacklist
|
||||
|
||||
```go
|
||||
// 創建黑名單管理器
|
||||
blacklist := centrifugo.NewTokenBlacklist(redisClient)
|
||||
|
||||
// 撤銷操作
|
||||
blacklist.RevokeToken(ctx, jti, ttl) // 撤銷單一 Token
|
||||
blacklist.RevokeUserTokens(ctx, userID) // 撤銷用戶所有 Token
|
||||
|
||||
// 驗證操作
|
||||
blacklist.IsTokenRevoked(ctx, jti) // 檢查 Token 是否被撤銷
|
||||
blacklist.GetUserTokenVersion(ctx, userID) // 獲取用戶 Token 版本
|
||||
blacklist.IsTokenVersionValid(ctx, userID, v) // 檢查版本是否有效
|
||||
```
|
||||
|
||||
### Online Manager
|
||||
|
||||
```go
|
||||
// 創建管理器
|
||||
store := centrifugo.NewRedisOnlineStore(redisClient)
|
||||
onlineManager := centrifugo.NewOnlineManagerWithTTL(client, store, 5*time.Minute)
|
||||
|
||||
// Redis 存儲操作
|
||||
onlineManager.HandleConnect(ctx, userID)
|
||||
onlineManager.HandleDisconnect(ctx, userID)
|
||||
onlineManager.IsUserOnline(ctx, userID)
|
||||
onlineManager.GetUsersOnlineStatus(ctx, userIDs)
|
||||
onlineManager.RefreshOnline(ctx, userID)
|
||||
|
||||
// Centrifugo Presence API
|
||||
onlineManager.IsUserInChannel(ctx, userID, channel)
|
||||
onlineManager.GetChannelOnlineUsers(ctx, channel)
|
||||
onlineManager.GetChannelStats(ctx, channel)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 前端整合範例
|
||||
|
||||
### JavaScript (使用 centrifuge-js)
|
||||
|
||||
```javascript
|
||||
import { Centrifuge } from 'centrifuge';
|
||||
|
||||
// 從後端 API 獲取 Token
|
||||
const getToken = async () => {
|
||||
const response = await fetch('/api/centrifugo/token');
|
||||
const data = await response.json();
|
||||
return data.token;
|
||||
};
|
||||
|
||||
// 創建連線
|
||||
const centrifuge = new Centrifuge('ws://localhost:8000/connection/websocket', {
|
||||
getToken: getToken,
|
||||
});
|
||||
|
||||
// 訂閱頻道
|
||||
const sub = centrifuge.newSubscription('chat:room-123');
|
||||
sub.on('publication', (ctx) => {
|
||||
console.log('Received:', ctx.data);
|
||||
});
|
||||
sub.subscribe();
|
||||
|
||||
// 連線
|
||||
centrifuge.connect();
|
||||
```
|
||||
|
||||
### 後端 Token API
|
||||
|
||||
```go
|
||||
// handlers/centrifugo.go
|
||||
func (h *Handler) GetConnectionToken(c *gin.Context) {
|
||||
userID := c.GetString("user_id") // 從 JWT 或 session 獲取
|
||||
|
||||
token, err := h.svc.GenerateToken(userID)
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{"error": "failed to generate token"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{"token": token})
|
||||
}
|
||||
|
||||
func (h *Handler) GetSubscriptionToken(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
channel := c.Query("channel")
|
||||
|
||||
// 驗證用戶是否有權限訂閱此頻道
|
||||
if !h.canSubscribe(userID, channel) {
|
||||
c.JSON(403, gin.H{"error": "forbidden"})
|
||||
return
|
||||
}
|
||||
|
||||
token, err := h.svc.Token().QuickSubscriptionToken(userID, channel)
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{"error": "failed to generate token"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{"token": token})
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Centrifugo Proxy 整合
|
||||
|
||||
### Connect Proxy
|
||||
|
||||
```go
|
||||
// POST /centrifugo/connect
|
||||
func (h *Handler) CentrifugoConnect(c *gin.Context) {
|
||||
var req struct {
|
||||
Client string `json:"client"`
|
||||
Transport string `json:"transport"`
|
||||
Protocol string `json:"protocol"`
|
||||
Data []byte `json:"data"`
|
||||
}
|
||||
c.BindJSON(&req)
|
||||
|
||||
// 從 Token 驗證用戶(Centrifugo 會傳遞)
|
||||
userID := extractUserID(req.Data)
|
||||
|
||||
// 記錄連線
|
||||
h.svc.Online().HandleConnect(c, userID)
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"result": map[string]interface{}{
|
||||
"user": userID,
|
||||
},
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
### Disconnect Proxy
|
||||
|
||||
```go
|
||||
// POST /centrifugo/disconnect
|
||||
func (h *Handler) CentrifugoDisconnect(c *gin.Context) {
|
||||
var req struct {
|
||||
Client string `json:"client"`
|
||||
User string `json:"user"`
|
||||
}
|
||||
c.BindJSON(&req)
|
||||
|
||||
// 記錄離線
|
||||
h.svc.Online().HandleDisconnect(c, req.User)
|
||||
|
||||
c.JSON(200, gin.H{"result": map[string]interface{}{}})
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Centrifugo 配置參考
|
||||
|
||||
```json
|
||||
{
|
||||
"token_hmac_secret_key": "your-jwt-secret",
|
||||
"api_key": "your-api-key",
|
||||
"admin": true,
|
||||
"allowed_origins": ["http://localhost:3000"],
|
||||
"proxy_connect_endpoint": "http://localhost:8080/centrifugo/connect",
|
||||
"proxy_disconnect_endpoint": "http://localhost:8080/centrifugo/disconnect",
|
||||
"namespaces": [
|
||||
{
|
||||
"name": "chat",
|
||||
"presence": true,
|
||||
"history_size": 100,
|
||||
"history_ttl": "300s"
|
||||
},
|
||||
{
|
||||
"name": "private",
|
||||
"presence": true,
|
||||
"protected": true
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## API 參考
|
||||
|
||||
### Service 方法
|
||||
|
||||
| 方法 | 說明 |
|
||||
|------|------|
|
||||
| `Client()` | 返回 HTTP API 客戶端 |
|
||||
| `Token()` | 返回 Token 生成器 |
|
||||
| `Blacklist()` | 返回黑名單管理器(可能為 nil) |
|
||||
| `Online()` | 返回在線狀態管理器(可能為 nil) |
|
||||
| `PublishJSON(ctx, channel, data)` | 發布 JSON 訊息 |
|
||||
| `BroadcastJSON(ctx, channels, data)` | 批量發布 JSON 訊息 |
|
||||
| `Disconnect(ctx, userID)` | 斷開用戶連線 |
|
||||
| `GenerateToken(userID)` | 快速生成連線 Token |
|
||||
| `GenerateTokenWithInfo(userID, info)` | 生成帶資訊的連線 Token |
|
||||
| `InvalidateUser(ctx, userID)` | 撤銷所有 Token 並斷開連線 |
|
||||
| `IsUserOnline(ctx, userID)` | 檢查用戶是否在線 |
|
||||
| `GetUsersOnlineStatus(ctx, userIDs)` | 批量獲取在線狀態 |
|
||||
|
||||
### Client 方法
|
||||
|
||||
| 方法 | 說明 | 返回值 |
|
||||
|------|------|--------|
|
||||
| `Publish(ctx, channel, data)` | 發布訊息 | `*PublishResult, error` |
|
||||
| `PublishJSON(ctx, channel, data)` | 發布 JSON | `*PublishResult, error` |
|
||||
| `Broadcast(ctx, channels, data)` | 批量發布 | `error` |
|
||||
| `BroadcastJSON(ctx, channels, data)` | 批量發布 JSON | `error` |
|
||||
| `Subscribe(ctx, user, channel)` | 訂閱用戶 | `error` |
|
||||
| `Unsubscribe(ctx, user, channel)` | 取消訂閱 | `error` |
|
||||
| `Disconnect(ctx, user)` | 斷開連線 | `error` |
|
||||
| `DisconnectWithCode(ctx, user, code, reason)` | 帶代碼斷開連線 | `error` |
|
||||
| `Presence(ctx, channel)` | 在線用戶 | `*PresenceResult, error` |
|
||||
| `PresenceStats(ctx, channel)` | 在線統計 | `*PresenceStatsResult, error` |
|
||||
| `History(ctx, channel, limit)` | 歷史訊息 | `*HistoryResult, error` |
|
||||
| `HistoryReverse(ctx, channel, limit)` | 歷史訊息(倒序) | `*HistoryResult, error` |
|
||||
| `Channels(ctx)` | 活躍頻道 | `*ChannelsResult, error` |
|
||||
| `ChannelsWithPattern(ctx, pattern)` | 匹配頻道 | `*ChannelsResult, error` |
|
||||
| `Info(ctx)` | 伺服器資訊 | `*InfoResult, error` |
|
||||
| `Ping(ctx)` | 健康檢查 | `error` |
|
||||
|
||||
### TokenGenerator 方法
|
||||
|
||||
| 方法 | 說明 |
|
||||
|------|------|
|
||||
| `GenerateConnectionToken(opts)` | 生成連線 Token(完整選項) |
|
||||
| `GenerateSubscriptionToken(opts)` | 生成訂閱 Token(完整選項) |
|
||||
| `GenerateAnonymousToken()` | 生成匿名 Token |
|
||||
| `QuickConnectionToken(userID)` | 快速生成連線 Token |
|
||||
| `QuickSubscriptionToken(userID, channel)` | 快速生成訂閱 Token |
|
||||
|
||||
### TokenBlacklist 方法
|
||||
|
||||
| 方法 | 說明 |
|
||||
|------|------|
|
||||
| `RevokeToken(ctx, jti, ttl)` | 撤銷特定 Token |
|
||||
| `RevokeUserTokens(ctx, userID)` | 撤銷用戶所有 Token |
|
||||
| `IsTokenRevoked(ctx, jti)` | 檢查 Token 是否被撤銷 |
|
||||
| `GetUserTokenVersion(ctx, userID)` | 獲取用戶 Token 版本 |
|
||||
| `IsTokenVersionValid(ctx, userID, version)` | 檢查版本是否有效 |
|
||||
|
||||
---
|
||||
|
||||
## 錯誤處理
|
||||
|
||||
```go
|
||||
result, err := svc.Client().Publish(ctx, channel, data)
|
||||
if err != nil {
|
||||
// 檢查是否為 Centrifugo API 錯誤
|
||||
if apiErr, ok := err.(*centrifugo.APIError); ok {
|
||||
fmt.Printf("Centrifugo error code: %d, message: %s\n",
|
||||
apiErr.Code, apiErr.Message)
|
||||
} else {
|
||||
// 網路錯誤或其他錯誤
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 檢查特定錯誤
|
||||
if errors.Is(err, centrifugo.ErrBlacklistNotConfigured) {
|
||||
// 黑名單未配置
|
||||
}
|
||||
if errors.Is(err, centrifugo.ErrOnlineStoreNotConfigured) {
|
||||
// 在線狀態存儲未配置
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 檔案結構
|
||||
|
||||
```
|
||||
pkg/library/centrifugo/
|
||||
├── centrifugo.go # 主入口,Service 整合介面
|
||||
├── client.go # HTTP API 客戶端
|
||||
├── token.go # JWT Token 生成器
|
||||
├── blacklist.go # Token 黑名單管理
|
||||
├── online.go # 在線狀態管理介面
|
||||
├── online_redis.go # Redis 在線狀態實作
|
||||
├── README.md # 文檔
|
||||
└── *_test.go # 測試文件
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
|
|
@ -0,0 +1,130 @@
|
|||
package centrifugo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
)
|
||||
|
||||
// 錯誤定義
|
||||
var (
|
||||
ErrBlacklistNotConfigured = errors.New("token blacklist is not configured")
|
||||
ErrOnlineStoreNotConfigured = errors.New("online store is not configured")
|
||||
)
|
||||
|
||||
// TokenBlacklist Token 黑名單管理器
|
||||
// 提供兩種撤銷機制:
|
||||
// 1. 單一 Token 撤銷:使用 JTI(JWT ID)將特定 Token 加入黑名單
|
||||
// 2. 用戶全部撤銷:使用版本號機制,使用戶之前所有 Token 失效
|
||||
type TokenBlacklist struct {
|
||||
redis *redis.Redis
|
||||
prefix string
|
||||
}
|
||||
|
||||
// NewTokenBlacklist 創建 Token 黑名單管理器
|
||||
func NewTokenBlacklist(redisClient *redis.Redis) *TokenBlacklist {
|
||||
return &TokenBlacklist{
|
||||
redis: redisClient,
|
||||
prefix: "centrifugo:blacklist:",
|
||||
}
|
||||
}
|
||||
|
||||
// NewTokenBlacklistWithPrefix 創建帶自定義前綴的 Token 黑名單管理器
|
||||
func NewTokenBlacklistWithPrefix(redisClient *redis.Redis, prefix string) *TokenBlacklist {
|
||||
return &TokenBlacklist{
|
||||
redis: redisClient,
|
||||
prefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 撤銷操作 ====================
|
||||
|
||||
// RevokeToken 撤銷特定 Token(使用 JTI)
|
||||
// ttl: 黑名單過期時間,應設置為 Token 的剩餘有效時間
|
||||
//
|
||||
// 使用場景:
|
||||
// - 用戶登出單一設備
|
||||
// - 檢測到可疑活動的特定 session
|
||||
func (b *TokenBlacklist) RevokeToken(ctx context.Context, jti string, ttl time.Duration) error {
|
||||
if jti == "" {
|
||||
return errors.New("jti cannot be empty")
|
||||
}
|
||||
key := b.tokenKey(jti)
|
||||
return b.redis.SetexCtx(ctx, key, "revoked", int(ttl.Seconds()))
|
||||
}
|
||||
|
||||
// RevokeUserTokens 撤銷用戶的所有 Token(使用版本控制)
|
||||
// 通過更新版本號,使該用戶之前發出的所有 Token 失效
|
||||
//
|
||||
// 使用場景:
|
||||
// - 用戶被封禁
|
||||
// - 密碼變更
|
||||
// - 用戶主動登出全部設備
|
||||
func (b *TokenBlacklist) RevokeUserTokens(ctx context.Context, userID string) error {
|
||||
if userID == "" {
|
||||
return errors.New("userID cannot be empty")
|
||||
}
|
||||
key := b.userVersionKey(userID)
|
||||
version := time.Now().UnixNano()
|
||||
// 設置 7 天過期,足夠長於任何 Token 的有效期
|
||||
return b.redis.SetexCtx(ctx, key, fmt.Sprintf("%d", version), 7*24*3600)
|
||||
}
|
||||
|
||||
// ==================== 驗證操作 ====================
|
||||
|
||||
// IsTokenRevoked 檢查 Token 是否被撤銷(使用 JTI)
|
||||
func (b *TokenBlacklist) IsTokenRevoked(ctx context.Context, jti string) (bool, error) {
|
||||
if jti == "" {
|
||||
return false, nil
|
||||
}
|
||||
key := b.tokenKey(jti)
|
||||
exists, err := b.redis.ExistsCtx(ctx, key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// GetUserTokenVersion 獲取用戶的 Token 版本
|
||||
// 返回 0 表示沒有設置版本(用戶從未被撤銷過)
|
||||
func (b *TokenBlacklist) GetUserTokenVersion(ctx context.Context, userID string) (int64, error) {
|
||||
if userID == "" {
|
||||
return 0, nil
|
||||
}
|
||||
key := b.userVersionKey(userID)
|
||||
val, err := b.redis.GetCtx(ctx, key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if val == "" {
|
||||
return 0, nil
|
||||
}
|
||||
var version int64
|
||||
_, err = fmt.Sscanf(val, "%d", &version)
|
||||
return version, err
|
||||
}
|
||||
|
||||
// IsTokenVersionValid 檢查 Token 版本是否有效
|
||||
// tokenVersion: Token 內嵌的版本號
|
||||
// 如果 currentVersion > tokenVersion,表示 Token 已被撤銷
|
||||
func (b *TokenBlacklist) IsTokenVersionValid(ctx context.Context, userID string, tokenVersion int64) (bool, error) {
|
||||
currentVersion, err := b.GetUserTokenVersion(ctx, userID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
// 如果沒有設置版本,或 Token 版本 >= 當前版本,則有效
|
||||
return currentVersion == 0 || tokenVersion >= currentVersion, nil
|
||||
}
|
||||
|
||||
// ==================== Key 生成 ====================
|
||||
|
||||
func (b *TokenBlacklist) tokenKey(jti string) string {
|
||||
return b.prefix + "token:" + jti
|
||||
}
|
||||
|
||||
func (b *TokenBlacklist) userVersionKey(userID string) string {
|
||||
return b.prefix + "user_version:" + userID
|
||||
}
|
||||
|
|
@ -0,0 +1,246 @@
|
|||
package centrifugo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
)
|
||||
|
||||
func setupTestRedis(t *testing.T) (*redis.Redis, func()) {
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
rds, err := redis.NewRedis(redis.RedisConf{
|
||||
Host: mr.Addr(),
|
||||
Type: "node",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return rds, func() {
|
||||
mr.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTokenBlacklist(t *testing.T) {
|
||||
rds, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
blacklist := NewTokenBlacklist(rds)
|
||||
|
||||
assert.NotNil(t, blacklist)
|
||||
assert.Equal(t, "centrifugo:blacklist:", blacklist.prefix)
|
||||
}
|
||||
|
||||
func TestNewTokenBlacklistWithPrefix(t *testing.T) {
|
||||
rds, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
blacklist := NewTokenBlacklistWithPrefix(rds, "custom:prefix:")
|
||||
|
||||
assert.NotNil(t, blacklist)
|
||||
assert.Equal(t, "custom:prefix:", blacklist.prefix)
|
||||
}
|
||||
|
||||
func TestRevokeToken(t *testing.T) {
|
||||
rds, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
blacklist := NewTokenBlacklist(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
jti := "test-jti-123"
|
||||
ttl := 1 * time.Hour
|
||||
|
||||
// 撤銷 Token
|
||||
err := blacklist.RevokeToken(ctx, jti, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 檢查是否被撤銷
|
||||
revoked, err := blacklist.IsTokenRevoked(ctx, jti)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, revoked)
|
||||
}
|
||||
|
||||
func TestRevokeToken_EmptyJTI(t *testing.T) {
|
||||
rds, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
blacklist := NewTokenBlacklist(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
err := blacklist.RevokeToken(ctx, "", time.Hour)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "jti cannot be empty")
|
||||
}
|
||||
|
||||
func TestIsTokenRevoked_NotRevoked(t *testing.T) {
|
||||
rds, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
blacklist := NewTokenBlacklist(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
// 檢查未撤銷的 Token
|
||||
revoked, err := blacklist.IsTokenRevoked(ctx, "non-existent-jti")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, revoked)
|
||||
}
|
||||
|
||||
func TestIsTokenRevoked_EmptyJTI(t *testing.T) {
|
||||
rds, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
blacklist := NewTokenBlacklist(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
// 空 JTI 應該返回 false
|
||||
revoked, err := blacklist.IsTokenRevoked(ctx, "")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, revoked)
|
||||
}
|
||||
|
||||
func TestRevokeUserTokens(t *testing.T) {
|
||||
rds, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
blacklist := NewTokenBlacklist(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user-123"
|
||||
|
||||
// 撤銷用戶所有 Token
|
||||
err := blacklist.RevokeUserTokens(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 獲取版本
|
||||
version, err := blacklist.GetUserTokenVersion(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, version, int64(0))
|
||||
}
|
||||
|
||||
func TestRevokeUserTokens_EmptyUserID(t *testing.T) {
|
||||
rds, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
blacklist := NewTokenBlacklist(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
err := blacklist.RevokeUserTokens(ctx, "")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "userID cannot be empty")
|
||||
}
|
||||
|
||||
func TestGetUserTokenVersion_NoVersion(t *testing.T) {
|
||||
rds, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
blacklist := NewTokenBlacklist(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
// 未設置版本的用戶應該返回 0
|
||||
version, err := blacklist.GetUserTokenVersion(ctx, "new-user")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), version)
|
||||
}
|
||||
|
||||
func TestGetUserTokenVersion_EmptyUserID(t *testing.T) {
|
||||
rds, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
blacklist := NewTokenBlacklist(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
version, err := blacklist.GetUserTokenVersion(ctx, "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), version)
|
||||
}
|
||||
|
||||
func TestIsTokenVersionValid(t *testing.T) {
|
||||
rds, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
blacklist := NewTokenBlacklist(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user-123"
|
||||
|
||||
// 未設置版本時,任何版本都應該有效
|
||||
valid, err := blacklist.IsTokenVersionValid(ctx, userID, 0)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, valid)
|
||||
|
||||
// 撤銷用戶 Token
|
||||
err = blacklist.RevokeUserTokens(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 獲取當前版本
|
||||
currentVersion, err := blacklist.GetUserTokenVersion(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 舊版本應該無效
|
||||
valid, err = blacklist.IsTokenVersionValid(ctx, userID, currentVersion-1)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, valid)
|
||||
|
||||
// 當前版本應該有效
|
||||
valid, err = blacklist.IsTokenVersionValid(ctx, userID, currentVersion)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, valid)
|
||||
|
||||
// 更新版本應該有效
|
||||
valid, err = blacklist.IsTokenVersionValid(ctx, userID, currentVersion+1)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, valid)
|
||||
}
|
||||
|
||||
func TestRevokeUserTokens_MultipleRevokes(t *testing.T) {
|
||||
rds, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
blacklist := NewTokenBlacklist(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user-123"
|
||||
|
||||
// 第一次撤銷
|
||||
err := blacklist.RevokeUserTokens(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
version1, err := blacklist.GetUserTokenVersion(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 等待一點時間確保時間戳不同
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// 第二次撤銷
|
||||
err = blacklist.RevokeUserTokens(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
version2, err := blacklist.GetUserTokenVersion(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 第二次的版本應該更大
|
||||
assert.Greater(t, version2, version1)
|
||||
|
||||
// 第一次的版本應該已經無效
|
||||
valid, err := blacklist.IsTokenVersionValid(ctx, userID, version1)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, valid)
|
||||
}
|
||||
|
||||
func TestKeyGeneration(t *testing.T) {
|
||||
rds, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
blacklist := NewTokenBlacklistWithPrefix(rds, "test:")
|
||||
|
||||
// 測試 key 生成
|
||||
assert.Equal(t, "test:token:jti-123", blacklist.tokenKey("jti-123"))
|
||||
assert.Equal(t, "test:user_version:user-456", blacklist.userVersionKey("user-456"))
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,188 @@
|
|||
// Package centrifugo 提供 Centrifugo 即時訊息服務的完整 Go 客戶端
|
||||
//
|
||||
// 功能包含:
|
||||
// - HTTP API 客戶端(發布訊息、訂閱管理、在線狀態等)
|
||||
// - JWT Token 生成(連線認證、私有頻道訂閱)
|
||||
// - Token 黑名單管理(撤銷單一 Token、撤銷用戶所有 Token)
|
||||
// - 在線狀態追蹤(Redis 或記憶體存儲)
|
||||
//
|
||||
// 基本使用:
|
||||
//
|
||||
// // 創建服務實例
|
||||
// svc := centrifugo.NewService(centrifugo.ServiceConfig{
|
||||
// APIURL: "http://localhost:8000",
|
||||
// APIKey: "your-api-key",
|
||||
// TokenSecret: "your-jwt-secret",
|
||||
// Redis: redisClient, // 可選,用於黑名單和在線狀態
|
||||
// })
|
||||
//
|
||||
// // 發布訊息
|
||||
// svc.Client().PublishJSON(ctx, "chat:room-1", data)
|
||||
//
|
||||
// // 生成 Token
|
||||
// token, _ := svc.Token().QuickConnectionToken("user-123")
|
||||
//
|
||||
// // 撤銷用戶所有 Token 並踢出
|
||||
// svc.InvalidateUser(ctx, "user-123")
|
||||
package centrifugo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
)
|
||||
|
||||
// Service Centrifugo 服務整合介面
|
||||
// 提供 HTTP API、Token 生成、黑名單管理、在線狀態追蹤的統一入口
|
||||
type Service struct {
|
||||
client *Client
|
||||
token *TokenGenerator
|
||||
blacklist *TokenBlacklist
|
||||
online *OnlineManager
|
||||
}
|
||||
|
||||
// ServiceConfig 服務配置
|
||||
type ServiceConfig struct {
|
||||
// APIURL Centrifugo HTTP API 地址(必填)
|
||||
APIURL string
|
||||
// APIKey Centrifugo API 密鑰(必填)
|
||||
APIKey string
|
||||
// TokenSecret JWT Token 簽名密鑰(必填)
|
||||
TokenSecret string
|
||||
// TokenExpire Token 過期時間(預設 1 小時)
|
||||
TokenExpire time.Duration
|
||||
// Redis 客戶端(可選,用於黑名單和在線狀態)
|
||||
Redis *redis.Redis
|
||||
// ClientConfig HTTP 客戶端配置(可選)
|
||||
ClientConfig *ClientConfig
|
||||
// OnlineTTL 在線狀態過期時間(預設 5 分鐘)
|
||||
OnlineTTL time.Duration
|
||||
// KeyPrefix Redis key 前綴(預設 "centrifugo:")
|
||||
KeyPrefix string
|
||||
}
|
||||
|
||||
// NewService 創建 Centrifugo 服務實例
|
||||
func NewService(cfg ServiceConfig) *Service {
|
||||
// 設定預設值
|
||||
if cfg.TokenExpire == 0 {
|
||||
cfg.TokenExpire = time.Hour
|
||||
}
|
||||
if cfg.OnlineTTL == 0 {
|
||||
cfg.OnlineTTL = 5 * time.Minute
|
||||
}
|
||||
if cfg.KeyPrefix == "" {
|
||||
cfg.KeyPrefix = "centrifugo:"
|
||||
}
|
||||
|
||||
// 創建 HTTP 客戶端
|
||||
var client *Client
|
||||
if cfg.ClientConfig != nil {
|
||||
client = NewClientWithConfig(*cfg.ClientConfig)
|
||||
} else {
|
||||
client = NewClient(cfg.APIURL, cfg.APIKey)
|
||||
}
|
||||
|
||||
// 創建 Token 生成器
|
||||
token := NewTokenGeneratorWithConfig(TokenConfig{
|
||||
Secret: cfg.TokenSecret,
|
||||
ExpireIn: cfg.TokenExpire,
|
||||
})
|
||||
|
||||
svc := &Service{
|
||||
client: client,
|
||||
token: token,
|
||||
}
|
||||
|
||||
// 如果有 Redis,創建黑名單管理器和在線狀態管理器
|
||||
if cfg.Redis != nil {
|
||||
svc.blacklist = NewTokenBlacklistWithPrefix(cfg.Redis, cfg.KeyPrefix+"blacklist:")
|
||||
store := NewRedisOnlineStoreWithPrefix(cfg.Redis, cfg.KeyPrefix+"online:")
|
||||
svc.online = NewOnlineManagerWithTTL(client, store, cfg.OnlineTTL)
|
||||
}
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
// Client 返回 HTTP API 客戶端
|
||||
func (s *Service) Client() *Client {
|
||||
return s.client
|
||||
}
|
||||
|
||||
// Token 返回 Token 生成器
|
||||
func (s *Service) Token() *TokenGenerator {
|
||||
return s.token
|
||||
}
|
||||
|
||||
// Blacklist 返回黑名單管理器(可能為 nil)
|
||||
func (s *Service) Blacklist() *TokenBlacklist {
|
||||
return s.blacklist
|
||||
}
|
||||
|
||||
// Online 返回在線狀態管理器(可能為 nil)
|
||||
func (s *Service) Online() *OnlineManager {
|
||||
return s.online
|
||||
}
|
||||
|
||||
// ==================== 便捷方法 ====================
|
||||
|
||||
// PublishJSON 發布 JSON 訊息到頻道
|
||||
func (s *Service) PublishJSON(ctx context.Context, channel string, data interface{}) (*PublishResult, error) {
|
||||
return s.client.PublishJSON(ctx, channel, data)
|
||||
}
|
||||
|
||||
// BroadcastJSON 批量發布 JSON 訊息到多個頻道
|
||||
func (s *Service) BroadcastJSON(ctx context.Context, channels []string, data interface{}) error {
|
||||
return s.client.BroadcastJSON(ctx, channels, data)
|
||||
}
|
||||
|
||||
// Disconnect 斷開用戶連線
|
||||
func (s *Service) Disconnect(ctx context.Context, userID string) error {
|
||||
return s.client.Disconnect(ctx, userID)
|
||||
}
|
||||
|
||||
// GenerateToken 快速生成連線 Token
|
||||
func (s *Service) GenerateToken(userID string) (string, error) {
|
||||
return s.token.QuickConnectionToken(userID)
|
||||
}
|
||||
|
||||
// GenerateTokenWithInfo 生成帶用戶資訊的連線 Token
|
||||
func (s *Service) GenerateTokenWithInfo(userID string, info map[string]interface{}) (string, error) {
|
||||
return s.token.GenerateConnectionToken(ConnectionTokenOptions{
|
||||
UserID: userID,
|
||||
Info: info,
|
||||
})
|
||||
}
|
||||
|
||||
// InvalidateUser 撤銷用戶所有 Token 並斷開連線
|
||||
// 這是最常用的「踢人」方法,適用於:
|
||||
// - 用戶被封禁
|
||||
// - 密碼變更
|
||||
// - 用戶登出(全設備)
|
||||
func (s *Service) InvalidateUser(ctx context.Context, userID string) error {
|
||||
// 撤銷所有 Token
|
||||
if s.blacklist != nil {
|
||||
if err := s.blacklist.RevokeUserTokens(ctx, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// 斷開連線
|
||||
return s.client.Disconnect(ctx, userID)
|
||||
}
|
||||
|
||||
// IsUserOnline 檢查用戶是否在線
|
||||
func (s *Service) IsUserOnline(ctx context.Context, userID string) (bool, error) {
|
||||
if s.online == nil {
|
||||
return false, ErrOnlineStoreNotConfigured
|
||||
}
|
||||
return s.online.IsUserOnline(ctx, userID)
|
||||
}
|
||||
|
||||
// GetUsersOnlineStatus 批量獲取用戶在線狀態
|
||||
func (s *Service) GetUsersOnlineStatus(ctx context.Context, userIDs []string) (map[string]bool, error) {
|
||||
if s.online == nil {
|
||||
return nil, ErrOnlineStoreNotConfigured
|
||||
}
|
||||
return s.online.GetUsersOnlineStatus(ctx, userIDs)
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,303 @@
|
|||
package centrifugo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
)
|
||||
|
||||
func setupServiceTestRedis(t *testing.T) (*redis.Redis, func()) {
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
rds, err := redis.NewRedis(redis.RedisConf{
|
||||
Host: mr.Addr(),
|
||||
Type: "node",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return rds, func() {
|
||||
mr.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewService(t *testing.T) {
|
||||
rds, cleanup := setupServiceTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
svc := NewService(ServiceConfig{
|
||||
APIURL: "http://localhost:8000",
|
||||
APIKey: "test-api-key",
|
||||
TokenSecret: "test-secret",
|
||||
Redis: rds,
|
||||
})
|
||||
|
||||
assert.NotNil(t, svc)
|
||||
assert.NotNil(t, svc.Client())
|
||||
assert.NotNil(t, svc.Token())
|
||||
assert.NotNil(t, svc.Blacklist())
|
||||
assert.NotNil(t, svc.Online())
|
||||
}
|
||||
|
||||
func TestNewService_WithoutRedis(t *testing.T) {
|
||||
svc := NewService(ServiceConfig{
|
||||
APIURL: "http://localhost:8000",
|
||||
APIKey: "test-api-key",
|
||||
TokenSecret: "test-secret",
|
||||
Redis: nil, // 沒有 Redis
|
||||
})
|
||||
|
||||
assert.NotNil(t, svc)
|
||||
assert.NotNil(t, svc.Client())
|
||||
assert.NotNil(t, svc.Token())
|
||||
assert.Nil(t, svc.Blacklist()) // 應該為 nil
|
||||
assert.Nil(t, svc.Online()) // 應該為 nil
|
||||
}
|
||||
|
||||
func TestNewService_DefaultValues(t *testing.T) {
|
||||
svc := NewService(ServiceConfig{
|
||||
APIURL: "http://localhost:8000",
|
||||
APIKey: "test-api-key",
|
||||
TokenSecret: "test-secret",
|
||||
})
|
||||
|
||||
assert.NotNil(t, svc)
|
||||
// 檢查預設值已被應用(通過生成 Token 間接驗證)
|
||||
token, err := svc.GenerateToken("user-123")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestNewService_WithCustomConfig(t *testing.T) {
|
||||
rds, cleanup := setupServiceTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
customExpire := 24 * time.Hour
|
||||
customTTL := 10 * time.Minute
|
||||
customPrefix := "custom:"
|
||||
|
||||
svc := NewService(ServiceConfig{
|
||||
APIURL: "http://localhost:8000",
|
||||
APIKey: "test-api-key",
|
||||
TokenSecret: "test-secret",
|
||||
TokenExpire: customExpire,
|
||||
Redis: rds,
|
||||
OnlineTTL: customTTL,
|
||||
KeyPrefix: customPrefix,
|
||||
})
|
||||
|
||||
assert.NotNil(t, svc)
|
||||
}
|
||||
|
||||
func TestService_GenerateToken(t *testing.T) {
|
||||
svc := NewService(ServiceConfig{
|
||||
APIURL: "http://localhost:8000",
|
||||
APIKey: "test-api-key",
|
||||
TokenSecret: "test-secret",
|
||||
})
|
||||
|
||||
token, err := svc.GenerateToken("user-123")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestService_GenerateTokenWithInfo(t *testing.T) {
|
||||
svc := NewService(ServiceConfig{
|
||||
APIURL: "http://localhost:8000",
|
||||
APIKey: "test-api-key",
|
||||
TokenSecret: "test-secret",
|
||||
})
|
||||
|
||||
info := map[string]interface{}{
|
||||
"name": "Daniel",
|
||||
"avatar": "https://example.com/avatar.jpg",
|
||||
}
|
||||
|
||||
token, err := svc.GenerateTokenWithInfo("user-123", info)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestService_IsUserOnline_WithoutRedis(t *testing.T) {
|
||||
svc := NewService(ServiceConfig{
|
||||
APIURL: "http://localhost:8000",
|
||||
APIKey: "test-api-key",
|
||||
TokenSecret: "test-secret",
|
||||
Redis: nil,
|
||||
})
|
||||
|
||||
_, err := svc.IsUserOnline(context.Background(), "user-123")
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrOnlineStoreNotConfigured)
|
||||
}
|
||||
|
||||
func TestService_GetUsersOnlineStatus_WithoutRedis(t *testing.T) {
|
||||
svc := NewService(ServiceConfig{
|
||||
APIURL: "http://localhost:8000",
|
||||
APIKey: "test-api-key",
|
||||
TokenSecret: "test-secret",
|
||||
Redis: nil,
|
||||
})
|
||||
|
||||
_, err := svc.GetUsersOnlineStatus(context.Background(), []string{"user-1", "user-2"})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrOnlineStoreNotConfigured)
|
||||
}
|
||||
|
||||
func TestService_IsUserOnline_WithRedis(t *testing.T) {
|
||||
rds, cleanup := setupServiceTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
svc := NewService(ServiceConfig{
|
||||
APIURL: "http://localhost:8000",
|
||||
APIKey: "test-api-key",
|
||||
TokenSecret: "test-secret",
|
||||
Redis: rds,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
userID := "user-123"
|
||||
|
||||
// 初始狀態應該是離線
|
||||
online, err := svc.IsUserOnline(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, online)
|
||||
|
||||
// 處理連線事件
|
||||
err = svc.Online().HandleConnect(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 現在應該在線
|
||||
online, err = svc.IsUserOnline(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, online)
|
||||
|
||||
// 處理斷線事件
|
||||
err = svc.Online().HandleDisconnect(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 現在應該離線
|
||||
online, err = svc.IsUserOnline(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, online)
|
||||
}
|
||||
|
||||
func TestService_GetUsersOnlineStatus_WithRedis(t *testing.T) {
|
||||
rds, cleanup := setupServiceTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
svc := NewService(ServiceConfig{
|
||||
APIURL: "http://localhost:8000",
|
||||
APIKey: "test-api-key",
|
||||
TokenSecret: "test-secret",
|
||||
Redis: rds,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 設置一些用戶在線
|
||||
err := svc.Online().HandleConnect(ctx, "user-1")
|
||||
require.NoError(t, err)
|
||||
err = svc.Online().HandleConnect(ctx, "user-3")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 批量獲取在線狀態
|
||||
status, err := svc.GetUsersOnlineStatus(ctx, []string{"user-1", "user-2", "user-3"})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, status["user-1"])
|
||||
assert.False(t, status["user-2"])
|
||||
assert.True(t, status["user-3"])
|
||||
}
|
||||
|
||||
func TestService_Blacklist_Integration(t *testing.T) {
|
||||
rds, cleanup := setupServiceTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
svc := NewService(ServiceConfig{
|
||||
APIURL: "http://localhost:8000",
|
||||
APIKey: "test-api-key",
|
||||
TokenSecret: "test-secret",
|
||||
Redis: rds,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
userID := "user-123"
|
||||
|
||||
// 撤銷用戶所有 Token
|
||||
err := svc.Blacklist().RevokeUserTokens(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 獲取版本
|
||||
version, err := svc.Blacklist().GetUserTokenVersion(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, version, int64(0))
|
||||
|
||||
// 檢查舊版本無效
|
||||
valid, err := svc.Blacklist().IsTokenVersionValid(ctx, userID, version-1)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, valid)
|
||||
|
||||
// 檢查當前版本有效
|
||||
valid, err = svc.Blacklist().IsTokenVersionValid(ctx, userID, version)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, valid)
|
||||
}
|
||||
|
||||
func TestService_MultipleConnections(t *testing.T) {
|
||||
rds, cleanup := setupServiceTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
svc := NewService(ServiceConfig{
|
||||
APIURL: "http://localhost:8000",
|
||||
APIKey: "test-api-key",
|
||||
TokenSecret: "test-secret",
|
||||
Redis: rds,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
userID := "user-123"
|
||||
|
||||
// 模擬多個設備連線
|
||||
err := svc.Online().HandleConnect(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
err = svc.Online().HandleConnect(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
err = svc.Online().HandleConnect(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 用戶應該在線
|
||||
online, err := svc.IsUserOnline(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, online)
|
||||
|
||||
// 斷開一個設備
|
||||
err = svc.Online().HandleDisconnect(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 用戶仍然在線(還有 2 個連線)
|
||||
online, err = svc.IsUserOnline(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, online)
|
||||
|
||||
// 斷開剩餘設備
|
||||
err = svc.Online().HandleDisconnect(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
err = svc.Online().HandleDisconnect(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 用戶現在離線
|
||||
online, err = svc.IsUserOnline(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, online)
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,519 @@
|
|||
package centrifugo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client Centrifugo 客戶端
|
||||
type Client struct {
|
||||
apiURL string
|
||||
apiKey string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// ClientConfig 客戶端配置
|
||||
type ClientConfig struct {
|
||||
// APIURL Centrifugo API 地址(必填)
|
||||
APIURL string
|
||||
// APIKey API 密鑰(必填)
|
||||
APIKey string
|
||||
|
||||
// Timeout 整體請求超時時間(預設 10 秒)
|
||||
Timeout time.Duration
|
||||
// MaxIdleConns 最大閒置連線數(預設 100)
|
||||
MaxIdleConns int
|
||||
// MaxIdleConnsPerHost 每個 host 最大閒置連線數(預設 20,適合高併發)
|
||||
MaxIdleConnsPerHost int
|
||||
// IdleConnTimeout 閒置連線超時時間(預設 90 秒)
|
||||
IdleConnTimeout time.Duration
|
||||
// DialTimeout 建立連線超時時間(預設 5 秒)
|
||||
DialTimeout time.Duration
|
||||
// TLSHandshakeTimeout TLS 握手超時時間(預設 5 秒)
|
||||
TLSHandshakeTimeout time.Duration
|
||||
// ResponseHeaderTimeout 等待響應頭超時時間(預設 10 秒)
|
||||
ResponseHeaderTimeout time.Duration
|
||||
}
|
||||
|
||||
// DefaultConfig 返回預設配置
|
||||
func DefaultConfig(apiURL, apiKey string) ClientConfig {
|
||||
return ClientConfig{
|
||||
APIURL: apiURL,
|
||||
APIKey: apiKey,
|
||||
Timeout: 10 * time.Second,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 20, // 提高以支援高併發
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 5 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// HighPerformanceConfig 返回高效能配置(適合高併發場景)
|
||||
func HighPerformanceConfig(apiURL, apiKey string) ClientConfig {
|
||||
return ClientConfig{
|
||||
APIURL: apiURL,
|
||||
APIKey: apiKey,
|
||||
Timeout: 5 * time.Second, // 更短的超時
|
||||
MaxIdleConns: 200, // 更多閒置連線
|
||||
MaxIdleConnsPerHost: 50, // 更多每 host 連線
|
||||
IdleConnTimeout: 120 * time.Second,
|
||||
DialTimeout: 3 * time.Second,
|
||||
TLSHandshakeTimeout: 3 * time.Second,
|
||||
ResponseHeaderTimeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// NewClient 創建新的 Centrifugo 客戶端(使用預設配置)
|
||||
func NewClient(apiURL, apiKey string) *Client {
|
||||
return NewClientWithConfig(DefaultConfig(apiURL, apiKey))
|
||||
}
|
||||
|
||||
// NewClientWithConfig 創建使用自定義配置的 Centrifugo 客戶端
|
||||
func NewClientWithConfig(config ClientConfig) *Client {
|
||||
// 設定預設值
|
||||
if config.Timeout == 0 {
|
||||
config.Timeout = 10 * time.Second
|
||||
}
|
||||
if config.MaxIdleConns == 0 {
|
||||
config.MaxIdleConns = 100
|
||||
}
|
||||
if config.MaxIdleConnsPerHost == 0 {
|
||||
config.MaxIdleConnsPerHost = 20
|
||||
}
|
||||
if config.IdleConnTimeout == 0 {
|
||||
config.IdleConnTimeout = 90 * time.Second
|
||||
}
|
||||
if config.DialTimeout == 0 {
|
||||
config.DialTimeout = 5 * time.Second
|
||||
}
|
||||
if config.TLSHandshakeTimeout == 0 {
|
||||
config.TLSHandshakeTimeout = 5 * time.Second
|
||||
}
|
||||
if config.ResponseHeaderTimeout == 0 {
|
||||
config.ResponseHeaderTimeout = 10 * time.Second
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: config.DialTimeout,
|
||||
KeepAlive: 30 * time.Second, // TCP keep-alive
|
||||
}).DialContext,
|
||||
MaxIdleConns: config.MaxIdleConns,
|
||||
MaxIdleConnsPerHost: config.MaxIdleConnsPerHost,
|
||||
IdleConnTimeout: config.IdleConnTimeout,
|
||||
TLSHandshakeTimeout: config.TLSHandshakeTimeout,
|
||||
ResponseHeaderTimeout: config.ResponseHeaderTimeout,
|
||||
ForceAttemptHTTP2: true, // 嘗試使用 HTTP/2
|
||||
}
|
||||
|
||||
return &Client{
|
||||
apiURL: config.APIURL,
|
||||
apiKey: config.APIKey,
|
||||
client: &http.Client{
|
||||
Timeout: config.Timeout,
|
||||
Transport: transport,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewClientWithHTTP 創建使用自定義 HTTP 客戶端的 Centrifugo 客戶端
|
||||
func NewClientWithHTTP(apiURL, apiKey string, httpClient *http.Client) *Client {
|
||||
return &Client{
|
||||
apiURL: apiURL,
|
||||
apiKey: apiKey,
|
||||
client: httpClient,
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Request/Response 結構 ====================
|
||||
|
||||
// PublishRequest 發布請求
|
||||
type PublishRequest struct {
|
||||
Channel string `json:"channel"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
// BroadcastRequest 批量發布請求
|
||||
type BroadcastRequest struct {
|
||||
Channels []string `json:"channels"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
// SubscribeRequest 訂閱請求
|
||||
type SubscribeRequest struct {
|
||||
User string `json:"user"`
|
||||
Channel string `json:"channel"`
|
||||
}
|
||||
|
||||
// UnsubscribeRequest 取消訂閱請求
|
||||
type UnsubscribeRequest struct {
|
||||
User string `json:"user"`
|
||||
Channel string `json:"channel"`
|
||||
}
|
||||
|
||||
// DisconnectRequest 斷開連線請求
|
||||
type DisconnectRequest struct {
|
||||
User string `json:"user"`
|
||||
}
|
||||
|
||||
// DisconnectWithCodeRequest 帶代碼的斷開連線請求
|
||||
type DisconnectWithCodeRequest struct {
|
||||
User string `json:"user"`
|
||||
Disconnect DisconnectInfo `json:"disconnect,omitempty"`
|
||||
}
|
||||
|
||||
// DisconnectInfo 斷開連線資訊
|
||||
type DisconnectInfo struct {
|
||||
Code uint32 `json:"code"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// PresenceRequest 在線狀態請求
|
||||
type PresenceRequest struct {
|
||||
Channel string `json:"channel"`
|
||||
}
|
||||
|
||||
// PresenceStatsRequest 在線統計請求
|
||||
type PresenceStatsRequest struct {
|
||||
Channel string `json:"channel"`
|
||||
}
|
||||
|
||||
// HistoryRequest 歷史訊息請求
|
||||
type HistoryRequest struct {
|
||||
Channel string `json:"channel"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
Reverse bool `json:"reverse,omitempty"`
|
||||
}
|
||||
|
||||
// ChannelsRequest 頻道列表請求
|
||||
type ChannelsRequest struct {
|
||||
Pattern string `json:"pattern,omitempty"`
|
||||
}
|
||||
|
||||
// InfoRequest 伺服器資訊請求
|
||||
type InfoRequest struct{}
|
||||
|
||||
// APIResponse Centrifugo API 通用響應
|
||||
type APIResponse struct {
|
||||
Error *APIError `json:"error,omitempty"`
|
||||
Result interface{} `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
// APIError Centrifugo API 錯誤
|
||||
type APIError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func (e *APIError) Error() string {
|
||||
return fmt.Sprintf("centrifugo error %d: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// PublishResult 發布結果
|
||||
type PublishResult struct {
|
||||
Offset uint64 `json:"offset,omitempty"`
|
||||
Epoch string `json:"epoch,omitempty"`
|
||||
}
|
||||
|
||||
// PresenceResult 在線狀態結果
|
||||
type PresenceResult struct {
|
||||
Presence map[string]ClientInfo `json:"presence"`
|
||||
}
|
||||
|
||||
// ClientInfo 客戶端資訊
|
||||
type ClientInfo struct {
|
||||
User string `json:"user"`
|
||||
Client string `json:"client"`
|
||||
ConnInfo json.RawMessage `json:"conn_info,omitempty"`
|
||||
ChanInfo json.RawMessage `json:"chan_info,omitempty"`
|
||||
}
|
||||
|
||||
// PresenceStatsResult 在線統計結果
|
||||
type PresenceStatsResult struct {
|
||||
NumClients int `json:"num_clients"`
|
||||
NumUsers int `json:"num_users"`
|
||||
}
|
||||
|
||||
// HistoryResult 歷史訊息結果
|
||||
type HistoryResult struct {
|
||||
Publications []Publication `json:"publications"`
|
||||
Offset uint64 `json:"offset,omitempty"`
|
||||
Epoch string `json:"epoch,omitempty"`
|
||||
}
|
||||
|
||||
// Publication 發布的訊息
|
||||
type Publication struct {
|
||||
Offset uint64 `json:"offset,omitempty"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Info *ClientInfo `json:"info,omitempty"`
|
||||
}
|
||||
|
||||
// ChannelsResult 頻道列表結果
|
||||
type ChannelsResult struct {
|
||||
Channels map[string]ChannelInfo `json:"channels"`
|
||||
}
|
||||
|
||||
// ChannelInfo 頻道資訊
|
||||
type ChannelInfo struct {
|
||||
NumClients int `json:"num_clients"`
|
||||
}
|
||||
|
||||
// InfoResult 伺服器資訊結果
|
||||
type InfoResult struct {
|
||||
Nodes []NodeInfo `json:"nodes"`
|
||||
}
|
||||
|
||||
// NodeInfo 節點資訊
|
||||
type NodeInfo struct {
|
||||
UID string `json:"uid"`
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
NumClients int `json:"num_clients"`
|
||||
NumUsers int `json:"num_users"`
|
||||
NumChannels int `json:"num_channels"`
|
||||
Uptime int `json:"uptime"`
|
||||
}
|
||||
|
||||
// ==================== 發布相關方法 ====================
|
||||
|
||||
// Publish 發布訊息到指定頻道
|
||||
func (c *Client) Publish(ctx context.Context, channel string, data []byte) (*PublishResult, error) {
|
||||
req := PublishRequest{
|
||||
Channel: channel,
|
||||
Data: json.RawMessage(data),
|
||||
}
|
||||
return c.publish(ctx, req)
|
||||
}
|
||||
|
||||
// PublishJSON 發布 JSON 訊息到指定頻道
|
||||
func (c *Client) PublishJSON(ctx context.Context, channel string, data interface{}) (*PublishResult, error) {
|
||||
req := PublishRequest{
|
||||
Channel: channel,
|
||||
Data: data,
|
||||
}
|
||||
return c.publish(ctx, req)
|
||||
}
|
||||
|
||||
func (c *Client) publish(ctx context.Context, req PublishRequest) (*PublishResult, error) {
|
||||
var result PublishResult
|
||||
if err := c.callAPI(ctx, "publish", req, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Broadcast 批量發布訊息到多個頻道
|
||||
func (c *Client) Broadcast(ctx context.Context, channels []string, data []byte) error {
|
||||
req := BroadcastRequest{
|
||||
Channels: channels,
|
||||
Data: json.RawMessage(data),
|
||||
}
|
||||
return c.callAPI(ctx, "broadcast", req, nil)
|
||||
}
|
||||
|
||||
// BroadcastJSON 批量發布 JSON 訊息到多個頻道
|
||||
func (c *Client) BroadcastJSON(ctx context.Context, channels []string, data interface{}) error {
|
||||
req := BroadcastRequest{
|
||||
Channels: channels,
|
||||
Data: data,
|
||||
}
|
||||
return c.callAPI(ctx, "broadcast", req, nil)
|
||||
}
|
||||
|
||||
// ==================== 訂閱管理方法 ====================
|
||||
|
||||
// Subscribe 訂閱用戶到頻道
|
||||
func (c *Client) Subscribe(ctx context.Context, user, channel string) error {
|
||||
req := SubscribeRequest{
|
||||
User: user,
|
||||
Channel: channel,
|
||||
}
|
||||
return c.callAPI(ctx, "subscribe", req, nil)
|
||||
}
|
||||
|
||||
// Unsubscribe 取消用戶訂閱
|
||||
func (c *Client) Unsubscribe(ctx context.Context, user, channel string) error {
|
||||
req := UnsubscribeRequest{
|
||||
User: user,
|
||||
Channel: channel,
|
||||
}
|
||||
return c.callAPI(ctx, "unsubscribe", req, nil)
|
||||
}
|
||||
|
||||
// Disconnect 強制斷開用戶連線
|
||||
func (c *Client) Disconnect(ctx context.Context, user string) error {
|
||||
req := DisconnectRequest{
|
||||
User: user,
|
||||
}
|
||||
return c.callAPI(ctx, "disconnect", req, nil)
|
||||
}
|
||||
|
||||
// DisconnectWithCode 強制斷開用戶連線(帶斷開代碼和原因)
|
||||
func (c *Client) DisconnectWithCode(ctx context.Context, user string, code uint32, reason string) error {
|
||||
req := DisconnectWithCodeRequest{
|
||||
User: user,
|
||||
Disconnect: DisconnectInfo{
|
||||
Code: code,
|
||||
Reason: reason,
|
||||
},
|
||||
}
|
||||
return c.callAPI(ctx, "disconnect", req, nil)
|
||||
}
|
||||
|
||||
|
||||
// ==================== 在線狀態方法 ====================
|
||||
|
||||
// Presence 獲取頻道在線用戶
|
||||
func (c *Client) Presence(ctx context.Context, channel string) (*PresenceResult, error) {
|
||||
req := PresenceRequest{
|
||||
Channel: channel,
|
||||
}
|
||||
var result PresenceResult
|
||||
if err := c.callAPI(ctx, "presence", req, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// PresenceStats 獲取頻道在線統計
|
||||
func (c *Client) PresenceStats(ctx context.Context, channel string) (*PresenceStatsResult, error) {
|
||||
req := PresenceStatsRequest{
|
||||
Channel: channel,
|
||||
}
|
||||
var result PresenceStatsResult
|
||||
if err := c.callAPI(ctx, "presence_stats", req, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ==================== 歷史訊息方法 ====================
|
||||
|
||||
// History 獲取頻道歷史訊息
|
||||
func (c *Client) History(ctx context.Context, channel string, limit int) (*HistoryResult, error) {
|
||||
req := HistoryRequest{
|
||||
Channel: channel,
|
||||
Limit: limit,
|
||||
}
|
||||
var result HistoryResult
|
||||
if err := c.callAPI(ctx, "history", req, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// HistoryReverse 獲取頻道歷史訊息(倒序)
|
||||
func (c *Client) HistoryReverse(ctx context.Context, channel string, limit int) (*HistoryResult, error) {
|
||||
req := HistoryRequest{
|
||||
Channel: channel,
|
||||
Limit: limit,
|
||||
Reverse: true,
|
||||
}
|
||||
var result HistoryResult
|
||||
if err := c.callAPI(ctx, "history", req, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ==================== 頻道管理方法 ====================
|
||||
|
||||
// Channels 獲取所有活躍頻道
|
||||
func (c *Client) Channels(ctx context.Context) (*ChannelsResult, error) {
|
||||
req := ChannelsRequest{}
|
||||
var result ChannelsResult
|
||||
if err := c.callAPI(ctx, "channels", req, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ChannelsWithPattern 獲取匹配模式的活躍頻道
|
||||
func (c *Client) ChannelsWithPattern(ctx context.Context, pattern string) (*ChannelsResult, error) {
|
||||
req := ChannelsRequest{
|
||||
Pattern: pattern,
|
||||
}
|
||||
var result ChannelsResult
|
||||
if err := c.callAPI(ctx, "channels", req, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ==================== 伺服器狀態方法 ====================
|
||||
|
||||
// Info 獲取伺服器狀態資訊
|
||||
func (c *Client) Info(ctx context.Context) (*InfoResult, error) {
|
||||
req := InfoRequest{}
|
||||
var result InfoResult
|
||||
if err := c.callAPI(ctx, "info", req, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Ping 檢查伺服器是否健康
|
||||
func (c *Client) Ping(ctx context.Context) error {
|
||||
_, err := c.Info(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ==================== 內部方法 ====================
|
||||
|
||||
// callAPI 調用 Centrifugo API
|
||||
func (c *Client) callAPI(ctx context.Context, method string, params interface{}, result interface{}) error {
|
||||
body, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/%s", c.apiURL, method)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if c.apiKey != "" {
|
||||
httpReq.Header.Set("Authorization", "apikey "+c.apiKey)
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("centrifugo returned status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// 解析響應
|
||||
var apiResp APIResponse
|
||||
if result != nil {
|
||||
apiResp.Result = result
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(respBody, &apiResp); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
if apiResp.Error != nil {
|
||||
return apiResp.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,175 @@
|
|||
package centrifugo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OnlineStatus 在線狀態
|
||||
type OnlineStatus struct {
|
||||
UserID string `json:"user_id"`
|
||||
IsOnline bool `json:"is_online"`
|
||||
LastSeenAt time.Time `json:"last_seen_at,omitempty"`
|
||||
Clients int `json:"clients,omitempty"` // 連線數(可能多個設備)
|
||||
}
|
||||
|
||||
// OnlineStore 在線狀態存儲介面
|
||||
// 可以用 Redis、Memory 或其他存儲實作
|
||||
type OnlineStore interface {
|
||||
// SetOnline 設置用戶在線
|
||||
SetOnline(ctx context.Context, userID string, ttl time.Duration) error
|
||||
// SetOffline 設置用戶離線
|
||||
SetOffline(ctx context.Context, userID string) error
|
||||
// IsOnline 檢查用戶是否在線
|
||||
IsOnline(ctx context.Context, userID string) (bool, error)
|
||||
// GetOnlineUsers 獲取在線用戶列表
|
||||
GetOnlineUsers(ctx context.Context, userIDs []string) (map[string]bool, error)
|
||||
// IncrClient 增加用戶連線數
|
||||
IncrClient(ctx context.Context, userID string) (int64, error)
|
||||
// DecrClient 減少用戶連線數
|
||||
DecrClient(ctx context.Context, userID string) (int64, error)
|
||||
}
|
||||
|
||||
// OnlineManager 在線狀態管理器
|
||||
// 結合 Redis 存儲和 Centrifugo Presence API 提供在線狀態追蹤
|
||||
type OnlineManager struct {
|
||||
client *Client
|
||||
store OnlineStore
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// NewOnlineManager 創建在線狀態管理器
|
||||
// store 可以為 nil,此時只使用 Centrifugo Presence API
|
||||
func NewOnlineManager(client *Client, store OnlineStore) *OnlineManager {
|
||||
return &OnlineManager{
|
||||
client: client,
|
||||
store: store,
|
||||
ttl: 5 * time.Minute, // 預設 5 分鐘過期
|
||||
}
|
||||
}
|
||||
|
||||
// NewOnlineManagerWithTTL 創建帶 TTL 的在線狀態管理器
|
||||
func NewOnlineManagerWithTTL(client *Client, store OnlineStore, ttl time.Duration) *OnlineManager {
|
||||
return &OnlineManager{
|
||||
client: client,
|
||||
store: store,
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 連線事件處理 ====================
|
||||
|
||||
// HandleConnect 處理用戶連線事件(用於 Centrifugo Connect Proxy)
|
||||
func (m *OnlineManager) HandleConnect(ctx context.Context, userID string) error {
|
||||
if m.store == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 增加連線數
|
||||
count, err := m.store.IncrClient(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to incr client: %w", err)
|
||||
}
|
||||
|
||||
// 如果是第一個連線,設置在線狀態
|
||||
if count == 1 {
|
||||
if err := m.store.SetOnline(ctx, userID, m.ttl); err != nil {
|
||||
return fmt.Errorf("failed to set online: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleDisconnect 處理用戶斷線事件(用於 Centrifugo Disconnect Proxy)
|
||||
func (m *OnlineManager) HandleDisconnect(ctx context.Context, userID string) error {
|
||||
if m.store == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 減少連線數
|
||||
count, err := m.store.DecrClient(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decr client: %w", err)
|
||||
}
|
||||
|
||||
// 如果沒有連線了,設置離線狀態
|
||||
if count <= 0 {
|
||||
if err := m.store.SetOffline(ctx, userID); err != nil {
|
||||
return fmt.Errorf("failed to set offline: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== 在線狀態查詢 ====================
|
||||
|
||||
// IsUserOnline 檢查用戶是否在線(使用 Store)
|
||||
func (m *OnlineManager) IsUserOnline(ctx context.Context, userID string) (bool, error) {
|
||||
if m.store == nil {
|
||||
return false, ErrOnlineStoreNotConfigured
|
||||
}
|
||||
return m.store.IsOnline(ctx, userID)
|
||||
}
|
||||
|
||||
// GetUsersOnlineStatus 批量獲取用戶在線狀態
|
||||
func (m *OnlineManager) GetUsersOnlineStatus(ctx context.Context, userIDs []string) (map[string]bool, error) {
|
||||
if m.store == nil {
|
||||
return nil, ErrOnlineStoreNotConfigured
|
||||
}
|
||||
return m.store.GetOnlineUsers(ctx, userIDs)
|
||||
}
|
||||
|
||||
// RefreshOnline 刷新用戶在線狀態(用於心跳)
|
||||
func (m *OnlineManager) RefreshOnline(ctx context.Context, userID string) error {
|
||||
if m.store == nil {
|
||||
return nil
|
||||
}
|
||||
return m.store.SetOnline(ctx, userID, m.ttl)
|
||||
}
|
||||
|
||||
// ==================== Centrifugo Presence API ====================
|
||||
|
||||
// IsUserInChannel 檢查用戶是否在指定頻道中(使用 Centrifugo Presence)
|
||||
func (m *OnlineManager) IsUserInChannel(ctx context.Context, userID, channel string) (bool, error) {
|
||||
presence, err := m.client.Presence(ctx, channel)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, info := range presence.Presence {
|
||||
if info.User == userID {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// GetChannelOnlineUsers 獲取頻道中的在線用戶(使用 Centrifugo Presence)
|
||||
func (m *OnlineManager) GetChannelOnlineUsers(ctx context.Context, channel string) ([]string, error) {
|
||||
presence, err := m.client.Presence(ctx, channel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 去重(一個用戶可能有多個連線)
|
||||
userMap := make(map[string]bool)
|
||||
for _, info := range presence.Presence {
|
||||
userMap[info.User] = true
|
||||
}
|
||||
|
||||
users := make([]string, 0, len(userMap))
|
||||
for userID := range userMap {
|
||||
users = append(users, userID)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetChannelStats 獲取頻道在線統計
|
||||
func (m *OnlineManager) GetChannelStats(ctx context.Context, channel string) (*PresenceStatsResult, error) {
|
||||
return m.client.PresenceStats(ctx, channel)
|
||||
}
|
||||
|
|
@ -0,0 +1,141 @@
|
|||
package centrifugo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
)
|
||||
|
||||
// RedisOnlineStore 使用 Redis 實作的在線狀態存儲
|
||||
type RedisOnlineStore struct {
|
||||
client *redis.Redis
|
||||
keyPrefix string
|
||||
}
|
||||
|
||||
// NewRedisOnlineStore 創建 Redis 在線狀態存儲
|
||||
func NewRedisOnlineStore(client *redis.Redis) *RedisOnlineStore {
|
||||
return &RedisOnlineStore{
|
||||
client: client,
|
||||
keyPrefix: "online:",
|
||||
}
|
||||
}
|
||||
|
||||
// NewRedisOnlineStoreWithPrefix 創建帶自定義前綴的 Redis 在線狀態存儲
|
||||
func NewRedisOnlineStoreWithPrefix(client *redis.Redis, prefix string) *RedisOnlineStore {
|
||||
return &RedisOnlineStore{
|
||||
client: client,
|
||||
keyPrefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
// key 生成 Redis key
|
||||
func (s *RedisOnlineStore) key(userID string) string {
|
||||
return s.keyPrefix + userID
|
||||
}
|
||||
|
||||
// clientCountKey 生成連線數 key
|
||||
func (s *RedisOnlineStore) clientCountKey(userID string) string {
|
||||
return s.keyPrefix + "clients:" + userID
|
||||
}
|
||||
|
||||
// SetOnline 設置用戶在線
|
||||
func (s *RedisOnlineStore) SetOnline(ctx context.Context, userID string, ttl time.Duration) error {
|
||||
return s.client.SetexCtx(ctx, s.key(userID), fmt.Sprintf("%d", time.Now().Unix()), int(ttl.Seconds()))
|
||||
}
|
||||
|
||||
// SetOffline 設置用戶離線
|
||||
func (s *RedisOnlineStore) SetOffline(ctx context.Context, userID string) error {
|
||||
// 刪除在線狀態
|
||||
_, err := s.client.DelCtx(ctx, s.key(userID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 刪除連線數
|
||||
_, err = s.client.DelCtx(ctx, s.clientCountKey(userID))
|
||||
return err
|
||||
}
|
||||
|
||||
// IsOnline 檢查用戶是否在線
|
||||
func (s *RedisOnlineStore) IsOnline(ctx context.Context, userID string) (bool, error) {
|
||||
return s.client.ExistsCtx(ctx, s.key(userID))
|
||||
}
|
||||
|
||||
// GetOnlineUsers 批量獲取在線用戶狀態
|
||||
func (s *RedisOnlineStore) GetOnlineUsers(ctx context.Context, userIDs []string) (map[string]bool, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return make(map[string]bool), nil
|
||||
}
|
||||
|
||||
result := make(map[string]bool, len(userIDs))
|
||||
for _, userID := range userIDs {
|
||||
exists, err := s.client.ExistsCtx(ctx, s.key(userID))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check online status for %s: %w", userID, err)
|
||||
}
|
||||
result[userID] = exists
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// IncrClient 增加用戶連線數
|
||||
func (s *RedisOnlineStore) IncrClient(ctx context.Context, userID string) (int64, error) {
|
||||
count, err := s.client.IncrCtx(ctx, s.clientCountKey(userID))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(count), nil
|
||||
}
|
||||
|
||||
// DecrClient 減少用戶連線數
|
||||
func (s *RedisOnlineStore) DecrClient(ctx context.Context, userID string) (int64, error) {
|
||||
count, err := s.client.DecrCtx(ctx, s.clientCountKey(userID))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 確保不會變成負數
|
||||
if count < 0 {
|
||||
_ = s.client.SetCtx(ctx, s.clientCountKey(userID), "0")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return int64(count), nil
|
||||
}
|
||||
|
||||
// GetClientCount 獲取用戶連線數
|
||||
func (s *RedisOnlineStore) GetClientCount(ctx context.Context, userID string) (int64, error) {
|
||||
val, err := s.client.GetCtx(ctx, s.clientCountKey(userID))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if val == "" {
|
||||
return 0, nil
|
||||
}
|
||||
return strconv.ParseInt(val, 10, 64)
|
||||
}
|
||||
|
||||
// GetAllOnlineUserIDs 獲取所有在線用戶 ID
|
||||
// 注意:此方法使用 KEYS 命令,在大規模生產環境中可能有性能問題
|
||||
// 建議在需要時使用 Centrifugo Presence API 替代
|
||||
func (s *RedisOnlineStore) GetAllOnlineUserIDs(ctx context.Context) ([]string, error) {
|
||||
// 使用 KEYS 查找所有在線用戶(排除 clients: 開頭的 key)
|
||||
pattern := s.keyPrefix + "[^c]*"
|
||||
keys, err := s.client.KeysCtx(ctx, pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userIDs := make([]string, 0, len(keys))
|
||||
prefixLen := len(s.keyPrefix)
|
||||
for _, key := range keys {
|
||||
if len(key) > prefixLen {
|
||||
userIDs = append(userIDs, key[prefixLen:])
|
||||
}
|
||||
}
|
||||
|
||||
return userIDs, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,272 @@
|
|||
package centrifugo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
)
|
||||
|
||||
func setupOnlineTestRedis(t *testing.T) (*redis.Redis, func()) {
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
rds, err := redis.NewRedis(redis.RedisConf{
|
||||
Host: mr.Addr(),
|
||||
Type: "node",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return rds, func() {
|
||||
mr.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRedisOnlineStore(t *testing.T) {
|
||||
rds, cleanup := setupOnlineTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
store := NewRedisOnlineStore(rds)
|
||||
|
||||
assert.NotNil(t, store)
|
||||
assert.Equal(t, "online:", store.keyPrefix)
|
||||
}
|
||||
|
||||
func TestNewRedisOnlineStoreWithPrefix(t *testing.T) {
|
||||
rds, cleanup := setupOnlineTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
store := NewRedisOnlineStoreWithPrefix(rds, "custom:online:")
|
||||
|
||||
assert.NotNil(t, store)
|
||||
assert.Equal(t, "custom:online:", store.keyPrefix)
|
||||
}
|
||||
|
||||
func TestSetOnline(t *testing.T) {
|
||||
rds, cleanup := setupOnlineTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
store := NewRedisOnlineStore(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user-123"
|
||||
ttl := 5 * time.Minute
|
||||
|
||||
// 設置在線
|
||||
err := store.SetOnline(ctx, userID, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 檢查是否在線
|
||||
online, err := store.IsOnline(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, online)
|
||||
}
|
||||
|
||||
func TestSetOffline(t *testing.T) {
|
||||
rds, cleanup := setupOnlineTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
store := NewRedisOnlineStore(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user-123"
|
||||
|
||||
// 先設置在線
|
||||
err := store.SetOnline(ctx, userID, 5*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 設置離線
|
||||
err = store.SetOffline(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 檢查是否離線
|
||||
online, err := store.IsOnline(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, online)
|
||||
}
|
||||
|
||||
func TestIsOnline_NotOnline(t *testing.T) {
|
||||
rds, cleanup := setupOnlineTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
store := NewRedisOnlineStore(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
// 未設置在線的用戶應該返回 false
|
||||
online, err := store.IsOnline(ctx, "non-existent-user")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, online)
|
||||
}
|
||||
|
||||
func TestGetOnlineUsers(t *testing.T) {
|
||||
rds, cleanup := setupOnlineTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
store := NewRedisOnlineStore(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
// 設置一些用戶在線
|
||||
err := store.SetOnline(ctx, "user-1", 5*time.Minute)
|
||||
require.NoError(t, err)
|
||||
err = store.SetOnline(ctx, "user-3", 5*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 批量獲取在線狀態
|
||||
userIDs := []string{"user-1", "user-2", "user-3"}
|
||||
status, err := store.GetOnlineUsers(ctx, userIDs)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, status["user-1"])
|
||||
assert.False(t, status["user-2"])
|
||||
assert.True(t, status["user-3"])
|
||||
}
|
||||
|
||||
func TestGetOnlineUsers_Empty(t *testing.T) {
|
||||
rds, cleanup := setupOnlineTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
store := NewRedisOnlineStore(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
// 空列表應該返回空 map
|
||||
status, err := store.GetOnlineUsers(ctx, []string{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, status)
|
||||
}
|
||||
|
||||
func TestIncrClient(t *testing.T) {
|
||||
rds, cleanup := setupOnlineTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
store := NewRedisOnlineStore(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user-123"
|
||||
|
||||
// 第一次增加
|
||||
count, err := store.IncrClient(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count)
|
||||
|
||||
// 第二次增加
|
||||
count, err = store.IncrClient(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(2), count)
|
||||
|
||||
// 獲取連線數
|
||||
count, err = store.GetClientCount(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(2), count)
|
||||
}
|
||||
|
||||
func TestDecrClient(t *testing.T) {
|
||||
rds, cleanup := setupOnlineTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
store := NewRedisOnlineStore(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user-123"
|
||||
|
||||
// 先增加到 2
|
||||
_, err := store.IncrClient(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
_, err = store.IncrClient(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 減少一次
|
||||
count, err := store.DecrClient(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count)
|
||||
|
||||
// 再減少一次
|
||||
count, err = store.DecrClient(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func TestDecrClient_NegativeProtection(t *testing.T) {
|
||||
rds, cleanup := setupOnlineTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
store := NewRedisOnlineStore(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user-123"
|
||||
|
||||
// 直接減少(沒有先增加)
|
||||
count, err := store.DecrClient(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count) // 應該被保護為 0
|
||||
|
||||
// 再次減少
|
||||
count, err = store.DecrClient(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count) // 仍然是 0
|
||||
}
|
||||
|
||||
func TestGetClientCount_NoClient(t *testing.T) {
|
||||
rds, cleanup := setupOnlineTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
store := NewRedisOnlineStore(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
// 未設置連線數的用戶應該返回 0
|
||||
count, err := store.GetClientCount(ctx, "non-existent-user")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func TestSetOffline_ClearsClientCount(t *testing.T) {
|
||||
rds, cleanup := setupOnlineTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
store := NewRedisOnlineStore(rds)
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user-123"
|
||||
|
||||
// 設置在線和連線數
|
||||
err := store.SetOnline(ctx, userID, 5*time.Minute)
|
||||
require.NoError(t, err)
|
||||
_, err = store.IncrClient(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
_, err = store.IncrClient(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 設置離線
|
||||
err = store.SetOffline(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 連線數也應該被清除
|
||||
count, err := store.GetClientCount(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func TestKeyGeneration_OnlineStore(t *testing.T) {
|
||||
rds, cleanup := setupOnlineTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
store := NewRedisOnlineStoreWithPrefix(rds, "test:")
|
||||
|
||||
// 測試 key 生成
|
||||
assert.Equal(t, "test:user-123", store.key("user-123"))
|
||||
assert.Equal(t, "test:clients:user-456", store.clientCountKey("user-456"))
|
||||
}
|
||||
|
||||
func TestOnlineStore_ImplementsInterface(t *testing.T) {
|
||||
rds, cleanup := setupOnlineTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
store := NewRedisOnlineStore(rds)
|
||||
|
||||
// 確保 RedisOnlineStore 實現了 OnlineStore 介面
|
||||
var _ OnlineStore = store
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,203 @@
|
|||
package centrifugo
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TokenConfig JWT Token 配置
|
||||
type TokenConfig struct {
|
||||
// Secret 用於簽名的密鑰(與 Centrifugo 配置的 token_hmac_secret_key 一致)
|
||||
Secret string
|
||||
// ExpireIn Token 過期時間(預設 1 小時)
|
||||
ExpireIn time.Duration
|
||||
}
|
||||
|
||||
// TokenGenerator JWT Token 生成器
|
||||
type TokenGenerator struct {
|
||||
config TokenConfig
|
||||
}
|
||||
|
||||
// NewTokenGenerator 創建新的 Token 生成器
|
||||
func NewTokenGenerator(secret string) *TokenGenerator {
|
||||
return &TokenGenerator{
|
||||
config: TokenConfig{
|
||||
Secret: secret,
|
||||
ExpireIn: time.Hour,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewTokenGeneratorWithConfig 創建使用自定義配置的 Token 生成器
|
||||
func NewTokenGeneratorWithConfig(config TokenConfig) *TokenGenerator {
|
||||
if config.ExpireIn == 0 {
|
||||
config.ExpireIn = time.Hour
|
||||
}
|
||||
return &TokenGenerator{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectionClaims 連線 Token 的 Claims
|
||||
type ConnectionClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
// Sub 用戶 ID(必填)
|
||||
Sub string `json:"sub"`
|
||||
// Info 用戶資訊(可選,會在 presence 中顯示)
|
||||
Info map[string]interface{} `json:"info,omitempty"`
|
||||
// Channels 自動訂閱的頻道列表(可選)
|
||||
Channels []string `json:"channels,omitempty"`
|
||||
// TokenVersion Token 版本(用於批量撤銷)
|
||||
TokenVersion int64 `json:"tv,omitempty"`
|
||||
}
|
||||
|
||||
// SubscriptionClaims 訂閱 Token 的 Claims(用於私有頻道)
|
||||
type SubscriptionClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
// Sub 用戶 ID(必填)
|
||||
Sub string `json:"sub"`
|
||||
// Channel 頻道名稱(必填)
|
||||
Channel string `json:"channel"`
|
||||
// Info 頻道特定的用戶資訊(可選)
|
||||
Info map[string]interface{} `json:"info,omitempty"`
|
||||
// TokenVersion Token 版本(用於批量撤銷)
|
||||
TokenVersion int64 `json:"tv,omitempty"`
|
||||
}
|
||||
|
||||
// ConnectionTokenOptions 連線 Token 選項
|
||||
type ConnectionTokenOptions struct {
|
||||
// UserID 用戶 ID(必填)
|
||||
UserID string
|
||||
// Info 用戶資訊(可選)
|
||||
Info map[string]interface{}
|
||||
// Channels 自動訂閱的頻道列表(可選)
|
||||
Channels []string
|
||||
// ExpireAt 自定義過期時間(可選,為空則使用預設)
|
||||
ExpireAt *time.Time
|
||||
// TokenVersion Token 版本(可選,用於黑名單機制)
|
||||
TokenVersion int64
|
||||
}
|
||||
|
||||
// SubscriptionTokenOptions 訂閱 Token 選項
|
||||
type SubscriptionTokenOptions struct {
|
||||
// UserID 用戶 ID(必填)
|
||||
UserID string
|
||||
// Channel 頻道名稱(必填)
|
||||
Channel string
|
||||
// Info 頻道特定的用戶資訊(可選)
|
||||
Info map[string]interface{}
|
||||
// ExpireAt 自定義過期時間(可選,為空則使用預設)
|
||||
ExpireAt *time.Time
|
||||
// TokenVersion Token 版本(可選,用於黑名單機制)
|
||||
TokenVersion int64
|
||||
}
|
||||
|
||||
// TokenResult 生成 Token 的結果
|
||||
type TokenResult struct {
|
||||
Token string // JWT Token 字串
|
||||
JTI string // JWT ID(用於撤銷單一 Token)
|
||||
ExpiresAt time.Time // 過期時間
|
||||
}
|
||||
|
||||
// GenerateConnectionToken 生成連線 Token
|
||||
// 用於前端建立 WebSocket 連線時的身份驗證
|
||||
func (g *TokenGenerator) GenerateConnectionToken(opts ConnectionTokenOptions) (string, error) {
|
||||
result, err := g.GenerateConnectionTokenWithJTI(opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return result.Token, nil
|
||||
}
|
||||
|
||||
// GenerateConnectionTokenWithJTI 生成連線 Token 並返回 JTI
|
||||
// 用於需要支援單一 Token 撤銷的場景
|
||||
func (g *TokenGenerator) GenerateConnectionTokenWithJTI(opts ConnectionTokenOptions) (*TokenResult, error) {
|
||||
now := time.Now()
|
||||
expireAt := now.Add(g.config.ExpireIn)
|
||||
if opts.ExpireAt != nil {
|
||||
expireAt = *opts.ExpireAt
|
||||
}
|
||||
|
||||
// 生成唯一的 JTI
|
||||
jti := uuid.New().String()
|
||||
|
||||
// 如果沒有指定 TokenVersion,使用當前時間戳
|
||||
tokenVersion := opts.TokenVersion
|
||||
if tokenVersion == 0 {
|
||||
tokenVersion = now.UnixNano()
|
||||
}
|
||||
|
||||
claims := ConnectionClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ID: jti,
|
||||
Subject: opts.UserID,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(expireAt),
|
||||
},
|
||||
Sub: opts.UserID,
|
||||
Info: opts.Info,
|
||||
Channels: opts.Channels,
|
||||
TokenVersion: tokenVersion,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, err := token.SignedString([]byte(g.config.Secret))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TokenResult{
|
||||
Token: tokenStr,
|
||||
JTI: jti,
|
||||
ExpiresAt: expireAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateSubscriptionToken 生成訂閱 Token
|
||||
// 用於訂閱私有頻道時的身份驗證
|
||||
func (g *TokenGenerator) GenerateSubscriptionToken(opts SubscriptionTokenOptions) (string, error) {
|
||||
now := time.Now()
|
||||
expireAt := now.Add(g.config.ExpireIn)
|
||||
if opts.ExpireAt != nil {
|
||||
expireAt = *opts.ExpireAt
|
||||
}
|
||||
|
||||
claims := SubscriptionClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: opts.UserID,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(expireAt),
|
||||
},
|
||||
Sub: opts.UserID,
|
||||
Channel: opts.Channel,
|
||||
Info: opts.Info,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(g.config.Secret))
|
||||
}
|
||||
|
||||
// GenerateAnonymousToken 生成匿名連線 Token
|
||||
// 用於允許匿名用戶連線(但仍需要驗證)
|
||||
func (g *TokenGenerator) GenerateAnonymousToken() (string, error) {
|
||||
return g.GenerateConnectionToken(ConnectionTokenOptions{
|
||||
UserID: "", // 空的 UserID 表示匿名用戶
|
||||
})
|
||||
}
|
||||
|
||||
// QuickConnectionToken 快速生成連線 Token(只需用戶 ID)
|
||||
func (g *TokenGenerator) QuickConnectionToken(userID string) (string, error) {
|
||||
return g.GenerateConnectionToken(ConnectionTokenOptions{
|
||||
UserID: userID,
|
||||
})
|
||||
}
|
||||
|
||||
// QuickSubscriptionToken 快速生成訂閱 Token(只需用戶 ID 和頻道)
|
||||
func (g *TokenGenerator) QuickSubscriptionToken(userID, channel string) (string, error) {
|
||||
return g.GenerateSubscriptionToken(SubscriptionTokenOptions{
|
||||
UserID: userID,
|
||||
Channel: channel,
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,263 @@
|
|||
package centrifugo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewTokenGenerator(t *testing.T) {
|
||||
secret := "test-secret"
|
||||
gen := NewTokenGenerator(secret)
|
||||
|
||||
assert.NotNil(t, gen)
|
||||
assert.Equal(t, secret, gen.config.Secret)
|
||||
assert.Equal(t, time.Hour, gen.config.ExpireIn)
|
||||
}
|
||||
|
||||
func TestNewTokenGeneratorWithConfig(t *testing.T) {
|
||||
config := TokenConfig{
|
||||
Secret: "custom-secret",
|
||||
ExpireIn: 24 * time.Hour,
|
||||
}
|
||||
gen := NewTokenGeneratorWithConfig(config)
|
||||
|
||||
assert.NotNil(t, gen)
|
||||
assert.Equal(t, config.Secret, gen.config.Secret)
|
||||
assert.Equal(t, config.ExpireIn, gen.config.ExpireIn)
|
||||
}
|
||||
|
||||
func TestNewTokenGeneratorWithConfig_DefaultExpire(t *testing.T) {
|
||||
config := TokenConfig{
|
||||
Secret: "test-secret",
|
||||
ExpireIn: 0, // 應該使用預設值
|
||||
}
|
||||
gen := NewTokenGeneratorWithConfig(config)
|
||||
|
||||
assert.Equal(t, time.Hour, gen.config.ExpireIn)
|
||||
}
|
||||
|
||||
func TestQuickConnectionToken(t *testing.T) {
|
||||
gen := NewTokenGenerator("test-secret")
|
||||
userID := "user-123"
|
||||
|
||||
token, err := gen.QuickConnectionToken(userID)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// 驗證 Token 內容
|
||||
claims := &ConnectionClaims{}
|
||||
parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte("test-secret"), nil
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, parsedToken.Valid)
|
||||
assert.Equal(t, userID, claims.Sub)
|
||||
assert.NotNil(t, claims.ExpiresAt)
|
||||
assert.NotNil(t, claims.IssuedAt)
|
||||
}
|
||||
|
||||
func TestGenerateConnectionToken(t *testing.T) {
|
||||
gen := NewTokenGenerator("test-secret")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts ConnectionTokenOptions
|
||||
checkFn func(t *testing.T, claims *ConnectionClaims)
|
||||
}{
|
||||
{
|
||||
name: "basic token",
|
||||
opts: ConnectionTokenOptions{
|
||||
UserID: "user-123",
|
||||
},
|
||||
checkFn: func(t *testing.T, claims *ConnectionClaims) {
|
||||
assert.Equal(t, "user-123", claims.Sub)
|
||||
assert.Nil(t, claims.Info)
|
||||
assert.Nil(t, claims.Channels)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "token with info",
|
||||
opts: ConnectionTokenOptions{
|
||||
UserID: "user-456",
|
||||
Info: map[string]interface{}{
|
||||
"name": "Daniel",
|
||||
"role": "admin",
|
||||
},
|
||||
},
|
||||
checkFn: func(t *testing.T, claims *ConnectionClaims) {
|
||||
assert.Equal(t, "user-456", claims.Sub)
|
||||
assert.NotNil(t, claims.Info)
|
||||
assert.Equal(t, "Daniel", claims.Info["name"])
|
||||
assert.Equal(t, "admin", claims.Info["role"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "token with channels",
|
||||
opts: ConnectionTokenOptions{
|
||||
UserID: "user-789",
|
||||
Channels: []string{"chat:room-1", "chat:room-2"},
|
||||
},
|
||||
checkFn: func(t *testing.T, claims *ConnectionClaims) {
|
||||
assert.Equal(t, "user-789", claims.Sub)
|
||||
assert.Equal(t, []string{"chat:room-1", "chat:room-2"}, claims.Channels)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "token with custom expire",
|
||||
opts: ConnectionTokenOptions{
|
||||
UserID: "user-abc",
|
||||
ExpireAt: ptrTime(time.Now().Add(48 * time.Hour)),
|
||||
},
|
||||
checkFn: func(t *testing.T, claims *ConnectionClaims) {
|
||||
assert.Equal(t, "user-abc", claims.Sub)
|
||||
// 檢查過期時間大約在 48 小時後
|
||||
expireTime := claims.ExpiresAt.Time
|
||||
assert.True(t, expireTime.After(time.Now().Add(47*time.Hour)))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, err := gen.GenerateConnectionToken(tt.opts)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// 解析 Token
|
||||
claims := &ConnectionClaims{}
|
||||
parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte("test-secret"), nil
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, parsedToken.Valid)
|
||||
tt.checkFn(t, claims)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuickSubscriptionToken(t *testing.T) {
|
||||
gen := NewTokenGenerator("test-secret")
|
||||
userID := "user-123"
|
||||
channel := "private:room-456"
|
||||
|
||||
token, err := gen.QuickSubscriptionToken(userID, channel)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// 驗證 Token 內容
|
||||
claims := &SubscriptionClaims{}
|
||||
parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte("test-secret"), nil
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, parsedToken.Valid)
|
||||
assert.Equal(t, userID, claims.Sub)
|
||||
assert.Equal(t, channel, claims.Channel)
|
||||
}
|
||||
|
||||
func TestGenerateSubscriptionToken(t *testing.T) {
|
||||
gen := NewTokenGenerator("test-secret")
|
||||
|
||||
opts := SubscriptionTokenOptions{
|
||||
UserID: "user-123",
|
||||
Channel: "private:room-456",
|
||||
Info: map[string]interface{}{
|
||||
"role": "moderator",
|
||||
},
|
||||
}
|
||||
|
||||
token, err := gen.GenerateSubscriptionToken(opts)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// 驗證 Token 內容
|
||||
claims := &SubscriptionClaims{}
|
||||
parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte("test-secret"), nil
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, parsedToken.Valid)
|
||||
assert.Equal(t, opts.UserID, claims.Sub)
|
||||
assert.Equal(t, opts.Channel, claims.Channel)
|
||||
assert.Equal(t, "moderator", claims.Info["role"])
|
||||
}
|
||||
|
||||
func TestGenerateAnonymousToken(t *testing.T) {
|
||||
gen := NewTokenGenerator("test-secret")
|
||||
|
||||
token, err := gen.GenerateAnonymousToken()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// 驗證 Token 內容
|
||||
claims := &ConnectionClaims{}
|
||||
parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte("test-secret"), nil
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, parsedToken.Valid)
|
||||
assert.Equal(t, "", claims.Sub) // 匿名用戶的 UserID 為空
|
||||
}
|
||||
|
||||
func TestTokenExpiration(t *testing.T) {
|
||||
// 創建一個很短過期時間的生成器
|
||||
gen := NewTokenGeneratorWithConfig(TokenConfig{
|
||||
Secret: "test-secret",
|
||||
ExpireIn: 1 * time.Second,
|
||||
})
|
||||
|
||||
token, err := gen.QuickConnectionToken("user-123")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 立即驗證應該成功
|
||||
claims := &ConnectionClaims{}
|
||||
parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte("test-secret"), nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, parsedToken.Valid)
|
||||
|
||||
// 等待過期
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// 過期後驗證應該失敗
|
||||
claims2 := &ConnectionClaims{}
|
||||
_, err = jwt.ParseWithClaims(token, claims2, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte("test-secret"), nil
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "token is expired")
|
||||
}
|
||||
|
||||
func TestTokenWithWrongSecret(t *testing.T) {
|
||||
gen := NewTokenGenerator("correct-secret")
|
||||
|
||||
token, err := gen.QuickConnectionToken("user-123")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 使用錯誤的密鑰驗證
|
||||
claims := &ConnectionClaims{}
|
||||
_, err = jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte("wrong-secret"), nil
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// 輔助函數
|
||||
func ptrTime(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
package utils
|
||||
|
||||
import "time"
|
||||
|
||||
// GetBucketDay 取得 bucket_day(yyyyMMdd 格式)
|
||||
func GetBucketDay(t time.Time) string {
|
||||
return t.Format("20060102")
|
||||
}
|
||||
|
||||
// GetTodayBucketDay 取得今天的 bucket_day
|
||||
func GetTodayBucketDay() string {
|
||||
return GetBucketDay(time.Now())
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
//// GenerateUID 生成匿名 UID
|
||||
//func GenerateUID() string {
|
||||
// return fmt.Sprintf("%s%s", consts.AnonUIDPrefix, uuid.New().String()[:8])
|
||||
//}
|
||||
//
|
||||
//// GenerateRoomID 生成房間 ID
|
||||
//func GenerateRoomID() string {
|
||||
// return fmt.Sprintf("%s%s", consts.RoomIDPrefix, uuid.New().String())
|
||||
//}
|
||||
|
||||
// GenerateMessageID 生成訊息 ID
|
||||
func GenerateMessageID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
Loading…
Reference in New Issue