package repository import ( "code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/entity" "code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/repository" "code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/wallet" "context" "database/sql" "errors" "fmt" "github.com/shopspring/decimal" "gorm.io/gorm" "gorm.io/gorm/clause" "time" ) // 用戶某個幣種餘額 type userWallet struct { db *gorm.DB uid string asset string // local wallet 相關計算的餘額存在這裡 localWalletBalance map[wallet.Types]entity.Wallet // local order wallet 相關計算的餘額存在這裡 localOrderBalance map[int64]decimal.Decimal // local wallet 內所有餘額變化紀錄 transactions []entity.WalletTransaction } func NewUserWallet(db *gorm.DB, uid, asset string) repository.UserWalletService { return &userWallet{ db: db, uid: uid, asset: asset, localWalletBalance: make(map[wallet.Types]entity.Wallet, len(wallet.AllTypes)), localOrderBalance: make(map[int64]decimal.Decimal, len(wallet.AllTypes)), } } func (repo *userWallet) Init(ctx context.Context, uid, asset, brand string) ([]entity.Wallet, error) { wallets := make([]entity.Wallet, 0, len(wallet.AllTypes)) for _, t := range wallet.AllTypes { balance := decimal.Zero wallets = append(wallets, entity.Wallet{ Brand: brand, UID: uid, Asset: asset, Balance: balance, Type: t, }) } if err := repo.db.WithContext(ctx).Create(&wallets).Error; err != nil { return nil, err } for _, v := range wallets { repo.localWalletBalance[v.Type] = v } return wallets, nil } func (repo *userWallet) All(ctx context.Context) ([]entity.Wallet, error) { var result []entity.Wallet err := repo.buildCommonWhereSQL(repo.uid, repo.asset). WithContext(ctx). Select("id, crypto, balance, type"). Find(&result).Error if err != nil { return []entity.Wallet{}, err } for _, v := range result { repo.localWalletBalance[v.Type] = v } return result, nil } func (repo *userWallet) Get(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) { var wallets []entity.Wallet err := repo.buildCommonWhereSQL(repo.uid, repo.asset). WithContext(ctx). Model(&entity.Wallet{}). Select("id, crypto, balance, type"). Where("type IN ?", kinds). Find(&wallets).Error if err != nil { return []entity.Wallet{}, notFoundError(err) } for _, w := range wallets { repo.localWalletBalance[w.Type] = w } return wallets, nil } func (repo *userWallet) GetWithLock(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) { var wallets []entity.Wallet err := repo.buildCommonWhereSQL(repo.uid, repo.asset). WithContext(ctx). Model(&entity.Wallet{}). Select("id, crypto, balance, type"). Where("type IN ?", kinds). Clauses(clause.Locking{Strength: "UPDATE"}). Find(&wallets).Error if err != nil { return []entity.Wallet{}, notFoundError(err) } for _, w := range wallets { repo.localWalletBalance[w.Type] = w } return wallets, nil } func (repo *userWallet) LocalBalance(kind wallet.Types) decimal.Decimal { w, ok := repo.localWalletBalance[kind] if !ok { return decimal.Zero } return w.Balance } func (repo *userWallet) LockByIDs(ctx context.Context, ids []int64) ([]entity.Wallet, error) { var wallets []entity.Wallet err := repo.db.WithContext(ctx). Model(&entity.Wallet{}). Select("id, crypto, balance, type"). Where("id IN ?", ids). Clauses(clause.Locking{Strength: "UPDATE"}). Find(&wallets).Error if err != nil { return []entity.Wallet{}, notFoundError(err) } for _, w := range wallets { repo.localWalletBalance[w.Type] = w } return wallets, nil } func (repo *userWallet) CheckReady(ctx context.Context) (bool, error) { var exists bool err := repo.buildCommonWhereSQL(repo.uid, repo.asset).WithContext(ctx). Model(&entity.Wallet{}). Select("1"). Where("type = ?", wallet.TypeAvailable). Limit(1). Scan(&exists).Error if err != nil { return false, err } return exists, nil } // Add 新增某種餘額餘額 // 使用前 localWalletBalance 必須有資料,所以必須執行過 GetWithLock / All 才會有資料 func (repo *userWallet) Add(kind wallet.Types, orderID string, amount decimal.Decimal) error { w, ok := repo.localWalletBalance[kind] if !ok { return repository.ErrRecordNotFound } w.Balance = w.Balance.Add(amount) if w.Balance.LessThan(decimal.Zero) { return repository.ErrBalanceInsufficient } repo.transactions = append(repo.transactions, entity.WalletTransaction{ OrderID: orderID, UID: repo.uid, WalletType: kind, Asset: repo.asset, Amount: amount, Balance: w.Balance, }) repo.localWalletBalance[kind] = w return nil } func (repo *userWallet) Sub(kind wallet.Types, orderID string, amount decimal.Decimal) error { return repo.Add(kind, orderID, decimal.Zero.Sub(amount)) } func (repo *userWallet) AddTransaction(txID int64, orderID string, brand string, business wallet.BusinessName, kind wallet.Types, amount decimal.Decimal) { balance := repo.LocalBalance(kind).Add(amount) repo.transactions = append(repo.transactions, entity.WalletTransaction{ TransactionID: txID, OrderID: orderID, Brand: brand, UID: repo.uid, WalletType: kind, BusinessType: business.ToInt8(), Asset: repo.asset, Amount: amount, Balance: balance, CreateAt: time.Now().UTC().UnixNano(), }) } func (repo *userWallet) Commit(ctx context.Context) error { // 事務隔離等級設定 rc := &sql.TxOptions{ Isolation: sql.LevelReadCommitted, ReadOnly: false, } err := repo.db.Transaction(func(tx *gorm.DB) error { for _, w := range repo.localWalletBalance { err := tx.WithContext(ctx). Model(&entity.Wallet{}). Where("id = ?", w.ID). UpdateColumns(map[string]any{ "balance": w.Balance, "update_time": time.Now().UTC().Unix(), }).Error if err != nil { return fmt.Errorf("failed to update wallet id %d: %w", w.ID, err) } } return nil // 所有更新成功才 return nil }, rc) if err != nil { return fmt.Errorf("update uid: %s asset: %s error: %w", repo.uid, repo.asset, err) } return nil } func (repo *userWallet) GetTransactions() []entity.WalletTransaction { return repo.transactions } func (repo *userWallet) CommitOrder(ctx context.Context) error { rc := &sql.TxOptions{ Isolation: sql.LevelReadCommitted, ReadOnly: false, } err := repo.db.Transaction(func(tx *gorm.DB) error { for id, balance := range repo.localOrderBalance { err := tx.WithContext(ctx). Model(&entity.Transaction{}). Where("id = ?", id). Update("balance", balance).Error if err != nil { return fmt.Errorf("failed to update order balance, id=%d, err=%w", id, err) } } return nil // 所有更新成功才 return nil }, rc) if err != nil { return fmt.Errorf("update uid: %s asset: %s error: %w", repo.uid, repo.asset, err) } return nil } // ============================================================================= func (repo *userWallet) buildCommonWhereSQL(uid, asset string) *gorm.DB { return repo.db.Where("uid = ?", uid). Where("asset = ?", asset) } func notFoundError(err error) error { if errors.Is(err, gorm.ErrRecordNotFound) { return repository.ErrRecordNotFound } return err }