121 lines
2.9 KiB
Go
121 lines
2.9 KiB
Go
package cassandra
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"time"
|
||
|
||
"github.com/gocql/gocql"
|
||
"github.com/scylladb/gocqlx/v3/qb"
|
||
)
|
||
|
||
const (
	// lockBaseDelay is the first back-off delay used by UnLock between
	// attempts; the delay doubles after every failed attempt.
	lockBaseDelay = time.Second / 10

	// defaultLockRetry is how many times UnLock attempts the delete.
	defaultLockRetry = 3

	// defaultLockTTLSec is the lock TTL, in seconds, that TryLock applies
	// when no LockOption overrides it.
	defaultLockTTLSec = 30
)
|
||
|
||
type (
	// lockOptions holds the settings TryLock consults when writing the lock row.
	lockOptions struct {
		// ttlSeconds is the lock TTL in seconds; a value <= 0 means the row
		// is written without a TTL and never expires on its own.
		ttlSeconds int
	}

	// LockOption configures the TTL behavior of TryLock.
	LockOption func(*lockOptions)
)
|
||
|
||
// WithLockTTL 設定鎖的 TTL
|
||
func WithLockTTL(d time.Duration) LockOption {
|
||
return func(o *lockOptions) {
|
||
o.ttlSeconds = int(d.Seconds())
|
||
}
|
||
}
|
||
|
||
// WithNoLockExpire 永不自動解鎖
|
||
func WithNoLockExpire() LockOption {
|
||
return func(o *lockOptions) {
|
||
o.ttlSeconds = 0
|
||
}
|
||
}
|
||
|
||
// TryLock 嘗試在表上插入一筆唯一鍵(IF NOT EXISTS)作為鎖
|
||
// 預設 30 秒 TTL,可透過 option 調整或取消 TTL
|
||
func (r *repository[T]) TryLock(ctx context.Context, doc T, opts ...LockOption) error {
|
||
// 組合 option
|
||
options := &lockOptions{ttlSeconds: defaultLockTTLSec}
|
||
for _, opt := range opts {
|
||
opt(options)
|
||
}
|
||
|
||
// 建 TTL 子句
|
||
builder := qb.Insert(r.table).
|
||
Unique(). // IF NOT EXISTS
|
||
Columns(r.metadata.Columns...)
|
||
|
||
if options.ttlSeconds > 0 {
|
||
ttl := time.Duration(options.ttlSeconds) * time.Second
|
||
builder = builder.TTL(ttl)
|
||
}
|
||
stmt, names := builder.ToCql()
|
||
|
||
// 執行 CAS
|
||
q := r.db.session.Query(stmt, names).BindStruct(doc).
|
||
WithContext(ctx).
|
||
WithTimestamp(time.Now().UnixNano() / 1e3).
|
||
SerialConsistency(gocql.Serial)
|
||
|
||
applied, err := q.ExecCASRelease()
|
||
if err != nil {
|
||
return ErrInvalidInput.WithTable(r.table).WithError(err)
|
||
}
|
||
|
||
if !applied {
|
||
return NewError(ErrCodeConflict, "acquire lock failed").WithTable(r.table)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// UnLock 釋放鎖,其實就是 Delete
|
||
func (r *repository[T]) UnLock(ctx context.Context, doc T) error {
|
||
var lastErr error
|
||
|
||
for i := 0; i < defaultLockRetry; i++ {
|
||
builder := qb.Delete(r.table).Existing()
|
||
|
||
// 動態添加 WHERE 條件(使用 Partition Key)
|
||
for _, key := range r.metadata.PartKey {
|
||
builder = builder.Where(qb.Eq(key))
|
||
}
|
||
stmt, names := builder.ToCql()
|
||
q := r.db.session.Query(stmt, names).BindStruct(doc).
|
||
WithContext(ctx).
|
||
WithTimestamp(time.Now().UnixNano() / 1e3).
|
||
SerialConsistency(gocql.Serial)
|
||
|
||
applied, err := q.ExecCASRelease()
|
||
if err == nil && applied {
|
||
return nil
|
||
}
|
||
|
||
if err != nil {
|
||
lastErr = fmt.Errorf("unlock error: %w", err)
|
||
} else if !applied {
|
||
lastErr = fmt.Errorf("unlock not applied: row not found or not visible yet")
|
||
}
|
||
|
||
time.Sleep(lockBaseDelay * time.Duration(1<<i)) // 100ms → 200ms → 400ms
|
||
}
|
||
|
||
return ErrInvalidInput.WithTable(r.table).WithError(
|
||
fmt.Errorf("unlock failed after %d retries: %w", defaultLockRetry, lastErr),
|
||
)
|
||
}
|
||
|
||
// IsLockFailed 檢查錯誤是否為獲取鎖失敗
|
||
func IsLockFailed(err error) bool {
|
||
var e *Error
|
||
if errors.As(err, &e) {
|
||
return e.Code == ErrCodeConflict && e.Message == "acquire lock failed"
|
||
}
|
||
return false
|
||
}
|