package cassandra import ( "context" "testing" "time" "github.com/gocql/gocql" "github.com/scylladb/gocqlx/v3/qb" "github.com/stretchr/testify/assert" ) func TestQueryBuilder(t *testing.T) { ctx := context.Background() db := &CassandraDB{} // 可以用 mock DB type args struct { cmp qb.Cmp whereArg map[string]any selects []string orderCol string order qb.Order limit uint setCol string setVal any } tests := []struct { name string args args wantPanic bool wantColumns []string wantOrderCol string wantOrder qb.Order wantLimit uint wantSetCol string wantSetVal any }{ { name: "where by partition key", args: args{ cmp: qb.Eq("id"), whereArg: map[string]any{"id": "abc"}, selects: []string{"id", "name"}, orderCol: "id", order: qb.ASC, limit: 1, setCol: "name", setVal: "Daniel", }, wantPanic: false, wantColumns: []string{"id", "name"}, wantOrderCol: "id", wantOrder: qb.ASC, wantLimit: 1, wantSetCol: "name", wantSetVal: "Daniel", }, { name: "where by sai index", args: args{ cmp: qb.Eq("name"), whereArg: map[string]any{"name": "daniel"}, selects: []string{"id", "name"}, orderCol: "name", order: qb.DESC, limit: 2, setCol: "name", setVal: "Jacky", }, wantPanic: false, wantColumns: []string{"id", "name"}, wantOrderCol: "name", wantOrder: qb.DESC, wantLimit: 2, wantSetCol: "name", wantSetVal: "Jacky", }, { name: "where by non-partition-non-sai", args: args{ cmp: qb.Eq("age"), whereArg: map[string]any{"age": 18}, selects: []string{"id", "name"}, orderCol: "age", order: qb.ASC, limit: 3, setCol: "age", setVal: 20, }, wantPanic: true, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { q := db.Model(ctx, &MonkeyEntity{}, "my_keyspace"). Where(tc.args.cmp, tc.args.whereArg). Select(tc.args.selects...). OrderBy(tc.args.orderCol, tc.args.order). Limit(tc.args.limit). Set(tc.args.setCol, tc.args.setVal) if tc.wantPanic { assert.Error(t, q.Update()) } else { assert.Equal(t, tc.wantColumns, q.columns) if len(q.orders) > 0 { assert.Equal(t, tc.wantOrderCol, q.orders[0].Column) assert.Equal(t, tc.wantOrder, q.orders[0].Order) } assert.Equal(t, tc.wantLimit, q.limit) if len(q.sets) > 0 { assert.Equal(t, tc.wantSetCol, q.sets[0].Col) assert.Equal(t, tc.wantSetVal, q.sets[0].Val) } } }) } } func TestQuery_Select(t *testing.T) { tests := []struct { name string selectCalls [][]string wantColumns []string }{ { name: "select one col", selectCalls: [][]string{{"id"}}, wantColumns: []string{"id"}, }, { name: "select multi col in one call", selectCalls: [][]string{{"id", "name"}}, wantColumns: []string{"id", "name"}, }, { name: "multiple select calls append columns", selectCalls: [][]string{{"id"}, {"name"}, {"age"}}, wantColumns: []string{"id", "name", "age"}, }, { name: "multiple select calls with overlap", selectCalls: [][]string{{"id"}, {"id", "name"}, {"name", "age"}}, wantColumns: []string{"id", "id", "name", "name", "age"}, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { q := &Query{columns: make([]string, 0)} for _, call := range tc.selectCalls { q = q.Select(call...) } assert.Equal(t, tc.wantColumns, q.columns) }) } } func TestQuery_Count(t *testing.T) { // 準備測試用資料 container, db := setupForTest(t) defer func() { _ = container.Container.Terminate(container.Ctx) }() db.AutoCreateSAIIndexes(&MonkeyEntity{}, "my_keyspace") now := time.Now().UTC() // 批量插入資料 docs := []MonkeyEntity{ {ID: gocql.TimeUUID(), Name: "Alice", CreateAt: now, UpdateAt: now}, {ID: gocql.TimeUUID(), Name: "Bob", CreateAt: now, UpdateAt: now}, {ID: gocql.TimeUUID(), Name: "Alice", CreateAt: now, UpdateAt: now}, } for _, doc := range docs { assert.NoError(t, db.Insert(container.Ctx, &doc, "my_keyspace")) } tests := []struct { name string filterName string wantCount int64 }{ {"CountAll", "", 3}, {"CountAlice", "Alice", 2}, {"CountBob", "Bob", 1}, {"CountNobody", "Charlie", 0}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { q := db.Model(container.Ctx, &MonkeyEntity{}, "my_keyspace") if tt.filterName != "" { q = q.Where(qb.Eq("name"), qb.M{"name": tt.filterName}) } count, err := q.Count() assert.NoError(t, err) assert.Equal(t, tt.wantCount, count) }) } } type TestUser struct { ID gocql.UUID `db:"id" partition_key:"true"` Name string `db:"name" sai:"true"` Age int64 `db:"age"` } func (TestUser) TableName() string { return "test_user" } func TestQueryBasicFlow(t *testing.T) { // 啟動 Cassandra container container, db := setupForTest(t) defer func() { _ = container.Container.Terminate(container.Ctx) }() ctx := context.Background() keyspace := "my_keyspace" err := db.EnsureTable(` CREATE TABLE IF NOT EXISTS my_keyspace.test_user ( id UUID, name TEXT, age BIGINT, PRIMARY KEY (id) );`) assert.NoError(t, err) err = db.AutoCreateSAIIndexes(&TestUser{}, keyspace) assert.NoError(t, err) // 測試資料 u1 := TestUser{ID: gocql.TimeUUID(), Name: "Alice", Age: 20} u2 := TestUser{ID: gocql.TimeUUID(), Name: "Bob", Age: 22} u3 := TestUser{ID: gocql.TimeUUID(), Name: "Carol", Age: 23} // InsertOne/InsertMany t.Run("InsertOne", func(t *testing.T) { q := db.Model(ctx, TestUser{}, keyspace) assert.NoError(t, q.InsertOne(u1)) }) t.Run("InsertMany", func(t *testing.T) { q := db.Model(ctx, TestUser{}, keyspace) assert.NoError(t, q.InsertMany([]TestUser{u2, u3})) }) // GetAll t.Run("GetAll", func(t *testing.T) { var got []TestUser q := db.Model(ctx, TestUser{}, keyspace) assert.NoError(t, q.GetAll(&got)) assert.GreaterOrEqual(t, len(got), 3) }) // Count t.Run("Count All", func(t *testing.T) { q := db.Model(ctx, TestUser{}, keyspace) count, err := q.Count() assert.NoError(t, err) assert.GreaterOrEqual(t, count, int64(3)) }) // Delete t.Run("Delete Carol", func(t *testing.T) { q2 := db.Model(ctx, TestUser{}, keyspace) q2.Where(qb.Eq("id"), map[string]any{"id": u3.ID}) assert.NoError(t, q2.Delete()) // 驗證已刪除 var user TestUser err := db.Model(ctx, TestUser{}, keyspace). Where(qb.Eq("id"), map[string]any{"id": u3.ID}).Scan(&user) assert.Error(t, err) q3 := db.Model(ctx, TestUser{}, keyspace) count, err := q3.Count() assert.NoError(t, err) assert.GreaterOrEqual(t, count, int64(2)) }) // Scan t.Run("Scan Find Alice", func(t *testing.T) { var user []TestUser err := db.Model(ctx, TestUser{}, keyspace). Where(qb.Eq("name"), map[string]any{"name": "Alice"}).Scan(&user) assert.NoError(t, err) assert.Equal(t, u1.Name, user[0].Name) }) // // Take (僅取一筆) t.Run("Take Get Bob", func(t *testing.T) { var user TestUser q2 := db.Model(ctx, TestUser{}, keyspace). Where(qb.Eq("name"), map[string]any{"name": "Bob"}) assert.NoError(t, q2.Take(&user)) assert.Equal(t, u2.Name, user.Name) }) // Update t.Run("Update Age of Alice", func(t *testing.T) { q := db.Model(ctx, TestUser{}, keyspace) assert.NoError(t, q.InsertMany([]TestUser{u1, u2, u3})) err = db.Model(ctx, TestUser{}, keyspace). Where(qb.Eq("id"), map[string]any{"id": u1.ID}). Set("age", 30). Update() assert.NoError(t, err) // 驗證 var user TestUser assert.NoError(t, db.Model(ctx, TestUser{}, keyspace). Where(qb.Eq("id"), map[string]any{"id": u1.ID}).Take(&user)) assert.Equal(t, int64(30), user.Age) }) // In 這個 case 不通過,原因是 sai key 也不一定可以確認 cassandra 分區 t.Run("In", func(t *testing.T) { q := db.Model(ctx, TestUser{}, keyspace) assert.NoError(t, q.InsertMany([]TestUser{u1, u2, u3})) var user []TestUser err = db.Model(ctx, TestUser{}, keyspace). Where(qb.In("name"), map[string]any{"name": []string{u1.Name, u2.Name}}). Scan(&user) assert.Error(t, err) }) }