463 lines
11 KiB
Go
463 lines
11 KiB
Go
package cassandra
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"reflect"
|
||
|
||
"github.com/gocql/gocql"
|
||
"github.com/scylladb/gocqlx/v3/qb"
|
||
"github.com/scylladb/gocqlx/v3/table"
|
||
)
|
||
|
||
func (db *CassandraDB) AutoCreateSAIIndexes(doc any, keyspace string) error {
|
||
metadata, err := GenerateTableMetadata(doc, keyspace)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
t := reflect.TypeOf(doc)
|
||
if t.Kind() == reflect.Ptr {
|
||
t = t.Elem()
|
||
}
|
||
for i := 0; i < t.NumField(); i++ {
|
||
f := t.Field(i)
|
||
if f.Tag.Get("sai") == "true" {
|
||
col := f.Tag.Get("db")
|
||
if col == "" {
|
||
col = toSnakeCase(f.Name)
|
||
}
|
||
stmt := fmt.Sprintf("CREATE INDEX IF NOT EXISTS ON %s (%s) USING 'sai';", metadata.Name, col)
|
||
if err := db.GetSession().ExecStmt(stmt); err != nil {
|
||
return fmt.Errorf("failed to create SAI index on table %s, column %s: %w", metadata.Name, col, err)
|
||
}
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// Query is a chainable query builder bound to one table. It is created by
// CassandraDB.Model. Builder methods (Where, Select, OrderBy, Limit, Set)
// record state, and execution methods (Scan, Take, Delete, Update, InsertOne,
// InsertMany, GetAll, Count) run a statement built from that state.
type Query struct {
	// db is the handle whose session executes every statement.
	db *CassandraDB
	// ctx is attached to each executed query via qh.withContextAndTimestamp.
	ctx context.Context
	// table is the table name from the model's metadata; it stays empty when
	// Model could not generate the metadata.
	table string
	// keyspace is the resolved keyspace used to (re)generate table metadata.
	keyspace string
	// columns are the columns selected by Scan; empty means all model columns.
	columns []string
	// cmps are the WHERE conditions, in the order they were added.
	cmps []qb.Cmp
	// bindMap holds the named values for the WHERE bind markers.
	bindMap map[string]any
	// orders are the ORDER BY clauses applied by Scan, in call order.
	orders []orderBy
	// limit is the LIMIT applied by Scan; 0 means no LIMIT clause.
	limit uint
	// document is the model value used to derive table metadata and SAI fields.
	document any
	sets     []setField // columns to update and their new values (used by Update)
	// errs collects errors from builder steps; execution methods return them
	// joined instead of running a statement.
	errs []error
}
|
||
|
||
// orderBy is a single ORDER BY clause recorded by Query.OrderBy and applied
// by Query.Scan.
type orderBy struct {
	// Column is the column to order by.
	Column string
	// Order is the sort direction passed to the qb select builder.
	Order qb.Order
}
|
||
|
||
// setField is a single SET assignment recorded by Query.Set and written by
// Query.Update.
type setField struct {
	// Col is the column to update.
	Col string
	// Val is the new value, bound positionally in Update.
	Val any
}
|
||
|
||
// Model 創建一個新的查詢構建器
|
||
// document: 用於推斷表結構的範例物件(必須實現 TableName() 方法)
|
||
// keyspace: 如果為空,則使用初始化時設定的預設 keyspace
|
||
func (db *CassandraDB) Model(ctx context.Context, document any, keyspace string) *Query {
|
||
keyspace = getKeyspace(db, keyspace)
|
||
metadata, err := GenerateTableMetadata(document, keyspace)
|
||
if err != nil {
|
||
// 如果 metadata 生成失敗,創建一個帶錯誤的 Query
|
||
return &Query{
|
||
db: db,
|
||
ctx: ctx,
|
||
keyspace: keyspace,
|
||
document: document,
|
||
errs: []error{err},
|
||
}
|
||
}
|
||
|
||
return &Query{
|
||
db: db,
|
||
ctx: ctx,
|
||
table: metadata.Name,
|
||
keyspace: keyspace,
|
||
columns: make([]string, 0),
|
||
cmps: make([]qb.Cmp, 0),
|
||
bindMap: make(map[string]any),
|
||
orders: make([]orderBy, 0),
|
||
limit: 0,
|
||
document: document, // document 用於生成 metadata 和驗證 SAI 欄位
|
||
errs: make([]error, 0),
|
||
}
|
||
}
|
||
|
||
// Where 添加 WHERE 條件
|
||
// 只允許 partition key 或有 sai index 的欄位進行 where 查詢
|
||
// cmp: 查詢條件(如 qb.Eq("id"))
|
||
// args: 參數映射(如 map[string]any{"id": uuid})
|
||
func (q *Query) Where(cmp qb.Cmp, args map[string]any) *Query {
|
||
// 如果之前有錯誤,直接返回
|
||
if len(q.errs) > 0 {
|
||
return q
|
||
}
|
||
|
||
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
|
||
if err != nil {
|
||
q.errs = append(q.errs, err)
|
||
return q
|
||
}
|
||
|
||
for k := range args {
|
||
// 允許 partition_key 或 sai 欄位
|
||
isPartition := contains(metadata.PartKey, k)
|
||
isSAI := IsSAIField(q.document, k)
|
||
if !isPartition && !isSAI {
|
||
q.errs = append(q.errs, NewError(
|
||
"INVALID_WHERE_FIELD",
|
||
fmt.Sprintf("where condition on field %s requires partition_key or sai index", k),
|
||
).WithTable(q.table))
|
||
}
|
||
}
|
||
|
||
q.cmps = append(q.cmps, cmp)
|
||
for k, v := range args {
|
||
q.bindMap[k] = v
|
||
}
|
||
|
||
return q
|
||
}
|
||
|
||
func (q *Query) Select(cols ...string) *Query {
|
||
q.columns = append(q.columns, cols...)
|
||
|
||
return q
|
||
}
|
||
|
||
func (q *Query) OrderBy(column string, order qb.Order) *Query {
|
||
q.orders = append(q.orders, orderBy{Column: column, Order: order})
|
||
|
||
return q
|
||
}
|
||
|
||
func (q *Query) Limit(limit uint) *Query {
|
||
q.limit = limit
|
||
|
||
return q
|
||
}
|
||
|
||
func (q *Query) Set(col string, val any) *Query {
|
||
q.sets = append(q.sets, setField{Col: col, Val: val})
|
||
|
||
return q
|
||
}
|
||
|
||
// Scan 執行查詢並將結果掃描到 dest
|
||
// dest 必須是指標類型:*Struct 用於單筆查詢,*[]Struct 用於多筆查詢
|
||
func (q *Query) Scan(dest any) error {
|
||
if len(q.errs) > 0 {
|
||
return errors.Join(q.errs...)
|
||
}
|
||
|
||
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
builder := qb.Select(q.table)
|
||
if len(q.columns) > 0 {
|
||
builder = builder.Columns(q.columns...)
|
||
} else {
|
||
// 如果沒有指定欄位,使用所有欄位
|
||
builder = builder.Columns(metadata.Columns...)
|
||
}
|
||
if len(q.cmps) > 0 {
|
||
builder = builder.Where(q.cmps...)
|
||
}
|
||
if len(q.orders) > 0 {
|
||
for _, o := range q.orders {
|
||
builder = builder.OrderBy(o.Column, o.Order)
|
||
}
|
||
}
|
||
if q.limit > 0 {
|
||
builder = builder.Limit(q.limit)
|
||
}
|
||
|
||
stmt, names := builder.ToCql()
|
||
query := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names))
|
||
if q.bindMap == nil {
|
||
q.bindMap = qb.M{}
|
||
}
|
||
query = query.BindMap(q.bindMap)
|
||
|
||
// 型態判斷自動選用單筆/多筆查詢
|
||
destType := reflect.TypeOf(dest)
|
||
if destType.Kind() != reflect.Ptr {
|
||
return ErrInvalidInput.WithTable(q.table).WithError(fmt.Errorf("destination must be a pointer, got %T", dest))
|
||
}
|
||
elemType := destType.Elem()
|
||
switch elemType.Kind() {
|
||
case reflect.Slice:
|
||
return query.SelectRelease(dest)
|
||
case reflect.Struct:
|
||
err := query.GetRelease(dest)
|
||
if err == gocql.ErrNotFound {
|
||
return ErrNotFound.WithTable(q.table)
|
||
}
|
||
return err
|
||
default:
|
||
return ErrInvalidInput.WithTable(q.table).WithError(fmt.Errorf("destination must be pointer to struct or slice, got %T", dest))
|
||
}
|
||
}
|
||
|
||
func (q *Query) Take(dest any) error {
|
||
q.limit = 1
|
||
|
||
return q.Scan(dest)
|
||
}
|
||
|
||
// Delete 執行刪除操作
|
||
// 要求:必須提供所有 partition keys 在 WHERE 條件中
|
||
func (q *Query) Delete() error {
|
||
if len(q.errs) > 0 {
|
||
return errors.Join(q.errs...)
|
||
}
|
||
|
||
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 檢查是否提供所有 partition keys
|
||
missingKeys := make([]string, 0)
|
||
for _, pk := range metadata.PartKey {
|
||
if _, ok := q.bindMap[pk]; !ok {
|
||
missingKeys = append(missingKeys, pk)
|
||
}
|
||
}
|
||
if len(missingKeys) > 0 {
|
||
return ErrMissingPartitionKey.WithTable(q.table).WithError(
|
||
fmt.Errorf("missing partition keys: %v", missingKeys),
|
||
)
|
||
}
|
||
if len(q.cmps) == 0 {
|
||
return ErrMissingWhereCondition.WithTable(q.table)
|
||
}
|
||
|
||
// 組 Delete 語句
|
||
builder := qb.Delete(q.table)
|
||
builder = builder.Where(q.cmps...)
|
||
stmt, names := builder.ToCql()
|
||
query := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names))
|
||
if q.bindMap == nil {
|
||
q.bindMap = qb.M{}
|
||
}
|
||
query = query.BindMap(q.bindMap)
|
||
|
||
return query.ExecRelease()
|
||
}
|
||
|
||
// Update 執行更新操作
|
||
// 要求:必須提供至少一個 partition_key 或 sai indexed 欄位在 WHERE 條件中,且至少有一個 Set 欄位
|
||
func (q *Query) Update() error {
|
||
if len(q.errs) > 0 {
|
||
return errors.Join(q.errs...)
|
||
}
|
||
|
||
if q.document == nil {
|
||
return ErrInvalidInput.WithTable(q.table).WithError(
|
||
fmt.Errorf("update requires document model to check partition keys"),
|
||
)
|
||
}
|
||
|
||
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 先收集所有可被當作主查詢條件的欄位
|
||
allowed := make(map[string]struct{})
|
||
|
||
// 收集 partition_key
|
||
for _, pk := range metadata.PartKey {
|
||
allowed[pk] = struct{}{}
|
||
}
|
||
|
||
// 收集所有 sai 欄位
|
||
for _, f := range reflect.VisibleFields(reflect.TypeOf(q.document)) {
|
||
if f.Tag.Get("sai") == "true" {
|
||
col := f.Tag.Get("db")
|
||
if col == "" {
|
||
col = toSnakeCase(f.Name)
|
||
}
|
||
allowed[col] = struct{}{}
|
||
}
|
||
}
|
||
|
||
// 檢查 bindMap 有沒有 hit 到
|
||
hasCondition := false
|
||
for k := range q.bindMap {
|
||
if _, ok := allowed[k]; ok {
|
||
hasCondition = true
|
||
break
|
||
}
|
||
}
|
||
if !hasCondition {
|
||
return ErrMissingPartitionKey.WithTable(q.table).WithError(
|
||
fmt.Errorf("requires at least one partition_key or sai indexed field in WHERE clause"),
|
||
)
|
||
}
|
||
|
||
// 至少要有一個 set 欄位
|
||
if len(q.sets) == 0 {
|
||
return ErrNoFieldsToUpdate.WithTable(q.table)
|
||
}
|
||
// 至少一個 where
|
||
if len(q.cmps) == 0 {
|
||
return ErrMissingWhereCondition.WithTable(q.table)
|
||
}
|
||
|
||
// 組合 set 欄位
|
||
setCols := make([]string, 0, len(q.sets))
|
||
setVals := make([]any, 0, len(q.sets))
|
||
for _, s := range q.sets {
|
||
setCols = append(setCols, s.Col)
|
||
setVals = append(setVals, s.Val)
|
||
}
|
||
|
||
// 組合 CQL
|
||
builder := qb.Update(q.table).Set(setCols...)
|
||
builder = builder.Where(q.cmps...)
|
||
stmt, names := builder.ToCql()
|
||
|
||
// setVals 要先,剩下的 where bind 順序依照 names
|
||
bindVals := append([]any{}, setVals...)
|
||
for _, name := range names[len(setCols):] {
|
||
if v, ok := q.bindMap[name]; ok {
|
||
bindVals = append(bindVals, v)
|
||
}
|
||
}
|
||
|
||
query := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names))
|
||
if len(bindVals) > 0 {
|
||
query = query.Bind(bindVals...)
|
||
}
|
||
return query.ExecRelease()
|
||
}
|
||
|
||
// InsertOne 插入單筆資料
|
||
func (q *Query) InsertOne(data any) error {
|
||
if len(q.errs) > 0 {
|
||
return errors.Join(q.errs...)
|
||
}
|
||
|
||
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
tbl := table.New(metadata)
|
||
qry := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(tbl.Insert()))
|
||
|
||
switch reflect.TypeOf(data).Kind() {
|
||
case reflect.Map:
|
||
qry = qry.BindMap(data.(map[string]any))
|
||
default:
|
||
qry = qry.BindStruct(data)
|
||
}
|
||
return qry.ExecRelease()
|
||
}
|
||
|
||
func (q *Query) InsertMany(documents any) error {
|
||
if len(q.errs) > 0 {
|
||
return errors.Join(q.errs...)
|
||
}
|
||
|
||
v := reflect.ValueOf(documents)
|
||
if v.Kind() != reflect.Slice {
|
||
return fmt.Errorf("insert many: input must be a slice, got %T", documents)
|
||
}
|
||
if v.Len() == 0 {
|
||
return nil
|
||
}
|
||
|
||
for i := 0; i < v.Len(); i++ {
|
||
item := v.Index(i).Interface()
|
||
if err := q.InsertOne(item); err != nil {
|
||
return fmt.Errorf("insert many: failed at index %d (table: %s): %w", i, q.table, err)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GetAll 查詢所有資料(不帶條件)
|
||
func (q *Query) GetAll(dest any) error {
|
||
if len(q.errs) > 0 {
|
||
return errors.Join(q.errs...)
|
||
}
|
||
|
||
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
t := table.New(metadata)
|
||
|
||
stmt, names := qb.Select(t.Name()).Columns(metadata.Columns...).ToCql()
|
||
exec := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names))
|
||
|
||
return exec.SelectRelease(dest)
|
||
}
|
||
|
||
// Count 計算符合條件的記錄數
|
||
func (q *Query) Count() (int64, error) {
|
||
if len(q.errs) > 0 {
|
||
return 0, errors.Join(q.errs...)
|
||
}
|
||
|
||
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
t := table.New(metadata)
|
||
builder := qb.Select(t.Name()).Columns("COUNT(*)")
|
||
if len(q.cmps) > 0 {
|
||
builder = builder.Where(q.cmps...)
|
||
}
|
||
|
||
stmt, names := builder.ToCql()
|
||
query := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names))
|
||
if q.bindMap == nil {
|
||
q.bindMap = qb.M{}
|
||
}
|
||
query = query.BindMap(q.bindMap)
|
||
|
||
var count int64
|
||
if err := query.GetRelease(&count); err != nil {
|
||
if err == gocql.ErrNotFound {
|
||
return 0, nil // COUNT 查詢不會返回 ErrNotFound,但為了安全起見
|
||
}
|
||
return 0, err
|
||
}
|
||
return count, nil
|
||
}
|
||
|
||
func IsSAIField(model any, fieldName string) bool {
|
||
t := reflect.TypeOf(model)
|
||
if t.Kind() == reflect.Ptr {
|
||
t = t.Elem()
|
||
}
|
||
for i := 0; i < t.NumField(); i++ {
|
||
f := t.Field(i)
|
||
tag := f.Tag.Get("sai")
|
||
col := f.Tag.Get("db")
|
||
if col == "" {
|
||
col = toSnakeCase(f.Name)
|
||
}
|
||
if (col == fieldName || f.Name == fieldName) && tag == "true" {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|