app-cloudep-wallet-service/pkg/repository/user_wallet.go

220 lines
6.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/entity"
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/repository"
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/wallet"
"context"
"database/sql"
"errors"
"fmt"
"github.com/shopspring/decimal"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"time"
)
// WalletService is the wallet repository for one user and one asset. It
// reads and writes wallet rows through GORM and stages balance changes and
// wallet-transaction records in memory until they are persisted.
type WalletService struct {
	// db is the GORM handle every query of this service runs on.
	db *gorm.DB
	// uid identifies the wallet owner.
	uid string
	// asset is the asset code (e.g. BTC, ETH, TWD).
	asset string
	// localBalances caches the most recently loaded/updated wallet row per
	// wallet type.
	localBalances map[wallet.Types]entity.Wallet
	// localOrderBalances maps a transaction-table row ID to the balance that
	// PersistOrderBalances writes into its post_transfer_balance column.
	// NOTE(review): nothing in this file writes to this map — confirm it is
	// populated elsewhere.
	localOrderBalances map[int64]decimal.Decimal
	// transactions holds wallet-transaction records not yet written to the DB.
	transactions []entity.WalletTransaction
}
// NewWalletService returns a WalletService bound to db for the given user and
// asset, with empty in-memory caches and no staged transactions.
func NewWalletService(db *gorm.DB, uid, asset string) repository.UserWalletService {
	sizeHint := len(wallet.AllTypes)

	svc := &WalletService{db: db, uid: uid, asset: asset}
	svc.localBalances = make(map[wallet.Types]entity.Wallet, sizeHint)
	svc.localOrderBalances = make(map[int64]decimal.Decimal, sizeHint)
	return svc
}
// InitializeWallets inserts one zero-balance wallet row for every type in
// wallet.AllTypes (for this user and asset, under the given brand), caches the
// inserted rows in memory and returns them. On insert failure it returns the
// database error and caches nothing.
func (s *WalletService) InitializeWallets(ctx context.Context, brand string) ([]entity.Wallet, error) {
	wallets := make([]entity.Wallet, len(wallet.AllTypes))
	for i, t := range wallet.AllTypes {
		wallets[i] = entity.Wallet{
			Brand:   brand,
			UID:     s.uid,
			Asset:   s.asset,
			Balance: decimal.Zero,
			Type:    t,
		}
	}

	if err := s.db.WithContext(ctx).Create(&wallets).Error; err != nil {
		return nil, err
	}

	// Cache only after a successful insert; GORM back-fills primary keys on
	// Create, and PersistBalances later relies on the cached IDs.
	for i := range wallets {
		s.localBalances[wallets[i].Type] = wallets[i]
	}
	return wallets, nil
}
// GetAllBalances loads every wallet row of this user's asset. Only the id,
// asset, balance and type columns are selected. The rows are written into the
// in-memory cache (one entry per wallet type) and returned. Database errors
// are returned unchanged.
func (s *WalletService) GetAllBalances(ctx context.Context) ([]entity.Wallet, error) {
	query := s.db.WithContext(ctx).
		Where("uid = ? AND asset = ?", s.uid, s.asset).
		Select("id, asset, balance, type")

	var wallets []entity.Wallet
	if err := query.Find(&wallets).Error; err != nil {
		return nil, err
	}

	for i := range wallets {
		s.localBalances[wallets[i].Type] = wallets[i]
	}
	return wallets, nil
}
// GetBalancesForTypes loads this user's wallet rows for the given types
// without taking row locks. Only id, asset, balance and type are selected.
// The rows refresh the in-memory cache and are returned. A GORM
// record-not-found error is mapped to repository.ErrRecordNotFound.
func (s *WalletService) GetBalancesForTypes(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) {
	query := s.db.WithContext(ctx).
		Where("uid = ? AND asset = ?", s.uid, s.asset).
		Where("type IN ?", kinds).
		Select("id, asset, balance, type")

	var wallets []entity.Wallet
	if err := query.Find(&wallets).Error; err != nil {
		return nil, translateNotFound(err)
	}

	for i := range wallets {
		s.localBalances[wallets[i].Type] = wallets[i]
	}
	return wallets, nil
}
// GetBalancesForUpdate loads this user's wallet rows for the given types with
// SELECT ... FOR UPDATE. Only id, asset, balance and type are selected. The
// rows refresh the in-memory cache and are returned. A GORM record-not-found
// error is mapped to repository.ErrRecordNotFound.
//
// NOTE(review): row locks are held only for the lifetime of the enclosing DB
// transaction; this presumably expects s.db to already be a transaction —
// confirm with callers.
func (s *WalletService) GetBalancesForUpdate(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) {
	query := s.db.WithContext(ctx).
		Where("uid = ? AND asset = ?", s.uid, s.asset).
		Where("type IN ?", kinds).
		Clauses(clause.Locking{Strength: "UPDATE"}).
		Select("id, asset, balance, type")

	var locked []entity.Wallet
	if err := query.Find(&locked).Error; err != nil {
		return nil, translateNotFound(err)
	}

	for i := range locked {
		s.localBalances[locked[i].Type] = locked[i]
	}
	return locked, nil
}
// CurrentBalance returns the cached balance of the wallet of the given type,
// or decimal.Zero when that type has not been loaded into the cache.
func (s *WalletService) CurrentBalance(kind wallet.Types) decimal.Decimal {
	cached, found := s.localBalances[kind]
	if !found {
		return decimal.Zero
	}
	return cached.Balance
}
// IncreaseBalance adds amount (which may be negative) to the cached wallet of
// the given type. It also stages a WalletTransaction recording the signed
// amount and the resulting balance. Nothing is written to the database here.
//
// It returns repository.ErrRecordNotFound when the wallet type is not cached.
// It returns repository.ErrBalanceInsufficient when the new balance would be
// negative. In both error cases the cache and the staged records are left
// untouched.
func (s *WalletService) IncreaseBalance(kind wallet.Types, orderID string, amount decimal.Decimal) error {
	cached, ok := s.localBalances[kind]
	if !ok {
		return repository.ErrRecordNotFound
	}

	next := cached.Balance.Add(amount)
	if next.IsNegative() {
		return repository.ErrBalanceInsufficient
	}
	cached.Balance = next

	record := entity.WalletTransaction{
		OrderID:    orderID,
		UID:        s.uid,
		WalletType: kind,
		Asset:      s.asset,
		Amount:     amount,
		Balance:    next,
	}
	s.transactions = append(s.transactions, record)
	s.localBalances[kind] = cached
	return nil
}
// DecreaseBalance subtracts amount from the cached wallet of the given type.
// It is IncreaseBalance with the amount negated, so it has the same error
// behavior.
func (s *WalletService) DecreaseBalance(kind wallet.Types, orderID string, amount decimal.Decimal) error {
	negated := amount.Neg()
	return s.IncreaseBalance(kind, orderID, negated)
}
// PrepareTransactions stamps every staged WalletTransaction with the shared
// fields txID, orderID, brand and businessType, then returns the staged slice.
// It overwrites the OrderID set by IncreaseBalance.
//
// The returned slice is the service's internal buffer, not a copy, and it is
// not cleared by this call.
func (s *WalletService) PrepareTransactions(
	txID int64,
	orderID, brand string,
	businessType wallet.BusinessName,
) []entity.WalletTransaction {
	for i := range s.transactions {
		t := &s.transactions[i]
		t.TransactionID = txID
		t.OrderID = orderID
		t.Brand = brand
		t.BusinessType = businessType.ToInt8()
	}
	return s.transactions
}
// PersistBalances writes the balance of every cached wallet back to its row
// (matched by ID), inside a single READ COMMITTED transaction. Every cached
// wallet is written, including ones that were only read and never changed.
//
// All rows get the same update_at value (Unix seconds), captured once before
// the transaction starts. ctx is bound to the transaction itself, so
// cancellation also covers BEGIN and COMMIT, not just the UPDATE statements.
// If no wallets are cached, it returns nil without touching the database.
// The first failing UPDATE aborts the transaction, and its error is returned
// wrapped with the wallet ID.
func (s *WalletService) PersistBalances(ctx context.Context) error {
	if len(s.localBalances) == 0 {
		return nil
	}

	now := time.Now().Unix()
	return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		for _, w := range s.localBalances {
			err := tx.Model(&entity.Wallet{}).
				Where("id = ?", w.ID).
				UpdateColumns(map[string]interface{}{
					"balance":   w.Balance,
					"update_at": now,
				}).Error
			if err != nil {
				return fmt.Errorf("更新錢包餘額失敗 (id=%d): %w", w.ID, err)
			}
		}
		return nil
	}, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
}
// PersistOrderBalances writes each entry of localOrderBalances into the
// post_transfer_balance column of the transaction-table row with that ID,
// inside a single READ COMMITTED transaction.
//
// ctx is bound to the transaction itself, so cancellation also covers BEGIN
// and COMMIT. If the map is empty, it returns nil without touching the
// database. The first failing UPDATE aborts the transaction, and its error is
// returned wrapped with the row ID.
//
// NOTE(review): no code in this file populates localOrderBalances — confirm it
// is filled elsewhere, otherwise this is always a no-op.
func (s *WalletService) PersistOrderBalances(ctx context.Context) error {
	if len(s.localOrderBalances) == 0 {
		return nil
	}

	return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		for id, bal := range s.localOrderBalances {
			err := tx.Model(&entity.Transaction{}).
				Where("id = ?", id).
				Update("post_transfer_balance", bal).Error
			if err != nil {
				return fmt.Errorf("更新訂單錢包餘額失敗 (id=%d): %w", id, err)
			}
		}
		return nil
	}, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
}
// HasAvailableBalance reports whether a wallet row of type
// wallet.TypeAvailable exists for this user and asset. It runs a
// "SELECT 1 ... LIMIT 1" query and returns false when no row matches.
//
// NOTE(review): despite the name, the balance amount itself is not checked —
// only row existence. Confirm this is the intended contract.
func (s *WalletService) HasAvailableBalance(ctx context.Context) (bool, error) {
	query := s.db.WithContext(ctx).
		Model(&entity.Wallet{}).
		Select("1").
		Where("uid = ? AND asset = ?", s.uid, s.asset).
		Where("type = ?", wallet.TypeAvailable).
		Limit(1)

	var found bool
	if err := query.Scan(&found).Error; err != nil {
		return false, err
	}
	return found, nil
}
// translateNotFound maps GORM's ErrRecordNotFound (including wrapped forms) to
// repository.ErrRecordNotFound. Any other error, including nil, is returned
// unchanged.
func translateNotFound(err error) error {
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		return err
	}
	return repository.ErrRecordNotFound
}