package cassandra

import (
	"context"
	"errors"
	"fmt"
	"reflect"

	"github.com/gocql/gocql"
	"github.com/scylladb/gocqlx/v3/qb"
	"github.com/scylladb/gocqlx/v3/table"
)

// AutoCreateSAIIndexes issues a CREATE INDEX ... USING 'sai' statement for
// every top-level struct field of doc tagged `sai:"true"`.
//
// The indexed column is the field's `db` tag, falling back to
// toSnakeCase(field name). Statements use IF NOT EXISTS so that calling this
// repeatedly for the same model is intended to be harmless.
//
// keyspace: when empty it is resolved with getKeyspace, exactly like Model,
// so the indexes target the same table that Model-built queries use.
//
// doc must be a struct or a pointer to a struct; any other value returns an
// error instead of panicking inside reflect.
func (db *CassandraDB) AutoCreateSAIIndexes(doc any, keyspace string) error {
	keyspace = getKeyspace(db, keyspace)

	metadata, err := GenerateTableMetadata(doc, keyspace)
	if err != nil {
		return err
	}

	t := reflect.TypeOf(doc)
	if t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	// NumField panics on non-struct types, so reject them explicitly.
	if t == nil || t.Kind() != reflect.Struct {
		return fmt.Errorf("auto create SAI indexes: document must be a struct or pointer to struct, got %T", doc)
	}

	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if f.Tag.Get("sai") != "true" {
			continue
		}
		col := f.Tag.Get("db")
		if col == "" {
			col = toSnakeCase(f.Name)
		}
		stmt := fmt.Sprintf("CREATE INDEX IF NOT EXISTS ON %s (%s) USING 'sai';", metadata.Name, col)
		if err := db.GetSession().ExecStmt(stmt); err != nil {
			return fmt.Errorf("failed to create SAI index on table %s, column %s: %w", metadata.Name, col, err)
		}
	}
	return nil
}

// Query is a chainable query builder bound to a single table and keyspace.
//
// Model and Where record errors in errs instead of returning them. Every
// terminal method (Scan, Take, Delete, Update, InsertOne, InsertMany, GetAll,
// Count) returns those errors joined before executing anything.
type Query struct {
	db       *CassandraDB
	ctx      context.Context
	table    string
	keyspace string
	columns  []string
	cmps     []qb.Cmp
	bindMap  map[string]any
	orders   []orderBy
	limit    uint
	document any
	sets     []setField // columns to update and their new values
	errs     []error
}

// orderBy is one ORDER BY column and its direction.
type orderBy struct {
	Column string
	Order  qb.Order
}

// setField is one column/value pair written by Update.
type setField struct {
	Col string
	Val any
}

// Model creates a new query builder for the table described by document.
//
// document: a sample value used to infer the table structure through
// GenerateTableMetadata (per the original contract it must implement a
// TableName() method). It is also kept to validate SAI fields later.
// keyspace: when empty, getKeyspace substitutes the default keyspace
// configured at initialization.
//
// A metadata error is not returned here. It is stored on the returned Query
// and reported by whichever terminal method is called.
func (db *CassandraDB) Model(ctx context.Context, document any, keyspace string) *Query {
	keyspace = getKeyspace(db, keyspace)
	metadata, err := GenerateTableMetadata(document, keyspace)
	if err != nil {
		// Metadata generation failed: return a Query that carries the error.
		return &Query{
			db:       db,
			ctx:      ctx,
			keyspace: keyspace,
			document: document,
			errs:     []error{err},
		}
	}

	return &Query{
		db:       db,
		ctx:      ctx,
		table:    metadata.Name,
		keyspace: keyspace,
		columns:  make([]string, 0),
		cmps:     make([]qb.Cmp, 0),
		bindMap:  make(map[string]any),
		orders:   make([]orderBy, 0),
		limit:    0,
		document: document, // used to generate metadata and to validate SAI fields
		errs:     make([]error, 0),
	}
}

// Where adds a WHERE condition.
// Only partition key columns or fields tagged with an SAI index may be used.
// cmp: the condition (e.g. qb.Eq("id")).
// args: bind values keyed by parameter name (e.g. map[string]any{"id": uuid}).
func (q *Query) Where(cmp qb.Cmp, args map[string]any) *Query { // 如果之前有錯誤,直接返回 if len(q.errs) > 0 { return q } metadata, err := GenerateTableMetadata(q.document, q.keyspace) if err != nil { q.errs = append(q.errs, err) return q } for k := range args { // 允許 partition_key 或 sai 欄位 isPartition := contains(metadata.PartKey, k) isSAI := IsSAIField(q.document, k) if !isPartition && !isSAI { q.errs = append(q.errs, NewError( "INVALID_WHERE_FIELD", fmt.Sprintf("where condition on field %s requires partition_key or sai index", k), ).WithTable(q.table)) } } q.cmps = append(q.cmps, cmp) for k, v := range args { q.bindMap[k] = v } return q } func (q *Query) Select(cols ...string) *Query { q.columns = append(q.columns, cols...) return q } func (q *Query) OrderBy(column string, order qb.Order) *Query { q.orders = append(q.orders, orderBy{Column: column, Order: order}) return q } func (q *Query) Limit(limit uint) *Query { q.limit = limit return q } func (q *Query) Set(col string, val any) *Query { q.sets = append(q.sets, setField{Col: col, Val: val}) return q } // Scan 執行查詢並將結果掃描到 dest // dest 必須是指標類型:*Struct 用於單筆查詢,*[]Struct 用於多筆查詢 func (q *Query) Scan(dest any) error { if len(q.errs) > 0 { return errors.Join(q.errs...) } metadata, err := GenerateTableMetadata(q.document, q.keyspace) if err != nil { return err } builder := qb.Select(q.table) if len(q.columns) > 0 { builder = builder.Columns(q.columns...) } else { // 如果沒有指定欄位,使用所有欄位 builder = builder.Columns(metadata.Columns...) } if len(q.cmps) > 0 { builder = builder.Where(q.cmps...) 
} if len(q.orders) > 0 { for _, o := range q.orders { builder = builder.OrderBy(o.Column, o.Order) } } if q.limit > 0 { builder = builder.Limit(q.limit) } stmt, names := builder.ToCql() query := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names)) if q.bindMap == nil { q.bindMap = qb.M{} } query = query.BindMap(q.bindMap) // 型態判斷自動選用單筆/多筆查詢 destType := reflect.TypeOf(dest) if destType.Kind() != reflect.Ptr { return ErrInvalidInput.WithTable(q.table).WithError(fmt.Errorf("destination must be a pointer, got %T", dest)) } elemType := destType.Elem() switch elemType.Kind() { case reflect.Slice: return query.SelectRelease(dest) case reflect.Struct: err := query.GetRelease(dest) if err == gocql.ErrNotFound { return ErrNotFound.WithTable(q.table) } return err default: return ErrInvalidInput.WithTable(q.table).WithError(fmt.Errorf("destination must be pointer to struct or slice, got %T", dest)) } } func (q *Query) Take(dest any) error { q.limit = 1 return q.Scan(dest) } // Delete 執行刪除操作 // 要求:必須提供所有 partition keys 在 WHERE 條件中 func (q *Query) Delete() error { if len(q.errs) > 0 { return errors.Join(q.errs...) } metadata, err := GenerateTableMetadata(q.document, q.keyspace) if err != nil { return err } // 檢查是否提供所有 partition keys missingKeys := make([]string, 0) for _, pk := range metadata.PartKey { if _, ok := q.bindMap[pk]; !ok { missingKeys = append(missingKeys, pk) } } if len(missingKeys) > 0 { return ErrMissingPartitionKey.WithTable(q.table).WithError( fmt.Errorf("missing partition keys: %v", missingKeys), ) } if len(q.cmps) == 0 { return ErrMissingWhereCondition.WithTable(q.table) } // 組 Delete 語句 builder := qb.Delete(q.table) builder = builder.Where(q.cmps...) 
stmt, names := builder.ToCql() query := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names)) if q.bindMap == nil { q.bindMap = qb.M{} } query = query.BindMap(q.bindMap) return query.ExecRelease() } // Update 執行更新操作 // 要求:必須提供至少一個 partition_key 或 sai indexed 欄位在 WHERE 條件中,且至少有一個 Set 欄位 func (q *Query) Update() error { if len(q.errs) > 0 { return errors.Join(q.errs...) } if q.document == nil { return ErrInvalidInput.WithTable(q.table).WithError( fmt.Errorf("update requires document model to check partition keys"), ) } metadata, err := GenerateTableMetadata(q.document, q.keyspace) if err != nil { return err } // 先收集所有可被當作主查詢條件的欄位 allowed := make(map[string]struct{}) // 收集 partition_key for _, pk := range metadata.PartKey { allowed[pk] = struct{}{} } // 收集所有 sai 欄位 for _, f := range reflect.VisibleFields(reflect.TypeOf(q.document)) { if f.Tag.Get("sai") == "true" { col := f.Tag.Get("db") if col == "" { col = toSnakeCase(f.Name) } allowed[col] = struct{}{} } } // 檢查 bindMap 有沒有 hit 到 hasCondition := false for k := range q.bindMap { if _, ok := allowed[k]; ok { hasCondition = true break } } if !hasCondition { return ErrMissingPartitionKey.WithTable(q.table).WithError( fmt.Errorf("requires at least one partition_key or sai indexed field in WHERE clause"), ) } // 至少要有一個 set 欄位 if len(q.sets) == 0 { return ErrNoFieldsToUpdate.WithTable(q.table) } // 至少一個 where if len(q.cmps) == 0 { return ErrMissingWhereCondition.WithTable(q.table) } // 組合 set 欄位 setCols := make([]string, 0, len(q.sets)) setVals := make([]any, 0, len(q.sets)) for _, s := range q.sets { setCols = append(setCols, s.Col) setVals = append(setVals, s.Val) } // 組合 CQL builder := qb.Update(q.table).Set(setCols...) builder = builder.Where(q.cmps...) stmt, names := builder.ToCql() // setVals 要先,剩下的 where bind 順序依照 names bindVals := append([]any{}, setVals...) 
for _, name := range names[len(setCols):] { if v, ok := q.bindMap[name]; ok { bindVals = append(bindVals, v) } } query := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names)) if len(bindVals) > 0 { query = query.Bind(bindVals...) } return query.ExecRelease() } // InsertOne 插入單筆資料 func (q *Query) InsertOne(data any) error { if len(q.errs) > 0 { return errors.Join(q.errs...) } metadata, err := GenerateTableMetadata(q.document, q.keyspace) if err != nil { return err } tbl := table.New(metadata) qry := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(tbl.Insert())) switch reflect.TypeOf(data).Kind() { case reflect.Map: qry = qry.BindMap(data.(map[string]any)) default: qry = qry.BindStruct(data) } return qry.ExecRelease() } func (q *Query) InsertMany(documents any) error { if len(q.errs) > 0 { return errors.Join(q.errs...) } v := reflect.ValueOf(documents) if v.Kind() != reflect.Slice { return fmt.Errorf("insert many: input must be a slice, got %T", documents) } if v.Len() == 0 { return nil } for i := 0; i < v.Len(); i++ { item := v.Index(i).Interface() if err := q.InsertOne(item); err != nil { return fmt.Errorf("insert many: failed at index %d (table: %s): %w", i, q.table, err) } } return nil } // GetAll 查詢所有資料(不帶條件) func (q *Query) GetAll(dest any) error { if len(q.errs) > 0 { return errors.Join(q.errs...) } metadata, err := GenerateTableMetadata(q.document, q.keyspace) if err != nil { return err } t := table.New(metadata) stmt, names := qb.Select(t.Name()).Columns(metadata.Columns...).ToCql() exec := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names)) return exec.SelectRelease(dest) } // Count 計算符合條件的記錄數 func (q *Query) Count() (int64, error) { if len(q.errs) > 0 { return 0, errors.Join(q.errs...) 
} metadata, err := GenerateTableMetadata(q.document, q.keyspace) if err != nil { return 0, err } t := table.New(metadata) builder := qb.Select(t.Name()).Columns("COUNT(*)") if len(q.cmps) > 0 { builder = builder.Where(q.cmps...) } stmt, names := builder.ToCql() query := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names)) if q.bindMap == nil { q.bindMap = qb.M{} } query = query.BindMap(q.bindMap) var count int64 if err := query.GetRelease(&count); err != nil { if err == gocql.ErrNotFound { return 0, nil // COUNT 查詢不會返回 ErrNotFound,但為了安全起見 } return 0, err } return count, nil } func IsSAIField(model any, fieldName string) bool { t := reflect.TypeOf(model) if t.Kind() == reflect.Ptr { t = t.Elem() } for i := 0; i < t.NumField(); i++ { f := t.Field(i) tag := f.Tag.Get("sai") col := f.Tag.Get("db") if col == "" { col = toSnakeCase(f.Name) } if (col == fieldName || f.Name == fieldName) && tag == "true" { return true } } return false }