blockchain/internal/lib/cassandra/table.go

399 lines
9.0 KiB
Go

package cassandra
import (
"context"
"errors"
"fmt"
"reflect"
"time"
"github.com/scylladb/gocqlx/v3/qb"
"github.com/scylladb/gocqlx/v3/table"
)
// AutoCreateSAIIndexes issues `CREATE INDEX IF NOT EXISTS ... USING 'sai'`
// for every top-level field of doc tagged `sai:"true"`.
//
// The indexed column is the field's `db` tag, or toSnakeCase(field name) when
// the tag is empty. The index is created on metadata.Name as returned by
// GenerateTableMetadata(doc, keyspace).
// NOTE(review): whether metadata.Name is keyspace-qualified depends on
// GenerateTableMetadata — confirm.
//
// doc must be a struct or a pointer to a struct; anything else (including
// nil) returns an error instead of panicking inside reflect. Indexes are
// created one at a time; the first failure aborts the loop and is returned
// wrapped.
func (db *CassandraDB) AutoCreateSAIIndexes(doc any, keyspace string) error {
	metadata, err := GenerateTableMetadata(doc, keyspace)
	if err != nil {
		return err
	}
	t := reflect.TypeOf(doc)
	if t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	// reflect.Type.NumField panics for non-struct kinds, so reject them here.
	if t == nil || t.Kind() != reflect.Struct {
		return fmt.Errorf("AutoCreateSAIIndexes: doc must be a struct or pointer to struct, got %T", doc)
	}
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if f.Tag.Get("sai") != "true" {
			continue
		}
		col := f.Tag.Get("db")
		if col == "" {
			col = toSnakeCase(f.Name)
		}
		stmt := fmt.Sprintf("CREATE INDEX IF NOT EXISTS ON %s (%s) USING 'sai';", metadata.Name, col)
		if err := db.GetSession().ExecStmt(stmt); err != nil {
			return fmt.Errorf("SAI index create fail: %w", err)
		}
	}
	return nil
}
// Query is a chainable builder for statements against the single table
// derived from document. Builder methods (Where, Select, OrderBy, Limit, Set)
// only record state. Terminal methods (Scan, Take, Delete, Update, InsertOne,
// InsertMany, GetAll, Count) build and execute the statement, and return any
// errors accumulated in errs first.
type Query struct {
	db  *CassandraDB
	ctx context.Context // attached to every statement via WithContext

	table, keyspace string // table name from the model's metadata; keyspace as given to Model

	columns []string       // projection used by Scan when non-empty
	cmps    []qb.Cmp       // WHERE comparisons
	bindMap map[string]any // values supplied through Where, keyed by name
	orders  []orderBy      // ORDER BY clauses in call order
	limit   uint           // LIMIT applied by Scan when greater than zero

	document any        // model value used to (re)generate table metadata
	sets     []setField // columns to update and their new values
	errs     []error    // validation errors deferred to the terminal operation
}
type (
	// orderBy is one ORDER BY clause recorded by Query.OrderBy and applied
	// by Query.Scan.
	orderBy struct {
		Column string
		Order  qb.Order
	}

	// setField is one column assignment recorded by Query.Set and applied by
	// Query.Update.
	setField struct {
		Col string // column to assign
		Val any    // value bound to that column
	}
)
// Model starts a chainable Query for the table described by document.
//
// The table name comes from GenerateTableMetadata(document, keyspace). If
// metadata generation fails, the error is recorded on the returned Query
// instead of being discarded. The terminal operations (Scan, Take, Delete,
// Update, InsertOne, InsertMany, GetAll, Count) all return recorded errors
// before executing anything, so a bad model can no longer silently produce
// statements against an empty table name.
func (db *CassandraDB) Model(ctx context.Context, document any, keyspace string) *Query {
	q := &Query{
		db:       db,
		ctx:      ctx,
		keyspace: keyspace,
		columns:  make([]string, 0),
		cmps:     make([]qb.Cmp, 0),
		bindMap:  make(map[string]any),
		orders:   make([]orderBy, 0),
		limit:    0,
		document: document,
	}
	metadata, err := GenerateTableMetadata(document, keyspace)
	if err != nil {
		q.errs = append(q.errs, fmt.Errorf("model %T: %w", document, err))
		return q
	}
	q.table = metadata.Name
	return q
}
// Where adds a WHERE comparison and records the values it binds.
//
// Only partition-key columns or fields with a SAI index (see IsSAIField) may
// be used. Every key of args is checked; a violation, or a failure to
// generate table metadata, is recorded on q and returned by the terminal
// operation, so calls stay chainable. Values are stored in q.bindMap under
// their keys, and a later Where overwrites an earlier value for the same key.
func (q *Query) Where(cmp qb.Cmp, args map[string]any) *Query {
	metadata, err := GenerateTableMetadata(q.document, q.keyspace)
	if err != nil {
		q.errs = append(q.errs, fmt.Errorf("where: %w", err))
		return q
	}
	// Guard against a Query that was not created by Model.
	if q.bindMap == nil {
		q.bindMap = make(map[string]any, len(args))
	}
	for k, v := range args {
		if !contains(metadata.PartKey, k) && !IsSAIField(q.document, k) {
			q.errs = append(q.errs, fmt.Errorf("field %s must be partition key or SAI index", k))
		}
		q.bindMap[k] = v
	}
	q.cmps = append(q.cmps, cmp)
	return q
}
// Select adds columns to the projection used by Scan. Repeated calls
// accumulate columns in call order.
func (q *Query) Select(cols ...string) *Query {
	for _, c := range cols {
		q.columns = append(q.columns, c)
	}
	return q
}
// OrderBy appends an ORDER BY clause. Scan applies the clauses in the order
// they were added.
func (q *Query) OrderBy(column string, order qb.Order) *Query {
	clause := orderBy{Column: column, Order: order}
	q.orders = append(q.orders, clause)
	return q
}
// Limit sets the row limit used by Scan. A value of 0 means Scan adds no
// LIMIT clause. Each call replaces the previous value.
func (q *Query) Limit(limit uint) *Query {
	q.limit = limit

	return q
}
// Set queues a column assignment for Update. Assignments are applied in the
// order they were queued.
func (q *Query) Set(col string, val any) *Query {
	assignment := setField{Col: col, Val: val}
	q.sets = append(q.sets, assignment)
	return q
}
// Scan builds a SELECT from the accumulated columns, WHERE comparisons,
// ORDER BY clauses and LIMIT, executes it with the values recorded by Where,
// and scans the result into dest.
//
// dest must be a pointer. A pointer to a slice receives every row
// (SelectRelease); a pointer to a struct receives a single row (GetRelease).
// Errors recorded while building the query are returned first.
func (q *Query) Scan(dest any) error {
	if len(q.errs) > 0 {
		return errors.Join(q.errs...)
	}
	// Validate dest before building anything. reflect.TypeOf(nil) returns a
	// nil Type, and calling Kind on it would panic.
	destType := reflect.TypeOf(dest)
	if destType == nil || destType.Kind() != reflect.Ptr {
		return fmt.Errorf("dest must be a pointer")
	}
	destKind := destType.Elem().Kind()
	if destKind != reflect.Slice && destKind != reflect.Struct {
		return fmt.Errorf("dest must be pointer to struct or slice")
	}

	builder := qb.Select(q.table)
	if len(q.columns) > 0 {
		builder = builder.Columns(q.columns...)
	}
	if len(q.cmps) > 0 {
		builder = builder.Where(q.cmps...)
	}
	for _, o := range q.orders {
		builder = builder.OrderBy(o.Column, o.Order)
	}
	if q.limit > 0 {
		builder = builder.Limit(q.limit)
	}
	stmt, names := builder.ToCql()

	if q.bindMap == nil {
		q.bindMap = qb.M{}
	}
	query := q.db.GetSession().Query(stmt, names).WithContext(q.ctx).BindMap(q.bindMap)
	if destKind == reflect.Slice {
		return query.SelectRelease(dest) // multiple rows
	}
	return query.GetRelease(dest) // single row
}
// Take forces LIMIT 1 and delegates to Scan. It overwrites any limit set
// earlier on q.
func (q *Query) Take(dest any) error {
	return q.Limit(1).Scan(dest)
}
// Delete executes a DELETE on the rows matched by the Where comparisons.
//
// As a safety net, it refuses to run unless every partition-key column from
// the table metadata has a value recorded by Where and at least one WHERE
// comparison exists. Errors recorded while building the query are returned
// first.
func (q *Query) Delete() error {
	if len(q.errs) != 0 {
		return errors.Join(q.errs...)
	}
	meta, err := GenerateTableMetadata(q.document, q.keyspace)
	if err != nil {
		return err
	}

	missing := make([]string, 0)
	for _, pk := range meta.PartKey {
		if _, bound := q.bindMap[pk]; !bound {
			missing = append(missing, pk)
		}
	}
	switch {
	case len(missing) > 0:
		return fmt.Errorf("delete operation requires all partition keys in WHERE: missing %v", missing)
	case len(q.cmps) == 0:
		return fmt.Errorf("delete operation requires at least one WHERE condition for safety")
	}

	stmt, names := qb.Delete(q.table).Where(q.cmps...).ToCql()
	query := q.db.GetSession().Query(stmt, names).WithContext(q.ctx)
	if q.bindMap == nil {
		q.bindMap = qb.M{}
	}
	return query.BindMap(q.bindMap).ExecRelease()
}
// Update executes an UPDATE that assigns the columns queued with Set to the
// rows matched by the Where comparisons.
//
// It returns an error, without executing a statement, when:
//   - errors were recorded while building the query;
//   - the query has no model document, or the document is not a struct or a
//     pointer to a struct;
//   - table metadata cannot be generated for the document;
//   - no value recorded by Where belongs to a partition key or a
//     `sai:"true"` field;
//   - Set was never called, or no WHERE comparison was added;
//   - a WHERE placeholder has no recorded value.
//
// Values are bound positionally: the Set values first, in call order, then
// the WHERE values looked up by placeholder name.
func (q *Query) Update() error {
	if len(q.errs) > 0 {
		return errors.Join(q.errs...)
	}
	if q.document == nil {
		return fmt.Errorf("update requires modelType to check partition keys")
	}
	metadata, err := GenerateTableMetadata(q.document, q.keyspace)
	if err != nil {
		return fmt.Errorf("update: failed to get table metadata: %w", err)
	}
	// reflect.VisibleFields panics on non-struct types. Dereference a pointer
	// model (as IsSAIField does) and reject anything else.
	docType := reflect.TypeOf(q.document)
	if docType.Kind() == reflect.Ptr {
		docType = docType.Elem()
	}
	if docType.Kind() != reflect.Struct {
		return fmt.Errorf("update: model must be a struct or pointer to struct, got %T", q.document)
	}

	// Columns that may serve as the primary lookup condition: every partition
	// key plus every field tagged sai:"true" (db tag, else snake_case name).
	allowed := make(map[string]struct{}, len(metadata.PartKey))
	for _, pk := range metadata.PartKey {
		allowed[pk] = struct{}{}
	}
	for _, f := range reflect.VisibleFields(docType) {
		if f.Tag.Get("sai") != "true" {
			continue
		}
		col := f.Tag.Get("db")
		if col == "" {
			col = toSnakeCase(f.Name)
		}
		allowed[col] = struct{}{}
	}

	hasCondition := false
	for k := range q.bindMap {
		if _, ok := allowed[k]; ok {
			hasCondition = true
			break
		}
	}
	if !hasCondition {
		return fmt.Errorf("update/delete requires at least one partition_key or sai indexed field in WHERE")
	}
	if len(q.sets) == 0 {
		return fmt.Errorf("update requires at least one field to set")
	}
	if len(q.cmps) == 0 {
		return fmt.Errorf("update operation requires at least one WHERE condition for safety")
	}

	setCols := make([]string, 0, len(q.sets))
	bindVals := make([]any, 0, len(q.sets)+len(q.bindMap))
	for _, s := range q.sets {
		setCols = append(setCols, s.Col)
		bindVals = append(bindVals, s.Val)
	}
	stmt, names := qb.Update(q.table).Set(setCols...).Where(q.cmps...).ToCql()

	// SET placeholders come first, so every name after len(setCols) is a
	// WHERE parameter. Skipping a missing value would leave fewer bound
	// values than placeholders, so report it here instead.
	for _, name := range names[len(setCols):] {
		v, ok := q.bindMap[name]
		if !ok {
			return fmt.Errorf("update: no value bound for WHERE parameter %q", name)
		}
		bindVals = append(bindVals, v)
	}
	return q.db.GetSession().Query(stmt, names).WithContext(q.ctx).Bind(bindVals...).ExecRelease()
}
// InsertOne inserts a single row into the model's table using the INSERT
// statement generated by table.New(metadata).Insert().
//
// data is bound by key when it is a map whose type converts to
// map[string]any, including named map types such as qb.M. Any other non-nil
// value is bound as a struct via BindStruct. A nil data, or a map with an
// incompatible key/element type, returns an error instead of panicking. The
// write timestamp is the current time in microseconds since the Unix epoch.
func (q *Query) InsertOne(data any) error {
	if len(q.errs) > 0 {
		return errors.Join(q.errs...)
	}
	if data == nil {
		return fmt.Errorf("InsertOne: data must not be nil")
	}

	// Decide the binding mode before building the statement. A plain
	// data.(map[string]any) assertion panics when the dynamic type is a
	// named map type, so convert through reflect instead.
	var (
		byKey   bool
		bindMap map[string]any
	)
	if rv := reflect.ValueOf(data); rv.Kind() == reflect.Map {
		mapType := reflect.TypeOf(map[string]any(nil))
		if !rv.Type().ConvertibleTo(mapType) {
			return fmt.Errorf("InsertOne: map data must be convertible to map[string]any, got %T", data)
		}
		byKey = true
		bindMap = rv.Convert(mapType).Interface().(map[string]any)
	}

	metadata, err := GenerateTableMetadata(q.document, q.keyspace)
	if err != nil {
		return err
	}
	tbl := table.New(metadata)
	qry := q.db.GetSession().Query(tbl.Insert()).WithContext(q.ctx).WithTimestamp(time.Now().UnixMicro())
	if byKey {
		qry = qry.BindMap(bindMap)
	} else {
		qry = qry.BindStruct(data)
	}
	return qry.ExecRelease()
}
// InsertMany inserts each element of the documents slice with InsertOne, in
// index order. It stops at the first failure and returns that error wrapped
// with the failing index. Rows inserted before the failure are not rolled
// back. An empty slice is a no-op.
func (q *Query) InsertMany(documents any) error {
	if len(q.errs) != 0 {
		return errors.Join(q.errs...)
	}
	items := reflect.ValueOf(documents)
	if items.Kind() != reflect.Slice {
		return fmt.Errorf("InsertMany: input must be a slice")
	}
	for idx, n := 0, items.Len(); idx < n; idx++ {
		if err := q.InsertOne(items.Index(idx).Interface()); err != nil {
			return fmt.Errorf("InsertMany: failed at idx %d: %w", idx, err)
		}
	}
	return nil
}
// GetAll selects every metadata column of every row in the model's table
// into dest via SelectRelease. It does not apply the recorded WHERE
// comparisons, ORDER BY clauses, projection or limit; recorded build errors
// are still returned first.
func (q *Query) GetAll(dest any) error {
	if len(q.errs) != 0 {
		return errors.Join(q.errs...)
	}
	meta, err := GenerateTableMetadata(q.document, q.keyspace)
	if err != nil {
		return err
	}
	tableName := table.New(meta).Name()
	stmt, names := qb.Select(tableName).Columns(meta.Columns...).ToCql()
	return q.db.GetSession().
		Query(stmt, names).
		WithContext(q.ctx).
		WithTimestamp(time.Now().UnixNano() / 1e3).
		SelectRelease(dest)
}
// Count runs SELECT COUNT(*) on the model's table, restricted by the recorded
// WHERE comparisons when any exist. The values recorded by Where are bound by
// name. On any failure it returns 0 and the error.
func (q *Query) Count() (int64, error) {
	if len(q.errs) != 0 {
		return 0, errors.Join(q.errs...)
	}
	meta, err := GenerateTableMetadata(q.document, q.keyspace)
	if err != nil {
		return 0, err
	}
	sel := qb.Select(table.New(meta).Name()).Columns("COUNT(*)")
	if len(q.cmps) != 0 {
		sel = sel.Where(q.cmps...)
	}
	stmt, names := sel.ToCql()
	query := q.db.GetSession().Query(stmt, names).WithContext(q.ctx).WithTimestamp(time.Now().UnixNano() / 1e3)
	if q.bindMap == nil {
		q.bindMap = qb.M{}
	}

	var total int64
	if err := query.BindMap(q.bindMap).GetRelease(&total); err != nil {
		return 0, err
	}
	return total, nil
}
// IsSAIField reports whether model declares a top-level field tagged
// `sai:"true"` whose Go field name or column name equals fieldName. The
// column name is the field's `db` tag, or toSnakeCase(field name) when the
// tag is empty.
//
// model may be a struct or a pointer to a struct. Any other value, including
// nil, yields false instead of panicking inside reflect.
func IsSAIField(model any, fieldName string) bool {
	t := reflect.TypeOf(model)
	if t == nil {
		return false
	}
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	// reflect.Type.NumField panics for non-struct kinds.
	if t.Kind() != reflect.Struct {
		return false
	}
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if f.Tag.Get("sai") != "true" {
			continue
		}
		if f.Name == fieldName {
			return true
		}
		col := f.Tag.Get("db")
		if col == "" {
			col = toSnakeCase(f.Name)
		}
		if col == fieldName {
			return true
		}
	}
	return false
}