backend/pkg/library/cassandra/lock.go

121 lines
2.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package cassandra
import (
"context"
"errors"
"fmt"
"time"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx/v2/qb"
)
// Defaults used by the lock helpers in this file.
const (
// defaultLockTTLSec is the TTL, in seconds, that TryLock applies when no
// LockOption overrides it.
defaultLockTTLSec = 30
// defaultLockRetry is the number of delete attempts UnLock makes before
// giving up.
defaultLockRetry = 3
// lockBaseDelay is the first back-off delay used by UnLock; the delay
// doubles with each further attempt.
lockBaseDelay = 100 * time.Millisecond
)
// LockOption configures the TTL behavior of TryLock. Options are applied in
// the order they are passed, each one mutating the same lockOptions value.
type LockOption func(*lockOptions)
// lockOptions holds the settings that LockOption functions adjust before
// TryLock builds its INSERT statement.
type lockOptions struct {
ttlSeconds int // TTL in seconds; a value <= 0 means the lock row never expires
}
// WithLockTTL sets the TTL of the lock row.
//
// The TTL is stored in whole seconds. A positive duration is rounded UP to the
// next full second, so a sub-second value such as 500ms becomes 1s instead of
// being truncated to 0 (which TryLock would interpret as "never expire"), and
// the lock is never held for less time than the caller asked for.
// A zero or negative duration disables expiry, same as WithNoLockExpire.
func WithLockTTL(d time.Duration) LockOption {
	return func(o *lockOptions) {
		if d <= 0 {
			o.ttlSeconds = 0
			return
		}
		secs := int(d / time.Second)
		if d%time.Second != 0 {
			// Round a fractional remainder up rather than dropping it.
			secs++
		}
		o.ttlSeconds = secs
	}
}
// WithNoLockExpire disables the TTL so the lock row is never released
// automatically; it stays until it is deleted explicitly.
func WithNoLockExpire() LockOption {
	disableTTL := func(cfg *lockOptions) {
		cfg.ttlSeconds = 0
	}
	return disableTTL
}
// TryLock tries to acquire a lock by inserting doc into the table with a
// unique (IF NOT EXISTS) insert.
//
// The row gets a TTL of defaultLockTTLSec seconds unless a LockOption changes
// it; a resulting TTL <= 0 omits the TTL clause altogether.
//
// It returns nil when the insert was applied, a conflict error with the
// message "acquire lock failed" when the row already exists, and an
// ErrInvalidInput-based error wrapping the driver error when the query fails.
func (r *repository[T]) TryLock(ctx context.Context, doc T, opts ...LockOption) error {
	// Start from the default TTL and let each option adjust it in order.
	cfg := lockOptions{ttlSeconds: defaultLockTTLSec}
	for i := range opts {
		opts[i](&cfg)
	}

	// INSERT ... IF NOT EXISTS, with USING TTL only when a TTL is wanted.
	insert := qb.Insert(r.table).Unique().Columns(r.metadata.Columns...)
	if cfg.ttlSeconds > 0 {
		insert = insert.TTL(time.Duration(cfg.ttlSeconds) * time.Second)
	}
	cql, names := insert.ToCql()

	// Client-side write timestamp in microseconds.
	nowMicros := time.Now().UnixNano() / 1e3
	query := r.db.session.Query(cql, names).
		BindStruct(doc).
		WithContext(ctx).
		WithTimestamp(nowMicros).
		SerialConsistency(gocql.Serial)

	// Run the compare-and-set and map its outcome to the package's errors.
	switch applied, err := query.ExecCASRelease(); {
	case err != nil:
		return ErrInvalidInput.WithTable(r.table).WithError(err)
	case !applied:
		return NewError(ErrCodeConflict, "acquire lock failed").WithTable(r.table)
	default:
		return nil
	}
}
// UnLock releases a lock by deleting its row with a conditional
// DELETE ... IF EXISTS whose WHERE clause matches the table's partition key
// columns, bound from doc.
//
// The delete is attempted up to defaultLockRetry times. An attempt counts as
// failed when the query errors or when the delete is not applied (row not
// found or not visible yet). Between attempts it backs off exponentially,
// starting at lockBaseDelay (100ms → 200ms). The back-off is abandoned as soon
// as ctx is done, and no delay is spent after the final attempt.
//
// It returns nil once a delete is applied; otherwise an ErrInvalidInput-based
// error wrapping the last failure (or the context error if ctx ended first).
func (r *repository[T]) UnLock(ctx context.Context, doc T) error {
	// The statement depends only on table metadata, so build it once instead
	// of on every retry.
	del := qb.Delete(r.table).Existing()
	for _, key := range r.metadata.PartKey {
		del = del.Where(qb.Eq(key))
	}
	stmt, names := del.ToCql()

	var lastErr error
	for attempt := 0; attempt < defaultLockRetry; attempt++ {
		if attempt > 0 {
			// Wait before retrying, but give up immediately if the caller
			// cancelled or the deadline passed.
			backoff := time.NewTimer(lockBaseDelay * time.Duration(1<<(attempt-1)))
			select {
			case <-ctx.Done():
				backoff.Stop()
				return ErrInvalidInput.WithTable(r.table).WithError(
					fmt.Errorf("unlock aborted after %d attempts: %w (last error: %v)", attempt, ctx.Err(), lastErr),
				)
			case <-backoff.C:
			}
		}

		// ExecCASRelease releases the query, so a fresh one is needed per attempt.
		q := r.db.session.Query(stmt, names).BindStruct(doc).
			WithContext(ctx).
			WithTimestamp(time.Now().UnixNano() / 1e3).
			SerialConsistency(gocql.Serial)

		applied, err := q.ExecCASRelease()
		switch {
		case err != nil:
			lastErr = fmt.Errorf("unlock error: %w", err)
		case !applied:
			lastErr = errors.New("unlock not applied: row not found or not visible yet")
		default:
			return nil
		}
	}

	return ErrInvalidInput.WithTable(r.table).WithError(
		fmt.Errorf("unlock failed after %d retries: %w", defaultLockRetry, lastErr),
	)
}
// IsLockFailed reports whether err is, or wraps, the conflict error that
// TryLock returns when the lock is already held: an *Error whose Code is
// ErrCodeConflict and whose Message is "acquire lock failed".
func IsLockFailed(err error) bool {
	var lockErr *Error
	return errors.As(err, &lockErr) &&
		lockErr.Code == ErrCodeConflict &&
		lockErr.Message == "acquire lock failed"
}