package cassandra

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/gocql/gocql"
	"github.com/scylladb/gocqlx/v2/qb"
)

const (
	// defaultLockTTLSec is the TTL, in seconds, that TryLock applies when no
	// LockOption overrides it.
	defaultLockTTLSec = 30
	// defaultLockRetry bounds the number of attempts made by UnLock.
	defaultLockRetry = 3
	// lockBaseDelay is the base of the delay UnLock sleeps between attempts.
	lockBaseDelay = 100 * time.Millisecond
)

// LockOption configures the TTL behaviour of TryLock.
type LockOption func(*lockOptions)

// lockOptions collects the settings produced by the LockOption values passed
// to TryLock.
type lockOptions struct {
	// ttlSeconds is the lock TTL in whole seconds; a value <= 0 means the lock
	// row never expires.
	ttlSeconds int
}

// WithLockTTL sets the TTL of the lock.
//
// The TTL is stored in whole seconds, so d is truncated to a second boundary.
// A positive duration shorter than one second is rounded up to one second:
// truncating it to zero would be interpreted as "never expire" and silently
// turn a short-lived lock into a permanent one. A zero or negative d disables
// expiry, exactly like WithNoLockExpire.
func WithLockTTL(d time.Duration) LockOption {
	return func(o *lockOptions) {
		secs := int(d / time.Second)
		if d > 0 && secs == 0 {
			secs = 1
		}
		o.ttlSeconds = secs
	}
}

// WithNoLockExpire makes the lock never expire automatically.
func WithNoLockExpire() LockOption {
	return func(o *lockOptions) {
		o.ttlSeconds = 0
	}
}

// TryLock attempts to acquire a lock by inserting doc as a unique row
// (INSERT ... IF NOT EXISTS) into the repository table.
//
// The row gets a TTL of defaultLockTTLSec seconds by default; pass WithLockTTL
// to change it or WithNoLockExpire to omit the TTL clause altogether. Nil
// options are ignored.
//
// It returns nil when the row was inserted, an ErrCodeConflict error when the
// conditional insert was not applied (the row already exists), and an
// ErrInvalidInput-based error wrapping the driver error when the query fails.
func (r *repository[T]) TryLock(ctx context.Context, doc T, opts ...LockOption) error {
	// Start from the default TTL and let the caller's options override it.
	options := &lockOptions{ttlSeconds: defaultLockTTLSec}
	for _, opt := range opts {
		if opt == nil {
			continue
		}
		opt(options)
	}

	builder := qb.Insert(r.table).
		Unique(). // IF NOT EXISTS
		Columns(r.metadata.Columns...)
	// Only emit a TTL clause when a positive TTL is requested.
	if options.ttlSeconds > 0 {
		ttl := time.Duration(options.ttlSeconds) * time.Second
		builder = builder.TTL(ttl)
	}
	stmt, names := builder.ToCql()

	// Run the conditional insert with serial consistency; the write timestamp
	// is the current time in microseconds.
	q := r.db.session.Query(stmt, names).BindStruct(doc).
		WithContext(ctx).
		WithTimestamp(time.Now().UnixNano() / 1e3).
		SerialConsistency(gocql.Serial)

	applied, err := q.ExecCASRelease()
	if err != nil {
		// NOTE(review): every execution failure is reported as ErrInvalidInput,
		// whatever its cause — confirm this classification is intended.
		return ErrInvalidInput.WithTable(r.table).WithError(err)
	}
	if !applied {
		return NewError(ErrCodeConflict, "acquire lock failed").WithTable(r.table)
	}

	return nil
}

// UnLock releases the lock, which is simply a conditional DELETE of the row.
func (r *repository[T]) UnLock(ctx context.Context, doc T) error {
	var lastErr error

	for i := 0; i < defaultLockRetry; i++ {
		builder := qb.Delete(r.table).Existing()

		// Dynamically add the WHERE conditions (one per partition key column).
		for _, key := range r.metadata.PartKey {
			builder = builder.Where(qb.Eq(key))
		}
		stmt, names := builder.ToCql()
		q := r.db.session.Query(stmt, names).BindStruct(doc).
			WithContext(ctx).
			WithTimestamp(time.Now().UnixNano() / 1e3).
SerialConsistency(gocql.Serial) applied, err := q.ExecCASRelease() if err == nil && applied { return nil } if err != nil { lastErr = fmt.Errorf("unlock error: %w", err) } else if !applied { lastErr = fmt.Errorf("unlock not applied: row not found or not visible yet") } time.Sleep(lockBaseDelay * time.Duration(1<