331 lines
8.1 KiB
Go
331 lines
8.1 KiB
Go
package cassandra
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gocql/gocql"
|
|
|
|
"github.com/scylladb/gocqlx/v3/qb"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestQueryBuilder(t *testing.T) {
|
|
ctx := context.Background()
|
|
db := &CassandraDB{} // 可以用 mock DB
|
|
|
|
type args struct {
|
|
cmp qb.Cmp
|
|
whereArg map[string]any
|
|
selects []string
|
|
orderCol string
|
|
order qb.Order
|
|
limit uint
|
|
setCol string
|
|
setVal any
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
wantPanic bool
|
|
wantColumns []string
|
|
wantOrderCol string
|
|
wantOrder qb.Order
|
|
wantLimit uint
|
|
wantSetCol string
|
|
wantSetVal any
|
|
}{
|
|
{
|
|
name: "where by partition key",
|
|
args: args{
|
|
cmp: qb.Eq("id"),
|
|
whereArg: map[string]any{"id": "abc"},
|
|
selects: []string{"id", "name"},
|
|
orderCol: "id",
|
|
order: qb.ASC,
|
|
limit: 1,
|
|
setCol: "name",
|
|
setVal: "Daniel",
|
|
},
|
|
wantPanic: false,
|
|
wantColumns: []string{"id", "name"},
|
|
wantOrderCol: "id",
|
|
wantOrder: qb.ASC,
|
|
wantLimit: 1,
|
|
wantSetCol: "name",
|
|
wantSetVal: "Daniel",
|
|
},
|
|
{
|
|
name: "where by sai index",
|
|
args: args{
|
|
cmp: qb.Eq("name"),
|
|
whereArg: map[string]any{"name": "daniel"},
|
|
selects: []string{"id", "name"},
|
|
orderCol: "name",
|
|
order: qb.DESC,
|
|
limit: 2,
|
|
setCol: "name",
|
|
setVal: "Jacky",
|
|
},
|
|
wantPanic: false,
|
|
wantColumns: []string{"id", "name"},
|
|
wantOrderCol: "name",
|
|
wantOrder: qb.DESC,
|
|
wantLimit: 2,
|
|
wantSetCol: "name",
|
|
wantSetVal: "Jacky",
|
|
},
|
|
{
|
|
name: "where by non-partition-non-sai",
|
|
args: args{
|
|
cmp: qb.Eq("age"),
|
|
whereArg: map[string]any{"age": 18},
|
|
selects: []string{"id", "name"},
|
|
orderCol: "age",
|
|
order: qb.ASC,
|
|
limit: 3,
|
|
setCol: "age",
|
|
setVal: 20,
|
|
},
|
|
wantPanic: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
q := db.Model(ctx, &MonkeyEntity{}, "my_keyspace").
|
|
Where(tc.args.cmp, tc.args.whereArg).
|
|
Select(tc.args.selects...).
|
|
OrderBy(tc.args.orderCol, tc.args.order).
|
|
Limit(tc.args.limit).
|
|
Set(tc.args.setCol, tc.args.setVal)
|
|
|
|
if tc.wantPanic {
|
|
assert.Error(t, q.Update())
|
|
} else {
|
|
assert.Equal(t, tc.wantColumns, q.columns)
|
|
if len(q.orders) > 0 {
|
|
assert.Equal(t, tc.wantOrderCol, q.orders[0].Column)
|
|
assert.Equal(t, tc.wantOrder, q.orders[0].Order)
|
|
}
|
|
assert.Equal(t, tc.wantLimit, q.limit)
|
|
if len(q.sets) > 0 {
|
|
assert.Equal(t, tc.wantSetCol, q.sets[0].Col)
|
|
assert.Equal(t, tc.wantSetVal, q.sets[0].Val)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestQuery_Select(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
selectCalls [][]string
|
|
wantColumns []string
|
|
}{
|
|
{
|
|
name: "select one col",
|
|
selectCalls: [][]string{{"id"}},
|
|
wantColumns: []string{"id"},
|
|
},
|
|
{
|
|
name: "select multi col in one call",
|
|
selectCalls: [][]string{{"id", "name"}},
|
|
wantColumns: []string{"id", "name"},
|
|
},
|
|
{
|
|
name: "multiple select calls append columns",
|
|
selectCalls: [][]string{{"id"}, {"name"}, {"age"}},
|
|
wantColumns: []string{"id", "name", "age"},
|
|
},
|
|
{
|
|
name: "multiple select calls with overlap",
|
|
selectCalls: [][]string{{"id"}, {"id", "name"}, {"name", "age"}},
|
|
wantColumns: []string{"id", "id", "name", "name", "age"},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
q := &Query{columns: make([]string, 0)}
|
|
for _, call := range tc.selectCalls {
|
|
q = q.Select(call...)
|
|
}
|
|
assert.Equal(t, tc.wantColumns, q.columns)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestQuery_Count(t *testing.T) {
|
|
// 準備測試用資料
|
|
container, db := setupForTest(t)
|
|
defer func() {
|
|
_ = container.Container.Terminate(container.Ctx)
|
|
}()
|
|
|
|
db.AutoCreateSAIIndexes(&MonkeyEntity{}, "my_keyspace")
|
|
now := time.Now().UTC()
|
|
// 批量插入資料
|
|
docs := []MonkeyEntity{
|
|
{ID: gocql.TimeUUID(), Name: "Alice", CreateAt: now, UpdateAt: now},
|
|
{ID: gocql.TimeUUID(), Name: "Bob", CreateAt: now, UpdateAt: now},
|
|
{ID: gocql.TimeUUID(), Name: "Alice", CreateAt: now, UpdateAt: now},
|
|
}
|
|
for _, doc := range docs {
|
|
assert.NoError(t, db.Insert(container.Ctx, &doc, "my_keyspace"))
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
filterName string
|
|
wantCount int64
|
|
}{
|
|
{"CountAll", "", 3},
|
|
{"CountAlice", "Alice", 2},
|
|
{"CountBob", "Bob", 1},
|
|
{"CountNobody", "Charlie", 0},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
q := db.Model(container.Ctx, &MonkeyEntity{}, "my_keyspace")
|
|
if tt.filterName != "" {
|
|
q = q.Where(qb.Eq("name"), qb.M{"name": tt.filterName})
|
|
}
|
|
count, err := q.Count()
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.wantCount, count)
|
|
})
|
|
}
|
|
}
|
|
|
|
type TestUser struct {
|
|
ID gocql.UUID `db:"id" partition_key:"true"`
|
|
Name string `db:"name" sai:"true"`
|
|
Age int64 `db:"age"`
|
|
}
|
|
|
|
func (TestUser) TableName() string { return "test_user" }
|
|
|
|
func TestQueryBasicFlow(t *testing.T) {
|
|
// 啟動 Cassandra container
|
|
container, db := setupForTest(t)
|
|
defer func() {
|
|
_ = container.Container.Terminate(container.Ctx)
|
|
}()
|
|
|
|
ctx := context.Background()
|
|
keyspace := "my_keyspace"
|
|
err := db.EnsureTable(`
|
|
CREATE TABLE IF NOT EXISTS my_keyspace.test_user (
|
|
id UUID,
|
|
name TEXT,
|
|
age BIGINT,
|
|
PRIMARY KEY (id)
|
|
);`)
|
|
assert.NoError(t, err)
|
|
err = db.AutoCreateSAIIndexes(&TestUser{}, keyspace)
|
|
assert.NoError(t, err)
|
|
// 測試資料
|
|
u1 := TestUser{ID: gocql.TimeUUID(), Name: "Alice", Age: 20}
|
|
u2 := TestUser{ID: gocql.TimeUUID(), Name: "Bob", Age: 22}
|
|
u3 := TestUser{ID: gocql.TimeUUID(), Name: "Carol", Age: 23}
|
|
|
|
// InsertOne/InsertMany
|
|
t.Run("InsertOne", func(t *testing.T) {
|
|
q := db.Model(ctx, TestUser{}, keyspace)
|
|
assert.NoError(t, q.InsertOne(u1))
|
|
})
|
|
|
|
t.Run("InsertMany", func(t *testing.T) {
|
|
q := db.Model(ctx, TestUser{}, keyspace)
|
|
assert.NoError(t, q.InsertMany([]TestUser{u2, u3}))
|
|
})
|
|
|
|
// GetAll
|
|
t.Run("GetAll", func(t *testing.T) {
|
|
var got []TestUser
|
|
q := db.Model(ctx, TestUser{}, keyspace)
|
|
assert.NoError(t, q.GetAll(&got))
|
|
assert.GreaterOrEqual(t, len(got), 3)
|
|
})
|
|
|
|
// Count
|
|
t.Run("Count All", func(t *testing.T) {
|
|
q := db.Model(ctx, TestUser{}, keyspace)
|
|
count, err := q.Count()
|
|
assert.NoError(t, err)
|
|
assert.GreaterOrEqual(t, count, int64(3))
|
|
})
|
|
|
|
// Delete
|
|
t.Run("Delete Carol", func(t *testing.T) {
|
|
q2 := db.Model(ctx, TestUser{}, keyspace)
|
|
q2.Where(qb.Eq("id"), map[string]any{"id": u3.ID})
|
|
assert.NoError(t, q2.Delete())
|
|
// 驗證已刪除
|
|
var user TestUser
|
|
err := db.Model(ctx, TestUser{}, keyspace).
|
|
Where(qb.Eq("id"), map[string]any{"id": u3.ID}).Scan(&user)
|
|
assert.Error(t, err)
|
|
|
|
q3 := db.Model(ctx, TestUser{}, keyspace)
|
|
count, err := q3.Count()
|
|
assert.NoError(t, err)
|
|
assert.GreaterOrEqual(t, count, int64(2))
|
|
})
|
|
|
|
// Scan
|
|
t.Run("Scan Find Alice", func(t *testing.T) {
|
|
var user []TestUser
|
|
err := db.Model(ctx, TestUser{}, keyspace).
|
|
Where(qb.Eq("name"), map[string]any{"name": "Alice"}).Scan(&user)
|
|
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, u1.Name, user[0].Name)
|
|
})
|
|
//
|
|
// Take (僅取一筆)
|
|
t.Run("Take Get Bob", func(t *testing.T) {
|
|
var user TestUser
|
|
q2 := db.Model(ctx, TestUser{}, keyspace).
|
|
Where(qb.Eq("name"), map[string]any{"name": "Bob"})
|
|
assert.NoError(t, q2.Take(&user))
|
|
assert.Equal(t, u2.Name, user.Name)
|
|
})
|
|
// Update
|
|
t.Run("Update Age of Alice", func(t *testing.T) {
|
|
q := db.Model(ctx, TestUser{}, keyspace)
|
|
assert.NoError(t, q.InsertMany([]TestUser{u1, u2, u3}))
|
|
|
|
err = db.Model(ctx,
|
|
TestUser{}, keyspace).
|
|
Where(qb.Eq("id"), map[string]any{"id": u1.ID}).
|
|
Set("age", 30).
|
|
Update()
|
|
|
|
assert.NoError(t, err)
|
|
// 驗證
|
|
var user TestUser
|
|
assert.NoError(t, db.Model(ctx, TestUser{}, keyspace).
|
|
Where(qb.Eq("id"), map[string]any{"id": u1.ID}).Take(&user))
|
|
assert.Equal(t, int64(30), user.Age)
|
|
})
|
|
|
|
// In 這個 case 不通過,原因是 sai key 也不一定可以確認 cassandra 分區
|
|
t.Run("In", func(t *testing.T) {
|
|
q := db.Model(ctx, TestUser{}, keyspace)
|
|
assert.NoError(t, q.InsertMany([]TestUser{u1, u2, u3}))
|
|
|
|
var user []TestUser
|
|
err = db.Model(ctx,
|
|
TestUser{}, keyspace).
|
|
Where(qb.In("name"), map[string]any{"name": []string{u1.Name, u2.Name}}).
|
|
Scan(&user)
|
|
assert.Error(t, err)
|
|
})
|
|
}
|