backend/pkg/chat/repository/message_test.go

527 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"backend/pkg/chat/domain/entity"
"backend/pkg/chat/domain/repository"
"backend/pkg/library/cassandra"
"context"
"os"
"strconv"
"testing"
"time"
"github.com/gocql/gocql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
)
var (
// 全局變量:所有測試共享的資源
testDB *cassandra.DB
testRepo repository.MessageRepository
testContainer testcontainers.Container
cleanupContainer func()
)
// TestMain 在所有測試之前執行一次,設置共享的資料庫
func TestMain(m *testing.M) {
// 使用更長的 context 超時時間Cassandra 啟動需要較長時間)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// 啟動 Cassandra 容器(只執行一次)
// 不指定固定 port讓 testcontainers 自動分配,避免與本地開發環境衝突
req := testcontainers.ContainerRequest{
Image: "cassandra:4.1",
ExposedPorts: []string{"9042/tcp"},
// 等待 Cassandra 日誌確認啟動完成Cassandra 啟動需要較長時間)
WaitingFor: wait.ForLog("Startup complete").
WithStartupTimeout(3 * time.Minute).
WithOccurrence(1),
Env: map[string]string{
"CASSANDRA_CLUSTER_NAME": "test-cluster",
"HEAP_NEWSIZE": "128M",
"MAX_HEAP_SIZE": "512M",
},
}
var err error
testContainer, err = testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
ContainerRequest: req,
Started: true,
})
if err != nil {
panic("Failed to start Cassandra container: " + err.Error())
}
port, err := testContainer.MappedPort(ctx, "9042")
if err != nil {
testContainer.Terminate(ctx)
panic("Failed to get mapped port: " + err.Error())
}
host, err := testContainer.Host(ctx)
if err != nil {
testContainer.Terminate(ctx)
panic("Failed to get host: " + err.Error())
}
portInt, err := strconv.Atoi(port.Port())
if err != nil {
testContainer.Terminate(ctx)
panic("Failed to convert port to int: " + err.Error())
}
// 先創建 DB 連接(不指定 keyspace因為 keyspace 還不存在)
testDB, err = cassandra.New(
cassandra.WithHosts(host),
cassandra.WithPort(portInt),
// 不指定 keyspace先連接後再創建
)
if err != nil {
testContainer.Terminate(ctx)
panic("Failed to create DB: " + err.Error())
}
// 創建 keyspace需要在連接後才能創建
createKeyspaceStmt := "CREATE KEYSPACE IF NOT EXISTS test_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}"
if err := testDB.GetSession().Query(createKeyspaceStmt, nil).Exec(); err != nil {
testDB.Close()
testContainer.Terminate(ctx)
panic("Failed to create keyspace: " + err.Error())
}
// 等待 keyspace 創建完成
time.Sleep(500 * time.Millisecond)
// 創建 messages_by_room 表(需要指定 keyspace
createTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.messages_by_room (
room_id uuid,
bucket_day text,
ts bigint,
uid text,
content text,
PRIMARY KEY ((room_id, bucket_day), ts)
) WITH CLUSTERING ORDER BY (ts DESC)`
if err := testDB.GetSession().Query(createTableStmt, nil).Exec(); err != nil {
testDB.Close()
testContainer.Terminate(ctx)
panic("Failed to create table: " + err.Error())
}
// 創建 messages repository
testRepo, err = NewMessageRepository(testDB, "test_keyspace")
if err != nil {
testDB.Close()
testContainer.Terminate(ctx)
panic("Failed to create message repository: " + err.Error())
}
// 創建 room 相關表(供 room_test.go 使用)
createRoomTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.room (
room_id uuid PRIMARY KEY,
name text,
status text,
created_at bigint,
updated_at bigint
)`
if err := testDB.GetSession().Query(createRoomTableStmt, nil).Exec(); err != nil {
testDB.Close()
testContainer.Terminate(ctx)
panic("Failed to create room table: " + err.Error())
}
createMemberTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.room_member (
room_id uuid,
uid text,
role text,
joined_at bigint,
PRIMARY KEY (room_id, uid)
)`
if err := testDB.GetSession().Query(createMemberTableStmt, nil).Exec(); err != nil {
testDB.Close()
testContainer.Terminate(ctx)
panic("Failed to create room_member table: " + err.Error())
}
createUserRoomTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.user_room (
uid text,
room_id uuid,
joined_at bigint,
PRIMARY KEY (uid, room_id)
)`
if err := testDB.GetSession().Query(createUserRoomTableStmt, nil).Exec(); err != nil {
testDB.Close()
testContainer.Terminate(ctx)
panic("Failed to create user_room table: " + err.Error())
}
// 設置清理函數
cleanupContainer = func() {
if testDB != nil {
testDB.Close()
}
if testContainer != nil {
_ = testContainer.Terminate(ctx)
}
}
// 執行所有測試
code := m.Run()
// 清理資源
cleanupContainer()
// 退出
os.Exit(code)
}
// clearMessages empties test_keyspace.messages_by_room so each test starts
// from an empty table. It fails the calling test immediately if the TRUNCATE
// fails.
func clearMessages(t *testing.T) {
	// Report a failure at the caller's line rather than inside this helper.
	t.Helper()

	ctx := context.Background()
	// Fully qualified table name: the shared connection is not bound to a
	// keyspace.
	truncateStmt := "TRUNCATE test_keyspace.messages_by_room"
	if err := testDB.GetSession().Query(truncateStmt, nil).WithContext(ctx).Exec(); err != nil {
		t.Fatalf("Failed to truncate messages_by_room table: %v", err)
	}
	// Short pause kept from the original helper to let the truncate settle.
	time.Sleep(50 * time.Millisecond)
}
// TestNewMessageRepository checks that TestMain produced a non-nil shared
// repository.
func TestNewMessageRepository(t *testing.T) {
	clearMessages(t)
	is := assert.New(t)
	is.NotNil(testRepo)
}
// TestMessageRepository_Insert checks that Insert accepts well-formed messages
// and that, after a successful insert, the message carries a non-zero TS and a
// non-empty BucketDay. The second case leaves TS unset, so Insert is expected
// to fill it in on the passed message.
//
// NOTE(review): these cases use a YYYYMMDD bucket while the other tests use
// time.DateOnly (YYYY-MM-DD) — confirm which format the repository expects.
func TestMessageRepository_Insert(t *testing.T) {
	clearMessages(t)
	ctx := context.Background()

	cases := []struct {
		name    string
		message *entity.Message
		wantErr bool
	}{
		{
			name: "successful insert",
			message: &entity.Message{
				RoomID:    gocql.TimeUUID(),
				BucketDay: "20250117",
				TS:        time.Now().UnixNano(),
				UID:       "user-1",
				Content:   "Hello, world!",
			},
			wantErr: false,
		},
		{
			name: "insert with auto timestamp",
			message: &entity.Message{
				RoomID:    gocql.TimeUUID(),
				BucketDay: "20250117",
				UID:       "user-2",
				Content:   "Test message",
			},
			wantErr: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := testRepo.Insert(ctx, tc.message)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			// After Insert, TS and BucketDay must both be populated, whether
			// the caller supplied them or Insert filled them in.
			assert.NotZero(t, tc.message.TS)
			assert.NotEmpty(t, tc.message.BucketDay)
		})
	}
}
// TestMessageRepository_ListMessages seeds one (room, bucket) partition with
// three messages one second apart. It then checks the page-size limit, the
// default page size, LastTS cursor paging and an empty room, and that results
// come back newest first.
func TestMessageRepository_ListMessages(t *testing.T) {
	clearMessages(t)
	ctx := context.Background()

	room := gocql.TimeUUID()
	roomKey := room.String()
	day := time.Now().Format(time.DateOnly)

	// All three messages share a partition; seed[i] is the i-th oldest.
	seed := []*entity.Message{
		{RoomID: room, BucketDay: day, TS: time.Now().UnixNano(), UID: "user-1", Content: "Message 1"},
		{RoomID: room, BucketDay: day, TS: time.Now().Add(time.Second).UnixNano(), UID: "user-2", Content: "Message 2"},
		{RoomID: room, BucketDay: day, TS: time.Now().Add(2 * time.Second).UnixNano(), UID: "user-3", Content: "Message 3"},
	}
	for _, msg := range seed {
		require.NoError(t, testRepo.Insert(ctx, msg))
	}
	// Brief pause kept from the original test before reading back.
	time.Sleep(100 * time.Millisecond)

	cases := []struct {
		name    string
		param   repository.ListMessagesReq
		wantLen int
		wantErr bool
	}{
		{
			name: "list all messages",
			param: repository.ListMessagesReq{
				RoomID:    roomKey,
				BucketDay: day,
				PageSize:  10,
				LastTS:    0,
			},
			wantLen: 3,
			wantErr: false,
		},
		{
			name: "list with page size limit",
			param: repository.ListMessagesReq{
				RoomID:    roomKey,
				BucketDay: day,
				PageSize:  2,
				LastTS:    0,
			},
			wantLen: 2,
			wantErr: false,
		},
		{
			// Cursor on the middle message: only the strictly older seed[0]
			// should come back.
			name: "list with cursor pagination",
			param: repository.ListMessagesReq{
				RoomID:    roomKey,
				BucketDay: day,
				PageSize:  10,
				LastTS:    seed[1].TS,
			},
			wantLen: 1,
			wantErr: false,
		},
		{
			// PageSize 0 is expected to fall back to the repository's
			// default page size (20).
			name: "list with default page size",
			param: repository.ListMessagesReq{
				RoomID:    roomKey,
				BucketDay: day,
				PageSize:  0,
				LastTS:    0,
			},
			wantLen: 3,
			wantErr: false,
		},
		{
			// A freshly generated room UUID that has no messages.
			name: "list empty room",
			param: repository.ListMessagesReq{
				RoomID:    gocql.TimeUUID().String(),
				BucketDay: day,
				PageSize:  10,
				LastTS:    0,
			},
			wantLen: 0,
			wantErr: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := testRepo.ListMessages(ctx, tc.param)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Len(t, got, tc.wantLen)
			// Newest first: each TS must be >= the one that follows it.
			for i := 1; i < len(got); i++ {
				assert.GreaterOrEqual(t, got[i-1].TS, got[i].TS, "Messages should be sorted by TS DESC")
			}
		})
	}
}
// TestMessageRepository_Count inserts five messages into one room. It checks
// that Count reports 5 for that room and 0 for a freshly generated room ID
// with no messages.
func TestMessageRepository_Count(t *testing.T) {
	clearMessages(t)
	ctx := context.Background()

	room := gocql.TimeUUID()
	roomKey := room.String()
	day := time.Now().Format(time.DateOnly)

	// Five messages in the same room and bucket, one second apart.
	for n := 0; n < 5; n++ {
		require.NoError(t, testRepo.Insert(ctx, &entity.Message{
			RoomID:    room,
			BucketDay: day,
			TS:        time.Now().Add(time.Duration(n) * time.Second).UnixNano(),
			UID:       "user-1",
			Content:   "Message",
		}))
	}
	// Brief pause kept from the original test before reading back.
	time.Sleep(100 * time.Millisecond)

	cases := []struct {
		name    string
		roomID  string
		want    int64
		wantErr bool
	}{
		{name: "count existing room", roomID: roomKey, want: 5, wantErr: false},
		{name: "count non-existent room", roomID: gocql.TimeUUID().String(), want: 0, wantErr: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := testRepo.Count(ctx, tc.roomID)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tc.want, got)
		})
	}
}
// TestMessageRepository_Insert_ListMessages_Integration inserts ten messages,
// one second apart, into a single (room, bucket) partition. It checks that:
// ListMessages returns all of them; paging with the LastTS cursor yields
// exactly the five newest and then the five oldest, newest first, with no gap
// or overlap; and Count agrees.
func TestMessageRepository_Insert_ListMessages_Integration(t *testing.T) {
	clearMessages(t)
	ctx := context.Background()
	roomID := gocql.TimeUUID()
	roomIDStr := roomID.String()
	bucketDay := time.Now().Format(time.DateOnly)

	const (
		total    = 10
		pageSize = 5
	)

	// messages[i] is the i-th oldest message. Insert receives the pointer, so
	// messages[i].TS reflects whatever TS was actually stored.
	messages := make([]*entity.Message, total)
	for i := range messages {
		msg := &entity.Message{
			RoomID:    roomID,
			BucketDay: bucketDay,
			TS:        time.Now().Add(time.Duration(i) * time.Second).UnixNano(),
			UID:       "user-1",
			Content:   "Message",
		}
		messages[i] = msg
		require.NoError(t, testRepo.Insert(ctx, msg))
	}
	// Brief pause kept from the original test before reading back.
	time.Sleep(100 * time.Millisecond)

	// A page larger than the data set returns everything.
	result, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
		RoomID:    roomIDStr,
		BucketDay: bucketDay,
		PageSize:  20,
		LastTS:    0,
	})
	require.NoError(t, err)
	assert.Len(t, result, total)

	// First page: the five newest messages, newest first.
	firstPage, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
		RoomID:    roomIDStr,
		BucketDay: bucketDay,
		PageSize:  pageSize,
		LastTS:    0,
	})
	require.NoError(t, err)
	// require, not assert: the cursor below indexes into firstPage.
	require.Len(t, firstPage, pageSize)
	for i, got := range firstPage {
		assert.Equal(t, messages[total-1-i].TS, got.TS, "first page item %d", i)
	}

	// Second page via the cursor: the remaining five, all strictly older than
	// the last item of the first page.
	lastTS := firstPage[len(firstPage)-1].TS
	secondPage, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
		RoomID:    roomIDStr,
		BucketDay: bucketDay,
		PageSize:  pageSize,
		LastTS:    lastTS,
	})
	require.NoError(t, err)
	require.Len(t, secondPage, pageSize)
	assert.Less(t, secondPage[0].TS, lastTS)
	for i, got := range secondPage {
		assert.Equal(t, messages[total-1-pageSize-i].TS, got.TS, "second page item %d", i)
	}

	count, err := testRepo.Count(ctx, roomIDStr)
	require.NoError(t, err)
	assert.Equal(t, int64(total), count)
}
// TestMessageRepository_DifferentBuckets checks that messages for the same
// room but different BucketDay values are kept apart: listing one bucket
// returns only that bucket's message.
func TestMessageRepository_DifferentBuckets(t *testing.T) {
	clearMessages(t)
	ctx := context.Background()
	roomID := gocql.TimeUUID()
	roomIDStr := roomID.String()
	today := time.Now().Format(time.DateOnly)
	yesterday := time.Now().AddDate(0, 0, -1).Format(time.DateOnly)

	// Same room, one message per bucket.
	todayMsg := &entity.Message{
		RoomID:    roomID,
		BucketDay: today,
		TS:        time.Now().UnixNano(),
		UID:       "user-1",
		Content:   "Today's message",
	}
	yesterdayMsg := &entity.Message{
		RoomID:    roomID,
		BucketDay: yesterday,
		TS:        time.Now().UnixNano(),
		UID:       "user-1",
		Content:   "Yesterday's message",
	}
	require.NoError(t, testRepo.Insert(ctx, todayMsg))
	require.NoError(t, testRepo.Insert(ctx, yesterdayMsg))
	// Brief pause kept from the original test before reading back.
	time.Sleep(100 * time.Millisecond)

	todayMessages, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
		RoomID:    roomIDStr,
		BucketDay: today,
		PageSize:  10,
		LastTS:    0,
	})
	require.NoError(t, err)
	// require, not assert: indexing [0] on an empty result would panic, and a
	// panicking test aborts the whole test binary (skipping container cleanup).
	require.Len(t, todayMessages, 1)
	assert.Equal(t, "Today's message", todayMessages[0].Content)

	yesterdayMessages, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
		RoomID:    roomIDStr,
		BucketDay: yesterday,
		PageSize:  10,
		LastTS:    0,
	})
	require.NoError(t, err)
	require.Len(t, yesterdayMessages, 1)
	assert.Equal(t, "Yesterday's message", yesterdayMessages[0].Content)
}