112 lines
2.8 KiB
Go
112 lines
2.8 KiB
Go
package cassandra
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"reflect"
|
|
|
|
"github.com/gocql/gocql"
|
|
"github.com/scylladb/gocqlx/v3"
|
|
"github.com/scylladb/gocqlx/v3/qb"
|
|
"github.com/scylladb/gocqlx/v3/table"
|
|
)
|
|
|
|
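// TODO: Consistency is only guaranteed for statements that share the same PK.
// If a statement fails partway through, it is possible that only the failed
// statement is not written while the other, successful statements still take effect.
// Two directions to pursue later:
//  1. Eventual consistency: the current design writes directly to the secondary
//     tables, and a background worker reads the sync_task table to backfill the
//     secondary-table data.
//  2. Investigate a home-grown TX_ID + STATUS scheme.
// This is a known issue and must be fixed.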
|
|
|
|
func (db *CassandraDB) NewBatch(ctx context.Context, keyspace string) *Batch {
|
|
session := db.GetSession()
|
|
return &Batch{
|
|
ctx: ctx,
|
|
keyspace: keyspace,
|
|
db: db,
|
|
batch: gocqlx.Batch{
|
|
Batch: session.NewBatch(gocql.LoggedBatch).WithContext(ctx),
|
|
},
|
|
}
|
|
}
|
|
|
|
// Batch accumulates INSERT, UPDATE and DELETE statements (queued via its
// Insert, Update and Delete methods) and executes them together on Commit.
// Construct it with CassandraDB.NewBatch.
type Batch struct {
	// ctx is the context passed to NewBatch. It is also attached to batch
	// via WithContext there.
	// NOTE(review): not read by Insert/Update/Delete/Commit; confirm whether
	// anything else needs it.
	ctx context.Context
	// keyspace is passed to GenerateTableMetadata for every queued statement.
	keyspace string
	// db supplies the session used to build queries and execute the batch.
	db *CassandraDB
	// batch is the gocqlx batch that statements are bound into.
	batch gocqlx.Batch
}
|
|
|
|
func (tx *Batch) Insert(doc any) error {
|
|
metadata, err := GenerateTableMetadata(doc, tx.keyspace)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
tbl := table.New(metadata)
|
|
stmt, names := tbl.Insert()
|
|
return tx.batch.BindStruct(tx.db.GetSession().Query(stmt, names), doc)
|
|
}
|
|
|
|
func (tx *Batch) Delete(doc any) error {
|
|
metadata, err := GenerateTableMetadata(doc, tx.keyspace)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
tbl := table.New(metadata)
|
|
stmt, names := tbl.Delete()
|
|
return tx.batch.BindStruct(tx.db.GetSession().Query(stmt, names), doc)
|
|
}
|
|
|
|
func (tx *Batch) Update(doc any) error {
|
|
metadata, err := GenerateTableMetadata(doc, tx.keyspace)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
v := reflect.ValueOf(doc)
|
|
if v.Kind() == reflect.Ptr {
|
|
v = v.Elem()
|
|
}
|
|
typ := v.Type()
|
|
|
|
setCols := make([]string, 0)
|
|
setVals := make([]any, 0)
|
|
whereCols := make([]string, 0)
|
|
whereVals := make([]any, 0)
|
|
|
|
for i := 0; i < typ.NumField(); i++ {
|
|
field := typ.Field(i)
|
|
tag := field.Tag.Get("db")
|
|
if tag == "" || tag == "-" {
|
|
continue
|
|
}
|
|
val := v.Field(i)
|
|
if !val.IsValid() {
|
|
continue
|
|
}
|
|
if contains(metadata.PartKey, tag) || contains(metadata.SortKey, tag) {
|
|
whereCols = append(whereCols, tag)
|
|
whereVals = append(whereVals, val.Interface())
|
|
} else if !isZero(val) {
|
|
setCols = append(setCols, tag)
|
|
setVals = append(setVals, val.Interface())
|
|
}
|
|
}
|
|
|
|
if len(setCols) == 0 {
|
|
return fmt.Errorf("update: no non-zero fields in %+v", doc)
|
|
}
|
|
|
|
builder := qb.Update(metadata.Name).Set(setCols...)
|
|
for _, col := range whereCols {
|
|
builder = builder.Where(qb.Eq(col))
|
|
}
|
|
stmt, names := builder.ToCql()
|
|
args := append(setVals, whereVals...)
|
|
return tx.batch.Bind(tx.db.GetSession().Query(stmt, names), args...)
|
|
}
|
|
|
|
func (tx *Batch) Commit() error {
|
|
session := tx.db.GetSession()
|
|
|
|
return session.ExecuteBatch(&tx.batch)
|
|
}
|