backend/pkg/library/cassandra/table.go

463 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package cassandra
import (
"context"
"errors"
"fmt"
"reflect"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx/v3/qb"
"github.com/scylladb/gocqlx/v3/table"
)
// AutoCreateSAIIndexes issues a CREATE INDEX ... USING 'sai' statement for
// every field of doc whose struct tag contains `sai:"true"`.
//
// doc must be a struct or a pointer to a struct. It is passed, together with
// keyspace, to GenerateTableMetadata to resolve the table name used in the
// statement. The indexed column is taken from the field's `db` tag and falls
// back to toSnakeCase(field name) when that tag is empty. Every statement
// carries IF NOT EXISTS.
//
// Fields are processed in declaration order; the first failing statement
// stops the loop and its error is returned wrapped with the table and column.
func (db *CassandraDB) AutoCreateSAIIndexes(doc any, keyspace string) error {
	metadata, err := GenerateTableMetadata(doc, keyspace)
	if err != nil {
		return err
	}

	t := reflect.TypeOf(doc)
	if t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	// NumField panics on anything that is not a struct (and Kind panics on a
	// nil Type), so reject such documents with an error instead.
	if t == nil || t.Kind() != reflect.Struct {
		return fmt.Errorf("auto create SAI indexes: document must be a struct or pointer to struct, got %T", doc)
	}

	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if f.Tag.Get("sai") != "true" {
			continue
		}
		col := f.Tag.Get("db")
		if col == "" {
			col = toSnakeCase(f.Name)
		}
		stmt := fmt.Sprintf("CREATE INDEX IF NOT EXISTS ON %s (%s) USING 'sai';", metadata.Name, col)
		if err := db.GetSession().ExecStmt(stmt); err != nil {
			return fmt.Errorf("failed to create SAI index on table %s, column %s: %w", metadata.Name, col, err)
		}
	}
	return nil
}
// Query is a chainable query builder bound to one table. It is created by
// CassandraDB.Model, configured through Where/Select/OrderBy/Limit/Set, and
// executed by a terminal method such as Scan, Delete, Update, InsertOne or
// Count. Errors raised while building are collected in errs and returned by
// the terminal methods.
type Query struct {
// db provides the session (db.GetSession()) used to run statements.
db *CassandraDB
// ctx is handed to qh.withContextAndTimestamp when a statement is executed.
ctx context.Context
// table is the table name taken from the generated metadata (metadata.Name).
table string
// keyspace is the keyspace resolved by getKeyspace in Model.
keyspace string
// columns are the columns added via Select; when empty, Scan selects
// metadata.Columns.
columns []string
// cmps are the WHERE comparators added via Where.
cmps []qb.Cmp
// bindMap holds the named bind values supplied to Where.
bindMap map[string]any
// orders are the ORDER BY entries added via OrderBy.
orders []orderBy
// limit is the row limit; Scan only emits a LIMIT clause when it is > 0.
limit uint
// document is the model passed to Model; it is reused to regenerate table
// metadata and to check SAI-tagged fields.
document any
sets []setField // columns to update and their new values (see Set)
// errs accumulates build-time errors; terminal methods return them joined.
errs []error
}
// orderBy is one ORDER BY entry recorded by Query.OrderBy and replayed onto
// the select builder in Query.Scan.
type orderBy struct {
// Column is the column name to sort by.
Column string
// Order is the sort direction passed to the builder's OrderBy call.
Order qb.Order
}
// setField is one column/value assignment recorded by Query.Set and consumed
// by Query.Update.
type setField struct {
// Col is the column to assign.
Col string
// Val is the value bound for that column.
Val any
}
// Model creates a new query builder for the table described by document.
//
//   - document: sample object used to derive the table metadata via
//     GenerateTableMetadata (the original note says it must implement
//     TableName(); that requirement is enforced outside this function, if at
//     all — TODO confirm).
//   - keyspace: resolved through getKeyspace(db, keyspace); per the original
//     note an empty value falls back to the default keyspace configured at
//     initialisation.
//
// Model never returns nil. When metadata generation fails, the returned Query
// carries that error in errs and has no table name or initialised
// collections; the terminal methods then return the stored error.
func (db *CassandraDB) Model(ctx context.Context, document any, keyspace string) *Query {
	ks := getKeyspace(db, keyspace)

	// Fields shared by the success and the failure result.
	q := &Query{
		db:       db,
		ctx:      ctx,
		keyspace: ks,
		document: document, // kept for later metadata generation and SAI checks
	}

	metadata, err := GenerateTableMetadata(document, ks)
	if err != nil {
		// Metadata generation failed: hand back a Query that only carries the error.
		q.errs = []error{err}
		return q
	}

	q.table = metadata.Name
	q.columns = []string{}
	q.cmps = []qb.Cmp{}
	q.bindMap = map[string]any{}
	q.orders = []orderBy{}
	q.errs = []error{}
	return q
}
// Where adds a WHERE condition to the query.
//
// Only partition-key columns and fields tagged `sai:"true"` may be used: every
// key of args that is neither appends an INVALID_WHERE_FIELD error to the
// query's error list (the condition and its values are still recorded, but the
// terminal methods will return the error instead of executing).
//
//   - cmp:  the comparison to add, e.g. qb.Eq("id").
//   - args: named bind values for the comparison, e.g. map[string]any{"id": uuid}.
//
// If the query already carries an error, or table metadata cannot be
// generated, nothing is added.
func (q *Query) Where(cmp qb.Cmp, args map[string]any) *Query {
	// A previous step already failed; do not pile more state on top.
	if len(q.errs) > 0 {
		return q
	}

	metadata, err := GenerateTableMetadata(q.document, q.keyspace)
	if err != nil {
		q.errs = append(q.errs, err)
		return q
	}

	// Scan, Delete and Count all tolerate a nil bindMap; writing into a nil
	// map would panic, so apply the same tolerance here.
	if q.bindMap == nil {
		q.bindMap = make(map[string]any, len(args))
	}

	for k, v := range args {
		// Allowed: partition-key columns or SAI-indexed fields.
		if !contains(metadata.PartKey, k) && !IsSAIField(q.document, k) {
			q.errs = append(q.errs, NewError(
				"INVALID_WHERE_FIELD",
				fmt.Sprintf("where condition on field %s requires partition_key or sai index", k),
			).WithTable(q.table))
		}
		q.bindMap[k] = v
	}

	q.cmps = append(q.cmps, cmp)
	return q
}
// Select appends cols to the list of columns that Scan will read. When no
// column has been selected, Scan falls back to all columns from the table
// metadata. It returns q for chaining.
func (q *Query) Select(cols ...string) *Query {
	if len(cols) > 0 {
		q.columns = append(q.columns, cols...)
	}
	return q
}
// OrderBy records an ORDER BY entry for column in the given direction. Entries
// are applied by Scan in the order they were added. It returns q for chaining.
func (q *Query) OrderBy(column string, order qb.Order) *Query {
	entry := orderBy{
		Column: column,
		Order:  order,
	}
	q.orders = append(q.orders, entry)
	return q
}
// Limit sets the maximum number of rows to read, replacing any earlier value.
// Scan only adds a LIMIT clause when the value is greater than zero.
// It returns q for chaining.
func (q *Query) Limit(limit uint) *Query {
q.limit = limit
return q
}
// Set records that column col should be assigned val by a later Update call.
// Assignments are kept in the order they were added. It returns q for chaining.
func (q *Query) Set(col string, val any) *Query {
	assignment := setField{
		Col: col,
		Val: val,
	}
	q.sets = append(q.sets, assignment)
	return q
}
// Scan executes the SELECT described by the query and stores the result in dest.
//
// dest must be a non-nil pointer: *Struct reads a single row, *[]Struct reads
// all matching rows. The kind of the pointed-to type selects the mode
// automatically.
//
// Errors:
//   - the joined build errors, if any were recorded on the query;
//   - the error from GenerateTableMetadata;
//   - ErrInvalidInput when dest is nil, not a pointer, or points to something
//     other than a struct or slice;
//   - ErrNotFound when a single-row read finds no row;
//   - any other error reported while running the query.
func (q *Query) Scan(dest any) error {
	if len(q.errs) > 0 {
		return errors.Join(q.errs...)
	}

	metadata, err := GenerateTableMetadata(q.document, q.keyspace)
	if err != nil {
		return err
	}

	// Validate dest before a query object is created: a nil dest yields a nil
	// reflect.Type (calling Kind on it would panic), and bailing out after the
	// query exists would skip its *Release call.
	destType := reflect.TypeOf(dest)
	if destType == nil || destType.Kind() != reflect.Ptr {
		return ErrInvalidInput.WithTable(q.table).WithError(fmt.Errorf("destination must be a pointer, got %T", dest))
	}
	elemKind := destType.Elem().Kind()
	if elemKind != reflect.Slice && elemKind != reflect.Struct {
		return ErrInvalidInput.WithTable(q.table).WithError(fmt.Errorf("destination must be pointer to struct or slice, got %T", dest))
	}

	builder := qb.Select(q.table)
	if len(q.columns) > 0 {
		builder = builder.Columns(q.columns...)
	} else {
		// No explicit selection: read every column known to the metadata.
		builder = builder.Columns(metadata.Columns...)
	}
	if len(q.cmps) > 0 {
		builder = builder.Where(q.cmps...)
	}
	for _, o := range q.orders {
		builder = builder.OrderBy(o.Column, o.Order)
	}
	if q.limit > 0 {
		builder = builder.Limit(q.limit)
	}

	stmt, names := builder.ToCql()
	query := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names))
	if q.bindMap == nil {
		q.bindMap = qb.M{}
	}
	query = query.BindMap(q.bindMap)

	// Slice destination: multi-row read.
	if elemKind == reflect.Slice {
		return query.SelectRelease(dest)
	}

	// Struct destination: single-row read; map the driver's sentinel to the
	// package's own not-found error.
	err = query.GetRelease(dest)
	if errors.Is(err, gocql.ErrNotFound) {
		return ErrNotFound.WithTable(q.table)
	}
	return err
}
// Take reads at most one row into dest. It overwrites the query's limit with 1
// and then delegates to Scan, so dest follows the same rules as for Scan.
func (q *Query) Take(dest any) error {
	return q.Limit(1).Scan(dest)
}
// Delete executes a DELETE using the accumulated WHERE conditions.
//
// Requirement: every partition-key column reported by the table metadata must
// have a value bound through Where.
//
// Errors:
//   - the joined build errors, if any were recorded on the query;
//   - the error from GenerateTableMetadata;
//   - ErrMissingPartitionKey, listing the partition keys without a bound value;
//   - ErrMissingWhereCondition when no condition was added;
//   - any error reported while running the statement.
func (q *Query) Delete() error {
	if len(q.errs) > 0 {
		return errors.Join(q.errs...)
	}

	metadata, err := GenerateTableMetadata(q.document, q.keyspace)
	if err != nil {
		return err
	}

	// Collect the partition keys that have no bound value.
	var missing []string
	for _, key := range metadata.PartKey {
		if _, bound := q.bindMap[key]; bound {
			continue
		}
		missing = append(missing, key)
	}
	if len(missing) != 0 {
		return ErrMissingPartitionKey.WithTable(q.table).WithError(
			fmt.Errorf("missing partition keys: %v", missing),
		)
	}

	if len(q.cmps) == 0 {
		return ErrMissingWhereCondition.WithTable(q.table)
	}

	// Build and run the DELETE statement.
	stmt, names := qb.Delete(q.table).Where(q.cmps...).ToCql()
	if q.bindMap == nil {
		q.bindMap = qb.M{}
	}
	return qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names)).
		BindMap(q.bindMap).
		ExecRelease()
}
// Update executes an UPDATE that assigns the columns recorded via Set to the
// rows matched by the conditions recorded via Where.
//
// Requirements: at least one bound WHERE value must belong to a partition-key
// column or to a field tagged `sai:"true"`, and at least one Set assignment
// must exist.
//
// Errors:
//   - the joined build errors, if any were recorded on the query;
//   - ErrInvalidInput when the document model is nil or not a struct /
//     pointer to struct, or when a WHERE parameter has no bound value;
//   - the error from GenerateTableMetadata;
//   - ErrMissingPartitionKey when no bound value hits a partition-key or SAI column;
//   - ErrNoFieldsToUpdate when Set was never called;
//   - ErrMissingWhereCondition when Where was never called;
//   - any error reported while running the statement.
func (q *Query) Update() error {
	if len(q.errs) > 0 {
		return errors.Join(q.errs...)
	}
	if q.document == nil {
		return ErrInvalidInput.WithTable(q.table).WithError(
			errors.New("update requires document model to check partition keys"),
		)
	}

	metadata, err := GenerateTableMetadata(q.document, q.keyspace)
	if err != nil {
		return err
	}

	// reflect.VisibleFields panics unless it is given a struct type, so unwrap
	// a pointer model first (as the other reflection helpers in this file do)
	// and reject anything else.
	docType := reflect.TypeOf(q.document)
	if docType.Kind() == reflect.Ptr {
		docType = docType.Elem()
	}
	if docType.Kind() != reflect.Struct {
		return ErrInvalidInput.WithTable(q.table).WithError(
			fmt.Errorf("update requires a struct document model, got %T", q.document),
		)
	}

	// Columns that qualify as a primary lookup condition: partition keys ...
	allowed := make(map[string]struct{}, len(metadata.PartKey))
	for _, pk := range metadata.PartKey {
		allowed[pk] = struct{}{}
	}
	// ... plus every SAI-tagged field.
	for _, f := range reflect.VisibleFields(docType) {
		if f.Tag.Get("sai") != "true" {
			continue
		}
		col := f.Tag.Get("db")
		if col == "" {
			col = toSnakeCase(f.Name)
		}
		allowed[col] = struct{}{}
	}

	// At least one bound WHERE value has to hit an allowed column.
	hasCondition := false
	for k := range q.bindMap {
		if _, ok := allowed[k]; ok {
			hasCondition = true
			break
		}
	}
	if !hasCondition {
		return ErrMissingPartitionKey.WithTable(q.table).WithError(
			errors.New("requires at least one partition_key or sai indexed field in WHERE clause"),
		)
	}

	// At least one assignment is required.
	if len(q.sets) == 0 {
		return ErrNoFieldsToUpdate.WithTable(q.table)
	}
	// At least one WHERE condition is required.
	if len(q.cmps) == 0 {
		return ErrMissingWhereCondition.WithTable(q.table)
	}

	setCols := make([]string, 0, len(q.sets))
	for _, s := range q.sets {
		setCols = append(setCols, s.Col)
	}

	stmt, names := qb.Update(q.table).Set(setCols...).Where(q.cmps...).ToCql()

	// Values are bound positionally: the SET values come first, followed by
	// the WHERE values in the order given by names.
	bindVals := make([]any, 0, len(names))
	for _, s := range q.sets {
		bindVals = append(bindVals, s.Val)
	}
	for _, name := range names[len(setCols):] {
		v, ok := q.bindMap[name]
		if !ok {
			// Skipping the parameter would shift every later value by one
			// position; fail explicitly instead.
			return ErrInvalidInput.WithTable(q.table).WithError(
				fmt.Errorf("missing bind value for where parameter %q", name),
			)
		}
		bindVals = append(bindVals, v)
	}

	query := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names))
	return query.Bind(bindVals...).ExecRelease()
}
// InsertOne inserts a single row into the query's table.
//
// data is either a map[string]any (or qb.M) whose entries are bound by name,
// or a value that is bound with BindStruct. The INSERT statement is generated
// from the table metadata of the query's document model.
//
// Errors:
//   - the joined build errors, if any were recorded on the query;
//   - the error from GenerateTableMetadata;
//   - an error when data is nil or a map of any other type;
//   - any error reported while running the statement.
func (q *Query) InsertOne(data any) error {
	if len(q.errs) > 0 {
		return errors.Join(q.errs...)
	}

	metadata, err := GenerateTableMetadata(q.document, q.keyspace)
	if err != nil {
		return err
	}

	// Classify data before the query object exists so that invalid input
	// returns an error instead of panicking (nil data, or a map type that is
	// not assertable to map[string]any) and no query is left unreleased.
	var (
		values map[string]any
		asMap  bool
	)
	switch m := data.(type) {
	case nil:
		return fmt.Errorf("insert one: data must not be nil (table: %s)", q.table)
	case map[string]any:
		values, asMap = m, true
	case qb.M:
		// qb.M is a distinct named type, so it needs its own case.
		values, asMap = m, true
	default:
		if reflect.TypeOf(data).Kind() == reflect.Map {
			return fmt.Errorf("insert one: unsupported map type %T, expected map[string]any (table: %s)", data, q.table)
		}
	}

	tbl := table.New(metadata)
	qry := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(tbl.Insert()))
	if asMap {
		return qry.BindMap(values).ExecRelease()
	}
	return qry.BindStruct(data).ExecRelease()
}
// InsertMany inserts every element of documents, one InsertOne call per
// element, in slice order.
//
// documents must be a slice; an empty slice is a no-op. The first failing
// element stops the loop and its error is returned wrapped with the element
// index and the table name; elements inserted before it are not undone here.
// Build errors already recorded on the query are returned joined before
// anything is inserted.
func (q *Query) InsertMany(documents any) error {
	if len(q.errs) > 0 {
		return errors.Join(q.errs...)
	}

	items := reflect.ValueOf(documents)
	if items.Kind() != reflect.Slice {
		return fmt.Errorf("insert many: input must be a slice, got %T", documents)
	}

	for idx, total := 0, items.Len(); idx < total; idx++ {
		err := q.InsertOne(items.Index(idx).Interface())
		if err == nil {
			continue
		}
		return fmt.Errorf("insert many: failed at index %d (table: %s): %w", idx, q.table, err)
	}
	return nil
}
// GetAll reads every row of the table into dest, selecting all columns from
// the table metadata. Conditions, ordering, column selection and limit
// recorded on the query are not applied.
//
// dest is passed unchanged to SelectRelease. Build errors already recorded on
// the query are returned joined; otherwise the error from
// GenerateTableMetadata or from running the query is returned.
func (q *Query) GetAll(dest any) error {
	if len(q.errs) > 0 {
		return errors.Join(q.errs...)
	}

	metadata, err := GenerateTableMetadata(q.document, q.keyspace)
	if err != nil {
		return err
	}

	tableName := table.New(metadata).Name()
	selectAll := qb.Select(tableName).Columns(metadata.Columns...)
	stmt, names := selectAll.ToCql()

	session := q.db.GetSession()
	return qh.withContextAndTimestamp(q.ctx, session.Query(stmt, names)).SelectRelease(dest)
}
// Count returns the number of rows matching the query's WHERE conditions
// (all rows when no condition was added), using SELECT COUNT(*).
//
// Build errors already recorded on the query are returned joined; otherwise
// the error from GenerateTableMetadata or from running the query is returned.
// The count is 0 whenever an error is returned.
func (q *Query) Count() (int64, error) {
	if len(q.errs) > 0 {
		return 0, errors.Join(q.errs...)
	}

	metadata, err := GenerateTableMetadata(q.document, q.keyspace)
	if err != nil {
		return 0, err
	}

	t := table.New(metadata)
	builder := qb.Select(t.Name()).Columns("COUNT(*)")
	if len(q.cmps) > 0 {
		builder = builder.Where(q.cmps...)
	}

	stmt, names := builder.ToCql()
	query := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names))
	if q.bindMap == nil {
		q.bindMap = qb.M{}
	}
	query = query.BindMap(q.bindMap)

	var count int64
	if err := query.GetRelease(&count); err != nil {
		// A COUNT query is not expected to report ErrNotFound; treat it as
		// zero rows anyway to be safe. errors.Is also matches a wrapped sentinel.
		if errors.Is(err, gocql.ErrNotFound) {
			return 0, nil
		}
		return 0, err
	}
	return count, nil
}
// IsSAIField reports whether model has a field tagged `sai:"true"` that is
// identified by fieldName.
//
// model is a struct or a pointer to a struct; only its directly declared
// fields are inspected. fieldName matches either the Go field name or the
// column name, where the column name is the field's `db` tag or, when that tag
// is empty, toSnakeCase(field name).
//
// A nil model, or one that is neither a struct nor a pointer to a struct, has
// no fields to inspect and yields false.
func IsSAIField(model any, fieldName string) bool {
	t := reflect.TypeOf(model)
	if t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	// NumField panics for non-struct types and Kind panics on a nil Type.
	if t == nil || t.Kind() != reflect.Struct {
		return false
	}

	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if f.Tag.Get("sai") != "true" {
			continue
		}
		if f.Name == fieldName {
			return true
		}
		col := f.Tag.Get("db")
		if col == "" {
			col = toSnakeCase(f.Name)
		}
		if col == fieldName {
			return true
		}
	}
	return false
}