package cassandra import ( "context" "fmt" "github.com/scylladb/gocqlx/v3/qb" "github.com/scylladb/gocqlx/v3/table" "reflect" "time" ) func (db *DB) AutoCreateSAIIndexes(doc any, keyspace string) error { metadata, err := GenerateTableMetadata(doc, keyspace) if err != nil { return err } t := reflect.TypeOf(doc) if t.Kind() == reflect.Ptr { t = t.Elem() } for i := 0; i < t.NumField(); i++ { f := t.Field(i) if f.Tag.Get("sai") == "true" { col := f.Tag.Get("cql") if col == "" { col = toSnakeCase(f.Name) } stmt := fmt.Sprintf("CREATE INDEX IF NOT EXISTS ON %s (%s) USING 'sai';", metadata.Name, col) if err := db.GetSession().ExecStmt(stmt); err != nil { return fmt.Errorf("SAI index create fail: %w", err) } } } return nil } type Query struct { db *DB ctx context.Context table string keyspace string columns []string cmps []qb.Cmp bindMap map[string]any orders []orderBy limit uint document any sets []setField // 欲更新欄位及其值 } type orderBy struct { Column string Order qb.Order } type setField struct { Col string Val any } func (db *DB) Model(ctx context.Context, document any, keyspace string) *Query { metadata, _ := GenerateTableMetadata(document, keyspace) return &Query{ db: db, ctx: ctx, table: metadata.Name, keyspace: keyspace, columns: make([]string, 0), cmps: make([]qb.Cmp, 0), bindMap: make(map[string]any), orders: make([]orderBy, 0), limit: 0, document: document, } } // Where 只允許 partition key 或有 sai index 的欄位進行 where 查詢 func (q *Query) Where(cmp qb.Cmp, args map[string]any) *Query { metadata, _ := GenerateTableMetadata(q.document, q.keyspace) for k := range args { // 允許 partition_key 或 sai 欄位 isPartition := contains(metadata.PartKey, k) isSAI := IsSAIField(q.document, k) if !isPartition && !isSAI { panic(fmt.Sprintf("field %s must be partition key or SAI index", k)) } q.bindMap[k] = args[k] } q.cmps = append(q.cmps, cmp) for k, v := range args { q.bindMap[k] = v } return q } func (q *Query) Select(cols ...string) *Query { q.columns = append(q.columns, cols...) return q } func (q *Query) OrderBy(column string, order qb.Order) *Query { q.orders = append(q.orders, orderBy{Column: column, Order: order}) return q } func (q *Query) Limit(limit uint) *Query { q.limit = limit return q } func (q *Query) Set(col string, val any) *Query { q.sets = append(q.sets, setField{Col: col, Val: val}) return q } func (q *Query) Scan(dest any) error { builder := qb.Select(q.table) if len(q.columns) > 0 { builder = builder.Columns(q.columns...) } if len(q.cmps) > 0 { builder = builder.Where(q.cmps...) } if len(q.orders) > 0 { for _, o := range q.orders { builder = builder.OrderBy(o.Column, o.Order) } } if q.limit > 0 { builder = builder.Limit(q.limit) } stmt, names := builder.ToCql() query := q.db.GetSession().Query(stmt, names).WithContext(q.ctx) if q.bindMap == nil { q.bindMap = qb.M{} } query = query.BindMap(q.bindMap) return query.SelectRelease(dest) } func (q *Query) Take(dest any) error { q.limit = 1 return q.Scan(dest) } func (q *Query) Delete() error { // 拿 partition key 清單 metadata, err := GenerateTableMetadata(q.document, q.keyspace) if err != nil { return err } missingKeys := make([]string, 0) for _, pk := range metadata.PartKey { if _, ok := q.bindMap[pk]; !ok { missingKeys = append(missingKeys, pk) } } if len(missingKeys) > 0 { return fmt.Errorf("delete operation requires all partition keys in WHERE: missing %v", missingKeys) } if len(q.cmps) == 0 { return fmt.Errorf("delete operation requires at least one WHERE condition for safety") } // 組 Delete 語句 builder := qb.Delete(q.table) builder = builder.Where(q.cmps...) stmt, names := builder.ToCql() query := q.db.GetSession().Query(stmt, names).WithContext(q.ctx) if q.bindMap == nil { q.bindMap = qb.M{} } query = query.BindMap(q.bindMap) return query.ExecRelease() } func (q *Query) Update() error { if q.document == nil { return fmt.Errorf("update requires modelType to check partition keys") } metadata, err := GenerateTableMetadata(q.document, q.keyspace) if err != nil { return fmt.Errorf("update: failed to get table metadata: %w", err) } // 檢查 partition key 是否都在 bindMap missingKeys := make([]string, 0) for _, pk := range metadata.PartKey { if _, ok := q.bindMap[pk]; !ok { missingKeys = append(missingKeys, pk) } } if len(missingKeys) > 0 { return fmt.Errorf("update operation requires all partition keys in WHERE: missing %v", missingKeys) } // 至少要有一個 set 欄位 if len(q.sets) == 0 { return fmt.Errorf("update requires at least one field to set") } // 至少一個 where if len(q.cmps) == 0 { return fmt.Errorf("update operation requires at least one WHERE condition for safety") } // 組合 set 欄位 setCols := make([]string, 0, len(q.sets)) setVals := make([]any, 0, len(q.sets)) for _, s := range q.sets { setCols = append(setCols, s.Col) setVals = append(setVals, s.Val) } // 組合 CQL builder := qb.Update(q.table).Set(setCols...) builder = builder.Where(q.cmps...) stmt, names := builder.ToCql() // setVals 要先,剩下的 where bind 順序依照 names bindVals := append([]any{}, setVals...) for _, name := range names[len(setCols):] { if v, ok := q.bindMap[name]; ok { bindVals = append(bindVals, v) } } query := q.db.GetSession().Query(stmt, names).WithContext(q.ctx) if len(bindVals) > 0 { query = query.Bind(bindVals...) } return query.ExecRelease() } func (q *Query) InsertOne(data any) error { metadata, err := GenerateTableMetadata(q.document, q.keyspace) if err != nil { return err } tbl := table.New(metadata) qry := q.db.GetSession().Query(tbl.Insert()).WithContext(q.ctx).WithTimestamp(time.Now().UnixNano() / 1e3) switch reflect.TypeOf(data).Kind() { case reflect.Map: qry = qry.BindMap(data.(map[string]any)) default: qry = qry.BindStruct(data) } return qry.ExecRelease() } func (q *Query) InsertMany(documents any) error { v := reflect.ValueOf(documents) if v.Kind() != reflect.Slice { return fmt.Errorf("InsertMany: input must be a slice") } if v.Len() == 0 { return nil } for i := 0; i < v.Len(); i++ { item := v.Index(i).Interface() if err := q.InsertOne(item); err != nil { return fmt.Errorf("InsertMany: failed at idx %d: %w", i, err) } } return nil } func (q *Query) GetAll(dest any) error { metadata, err := GenerateTableMetadata(q.document, q.keyspace) if err != nil { return err } t := table.New(metadata) stmt, names := qb.Select(t.Name()).Columns(metadata.Columns...).ToCql() exec := q.db.GetSession().Query(stmt, names).WithContext(q.ctx).WithTimestamp(time.Now().UnixNano() / 1e3) return exec.SelectRelease(dest) } func (q *Query) Count() (int64, error) { metadata, err := GenerateTableMetadata(q.document, q.keyspace) if err != nil { return 0, err } t := table.New(metadata) builder := qb.Select(t.Name()).Columns("COUNT(*)") if len(q.cmps) > 0 { builder = builder.Where(q.cmps...) } stmt, names := builder.ToCql() query := q.db.GetSession().Query(stmt, names).WithContext(q.ctx).WithTimestamp(time.Now().UnixNano() / 1e3) if q.bindMap == nil { q.bindMap = qb.M{} } query = query.BindMap(q.bindMap) var count int64 if err := query.GetRelease(&count); err != nil { return 0, err } return count, nil } func IsSAIField(model any, fieldName string) bool { t := reflect.TypeOf(model) if t.Kind() == reflect.Ptr { t = t.Elem() } for i := 0; i < t.NumField(); i++ { f := t.Field(i) tag := f.Tag.Get("sai") col := f.Tag.Get("db") if col == "" { col = toSnakeCase(f.Name) } if (col == fieldName || f.Name == fieldName) && tag == "true" { return true } } return false }