backend/pkg/chat/repository/message_test.go

527 lines
13 KiB
Go
Raw Normal View History

2026-01-06 07:15:18 +00:00
package repository
import (
"backend/pkg/chat/domain/entity"
"backend/pkg/chat/domain/repository"
"backend/pkg/library/cassandra"
"context"
"os"
"strconv"
"testing"
"time"
"github.com/gocql/gocql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
)
var (
// 全局變量:所有測試共享的資源
testDB *cassandra.DB
testRepo repository.MessageRepository
testContainer testcontainers.Container
cleanupContainer func()
)
// TestMain starts one Cassandra container for the whole package and creates
// the test_keyspace schema. The schema covers messages_by_room for these tests
// and room, room_member and user_room for room_test.go. It then builds testRepo,
// runs the tests, and finally tears everything down.
//
// Any setup failure releases whatever was already acquired and then panics.
func TestMain(m *testing.M) {
	// Cassandra is slow to boot, so setup gets a generous budget. This context
	// is used for setup only. Teardown gets its own context: reusing this one
	// would hand Terminate an already-expired context whenever the suite runs
	// longer than the timeout, silently leaking the container.
	setupCtx, cancelSetup := context.WithTimeout(context.Background(), 5*time.Minute)

	// No fixed host port: testcontainers picks a free one, so we never clash
	// with a Cassandra used for local development.
	req := testcontainers.ContainerRequest{
		Image:        "cassandra:4.1",
		ExposedPorts: []string{"9042/tcp"},
		// Wait for Cassandra's own "Startup complete" log line.
		WaitingFor: wait.ForLog("Startup complete").
			WithStartupTimeout(3 * time.Minute).
			WithOccurrence(1),
		Env: map[string]string{
			"CASSANDRA_CLUSTER_NAME": "test-cluster",
			"HEAP_NEWSIZE":           "128M",
			"MAX_HEAP_SIZE":          "512M",
		},
	}

	ctr, err := testcontainers.GenericContainer(setupCtx, testcontainers.GenericContainerRequest{
		ContainerRequest: req,
		Started:          true,
	})
	if err != nil {
		cancelSetup()
		panic("Failed to start Cassandra container: " + err.Error())
	}
	testContainer = ctr

	// cleanupContainer releases the DB connection (if one was opened) and the
	// container. It is used both on setup failure and after m.Run.
	cleanupContainer = func() {
		if testDB != nil {
			testDB.Close()
		}
		if testContainer != nil {
			termCtx, cancelTerm := context.WithTimeout(context.Background(), time.Minute)
			defer cancelTerm()
			// Best effort: there is nothing useful to do if teardown fails.
			_ = testContainer.Terminate(termCtx)
		}
	}

	// abort tears down everything acquired so far and fails the whole run.
	abort := func(msg string, err error) {
		cleanupContainer()
		cancelSetup()
		panic(msg + err.Error())
	}

	port, err := testContainer.MappedPort(setupCtx, "9042")
	if err != nil {
		abort("Failed to get mapped port: ", err)
	}
	host, err := testContainer.Host(setupCtx)
	if err != nil {
		abort("Failed to get host: ", err)
	}
	portInt, err := strconv.Atoi(port.Port())
	if err != nil {
		abort("Failed to convert port to int: ", err)
	}

	// Connect without a keyspace: test_keyspace does not exist yet.
	db, err := cassandra.New(
		cassandra.WithHosts(host),
		cassandra.WithPort(portInt),
	)
	if err != nil {
		abort("Failed to create DB: ", err)
	}
	testDB = db

	createKeyspaceStmt := "CREATE KEYSPACE IF NOT EXISTS test_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}"
	if err := testDB.GetSession().Query(createKeyspaceStmt, nil).Exec(); err != nil {
		abort("Failed to create keyspace: ", err)
	}
	// Give the new keyspace a moment before creating tables in it.
	time.Sleep(500 * time.Millisecond)

	// Table names are fully qualified because the session has no default
	// keyspace. room, room_member and user_room are used by room_test.go.
	tables := []struct {
		failMsg string
		stmt    string
	}{
		{"Failed to create table: ", `CREATE TABLE IF NOT EXISTS test_keyspace.messages_by_room (
	room_id uuid,
	bucket_day text,
	ts bigint,
	uid text,
	content text,
	PRIMARY KEY ((room_id, bucket_day), ts)
) WITH CLUSTERING ORDER BY (ts DESC)`},
		{"Failed to create room table: ", `CREATE TABLE IF NOT EXISTS test_keyspace.room (
	room_id uuid PRIMARY KEY,
	name text,
	status text,
	created_at bigint,
	updated_at bigint
)`},
		{"Failed to create room_member table: ", `CREATE TABLE IF NOT EXISTS test_keyspace.room_member (
	room_id uuid,
	uid text,
	role text,
	joined_at bigint,
	PRIMARY KEY (room_id, uid)
)`},
		{"Failed to create user_room table: ", `CREATE TABLE IF NOT EXISTS test_keyspace.user_room (
	uid text,
	room_id uuid,
	joined_at bigint,
	PRIMARY KEY (uid, room_id)
)`},
	}
	for _, tbl := range tables {
		if err := testDB.GetSession().Query(tbl.stmt, nil).Exec(); err != nil {
			abort(tbl.failMsg, err)
		}
	}

	testRepo, err = NewMessageRepository(testDB, "test_keyspace")
	if err != nil {
		abort("Failed to create message repository: ", err)
	}

	code := m.Run()

	// Deferred calls would not run past os.Exit, so release everything explicitly.
	cleanupContainer()
	cancelSetup()
	os.Exit(code)
}
// clearMessages truncates test_keyspace.messages_by_room so each test starts
// from an empty table. It fails the calling test if the TRUNCATE fails.
func clearMessages(t *testing.T) {
	// Report a failure at the caller's line, not inside this helper.
	t.Helper()

	ctx := context.Background()
	// The table name must be fully qualified: the session has no default keyspace.
	truncateStmt := "TRUNCATE test_keyspace.messages_by_room"
	if err := testDB.GetSession().Query(truncateStmt, nil).WithContext(ctx).Exec(); err != nil {
		t.Fatalf("Failed to truncate messages_by_room table: %v", err)
	}
	// Short pause after the truncate before the test starts writing again.
	time.Sleep(50 * time.Millisecond)
}
// TestNewMessageRepository checks that TestMain produced a repository.
func TestNewMessageRepository(t *testing.T) {
	clearMessages(t)

	repo := testRepo
	assert.NotNil(t, repo)
}
// TestMessageRepository_Insert inserts one message with an explicit TS and one
// with TS left at zero. After each successful insert, the message must carry
// a non-zero TS and a non-empty BucketDay.
func TestMessageRepository_Insert(t *testing.T) {
	clearMessages(t)
	ctx := context.Background()

	tests := []struct {
		name    string
		message *entity.Message
		wantErr bool
	}{
		{
			name: "successful insert",
			message: &entity.Message{
				RoomID:    gocql.TimeUUID(),
				BucketDay: "20250117",
				TS:        time.Now().UnixNano(),
				UID:       "user-1",
				Content:   "Hello, world!",
			},
		},
		{
			name: "insert with auto timestamp",
			message: &entity.Message{
				RoomID:    gocql.TimeUUID(),
				BucketDay: "20250117",
				UID:       "user-2",
				Content:   "Test message",
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := testRepo.Insert(ctx, tt.message)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			require.NoError(t, err)

			// A field may be set by the caller or filled in by Insert, but it
			// must not still be unset after a successful insert.
			assert.NotZero(t, tt.message.TS)
			assert.NotEmpty(t, tt.message.BucketDay)
		})
	}
}
// TestMessageRepository_ListMessages seeds three messages, one second apart, in
// a single (room, bucket) partition. It then calls ListMessages with several
// page sizes and cursors and checks the result count and newest-first order.
func TestMessageRepository_ListMessages(t *testing.T) {
	clearMessages(t)
	ctx := context.Background()

	room := gocql.TimeUUID()
	roomKey := room.String()
	day := time.Now().Format(time.DateOnly)

	// Every seeded message shares the same room and bucket.
	seed := []*entity.Message{
		{RoomID: room, BucketDay: day, TS: time.Now().UnixNano(), UID: "user-1", Content: "Message 1"},
		{RoomID: room, BucketDay: day, TS: time.Now().Add(time.Second).UnixNano(), UID: "user-2", Content: "Message 2"},
		{RoomID: room, BucketDay: day, TS: time.Now().Add(2 * time.Second).UnixNano(), UID: "user-3", Content: "Message 3"},
	}
	for _, msg := range seed {
		require.NoError(t, testRepo.Insert(ctx, msg))
	}
	// Short pause before reading the rows back.
	time.Sleep(100 * time.Millisecond)

	cases := []struct {
		name    string
		param   repository.ListMessagesReq
		wantLen int
		wantErr bool
	}{
		{
			name:    "list all messages",
			param:   repository.ListMessagesReq{RoomID: roomKey, BucketDay: day, PageSize: 10, LastTS: 0},
			wantLen: 3,
		},
		{
			name:    "list with page size limit",
			param:   repository.ListMessagesReq{RoomID: roomKey, BucketDay: day, PageSize: 2, LastTS: 0},
			wantLen: 2,
		},
		{
			// Cursor on the middle message: only the older first message remains.
			name:    "list with cursor pagination",
			param:   repository.ListMessagesReq{RoomID: roomKey, BucketDay: day, PageSize: 10, LastTS: seed[1].TS},
			wantLen: 1,
		},
		{
			// PageSize 0 is expected to fall back to the repository's default (20).
			name:    "list with default page size",
			param:   repository.ListMessagesReq{RoomID: roomKey, BucketDay: day, PageSize: 0, LastTS: 0},
			wantLen: 3,
		},
		{
			// A freshly generated UUID names a room that has no messages.
			name:    "list empty room",
			param:   repository.ListMessagesReq{RoomID: gocql.TimeUUID().String(), BucketDay: day, PageSize: 10, LastTS: 0},
			wantLen: 0,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := testRepo.ListMessages(ctx, tc.param)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Len(t, got, tc.wantLen)

			// Newest first: each TS must be >= the TS that follows it.
			for i := 1; i < len(got); i++ {
				assert.GreaterOrEqual(t, got[i-1].TS, got[i].TS, "Messages should be sorted by TS DESC")
			}
		})
	}
}
// TestMessageRepository_Count seeds five messages into one room/bucket and checks
// that Count returns 5 for that room and 0 for a room with no messages.
func TestMessageRepository_Count(t *testing.T) {
	clearMessages(t)
	ctx := context.Background()

	room := gocql.TimeUUID()
	roomKey := room.String()
	day := time.Now().Format(time.DateOnly)

	// Five messages, one second apart, all in the same partition.
	for n := 0; n < 5; n++ {
		offset := time.Duration(n) * time.Second
		require.NoError(t, testRepo.Insert(ctx, &entity.Message{
			RoomID:    room,
			BucketDay: day,
			TS:        time.Now().Add(offset).UnixNano(),
			UID:       "user-1",
			Content:   "Message",
		}))
	}
	// Short pause before reading back.
	time.Sleep(100 * time.Millisecond)

	cases := []struct {
		name    string
		roomID  string
		want    int64
		wantErr bool
	}{
		{name: "count existing room", roomID: roomKey, want: 5},
		// A freshly generated UUID names a room that has no messages.
		{name: "count non-existent room", roomID: gocql.TimeUUID().String(), want: 0},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := testRepo.Count(ctx, tc.roomID)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tc.want, got)
		})
	}
}
// TestMessageRepository_Insert_ListMessages_Integration inserts ten messages into
// one room/bucket. It then checks that a single large page returns all of them,
// and that two cursor-linked pages of five reproduce that full listing in order
// with no gaps or overlap. Finally it checks that Count agrees.
func TestMessageRepository_Insert_ListMessages_Integration(t *testing.T) {
	clearMessages(t)
	ctx := context.Background()
	roomID := gocql.TimeUUID()
	roomIDStr := roomID.String()
	bucketDay := time.Now().Format(time.DateOnly)

	const total, pageSize = 10, 5

	// Messages one second apart, all in the same partition.
	for i := 0; i < total; i++ {
		msg := &entity.Message{
			RoomID:    roomID,
			BucketDay: bucketDay,
			TS:        time.Now().Add(time.Duration(i) * time.Second).UnixNano(),
			UID:       "user-1",
			Content:   "Message",
		}
		require.NoError(t, testRepo.Insert(ctx, msg))
	}
	// Short pause before reading back.
	time.Sleep(100 * time.Millisecond)

	// A page larger than the data set returns every message.
	all, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
		RoomID:    roomIDStr,
		BucketDay: bucketDay,
		PageSize:  20,
		LastTS:    0,
	})
	require.NoError(t, err)
	require.Len(t, all, total)

	// First page, no cursor.
	firstPage, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
		RoomID:    roomIDStr,
		BucketDay: bucketDay,
		PageSize:  pageSize,
		LastTS:    0,
	})
	require.NoError(t, err)
	require.Len(t, firstPage, pageSize)
	assert.Equal(t, all[:pageSize], firstPage)

	// Second page, using the oldest TS of the first page as the cursor.
	lastTS := firstPage[len(firstPage)-1].TS
	secondPage, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
		RoomID:    roomIDStr,
		BucketDay: bucketDay,
		PageSize:  pageSize,
		LastTS:    lastTS,
	})
	require.NoError(t, err)
	require.Len(t, secondPage, pageSize)
	// Results are newest-first, so checking the first element bounds them all.
	assert.Less(t, secondPage[0].TS, lastTS)
	assert.Equal(t, all[pageSize:], secondPage)

	count, err := testRepo.Count(ctx, roomIDStr)
	require.NoError(t, err)
	assert.Equal(t, int64(total), count)
}
// TestMessageRepository_DifferentBuckets writes one message to today's bucket
// and one to yesterday's bucket for the same room. It checks that ListMessages
// returns only the message stored in the requested bucket.
func TestMessageRepository_DifferentBuckets(t *testing.T) {
	clearMessages(t)
	ctx := context.Background()
	roomID := gocql.TimeUUID()
	roomIDStr := roomID.String()

	// Derive both bucket names from one instant. Two separate time.Now() calls
	// straddling midnight would name the same day, putting both messages in
	// one partition.
	now := time.Now()
	today := now.Format(time.DateOnly)
	yesterday := now.AddDate(0, 0, -1).Format(time.DateOnly)

	todayMsg := &entity.Message{
		RoomID:    roomID,
		BucketDay: today,
		TS:        time.Now().UnixNano(),
		UID:       "user-1",
		Content:   "Today's message",
	}
	yesterdayMsg := &entity.Message{
		RoomID:    roomID,
		BucketDay: yesterday,
		TS:        time.Now().UnixNano(),
		UID:       "user-1",
		Content:   "Yesterday's message",
	}
	require.NoError(t, testRepo.Insert(ctx, todayMsg))
	require.NoError(t, testRepo.Insert(ctx, yesterdayMsg))
	// Short pause before reading back.
	time.Sleep(100 * time.Millisecond)

	todayMessages, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
		RoomID:    roomIDStr,
		BucketDay: today,
		PageSize:  10,
		LastTS:    0,
	})
	require.NoError(t, err)
	// require (not assert): the next line indexes [0] and would panic on an empty result.
	require.Len(t, todayMessages, 1)
	assert.Equal(t, "Today's message", todayMessages[0].Content)

	yesterdayMessages, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
		RoomID:    roomIDStr,
		BucketDay: yesterday,
		PageSize:  10,
		LastTS:    0,
	})
	require.NoError(t, err)
	require.Len(t, yesterdayMessages, 1)
	assert.Equal(t, "Yesterday's message", yesterdayMessages[0].Content)
}