266 lines
7.1 KiB
Go
266 lines
7.1 KiB
Go
package cassandra
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"reflect"
|
||
|
||
"github.com/gocql/gocql"
|
||
"github.com/scylladb/gocqlx/v2"
|
||
"github.com/scylladb/gocqlx/v2/qb"
|
||
"github.com/scylladb/gocqlx/v2/table"
|
||
)
|
||
|
||
// Repository 定義資料存取介面(小介面,符合 M3)
|
||
type Repository[T Table] interface {
|
||
Insert(ctx context.Context, doc T) error
|
||
Get(ctx context.Context, pk any) (T, error)
|
||
Update(ctx context.Context, doc T) error
|
||
Delete(ctx context.Context, pk any) error
|
||
InsertMany(ctx context.Context, docs []T) error
|
||
Query() QueryBuilder[T]
|
||
TryLock(ctx context.Context, doc T, opts ...LockOption) error
|
||
UnLock(ctx context.Context, doc T) error
|
||
}
|
||
|
||
// repository 是 Repository 的具體實作
|
||
type repository[T Table] struct {
|
||
db *DB
|
||
keyspace string
|
||
table string
|
||
metadata table.Metadata
|
||
}
|
||
|
||
// NewRepository builds a Repository for table type T.
//
// If keyspace is empty, db's default keyspace is used. Table metadata is
// derived from the zero value of T via generateMetadata; a failure there is
// returned wrapped. A nil db is rejected with an error rather than causing a
// nil-pointer panic.
func NewRepository[T Table](db *DB, keyspace string) (Repository[T], error) {
	if db == nil {
		return nil, errors.New("db must not be nil")
	}
	if keyspace == "" {
		keyspace = db.defaultKeyspace
	}

	var zero T
	metadata, err := generateMetadata(zero, keyspace)
	if err != nil {
		return nil, fmt.Errorf("failed to generate metadata: %w", err)
	}

	return &repository[T]{
		db:       db,
		keyspace: keyspace,
		table:    metadata.Name,
		metadata: metadata,
	}, nil
}
|
||
|
||
// Insert writes a single row built from doc.
//
// The INSERT statement comes from the repository metadata. doc is bound by
// struct field mapping, and the query is passed through
// withContextAndTimestamp before it is executed and released.
func (r *repository[T]) Insert(ctx context.Context, doc T) error {
	stmt, names := table.New(r.metadata).Insert()
	bound := r.db.session.Query(stmt, names).BindStruct(doc)
	return r.db.withContextAndTimestamp(ctx, bound).ExecRelease()
}
|
||
|
||
// Get loads the single row whose primary key equals pk.
//
// pk must identify the complete primary key (every partition and clustering
// column). It may be:
//   - a struct, or a non-nil pointer to a struct, whose fields carry the
//     primary-key columns; it is bound with BindStruct;
//   - a single value (string, int, int64, gocql.UUID, []byte, ...), which is
//     only accepted when the table has exactly one partition key column and
//     no clustering columns; it is bound with Bind.
//
// Errors:
//   - ErrInvalidInput for a nil pk, a nil pointer, or a single value used
//     against a multi-column key;
//   - ErrNotFound when no row matches;
//   - ErrInvalidInput wrapping the driver error for any other query failure.
func (r *repository[T]) Get(ctx context.Context, pk any) (T, error) {
	var zero T

	// reflect.TypeOf(nil).Kind() would panic, so reject nil keys up front.
	keyVal := reflect.ValueOf(pk)
	if !keyVal.IsValid() || (keyVal.Kind() == reflect.Ptr && keyVal.IsNil()) {
		return zero, ErrInvalidInput.WithTable(r.table).WithError(
			errors.New("primary key must not be nil"),
		)
	}
	// Look through a pointer so *Struct keys are bound by field, like Struct keys.
	if keyVal.Kind() == reflect.Ptr {
		keyVal = keyVal.Elem()
	}
	structKey := keyVal.Kind() == reflect.Struct

	// A bare value can only fill a primary key made of exactly one column.
	if !structKey && (len(r.metadata.PartKey) != 1 || len(r.metadata.SortKey) > 0) {
		return zero, ErrInvalidInput.WithTable(r.table).WithError(
			fmt.Errorf("single value primary key only supported for single partition key without clustering key"),
		)
	}

	// table.Get builds the primary-key lookup from the metadata.
	stmt, names := table.New(r.metadata).Get()
	var q *gocqlx.Queryx
	if structKey {
		q = r.db.session.Query(stmt, names).BindStruct(pk)
	} else {
		q = r.db.session.Query(stmt, names).Bind(pk)
	}
	q = r.db.withContextAndTimestamp(ctx, q)

	var result T
	err := q.GetRelease(&result)
	if errors.Is(err, gocql.ErrNotFound) {
		return zero, ErrNotFound.WithTable(r.table)
	}
	if err != nil {
		return zero, ErrInvalidInput.WithTable(r.table).WithError(err)
	}
	return result, nil
}
|
||
|
||
// Update 更新資料(只更新非零值欄位)
|
||
func (r *repository[T]) Update(ctx context.Context, doc T) error {
|
||
return r.updateSelective(ctx, doc, false)
|
||
}
|
||
|
||
// UpdateAll 更新所有欄位(包括零值)
|
||
func (r *repository[T]) UpdateAll(ctx context.Context, doc T) error {
|
||
return r.updateSelective(ctx, doc, true)
|
||
}
|
||
|
||
// updateSelective issues an UPDATE for doc, keyed by its primary-key fields.
//
// includeZero controls whether zero-valued non-key fields are written (see
// buildUpdateFields). Bind arguments are the SET values followed by the
// WHERE values, matching the SET-then-WHERE order of the generated statement.
func (r *repository[T]) updateSelective(ctx context.Context, doc T, includeZero bool) error {
	fields, err := r.buildUpdateFields(doc, includeZero)
	if err != nil {
		return err
	}

	stmt, names := r.buildUpdateStatement(fields.setCols, fields.whereCols)

	args := make([]any, 0, len(fields.setVals)+len(fields.whereVals))
	args = append(args, fields.setVals...)
	args = append(args, fields.whereVals...)

	query := r.db.session.Query(stmt, names).Bind(args...)
	return r.db.withContextAndTimestamp(ctx, query).ExecRelease()
}
|
||
|
||
// Delete 刪除資料
|
||
// pk 可以是:string, int, int64, gocql.UUID, []byte 或包含主鍵欄位的 struct
|
||
func (r *repository[T]) Delete(ctx context.Context, pk any) error {
|
||
t := table.New(r.metadata)
|
||
stmt, names := t.Delete()
|
||
q := r.db.withContextAndTimestamp(ctx,
|
||
r.db.session.Query(stmt, names).Bind(pk))
|
||
return q.ExecRelease()
|
||
}
|
||
|
||
// InsertMany 批次插入資料
|
||
func (r *repository[T]) InsertMany(ctx context.Context, docs []T) error {
|
||
if len(docs) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 使用 Batch 操作
|
||
batch := r.db.session.NewBatch(gocql.LoggedBatch).WithContext(ctx)
|
||
t := table.New(r.metadata)
|
||
stmt, names := t.Insert()
|
||
|
||
for _, doc := range docs {
|
||
// 在 v2 中,需要手動提取值
|
||
v := reflect.ValueOf(doc)
|
||
if v.Kind() == reflect.Ptr {
|
||
v = v.Elem()
|
||
}
|
||
values := make([]interface{}, len(names))
|
||
for i, name := range names {
|
||
// 根據 metadata 找到對應的欄位
|
||
for j, col := range r.metadata.Columns {
|
||
if col == name {
|
||
fieldValue := v.Field(j)
|
||
values[i] = fieldValue.Interface()
|
||
break
|
||
}
|
||
}
|
||
}
|
||
batch.Query(stmt, values...)
|
||
}
|
||
|
||
return r.db.session.ExecuteBatch(batch)
|
||
}
|
||
|
||
// Query 返回查詢構建器
|
||
func (r *repository[T]) Query() QueryBuilder[T] {
|
||
return newQueryBuilder(r)
|
||
}
|
||
|
||
// updateFields holds the column/value pairs extracted from a document for
// building an UPDATE statement.
type updateFields struct {
	setCols   []string // non-key columns written by the SET clause
	setVals   []any    // values for setCols, index-aligned
	whereCols []string // primary-key columns matched in the WHERE clause
	whereVals []any    // values for whereCols, index-aligned
}
|
||
|
||
// buildUpdateFields splits doc's tagged, exported fields into the SET and
// WHERE parts of an UPDATE.
//
// Field routing:
//   - fields whose DBFiledName tag names a partition or clustering key column
//     go to WHERE;
//   - every other tagged field goes to SET, skipping values isZero reports as
//     zero unless includeZero is true;
//   - untagged, "-"-tagged and unexported fields are ignored.
//
// Errors:
//   - ErrInvalidInput when doc is not a struct or non-nil pointer to struct;
//   - ErrNoFieldsToUpdate when nothing is left to SET;
//   - ErrInvalidInput when doc does not supply every primary-key column.
func (r *repository[T]) buildUpdateFields(doc T, includeZero bool) (*updateFields, error) {
	v := reflect.ValueOf(doc)
	if v.Kind() == reflect.Ptr {
		// Elem of a nil pointer is the invalid Value, rejected just below.
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return nil, ErrInvalidInput.WithTable(r.table).WithError(
			fmt.Errorf("update document must be a struct or non-nil pointer to struct, got %T", doc),
		)
	}
	typ := v.Type()

	numFields := typ.NumField()
	numKeys := len(r.metadata.PartKey) + len(r.metadata.SortKey)
	fields := &updateFields{
		setCols:   make([]string, 0, numFields),
		setVals:   make([]any, 0, numFields),
		whereCols: make([]string, 0, numKeys),
		whereVals: make([]any, 0, numKeys),
	}

	for i := 0; i < numFields; i++ {
		field := typ.Field(i)
		// reflect.Value.Interface panics on values read from unexported fields.
		if !field.IsExported() {
			continue
		}
		tag := field.Tag.Get(DBFiledName)
		if tag == "" || tag == "-" {
			continue
		}

		val := v.Field(i)

		// Primary-key columns identify the row, so they always go to WHERE.
		if contains(r.metadata.PartKey, tag) || contains(r.metadata.SortKey, tag) {
			fields.whereCols = append(fields.whereCols, tag)
			fields.whereVals = append(fields.whereVals, val.Interface())
			continue
		}

		if !includeZero && isZero(val) {
			continue
		}
		fields.setCols = append(fields.setCols, tag)
		fields.setVals = append(fields.setVals, val.Interface())
	}

	if len(fields.setCols) == 0 {
		return nil, ErrNoFieldsToUpdate.WithTable(r.table)
	}

	// Without every key column the UPDATE would have an incomplete WHERE
	// clause; fail here instead of sending it to the server.
	for _, keyCols := range [][]string{r.metadata.PartKey, r.metadata.SortKey} {
		for _, col := range keyCols {
			if !contains(fields.whereCols, col) {
				return nil, ErrInvalidInput.WithTable(r.table).WithError(
					fmt.Errorf("primary key column %q is missing from the update document", col),
				)
			}
		}
	}

	return fields, nil
}
|
||
|
||
// buildUpdateStatement 構建 UPDATE CQL 語句
|
||
func (r *repository[T]) buildUpdateStatement(setCols, whereCols []string) (string, []string) {
|
||
builder := qb.Update(r.table).Set(setCols...)
|
||
for _, col := range whereCols {
|
||
builder = builder.Where(qb.Eq(col))
|
||
}
|
||
return builder.ToCql()
|
||
}
|
||
|
||
// contains reports whether target occurs in list. A nil or empty list never
// contains anything.
func contains(list []string, target string) bool {
	for i := range list {
		if list[i] == target {
			return true
		}
	}
	return false
}
|
||
|
||
// isZero reports whether v should be treated as "unset".
//
// Pointers, interfaces, maps and slices count as zero only when nil (an
// empty but non-nil map or slice is NOT zero). Every other kind is compared
// with reflect.DeepEqual against the zero value of its type.
func isZero(v reflect.Value) bool {
	k := v.Kind()
	if k == reflect.Ptr || k == reflect.Interface || k == reflect.Map || k == reflect.Slice {
		return v.IsNil()
	}
	zero := reflect.Zero(v.Type())
	return reflect.DeepEqual(v.Interface(), zero.Interface())
}
|