package cassandra

import (
	"context"
	"fmt"
	"reflect"

	"github.com/gocql/gocql"
	"github.com/scylladb/gocqlx/v3"
	"github.com/scylladb/gocqlx/v3/qb"
	"github.com/scylladb/gocqlx/v3/table"
)

// TODO: Consistency is only guaranteed for statements under the same
// partition key. If something fails midway, it is possible that only the
// failed statement is not written while the other statements still succeed.
// Two directions are planned:
//  1. Eventual consistency: the current design writes to the secondary table
//     directly, and a background worker reads the sync_task table to backfill
//     the secondary-table data.
//  2. Investigate a home-grown TX_ID and STATUS scheme.
//
// This is a known issue and must be resolved.

// NewBatch starts a logged batch whose statements target tables in keyspace.
// The batch carries ctx, which is used when the batch is executed by Commit.
// Statements are queued with Insert, Update and Delete and are only sent to
// the cluster when Commit is called.
func (db *DB) NewBatch(ctx context.Context, keyspace string) *Batch {
	session := db.GetSession()
	return &Batch{
		ctx:      ctx,
		keyspace: keyspace,
		db:       db,
		batch: gocqlx.Batch{
			Batch: session.NewBatch(gocql.LoggedBatch).WithContext(ctx),
		},
	}
}

// Batch accumulates INSERT, UPDATE and DELETE statements generated from
// `db`-tagged structs and executes them together as one logged batch.
type Batch struct {
	ctx      context.Context // context given to NewBatch; also attached to batch
	keyspace string          // keyspace passed to GenerateTableMetadata
	db       *DB             // source of the session used to build and run queries
	batch    gocqlx.Batch    // underlying logged batch receiving bound statements
}

// queueStruct derives table metadata from doc, builds a statement with build,
// and appends it to the batch with its values bound from doc's fields.
func (tx *Batch) queueStruct(doc any, build func(*table.Table) (string, []string)) error {
	metadata, err := GenerateTableMetadata(doc, tx.keyspace)
	if err != nil {
		return err
	}
	stmt, names := build(table.New(metadata))
	return tx.batch.BindStruct(tx.db.GetSession().Query(stmt, names), doc)
}

// Insert queues an INSERT of doc into the table described by doc's struct
// tags. Column values are bound from doc's fields.
func (tx *Batch) Insert(doc any) error {
	return tx.queueStruct(doc, func(t *table.Table) (string, []string) {
		return t.Insert()
	})
}

// Delete queues a DELETE of the row identified by doc's primary-key fields
// in the table described by doc's struct tags.
func (tx *Batch) Delete(doc any) error {
	return tx.queueStruct(doc, func(t *table.Table) (string, []string) {
		return t.Delete()
	})
}

// Update queues an UPDATE for doc.
//
// Fields whose `db` tag names a partition-key or sort-key column form the
// WHERE clause and are always bound, even when zero. Every other tagged field
// is written only when it is non-zero. As a result, Update cannot reset a
// column to its zero value.
//
// Update returns an error when:
//   - doc is not a struct or a non-nil pointer to a struct,
//   - doc has no non-zero value fields,
//   - doc has no key fields.
func (tx *Batch) Update(doc any) error {
	metadata, err := GenerateTableMetadata(doc, tx.keyspace)
	if err != nil {
		return err
	}

	v := reflect.ValueOf(doc)
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	// A nil pointer yields an invalid Value, and non-structs have no fields;
	// both would panic below, so report them as errors instead.
	if v.Kind() != reflect.Struct {
		return fmt.Errorf("update: expected struct or pointer to struct, got %T", doc)
	}
	typ := v.Type()

	var setCols, whereCols []string
	var setVals, whereVals []any
	for i := 0; i < typ.NumField(); i++ {
		field := typ.Field(i)
		tag := field.Tag.Get("db")
		// Unexported fields cannot be read through Interface() (it panics).
		if tag == "" || tag == "-" || !field.IsExported() {
			continue
		}
		val := v.Field(i)
		switch {
		case contains(metadata.PartKey, tag) || contains(metadata.SortKey, tag):
			whereCols = append(whereCols, tag)
			whereVals = append(whereVals, val.Interface())
		case !isZero(val):
			setCols = append(setCols, tag)
			setVals = append(setVals, val.Interface())
		}
	}

	if len(setCols) == 0 {
		return fmt.Errorf("update: no non-zero fields in %+v", doc)
	}
	// Without key columns the statement would be an UPDATE with no WHERE
	// clause, which Cassandra rejects only at Commit time; fail fast here.
	if len(whereCols) == 0 {
		return fmt.Errorf("update: no primary key fields in %+v", doc)
	}

	builder := qb.Update(metadata.Name).Set(setCols...)
	for _, col := range whereCols {
		builder = builder.Where(qb.Eq(col))
	}
	stmt, names := builder.ToCql()

	// Bind order must match names: SET columns first, then WHERE columns.
	args := append(setVals, whereVals...)
	return tx.batch.Bind(tx.db.GetSession().Query(stmt, names), args...)
}

// Commit executes every queued statement as a single logged batch. The
// context given to NewBatch is attached to the batch.
func (tx *Batch) Commit() error {
	session := tx.db.GetSession()
	return session.ExecuteBatch(&tx.batch)
}