527 lines
13 KiB
Go
527 lines
13 KiB
Go
package repository
|
||
|
||
import (
|
||
"backend/pkg/chat/domain/entity"
|
||
"backend/pkg/chat/domain/repository"
|
||
"backend/pkg/library/cassandra"
|
||
"context"
|
||
"os"
|
||
"strconv"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/gocql/gocql"
|
||
|
||
"github.com/stretchr/testify/assert"
|
||
"github.com/stretchr/testify/require"
|
||
"github.com/testcontainers/testcontainers-go"
|
||
"github.com/testcontainers/testcontainers-go/wait"
|
||
)
|
||
|
||
var (
|
||
// 全局變量:所有測試共享的資源
|
||
testDB *cassandra.DB
|
||
testRepo repository.MessageRepository
|
||
testContainer testcontainers.Container
|
||
cleanupContainer func()
|
||
)
|
||
|
||
// TestMain 在所有測試之前執行一次,設置共享的資料庫
|
||
func TestMain(m *testing.M) {
|
||
// 使用更長的 context 超時時間(Cassandra 啟動需要較長時間)
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||
defer cancel()
|
||
|
||
// 啟動 Cassandra 容器(只執行一次)
|
||
// 不指定固定 port,讓 testcontainers 自動分配,避免與本地開發環境衝突
|
||
req := testcontainers.ContainerRequest{
|
||
Image: "cassandra:4.1",
|
||
ExposedPorts: []string{"9042/tcp"},
|
||
// 等待 Cassandra 日誌確認啟動完成(Cassandra 啟動需要較長時間)
|
||
WaitingFor: wait.ForLog("Startup complete").
|
||
WithStartupTimeout(3 * time.Minute).
|
||
WithOccurrence(1),
|
||
Env: map[string]string{
|
||
"CASSANDRA_CLUSTER_NAME": "test-cluster",
|
||
"HEAP_NEWSIZE": "128M",
|
||
"MAX_HEAP_SIZE": "512M",
|
||
},
|
||
}
|
||
|
||
var err error
|
||
testContainer, err = testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
|
||
ContainerRequest: req,
|
||
Started: true,
|
||
})
|
||
if err != nil {
|
||
panic("Failed to start Cassandra container: " + err.Error())
|
||
}
|
||
|
||
port, err := testContainer.MappedPort(ctx, "9042")
|
||
if err != nil {
|
||
testContainer.Terminate(ctx)
|
||
panic("Failed to get mapped port: " + err.Error())
|
||
}
|
||
|
||
host, err := testContainer.Host(ctx)
|
||
if err != nil {
|
||
testContainer.Terminate(ctx)
|
||
panic("Failed to get host: " + err.Error())
|
||
}
|
||
|
||
portInt, err := strconv.Atoi(port.Port())
|
||
if err != nil {
|
||
testContainer.Terminate(ctx)
|
||
panic("Failed to convert port to int: " + err.Error())
|
||
}
|
||
|
||
// 先創建 DB 連接(不指定 keyspace,因為 keyspace 還不存在)
|
||
testDB, err = cassandra.New(
|
||
cassandra.WithHosts(host),
|
||
cassandra.WithPort(portInt),
|
||
// 不指定 keyspace,先連接後再創建
|
||
)
|
||
if err != nil {
|
||
testContainer.Terminate(ctx)
|
||
panic("Failed to create DB: " + err.Error())
|
||
}
|
||
|
||
// 創建 keyspace(需要在連接後才能創建)
|
||
createKeyspaceStmt := "CREATE KEYSPACE IF NOT EXISTS test_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}"
|
||
if err := testDB.GetSession().Query(createKeyspaceStmt, nil).Exec(); err != nil {
|
||
testDB.Close()
|
||
testContainer.Terminate(ctx)
|
||
panic("Failed to create keyspace: " + err.Error())
|
||
}
|
||
|
||
// 等待 keyspace 創建完成
|
||
time.Sleep(500 * time.Millisecond)
|
||
|
||
// 創建 messages_by_room 表(需要指定 keyspace)
|
||
createTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.messages_by_room (
|
||
room_id uuid,
|
||
bucket_day text,
|
||
ts bigint,
|
||
uid text,
|
||
content text,
|
||
PRIMARY KEY ((room_id, bucket_day), ts)
|
||
) WITH CLUSTERING ORDER BY (ts DESC)`
|
||
if err := testDB.GetSession().Query(createTableStmt, nil).Exec(); err != nil {
|
||
testDB.Close()
|
||
testContainer.Terminate(ctx)
|
||
panic("Failed to create table: " + err.Error())
|
||
}
|
||
|
||
// 創建 messages repository
|
||
testRepo, err = NewMessageRepository(testDB, "test_keyspace")
|
||
if err != nil {
|
||
testDB.Close()
|
||
testContainer.Terminate(ctx)
|
||
panic("Failed to create message repository: " + err.Error())
|
||
}
|
||
|
||
// 創建 room 相關表(供 room_test.go 使用)
|
||
createRoomTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.room (
|
||
room_id uuid PRIMARY KEY,
|
||
name text,
|
||
status text,
|
||
created_at bigint,
|
||
updated_at bigint
|
||
)`
|
||
if err := testDB.GetSession().Query(createRoomTableStmt, nil).Exec(); err != nil {
|
||
testDB.Close()
|
||
testContainer.Terminate(ctx)
|
||
panic("Failed to create room table: " + err.Error())
|
||
}
|
||
|
||
createMemberTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.room_member (
|
||
room_id uuid,
|
||
uid text,
|
||
role text,
|
||
joined_at bigint,
|
||
PRIMARY KEY (room_id, uid)
|
||
)`
|
||
if err := testDB.GetSession().Query(createMemberTableStmt, nil).Exec(); err != nil {
|
||
testDB.Close()
|
||
testContainer.Terminate(ctx)
|
||
panic("Failed to create room_member table: " + err.Error())
|
||
}
|
||
|
||
createUserRoomTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.user_room (
|
||
uid text,
|
||
room_id uuid,
|
||
joined_at bigint,
|
||
PRIMARY KEY (uid, room_id)
|
||
)`
|
||
if err := testDB.GetSession().Query(createUserRoomTableStmt, nil).Exec(); err != nil {
|
||
testDB.Close()
|
||
testContainer.Terminate(ctx)
|
||
panic("Failed to create user_room table: " + err.Error())
|
||
}
|
||
|
||
// 設置清理函數
|
||
cleanupContainer = func() {
|
||
if testDB != nil {
|
||
testDB.Close()
|
||
}
|
||
if testContainer != nil {
|
||
_ = testContainer.Terminate(ctx)
|
||
}
|
||
}
|
||
|
||
// 執行所有測試
|
||
code := m.Run()
|
||
|
||
// 清理資源
|
||
cleanupContainer()
|
||
|
||
// 退出
|
||
os.Exit(code)
|
||
}
|
||
|
||
// clearMessages 清空 messages_by_room 表的所有數據
|
||
func clearMessages(t *testing.T) {
|
||
ctx := context.Background()
|
||
// 使用完整的 keyspace.table 名稱
|
||
truncateStmt := "TRUNCATE test_keyspace.messages_by_room"
|
||
if err := testDB.GetSession().Query(truncateStmt, nil).WithContext(ctx).Exec(); err != nil {
|
||
t.Fatalf("Failed to truncate messages_by_room table: %v", err)
|
||
}
|
||
// 等待數據清空
|
||
time.Sleep(50 * time.Millisecond)
|
||
}
|
||
|
||
func TestNewMessageRepository(t *testing.T) {
|
||
clearMessages(t)
|
||
assert.NotNil(t, testRepo)
|
||
}
|
||
|
||
func TestMessageRepository_Insert(t *testing.T) {
|
||
clearMessages(t)
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
message *entity.Message
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "successful insert",
|
||
message: &entity.Message{
|
||
RoomID: gocql.TimeUUID(),
|
||
BucketDay: "20250117",
|
||
TS: time.Now().UnixNano(),
|
||
UID: "user-1",
|
||
Content: "Hello, world!",
|
||
},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "insert with auto timestamp",
|
||
message: &entity.Message{
|
||
RoomID: gocql.TimeUUID(),
|
||
BucketDay: "20250117",
|
||
UID: "user-2",
|
||
Content: "Test message",
|
||
},
|
||
wantErr: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
err := testRepo.Insert(ctx, tt.message)
|
||
if tt.wantErr {
|
||
assert.Error(t, err)
|
||
} else {
|
||
assert.NoError(t, err)
|
||
// 驗證 TS 和 BucketDay 是否被自動設置
|
||
if tt.message.TS == 0 {
|
||
assert.NotZero(t, tt.message.TS)
|
||
}
|
||
if tt.message.BucketDay == "" {
|
||
assert.NotEmpty(t, tt.message.BucketDay)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestMessageRepository_ListMessages(t *testing.T) {
|
||
clearMessages(t)
|
||
ctx := context.Background()
|
||
roomID := gocql.TimeUUID()
|
||
roomIDStr := roomID.String()
|
||
bucketDay := time.Now().Format(time.DateOnly)
|
||
|
||
// 插入測試數據(使用相同的 roomID)
|
||
messages := []*entity.Message{
|
||
{RoomID: roomID, BucketDay: bucketDay, TS: time.Now().UnixNano(), UID: "user-1", Content: "Message 1"},
|
||
{RoomID: roomID, BucketDay: bucketDay, TS: time.Now().Add(time.Second).UnixNano(), UID: "user-2", Content: "Message 2"},
|
||
{RoomID: roomID, BucketDay: bucketDay, TS: time.Now().Add(2 * time.Second).UnixNano(), UID: "user-3", Content: "Message 3"},
|
||
}
|
||
|
||
for _, msg := range messages {
|
||
require.NoError(t, testRepo.Insert(ctx, msg))
|
||
}
|
||
|
||
// 等待數據寫入
|
||
time.Sleep(100 * time.Millisecond)
|
||
|
||
tests := []struct {
|
||
name string
|
||
param repository.ListMessagesReq
|
||
wantLen int
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "list all messages",
|
||
param: repository.ListMessagesReq{
|
||
RoomID: roomIDStr,
|
||
BucketDay: bucketDay,
|
||
PageSize: 10,
|
||
LastTS: 0,
|
||
},
|
||
wantLen: 3,
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "list with page size limit",
|
||
param: repository.ListMessagesReq{
|
||
RoomID: roomIDStr,
|
||
BucketDay: bucketDay,
|
||
PageSize: 2,
|
||
LastTS: 0,
|
||
},
|
||
wantLen: 2,
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "list with cursor pagination",
|
||
param: repository.ListMessagesReq{
|
||
RoomID: roomIDStr,
|
||
BucketDay: bucketDay,
|
||
PageSize: 10,
|
||
LastTS: messages[1].TS, // 使用第二條訊息的 TS 作為 cursor
|
||
},
|
||
wantLen: 1, // 應該只返回比 LastTS 更早的訊息
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "list with default page size",
|
||
param: repository.ListMessagesReq{
|
||
RoomID: roomIDStr,
|
||
BucketDay: bucketDay,
|
||
PageSize: 0, // 應該使用預設值 20
|
||
LastTS: 0,
|
||
},
|
||
wantLen: 3,
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "list empty room",
|
||
param: repository.ListMessagesReq{
|
||
RoomID: gocql.TimeUUID().String(), // 使用不存在的 UUID
|
||
BucketDay: bucketDay,
|
||
PageSize: 10,
|
||
LastTS: 0,
|
||
},
|
||
wantLen: 0,
|
||
wantErr: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
result, err := testRepo.ListMessages(ctx, tt.param)
|
||
if tt.wantErr {
|
||
assert.Error(t, err)
|
||
} else {
|
||
assert.NoError(t, err)
|
||
assert.Len(t, result, tt.wantLen)
|
||
// 驗證排序(應該按 TS DESC 排序)
|
||
if len(result) > 1 {
|
||
for i := 0; i < len(result)-1; i++ {
|
||
assert.GreaterOrEqual(t, result[i].TS, result[i+1].TS, "Messages should be sorted by TS DESC")
|
||
}
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestMessageRepository_Count(t *testing.T) {
|
||
clearMessages(t)
|
||
ctx := context.Background()
|
||
roomID := gocql.TimeUUID()
|
||
roomIDStr := roomID.String()
|
||
bucketDay := time.Now().Format(time.DateOnly)
|
||
|
||
// 插入測試數據(使用相同的 roomID)
|
||
for i := 0; i < 5; i++ {
|
||
msg := &entity.Message{
|
||
RoomID: roomID,
|
||
BucketDay: bucketDay,
|
||
TS: time.Now().Add(time.Duration(i) * time.Second).UnixNano(),
|
||
UID: "user-1",
|
||
Content: "Message",
|
||
}
|
||
require.NoError(t, testRepo.Insert(ctx, msg))
|
||
}
|
||
|
||
// 等待數據寫入
|
||
time.Sleep(100 * time.Millisecond)
|
||
|
||
tests := []struct {
|
||
name string
|
||
roomID string
|
||
want int64
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "count existing room",
|
||
roomID: roomIDStr,
|
||
want: 5,
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "count non-existent room",
|
||
roomID: gocql.TimeUUID().String(), // 使用不存在的 UUID
|
||
want: 0,
|
||
wantErr: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
count, err := testRepo.Count(ctx, tt.roomID)
|
||
if tt.wantErr {
|
||
assert.Error(t, err)
|
||
} else {
|
||
assert.NoError(t, err)
|
||
assert.Equal(t, tt.want, count)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestMessageRepository_Insert_ListMessages_Integration(t *testing.T) {
|
||
clearMessages(t)
|
||
ctx := context.Background()
|
||
roomID := gocql.TimeUUID()
|
||
roomIDStr := roomID.String()
|
||
bucketDay := time.Now().Format(time.DateOnly)
|
||
|
||
// 插入多條訊息(使用相同的 roomID)
|
||
messages := make([]*entity.Message, 10)
|
||
for i := 0; i < 10; i++ {
|
||
msg := &entity.Message{
|
||
RoomID: roomID,
|
||
BucketDay: bucketDay,
|
||
TS: time.Now().Add(time.Duration(i) * time.Second).UnixNano(),
|
||
UID: "user-1",
|
||
Content: "Message",
|
||
}
|
||
messages[i] = msg
|
||
require.NoError(t, testRepo.Insert(ctx, msg))
|
||
}
|
||
|
||
// 等待數據寫入
|
||
time.Sleep(100 * time.Millisecond)
|
||
|
||
// 驗證可以查詢到所有訊息
|
||
result, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
|
||
RoomID: roomIDStr,
|
||
BucketDay: bucketDay,
|
||
PageSize: 20,
|
||
LastTS: 0,
|
||
})
|
||
require.NoError(t, err)
|
||
assert.Len(t, result, 10)
|
||
|
||
// 驗證分頁功能
|
||
firstPage, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
|
||
RoomID: roomIDStr,
|
||
BucketDay: bucketDay,
|
||
PageSize: 5,
|
||
LastTS: 0,
|
||
})
|
||
require.NoError(t, err)
|
||
assert.Len(t, firstPage, 5)
|
||
|
||
// 使用 cursor 獲取下一頁
|
||
if len(firstPage) > 0 {
|
||
lastTS := firstPage[len(firstPage)-1].TS
|
||
secondPage, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
|
||
RoomID: roomIDStr,
|
||
BucketDay: bucketDay,
|
||
PageSize: 5,
|
||
LastTS: lastTS,
|
||
})
|
||
require.NoError(t, err)
|
||
assert.Len(t, secondPage, 5)
|
||
// 驗證第二頁的訊息都比第一頁的最後一條更早
|
||
if len(secondPage) > 0 {
|
||
assert.Less(t, secondPage[0].TS, lastTS)
|
||
}
|
||
}
|
||
|
||
// 驗證計數
|
||
count, err := testRepo.Count(ctx, roomIDStr)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, int64(10), count)
|
||
}
|
||
|
||
func TestMessageRepository_DifferentBuckets(t *testing.T) {
|
||
clearMessages(t)
|
||
ctx := context.Background()
|
||
roomID := gocql.TimeUUID()
|
||
roomIDStr := roomID.String()
|
||
today := time.Now().Format(time.DateOnly)
|
||
yesterday := time.Now().AddDate(0, 0, -1).Format(time.DateOnly)
|
||
|
||
// 插入不同 bucket 的訊息(使用相同的 roomID)
|
||
todayMsg := &entity.Message{
|
||
RoomID: roomID,
|
||
BucketDay: today,
|
||
TS: time.Now().UnixNano(),
|
||
UID: "user-1",
|
||
Content: "Today's message",
|
||
}
|
||
yesterdayMsg := &entity.Message{
|
||
RoomID: roomID,
|
||
BucketDay: yesterday,
|
||
TS: time.Now().UnixNano(),
|
||
UID: "user-1",
|
||
Content: "Yesterday's message",
|
||
}
|
||
|
||
require.NoError(t, testRepo.Insert(ctx, todayMsg))
|
||
require.NoError(t, testRepo.Insert(ctx, yesterdayMsg))
|
||
|
||
// 等待數據寫入
|
||
time.Sleep(100 * time.Millisecond)
|
||
|
||
// 查詢今天的訊息
|
||
todayMessages, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
|
||
RoomID: roomIDStr,
|
||
BucketDay: today,
|
||
PageSize: 10,
|
||
LastTS: 0,
|
||
})
|
||
require.NoError(t, err)
|
||
assert.Len(t, todayMessages, 1)
|
||
assert.Equal(t, "Today's message", todayMessages[0].Content)
|
||
|
||
// 查詢昨天的訊息
|
||
yesterdayMessages, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{
|
||
RoomID: roomIDStr,
|
||
BucketDay: yesterday,
|
||
PageSize: 10,
|
||
LastTS: 0,
|
||
})
|
||
require.NoError(t, err)
|
||
assert.Len(t, yesterdayMessages, 1)
|
||
assert.Equal(t, "Yesterday's message", yesterdayMessages[0].Content)
|
||
}
|