package repository import ( "backend/pkg/chat/domain/entity" "backend/pkg/chat/domain/repository" "backend/pkg/library/cassandra" "context" "os" "strconv" "testing" "time" "github.com/gocql/gocql" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" ) var ( // 全局變量:所有測試共享的資源 testDB *cassandra.DB testRepo repository.MessageRepository testContainer testcontainers.Container cleanupContainer func() ) // TestMain 在所有測試之前執行一次,設置共享的資料庫 func TestMain(m *testing.M) { // 使用更長的 context 超時時間(Cassandra 啟動需要較長時間) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // 啟動 Cassandra 容器(只執行一次) // 不指定固定 port,讓 testcontainers 自動分配,避免與本地開發環境衝突 req := testcontainers.ContainerRequest{ Image: "cassandra:4.1", ExposedPorts: []string{"9042/tcp"}, // 等待 Cassandra 日誌確認啟動完成(Cassandra 啟動需要較長時間) WaitingFor: wait.ForLog("Startup complete"). WithStartupTimeout(3 * time.Minute). WithOccurrence(1), Env: map[string]string{ "CASSANDRA_CLUSTER_NAME": "test-cluster", "HEAP_NEWSIZE": "128M", "MAX_HEAP_SIZE": "512M", }, } var err error testContainer, err = testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ ContainerRequest: req, Started: true, }) if err != nil { panic("Failed to start Cassandra container: " + err.Error()) } port, err := testContainer.MappedPort(ctx, "9042") if err != nil { testContainer.Terminate(ctx) panic("Failed to get mapped port: " + err.Error()) } host, err := testContainer.Host(ctx) if err != nil { testContainer.Terminate(ctx) panic("Failed to get host: " + err.Error()) } portInt, err := strconv.Atoi(port.Port()) if err != nil { testContainer.Terminate(ctx) panic("Failed to convert port to int: " + err.Error()) } // 先創建 DB 連接(不指定 keyspace,因為 keyspace 還不存在) testDB, err = cassandra.New( cassandra.WithHosts(host), cassandra.WithPort(portInt), // 不指定 keyspace,先連接後再創建 ) if err != nil { testContainer.Terminate(ctx) panic("Failed to create DB: " + err.Error()) } // 創建 keyspace(需要在連接後才能創建) createKeyspaceStmt := "CREATE KEYSPACE IF NOT EXISTS test_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" if err := testDB.GetSession().Query(createKeyspaceStmt, nil).Exec(); err != nil { testDB.Close() testContainer.Terminate(ctx) panic("Failed to create keyspace: " + err.Error()) } // 等待 keyspace 創建完成 time.Sleep(500 * time.Millisecond) // 創建 messages_by_room 表(需要指定 keyspace) createTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.messages_by_room ( room_id uuid, bucket_day text, ts bigint, uid text, content text, PRIMARY KEY ((room_id, bucket_day), ts) ) WITH CLUSTERING ORDER BY (ts DESC)` if err := testDB.GetSession().Query(createTableStmt, nil).Exec(); err != nil { testDB.Close() testContainer.Terminate(ctx) panic("Failed to create table: " + err.Error()) } // 創建 messages repository testRepo, err = NewMessageRepository(testDB, "test_keyspace") if err != nil { testDB.Close() testContainer.Terminate(ctx) panic("Failed to create message repository: " + err.Error()) } // 創建 room 相關表(供 room_test.go 使用) createRoomTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.room ( room_id uuid PRIMARY KEY, name text, status text, created_at bigint, updated_at bigint )` if err := testDB.GetSession().Query(createRoomTableStmt, nil).Exec(); err != nil { testDB.Close() testContainer.Terminate(ctx) panic("Failed to create room table: " + err.Error()) } createMemberTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.room_member ( room_id uuid, uid text, role text, joined_at bigint, PRIMARY KEY (room_id, uid) )` if err := testDB.GetSession().Query(createMemberTableStmt, nil).Exec(); err != nil { testDB.Close() testContainer.Terminate(ctx) panic("Failed to create room_member table: " + err.Error()) } createUserRoomTableStmt := `CREATE TABLE IF NOT EXISTS test_keyspace.user_room ( uid text, room_id uuid, joined_at bigint, PRIMARY KEY (uid, room_id) )` if err := testDB.GetSession().Query(createUserRoomTableStmt, nil).Exec(); err != nil { testDB.Close() testContainer.Terminate(ctx) panic("Failed to create user_room table: " + err.Error()) } // 設置清理函數 cleanupContainer = func() { if testDB != nil { testDB.Close() } if testContainer != nil { _ = testContainer.Terminate(ctx) } } // 執行所有測試 code := m.Run() // 清理資源 cleanupContainer() // 退出 os.Exit(code) } // clearMessages 清空 messages_by_room 表的所有數據 func clearMessages(t *testing.T) { ctx := context.Background() // 使用完整的 keyspace.table 名稱 truncateStmt := "TRUNCATE test_keyspace.messages_by_room" if err := testDB.GetSession().Query(truncateStmt, nil).WithContext(ctx).Exec(); err != nil { t.Fatalf("Failed to truncate messages_by_room table: %v", err) } // 等待數據清空 time.Sleep(50 * time.Millisecond) } func TestNewMessageRepository(t *testing.T) { clearMessages(t) assert.NotNil(t, testRepo) } func TestMessageRepository_Insert(t *testing.T) { clearMessages(t) ctx := context.Background() tests := []struct { name string message *entity.Message wantErr bool }{ { name: "successful insert", message: &entity.Message{ RoomID: gocql.TimeUUID(), BucketDay: "20250117", TS: time.Now().UnixNano(), UID: "user-1", Content: "Hello, world!", }, wantErr: false, }, { name: "insert with auto timestamp", message: &entity.Message{ RoomID: gocql.TimeUUID(), BucketDay: "20250117", UID: "user-2", Content: "Test message", }, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := testRepo.Insert(ctx, tt.message) if tt.wantErr { assert.Error(t, err) } else { assert.NoError(t, err) // 驗證 TS 和 BucketDay 是否被自動設置 if tt.message.TS == 0 { assert.NotZero(t, tt.message.TS) } if tt.message.BucketDay == "" { assert.NotEmpty(t, tt.message.BucketDay) } } }) } } func TestMessageRepository_ListMessages(t *testing.T) { clearMessages(t) ctx := context.Background() roomID := gocql.TimeUUID() roomIDStr := roomID.String() bucketDay := time.Now().Format(time.DateOnly) // 插入測試數據(使用相同的 roomID) messages := []*entity.Message{ {RoomID: roomID, BucketDay: bucketDay, TS: time.Now().UnixNano(), UID: "user-1", Content: "Message 1"}, {RoomID: roomID, BucketDay: bucketDay, TS: time.Now().Add(time.Second).UnixNano(), UID: "user-2", Content: "Message 2"}, {RoomID: roomID, BucketDay: bucketDay, TS: time.Now().Add(2 * time.Second).UnixNano(), UID: "user-3", Content: "Message 3"}, } for _, msg := range messages { require.NoError(t, testRepo.Insert(ctx, msg)) } // 等待數據寫入 time.Sleep(100 * time.Millisecond) tests := []struct { name string param repository.ListMessagesReq wantLen int wantErr bool }{ { name: "list all messages", param: repository.ListMessagesReq{ RoomID: roomIDStr, BucketDay: bucketDay, PageSize: 10, LastTS: 0, }, wantLen: 3, wantErr: false, }, { name: "list with page size limit", param: repository.ListMessagesReq{ RoomID: roomIDStr, BucketDay: bucketDay, PageSize: 2, LastTS: 0, }, wantLen: 2, wantErr: false, }, { name: "list with cursor pagination", param: repository.ListMessagesReq{ RoomID: roomIDStr, BucketDay: bucketDay, PageSize: 10, LastTS: messages[1].TS, // 使用第二條訊息的 TS 作為 cursor }, wantLen: 1, // 應該只返回比 LastTS 更早的訊息 wantErr: false, }, { name: "list with default page size", param: repository.ListMessagesReq{ RoomID: roomIDStr, BucketDay: bucketDay, PageSize: 0, // 應該使用預設值 20 LastTS: 0, }, wantLen: 3, wantErr: false, }, { name: "list empty room", param: repository.ListMessagesReq{ RoomID: gocql.TimeUUID().String(), // 使用不存在的 UUID BucketDay: bucketDay, PageSize: 10, LastTS: 0, }, wantLen: 0, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, err := testRepo.ListMessages(ctx, tt.param) if tt.wantErr { assert.Error(t, err) } else { assert.NoError(t, err) assert.Len(t, result, tt.wantLen) // 驗證排序(應該按 TS DESC 排序) if len(result) > 1 { for i := 0; i < len(result)-1; i++ { assert.GreaterOrEqual(t, result[i].TS, result[i+1].TS, "Messages should be sorted by TS DESC") } } } }) } } func TestMessageRepository_Count(t *testing.T) { clearMessages(t) ctx := context.Background() roomID := gocql.TimeUUID() roomIDStr := roomID.String() bucketDay := time.Now().Format(time.DateOnly) // 插入測試數據(使用相同的 roomID) for i := 0; i < 5; i++ { msg := &entity.Message{ RoomID: roomID, BucketDay: bucketDay, TS: time.Now().Add(time.Duration(i) * time.Second).UnixNano(), UID: "user-1", Content: "Message", } require.NoError(t, testRepo.Insert(ctx, msg)) } // 等待數據寫入 time.Sleep(100 * time.Millisecond) tests := []struct { name string roomID string want int64 wantErr bool }{ { name: "count existing room", roomID: roomIDStr, want: 5, wantErr: false, }, { name: "count non-existent room", roomID: gocql.TimeUUID().String(), // 使用不存在的 UUID want: 0, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { count, err := testRepo.Count(ctx, tt.roomID) if tt.wantErr { assert.Error(t, err) } else { assert.NoError(t, err) assert.Equal(t, tt.want, count) } }) } } func TestMessageRepository_Insert_ListMessages_Integration(t *testing.T) { clearMessages(t) ctx := context.Background() roomID := gocql.TimeUUID() roomIDStr := roomID.String() bucketDay := time.Now().Format(time.DateOnly) // 插入多條訊息(使用相同的 roomID) messages := make([]*entity.Message, 10) for i := 0; i < 10; i++ { msg := &entity.Message{ RoomID: roomID, BucketDay: bucketDay, TS: time.Now().Add(time.Duration(i) * time.Second).UnixNano(), UID: "user-1", Content: "Message", } messages[i] = msg require.NoError(t, testRepo.Insert(ctx, msg)) } // 等待數據寫入 time.Sleep(100 * time.Millisecond) // 驗證可以查詢到所有訊息 result, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{ RoomID: roomIDStr, BucketDay: bucketDay, PageSize: 20, LastTS: 0, }) require.NoError(t, err) assert.Len(t, result, 10) // 驗證分頁功能 firstPage, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{ RoomID: roomIDStr, BucketDay: bucketDay, PageSize: 5, LastTS: 0, }) require.NoError(t, err) assert.Len(t, firstPage, 5) // 使用 cursor 獲取下一頁 if len(firstPage) > 0 { lastTS := firstPage[len(firstPage)-1].TS secondPage, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{ RoomID: roomIDStr, BucketDay: bucketDay, PageSize: 5, LastTS: lastTS, }) require.NoError(t, err) assert.Len(t, secondPage, 5) // 驗證第二頁的訊息都比第一頁的最後一條更早 if len(secondPage) > 0 { assert.Less(t, secondPage[0].TS, lastTS) } } // 驗證計數 count, err := testRepo.Count(ctx, roomIDStr) require.NoError(t, err) assert.Equal(t, int64(10), count) } func TestMessageRepository_DifferentBuckets(t *testing.T) { clearMessages(t) ctx := context.Background() roomID := gocql.TimeUUID() roomIDStr := roomID.String() today := time.Now().Format(time.DateOnly) yesterday := time.Now().AddDate(0, 0, -1).Format(time.DateOnly) // 插入不同 bucket 的訊息(使用相同的 roomID) todayMsg := &entity.Message{ RoomID: roomID, BucketDay: today, TS: time.Now().UnixNano(), UID: "user-1", Content: "Today's message", } yesterdayMsg := &entity.Message{ RoomID: roomID, BucketDay: yesterday, TS: time.Now().UnixNano(), UID: "user-1", Content: "Yesterday's message", } require.NoError(t, testRepo.Insert(ctx, todayMsg)) require.NoError(t, testRepo.Insert(ctx, yesterdayMsg)) // 等待數據寫入 time.Sleep(100 * time.Millisecond) // 查詢今天的訊息 todayMessages, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{ RoomID: roomIDStr, BucketDay: today, PageSize: 10, LastTS: 0, }) require.NoError(t, err) assert.Len(t, todayMessages, 1) assert.Equal(t, "Today's message", todayMessages[0].Content) // 查詢昨天的訊息 yesterdayMessages, err := testRepo.ListMessages(ctx, repository.ListMessagesReq{ RoomID: roomIDStr, BucketDay: yesterday, PageSize: 10, LastTS: 0, }) require.NoError(t, err) assert.Len(t, yesterdayMessages, 1) assert.Equal(t, "Yesterday's message", yesterdayMessages[0].Content) }