399 lines
9.0 KiB
Go
399 lines
9.0 KiB
Go
package cassandra
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"time"
|
|
|
|
"github.com/scylladb/gocqlx/v3/qb"
|
|
"github.com/scylladb/gocqlx/v3/table"
|
|
)
|
|
|
|
func (db *CassandraDB) AutoCreateSAIIndexes(doc any, keyspace string) error {
|
|
metadata, err := GenerateTableMetadata(doc, keyspace)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
t := reflect.TypeOf(doc)
|
|
if t.Kind() == reflect.Ptr {
|
|
t = t.Elem()
|
|
}
|
|
for i := 0; i < t.NumField(); i++ {
|
|
f := t.Field(i)
|
|
if f.Tag.Get("sai") == "true" {
|
|
col := f.Tag.Get("db")
|
|
if col == "" {
|
|
col = toSnakeCase(f.Name)
|
|
}
|
|
stmt := fmt.Sprintf("CREATE INDEX IF NOT EXISTS ON %s (%s) USING 'sai';", metadata.Name, col)
|
|
if err := db.GetSession().ExecStmt(stmt); err != nil {
|
|
return fmt.Errorf("SAI index create fail: %w", err)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type Query struct {
|
|
db *CassandraDB
|
|
ctx context.Context
|
|
table string
|
|
keyspace string
|
|
columns []string
|
|
cmps []qb.Cmp
|
|
bindMap map[string]any
|
|
orders []orderBy
|
|
limit uint
|
|
document any
|
|
sets []setField // 欲更新欄位及其值
|
|
errs []error
|
|
}
|
|
|
|
type orderBy struct {
|
|
Column string
|
|
Order qb.Order
|
|
}
|
|
|
|
// setField is one column assignment recorded by Query.Set and consumed by
// Query.Update.
type setField struct {
	// Col is the column being assigned.
	Col string
	// Val is the value bound to that column.
	Val any
}
|
|
|
|
func (db *CassandraDB) Model(ctx context.Context, document any, keyspace string) *Query {
|
|
metadata, _ := GenerateTableMetadata(document, keyspace)
|
|
|
|
return &Query{
|
|
db: db,
|
|
ctx: ctx,
|
|
table: metadata.Name,
|
|
keyspace: keyspace,
|
|
columns: make([]string, 0),
|
|
cmps: make([]qb.Cmp, 0),
|
|
bindMap: make(map[string]any),
|
|
orders: make([]orderBy, 0),
|
|
limit: 0,
|
|
document: document,
|
|
}
|
|
}
|
|
|
|
// Where 只允許 partition key 或有 sai index 的欄位進行 where 查詢
|
|
func (q *Query) Where(cmp qb.Cmp, args map[string]any) *Query {
|
|
metadata, _ := GenerateTableMetadata(q.document, q.keyspace)
|
|
for k := range args {
|
|
// 允許 partition_key 或 sai 欄位
|
|
isPartition := contains(metadata.PartKey, k)
|
|
isSAI := IsSAIField(q.document, k)
|
|
if !isPartition && !isSAI {
|
|
q.errs = append(q.errs, fmt.Errorf("field %s must be partition key or SAI index", k))
|
|
}
|
|
q.bindMap[k] = args[k]
|
|
}
|
|
q.cmps = append(q.cmps, cmp)
|
|
for k, v := range args {
|
|
q.bindMap[k] = v
|
|
}
|
|
|
|
return q
|
|
}
|
|
|
|
func (q *Query) Select(cols ...string) *Query {
|
|
q.columns = append(q.columns, cols...)
|
|
|
|
return q
|
|
}
|
|
|
|
func (q *Query) OrderBy(column string, order qb.Order) *Query {
|
|
q.orders = append(q.orders, orderBy{Column: column, Order: order})
|
|
|
|
return q
|
|
}
|
|
|
|
func (q *Query) Limit(limit uint) *Query {
|
|
q.limit = limit
|
|
|
|
return q
|
|
}
|
|
|
|
func (q *Query) Set(col string, val any) *Query {
|
|
q.sets = append(q.sets, setField{Col: col, Val: val})
|
|
|
|
return q
|
|
}
|
|
|
|
func (q *Query) Scan(dest any) error {
|
|
if len(q.errs) > 0 {
|
|
return errors.Join(q.errs...)
|
|
}
|
|
|
|
builder := qb.Select(q.table)
|
|
if len(q.columns) > 0 {
|
|
builder = builder.Columns(q.columns...)
|
|
}
|
|
if len(q.cmps) > 0 {
|
|
builder = builder.Where(q.cmps...)
|
|
}
|
|
if len(q.orders) > 0 {
|
|
for _, o := range q.orders {
|
|
builder = builder.OrderBy(o.Column, o.Order)
|
|
}
|
|
}
|
|
if q.limit > 0 {
|
|
builder = builder.Limit(q.limit)
|
|
}
|
|
|
|
stmt, names := builder.ToCql()
|
|
query := q.db.GetSession().Query(stmt, names).WithContext(q.ctx)
|
|
if q.bindMap == nil {
|
|
q.bindMap = qb.M{}
|
|
}
|
|
query = query.BindMap(q.bindMap)
|
|
|
|
// 型態判斷自動選用單筆/多筆查詢
|
|
destType := reflect.TypeOf(dest)
|
|
if destType.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("dest must be a pointer")
|
|
}
|
|
elemType := destType.Elem()
|
|
if elemType.Kind() == reflect.Slice {
|
|
return query.SelectRelease(dest) // 多筆
|
|
} else if elemType.Kind() == reflect.Struct {
|
|
return query.GetRelease(dest) // 單筆
|
|
} else {
|
|
return fmt.Errorf("dest must be pointer to struct or slice")
|
|
}
|
|
}
|
|
|
|
func (q *Query) Take(dest any) error {
|
|
q.limit = 1
|
|
|
|
return q.Scan(dest)
|
|
}
|
|
|
|
func (q *Query) Delete() error {
|
|
if len(q.errs) > 0 {
|
|
return errors.Join(q.errs...)
|
|
}
|
|
// 拿 partition key 清單
|
|
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
missingKeys := make([]string, 0)
|
|
for _, pk := range metadata.PartKey {
|
|
if _, ok := q.bindMap[pk]; !ok {
|
|
missingKeys = append(missingKeys, pk)
|
|
}
|
|
}
|
|
if len(missingKeys) > 0 {
|
|
return fmt.Errorf("delete operation requires all partition keys in WHERE: missing %v", missingKeys)
|
|
}
|
|
if len(q.cmps) == 0 {
|
|
return fmt.Errorf("delete operation requires at least one WHERE condition for safety")
|
|
}
|
|
|
|
// 組 Delete 語句
|
|
builder := qb.Delete(q.table)
|
|
builder = builder.Where(q.cmps...)
|
|
stmt, names := builder.ToCql()
|
|
query := q.db.GetSession().Query(stmt, names).WithContext(q.ctx)
|
|
if q.bindMap == nil {
|
|
q.bindMap = qb.M{}
|
|
}
|
|
query = query.BindMap(q.bindMap)
|
|
|
|
return query.ExecRelease()
|
|
}
|
|
|
|
func (q *Query) Update() error {
|
|
if len(q.errs) > 0 {
|
|
return errors.Join(q.errs...)
|
|
}
|
|
|
|
if q.document == nil {
|
|
return fmt.Errorf("update requires modelType to check partition keys")
|
|
}
|
|
|
|
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
|
|
if err != nil {
|
|
return fmt.Errorf("update: failed to get table metadata: %w", err)
|
|
}
|
|
|
|
// 先收集所有可被當作主查詢條件的欄位
|
|
allowed := make(map[string]struct{})
|
|
|
|
// 收集 partition_key
|
|
for _, pk := range metadata.PartKey {
|
|
allowed[pk] = struct{}{}
|
|
}
|
|
|
|
// 收集所有 sai 欄位
|
|
for _, f := range reflect.VisibleFields(reflect.TypeOf(q.document)) {
|
|
if f.Tag.Get("sai") == "true" {
|
|
col := f.Tag.Get("db")
|
|
if col == "" {
|
|
col = toSnakeCase(f.Name)
|
|
}
|
|
allowed[col] = struct{}{}
|
|
}
|
|
}
|
|
|
|
// 檢查 bindMap 有沒有 hit 到
|
|
hasCondition := false
|
|
for k := range q.bindMap {
|
|
if _, ok := allowed[k]; ok {
|
|
hasCondition = true
|
|
break
|
|
}
|
|
}
|
|
if !hasCondition {
|
|
return fmt.Errorf("update/delete requires at least one partition_key or sai indexed field in WHERE")
|
|
}
|
|
|
|
// 至少要有一個 set 欄位
|
|
if len(q.sets) == 0 {
|
|
return fmt.Errorf("update requires at least one field to set")
|
|
}
|
|
// 至少一個 where
|
|
if len(q.cmps) == 0 {
|
|
return fmt.Errorf("update operation requires at least one WHERE condition for safety")
|
|
}
|
|
|
|
// 組合 set 欄位
|
|
setCols := make([]string, 0, len(q.sets))
|
|
setVals := make([]any, 0, len(q.sets))
|
|
for _, s := range q.sets {
|
|
setCols = append(setCols, s.Col)
|
|
setVals = append(setVals, s.Val)
|
|
}
|
|
|
|
// 組合 CQL
|
|
builder := qb.Update(q.table).Set(setCols...)
|
|
builder = builder.Where(q.cmps...)
|
|
stmt, names := builder.ToCql()
|
|
|
|
// setVals 要先,剩下的 where bind 順序依照 names
|
|
bindVals := append([]any{}, setVals...)
|
|
for _, name := range names[len(setCols):] {
|
|
if v, ok := q.bindMap[name]; ok {
|
|
bindVals = append(bindVals, v)
|
|
}
|
|
}
|
|
|
|
query := q.db.GetSession().Query(stmt, names).WithContext(q.ctx)
|
|
if len(bindVals) > 0 {
|
|
query = query.Bind(bindVals...)
|
|
}
|
|
return query.ExecRelease()
|
|
}
|
|
|
|
func (q *Query) InsertOne(data any) error {
|
|
if len(q.errs) > 0 {
|
|
return errors.Join(q.errs...)
|
|
}
|
|
|
|
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
tbl := table.New(metadata)
|
|
qry := q.db.GetSession().Query(tbl.Insert()).WithContext(q.ctx).WithTimestamp(time.Now().UnixNano() / 1e3)
|
|
|
|
switch reflect.TypeOf(data).Kind() {
|
|
case reflect.Map:
|
|
qry = qry.BindMap(data.(map[string]any))
|
|
default:
|
|
qry = qry.BindStruct(data)
|
|
}
|
|
return qry.ExecRelease()
|
|
}
|
|
|
|
func (q *Query) InsertMany(documents any) error {
|
|
if len(q.errs) > 0 {
|
|
return errors.Join(q.errs...)
|
|
}
|
|
|
|
v := reflect.ValueOf(documents)
|
|
if v.Kind() != reflect.Slice {
|
|
return fmt.Errorf("InsertMany: input must be a slice")
|
|
}
|
|
if v.Len() == 0 {
|
|
return nil
|
|
}
|
|
|
|
for i := 0; i < v.Len(); i++ {
|
|
item := v.Index(i).Interface()
|
|
if err := q.InsertOne(item); err != nil {
|
|
return fmt.Errorf("InsertMany: failed at idx %d: %w", i, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (q *Query) GetAll(dest any) error {
|
|
if len(q.errs) > 0 {
|
|
return errors.Join(q.errs...)
|
|
}
|
|
|
|
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
t := table.New(metadata)
|
|
|
|
stmt, names := qb.Select(t.Name()).Columns(metadata.Columns...).ToCql()
|
|
exec := q.db.GetSession().Query(stmt, names).WithContext(q.ctx).WithTimestamp(time.Now().UnixNano() / 1e3)
|
|
|
|
return exec.SelectRelease(dest)
|
|
}
|
|
|
|
func (q *Query) Count() (int64, error) {
|
|
if len(q.errs) > 0 {
|
|
return 0, errors.Join(q.errs...)
|
|
}
|
|
|
|
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
t := table.New(metadata)
|
|
builder := qb.Select(t.Name()).Columns("COUNT(*)")
|
|
if len(q.cmps) > 0 {
|
|
builder = builder.Where(q.cmps...)
|
|
}
|
|
|
|
stmt, names := builder.ToCql()
|
|
query := q.db.GetSession().Query(stmt, names).WithContext(q.ctx).WithTimestamp(time.Now().UnixNano() / 1e3)
|
|
if q.bindMap == nil {
|
|
q.bindMap = qb.M{}
|
|
}
|
|
query = query.BindMap(q.bindMap)
|
|
|
|
var count int64
|
|
if err := query.GetRelease(&count); err != nil {
|
|
return 0, err
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
func IsSAIField(model any, fieldName string) bool {
|
|
t := reflect.TypeOf(model)
|
|
if t.Kind() == reflect.Ptr {
|
|
t = t.Elem()
|
|
}
|
|
for i := 0; i < t.NumField(); i++ {
|
|
f := t.Field(i)
|
|
tag := f.Tag.Get("sai")
|
|
col := f.Tag.Get("db")
|
|
if col == "" {
|
|
col = toSnakeCase(f.Name)
|
|
}
|
|
if (col == fieldName || f.Name == fieldName) && tag == "true" {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|