// Source: app-cloudep-wallet-service/pkg/repository/user_wallet.go
package repository
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"sort"
	"time"

	"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/entity"
	"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/repository"
	"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/wallet"

	"github.com/shopspring/decimal"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)
2025-04-18 09:10:40 +00:00
// WalletService 代表一個使用者在某資產上的錢包服務,
// 負責讀取/寫入資料庫並在記憶體暫存變動
type WalletService struct {
db *gorm.DB
uid string // 使用者識別碼
asset string // 資產代號 (如 BTC、ETH、TWD)
localBalances map[wallet.Types]entity.Wallet // 暫存各類型錢包當前餘額
localOrderBalances map[int64]decimal.Decimal // 暫存各訂單變動後的餘額
transactions []entity.WalletTransaction // 暫存所有尚未落庫的錢包交易紀錄
}
// NewWalletService 建立一個 WalletService 實例
func NewWalletService(db *gorm.DB, uid, asset string) repository.UserWalletService {
return &WalletService{
2025-04-16 09:24:54 +00:00
db: db,
uid: uid,
asset: asset,
2025-04-18 09:10:40 +00:00
localBalances: make(map[wallet.Types]entity.Wallet, len(wallet.AllTypes)),
localOrderBalances: make(map[int64]decimal.Decimal, len(wallet.AllTypes)),
2025-04-16 09:24:54 +00:00
}
}
2025-04-18 09:10:40 +00:00
// InitializeWallets 啟動時為新使用者初始化所有類型錢包,並寫入資料庫
func (s *WalletService) InitializeWallets(ctx context.Context, brand string) ([]entity.Wallet, error) {
var wallets []entity.Wallet
2025-04-16 09:24:54 +00:00
for _, t := range wallet.AllTypes {
wallets = append(wallets, entity.Wallet{
Brand: brand,
2025-04-18 09:10:40 +00:00
UID: s.uid,
Asset: s.asset,
Balance: decimal.Zero,
2025-04-16 09:24:54 +00:00
Type: t,
})
}
2025-04-18 09:10:40 +00:00
if err := s.db.WithContext(ctx).Create(&wallets).Error; err != nil {
2025-04-16 09:24:54 +00:00
return nil, err
}
2025-04-18 09:10:40 +00:00
// 將初始化後的錢包資料寫入本地緩存
for _, w := range wallets {
s.localBalances[w.Type] = w
2025-04-16 09:24:54 +00:00
}
return wallets, nil
}
2025-04-18 09:10:40 +00:00
// GetAllBalances 查詢該使用者某資產所有錢包類型當前餘額
func (s *WalletService) GetAllBalances(ctx context.Context) ([]entity.Wallet, error) {
2025-04-16 09:24:54 +00:00
var result []entity.Wallet
2025-04-18 09:10:40 +00:00
err := s.db.WithContext(ctx).
Where("uid = ? AND asset = ?", s.uid, s.asset).
Select("id, asset, balance, type").
2025-04-16 09:24:54 +00:00
Find(&result).Error
if err != nil {
2025-04-18 09:10:40 +00:00
return nil, err
2025-04-16 09:24:54 +00:00
}
2025-04-18 09:10:40 +00:00
for _, w := range result {
s.localBalances[w.Type] = w
2025-04-16 09:24:54 +00:00
}
return result, nil
}
2025-04-18 09:10:40 +00:00
// GetBalancesForTypes 查詢指定類型的錢包餘額,不上鎖
func (s *WalletService) GetBalancesForTypes(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) {
var result []entity.Wallet
err := s.db.WithContext(ctx).
Where("uid = ? AND asset = ?", s.uid, s.asset).
2025-04-16 09:24:54 +00:00
Where("type IN ?", kinds).
2025-04-18 09:10:40 +00:00
Select("id, asset, balance, type").
Find(&result).Error
2025-04-16 09:24:54 +00:00
if err != nil {
2025-04-18 09:10:40 +00:00
return nil, translateNotFound(err)
2025-04-16 09:24:54 +00:00
}
2025-04-18 09:10:40 +00:00
for _, w := range result {
s.localBalances[w.Type] = w
2025-04-16 09:24:54 +00:00
}
2025-04-18 09:10:40 +00:00
return result, nil
2025-04-16 09:24:54 +00:00
}
2025-04-18 09:10:40 +00:00
// GetBalancesForUpdate 查詢並鎖定指定類型的錢包 (FOR UPDATE)
func (s *WalletService) GetBalancesForUpdate(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) {
var result []entity.Wallet
err := s.db.WithContext(ctx).
Where("uid = ? AND asset = ?", s.uid, s.asset).
2025-04-16 09:24:54 +00:00
Where("type IN ?", kinds).
Clauses(clause.Locking{Strength: "UPDATE"}).
2025-04-18 09:10:40 +00:00
Select("id, asset, balance, type").
Find(&result).Error
2025-04-16 09:24:54 +00:00
if err != nil {
2025-04-18 09:10:40 +00:00
return nil, translateNotFound(err)
2025-04-16 09:24:54 +00:00
}
2025-04-18 09:10:40 +00:00
for _, w := range result {
s.localBalances[w.Type] = w
2025-04-10 23:23:42 +00:00
}
2025-04-18 09:10:40 +00:00
return result, nil
2025-04-16 09:24:54 +00:00
}
2025-04-18 09:10:40 +00:00
// CurrentBalance 從緩存中取得某種類型錢包的當前餘額
func (s *WalletService) CurrentBalance(kind wallet.Types) decimal.Decimal {
if w, ok := s.localBalances[kind]; ok {
return w.Balance
2025-04-16 09:24:54 +00:00
}
2025-04-18 09:10:40 +00:00
return decimal.Zero
2025-04-16 09:24:54 +00:00
}
2025-04-18 09:10:40 +00:00
// IncreaseBalance 在本地緩存新增餘額,並記錄一筆 WalletTransaction
func (s *WalletService) IncreaseBalance(kind wallet.Types, orderID string, amount decimal.Decimal) error {
w, ok := s.localBalances[kind]
2025-04-16 09:24:54 +00:00
if !ok {
return repository.ErrRecordNotFound
}
w.Balance = w.Balance.Add(amount)
if w.Balance.LessThan(decimal.Zero) {
return repository.ErrBalanceInsufficient
}
2025-04-18 09:10:40 +00:00
s.transactions = append(s.transactions, entity.WalletTransaction{
2025-04-16 09:24:54 +00:00
OrderID: orderID,
2025-04-18 09:10:40 +00:00
UID: s.uid,
2025-04-16 09:24:54 +00:00
WalletType: kind,
2025-04-18 09:10:40 +00:00
Asset: s.asset,
2025-04-16 09:24:54 +00:00
Amount: amount,
Balance: w.Balance,
})
2025-04-18 09:10:40 +00:00
s.localBalances[kind] = w
2025-04-16 09:24:54 +00:00
return nil
}
2025-04-18 09:10:40 +00:00
// DecreaseBalance 本質上是 IncreaseBalance 的負數版本
func (s *WalletService) DecreaseBalance(kind wallet.Types, orderID string, amount decimal.Decimal) error {
return s.IncreaseBalance(kind, orderID, amount.Neg())
2025-04-16 09:24:54 +00:00
}
2025-04-18 09:10:40 +00:00
// PrepareTransactions 為每筆暫存的 WalletTransaction 填入共用欄位 (txID, brand, businessType)
// 並回傳完整可落庫的切片
func (s *WalletService) PrepareTransactions(
2025-04-17 09:00:42 +00:00
txID int64,
2025-04-18 09:10:40 +00:00
orderID, brand string,
2025-04-17 09:00:42 +00:00
businessType wallet.BusinessName,
) []entity.WalletTransaction {
2025-04-18 09:10:40 +00:00
for i := range s.transactions {
s.transactions[i].TransactionID = txID
s.transactions[i].OrderID = orderID
s.transactions[i].Brand = brand
s.transactions[i].BusinessType = businessType.ToInt8()
2025-04-17 09:00:42 +00:00
}
2025-04-18 09:10:40 +00:00
return s.transactions
2025-04-16 09:24:54 +00:00
}
2025-04-18 09:10:40 +00:00
// PersistBalances 寫入本地緩存中所有錢包的最終餘額到資料庫
func (s *WalletService) PersistBalances(ctx context.Context) error {
return s.db.Transaction(func(tx *gorm.DB) error {
for _, w := range s.localBalances {
if err := tx.WithContext(ctx).
2025-04-16 09:24:54 +00:00
Model(&entity.Wallet{}).
Where("id = ?", w.ID).
2025-04-18 09:10:40 +00:00
UpdateColumns(map[string]interface{}{
"balance": w.Balance,
"update_at": time.Now().Unix(),
}).Error; err != nil {
return fmt.Errorf("更新錢包餘額失敗 (id=%d): %w", w.ID, err)
2025-04-16 09:24:54 +00:00
}
}
2025-04-18 09:10:40 +00:00
return nil
}, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
2025-04-16 09:24:54 +00:00
}
2025-04-18 09:10:40 +00:00
// PersistOrderBalances 寫入所有訂單錢包的最終餘額到 transaction 表
func (s *WalletService) PersistOrderBalances(ctx context.Context) error {
return s.db.Transaction(func(tx *gorm.DB) error {
for id, bal := range s.localOrderBalances {
if err := tx.WithContext(ctx).
2025-04-16 09:24:54 +00:00
Model(&entity.Transaction{}).
Where("id = ?", id).
2025-04-18 09:10:40 +00:00
Update("post_transfer_balance", bal).Error; err != nil {
return fmt.Errorf("更新訂單錢包餘額失敗 (id=%d): %w", id, err)
2025-04-16 09:24:54 +00:00
}
}
2025-04-18 09:10:40 +00:00
return nil
}, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
}
2025-04-16 09:24:54 +00:00
2025-04-18 09:10:40 +00:00
func (s *WalletService) HasAvailableBalance(ctx context.Context) (bool, error) {
var exists bool
err := s.db.WithContext(ctx).
Model(&entity.Wallet{}).
Select("1").
Where("uid = ? AND asset = ?", s.uid, s.asset).
Where("type = ?", wallet.TypeAvailable).
Limit(1).
Scan(&exists).Error
2025-04-16 09:24:54 +00:00
if err != nil {
2025-04-18 09:10:40 +00:00
return false, err
2025-04-16 09:24:54 +00:00
}
2025-04-18 09:10:40 +00:00
return exists, nil
2025-04-16 09:24:54 +00:00
}
2025-04-18 09:10:40 +00:00
// translateNotFound 將 GORM 的 RecordNotFound 轉為自訂錯誤
func translateNotFound(err error) error {
2025-04-16 09:24:54 +00:00
if errors.Is(err, gorm.ErrRecordNotFound) {
return repository.ErrRecordNotFound
}
return err
2025-04-10 23:23:42 +00:00
}