blockchain/internal/lib/cassandra/batch.go

112 lines
2.8 KiB
Go

package cassandra
import (
"context"
"fmt"
"reflect"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx/v3"
"github.com/scylladb/gocqlx/v3/qb"
"github.com/scylladb/gocqlx/v3/table"
)
// NewBatch starts a Batch whose statements are sent to Cassandra as one
// gocql.LoggedBatch when Commit is called. Statements queued on the returned
// Batch generate their table metadata against keyspace, and the underlying
// gocql batch is bound to ctx.
//
// TODO: Consistency is only guaranteed for writes under the same partition
// key. If something fails midway, possibly only the failed write is not
// persisted while the other writes still succeed. Two directions are planned:
//  1. Eventual consistency: the current design writes directly to the
//     secondary table, and a background worker reads the sync_task table to
//     back-fill the secondary table's data.
//  2. Investigate a scheme with a self-managed TX_ID and STATUS.
//
// This is a known issue and must be resolved.
func (db *CassandraDB) NewBatch(ctx context.Context, keyspace string) *Batch {
	// Keep the session in an addressable local: NewBatch may be declared on a
	// pointer receiver.
	session := db.GetSession()
	logged := session.NewBatch(gocql.LoggedBatch).WithContext(ctx)

	b := &Batch{db: db, keyspace: keyspace, ctx: ctx}
	b.batch = gocqlx.Batch{Batch: logged}
	return b
}
// Batch collects INSERT, UPDATE and DELETE statements built from tagged
// structs and executes them together on Commit. Construct it with
// CassandraDB.NewBatch; the zero value has no session and no underlying batch.
type Batch struct {
	// ctx is the context handed to NewBatch. The underlying batch is already
	// bound to it, and none of Batch's methods read this field directly.
	ctx context.Context
	// keyspace is passed to GenerateTableMetadata for every queued document.
	keyspace string
	// db supplies the session used to build queries and to execute the batch.
	db *CassandraDB
	// batch holds the statements queued so far.
	batch gocqlx.Batch
}
// Insert queues an INSERT for doc. The CQL comes from gocqlx's table.Insert
// over the metadata generated for doc in tx.keyspace, so it covers every
// column in that metadata. Values are bound by name from doc's struct fields.
// Any error from metadata generation or binding is returned unchanged.
func (tx *Batch) Insert(doc any) error {
	meta, err := GenerateTableMetadata(doc, tx.keyspace)
	if err != nil {
		return err
	}

	cql, columns := table.New(meta).Insert()
	query := tx.db.GetSession().Query(cql, columns)
	return tx.batch.BindStruct(query, doc)
}
// Delete queues a DELETE for doc. The CQL comes from gocqlx's table.Delete
// (called with no column list) over the metadata generated for doc in
// tx.keyspace. Its key values are bound by name from doc's struct fields.
// Any error from metadata generation or binding is returned unchanged.
func (tx *Batch) Delete(doc any) error {
	meta, err := GenerateTableMetadata(doc, tx.keyspace)
	if err != nil {
		return err
	}

	cql, columns := table.New(meta).Delete()
	query := tx.db.GetSession().Query(cql, columns)
	return tx.batch.BindStruct(query, doc)
}
// Update queues an UPDATE for doc, which must be a struct or a pointer to one.
//
// Struct fields carrying a `db:"<column>"` tag are split using the table
// metadata generated for doc in tx.keyspace:
//   - columns listed in metadata.PartKey or metadata.SortKey go into the
//     WHERE clause (in struct field order);
//   - every other column goes into the SET clause, but only if isZero reports
//     a non-zero value. Zero-valued fields are skipped, so Update cannot reset
//     a column to its Go zero value.
//
// Fields with no db tag, a "-" tag, or that are unexported are ignored.
//
// It returns an error when metadata generation fails, when no non-zero
// column is left to SET, or when doc does not supply every partition and
// clustering key column. An UPDATE missing part of its primary key would
// otherwise be queued and only fail later, when the batch is executed.
func (tx *Batch) Update(doc any) error {
	metadata, err := GenerateTableMetadata(doc, tx.keyspace)
	if err != nil {
		return err
	}

	v := reflect.ValueOf(doc)
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	typ := v.Type()

	n := typ.NumField()
	setCols := make([]string, 0, n)
	setVals := make([]any, 0, n)
	whereCols := make([]string, 0, n)
	whereVals := make([]any, 0, n)

	for i := 0; i < n; i++ {
		field := typ.Field(i)
		tag := field.Tag.Get("db")
		if tag == "" || tag == "-" {
			continue
		}
		// reflect.Value.Interface panics on unexported fields.
		if !field.IsExported() {
			continue
		}

		val := v.Field(i)
		switch {
		case contains(metadata.PartKey, tag) || contains(metadata.SortKey, tag):
			whereCols = append(whereCols, tag)
			whereVals = append(whereVals, val.Interface())
		case !isZero(val):
			setCols = append(setCols, tag)
			setVals = append(setVals, val.Interface())
		}
	}

	if len(setCols) == 0 {
		return fmt.Errorf("update: no non-zero fields in %+v", doc)
	}

	// Every primary-key column must be present. Otherwise the WHERE clause
	// would be partial or empty, and the statement would only be rejected
	// once the whole batch is executed.
	for _, keys := range [][]string{metadata.PartKey, metadata.SortKey} {
		for _, key := range keys {
			if !contains(whereCols, key) {
				return fmt.Errorf("update: primary key column %q missing from %T", key, doc)
			}
		}
	}

	builder := qb.Update(metadata.Name).Set(setCols...)
	for _, col := range whereCols {
		builder = builder.Where(qb.Eq(col))
	}
	stmt, names := builder.ToCql()

	// Positional args must follow the placeholder order of the generated CQL:
	// SET columns first, then WHERE columns.
	args := make([]any, 0, len(setVals)+len(whereVals))
	args = append(args, setVals...)
	args = append(args, whereVals...)
	return tx.batch.Bind(tx.db.GetSession().Query(stmt, names), args...)
}
// Commit executes every statement queued on tx as a single batch and returns
// the execution error, if any. When tx came from NewBatch, the batch carries
// the context supplied there.
func (tx *Batch) Commit() error {
	// Keep the session addressable: ExecuteBatch may use a pointer receiver.
	s := tx.db.GetSession()
	if err := s.ExecuteBatch(&tx.batch); err != nil {
		return err
	}
	return nil
}