app-cloudep-wallet-service/pkg/repository/user_wallet.go

305 lines
7.4 KiB
Go
Raw Normal View History

2025-04-10 23:23:42 +00:00
package repository
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"sort"
	"time"

	"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/entity"
	"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/repository"
	"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/wallet"
	"github.com/shopspring/decimal"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)
2025-04-16 09:24:54 +00:00
// userWallet holds one user's balances for a single asset (currency).
// It caches wallet rows locally so balance changes can be computed in memory
// and then written back in one go.
type userWallet struct {
	// db is the gorm handle used for every query issued by this repository.
	db *gorm.DB
	// uid and asset identify the user/asset pair this instance is bound to.
	uid   string
	asset string
	// localWalletBalance caches wallet rows per wallet type; balances computed
	// locally (e.g. by Add/Sub) are stored here until Commit writes them back.
	localWalletBalance map[wallet.Types]entity.Wallet
	// localOrderBalance caches order balances keyed by order row ID; CommitOrder
	// writes these back.
	localOrderBalance map[int64]decimal.Decimal
	// transactions records every balance change applied to the local wallet.
	transactions []entity.WalletTransaction
}
func NewUserWallet(db *gorm.DB, uid, asset string) repository.UserWalletService {
2025-04-10 23:23:42 +00:00
return &userWallet{
2025-04-16 09:24:54 +00:00
db: db,
uid: uid,
asset: asset,
localWalletBalance: make(map[wallet.Types]entity.Wallet, len(wallet.AllTypes)),
localOrderBalance: make(map[int64]decimal.Decimal, len(wallet.AllTypes)),
}
}
// Init inserts one zero-balance wallet row per entry of wallet.AllTypes for the
// given uid/asset/brand, in a single batch Create. On success it seeds the
// local balance cache with the inserted rows and returns them.
//
// Note: the rows use the uid and asset arguments, not the values this
// repository was constructed with.
func (repo *userWallet) Init(ctx context.Context, uid, asset, brand string) ([]entity.Wallet, error) {
	rows := make([]entity.Wallet, 0, len(wallet.AllTypes))
	for _, kind := range wallet.AllTypes {
		rows = append(rows, entity.Wallet{
			Brand:   brand,
			UID:     uid,
			Asset:   asset,
			Balance: decimal.Zero,
			Type:    kind,
		})
	}

	if err := repo.db.WithContext(ctx).Create(&rows).Error; err != nil {
		return nil, err
	}

	// Cache what was just written so Add/Sub can operate without a reload.
	for _, row := range rows {
		repo.localWalletBalance[row.Type] = row
	}
	return rows, nil
}
// All loads every wallet row of this repository's uid/asset pair (all wallet
// types, no row locks). It refreshes the local balance cache with the rows.
//
// The selected columns are id, asset, balance and type. The asset column is the
// same one buildCommonWhereSQL filters on; the previous "crypto" column does not
// match the Asset field or the filter. On error an empty, non-nil slice is
// returned.
func (repo *userWallet) All(ctx context.Context) ([]entity.Wallet, error) {
	var result []entity.Wallet
	err := repo.buildCommonWhereSQL(repo.uid, repo.asset).
		WithContext(ctx).
		Select("id, asset, balance, type").
		Find(&result).Error
	if err != nil {
		return []entity.Wallet{}, err
	}

	for _, v := range result {
		repo.localWalletBalance[v.Type] = v
	}
	return result, nil
}
// Get loads this user's wallets of the requested types for the repository's
// asset, without row locks. It refreshes the local balance cache with them.
//
// Find does not fail when nothing matches, so types without a row are simply
// absent from the result.
func (repo *userWallet) Get(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) {
	return repo.queryWalletsByType(ctx, kinds, false)
}

// GetWithLock is Get with SELECT ... FOR UPDATE. The lock is held only for the
// lifetime of the surrounding database transaction, so it is only meaningful
// when repo.db is already inside one.
func (repo *userWallet) GetWithLock(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) {
	return repo.queryWalletsByType(ctx, kinds, true)
}

// queryWalletsByType is the shared implementation of Get and GetWithLock.
// It selects id/asset/balance/type for the repository's uid/asset filtered by
// kinds. When forUpdate is set, it adds a FOR UPDATE locking clause. Fetched
// rows are cached by type. On error an empty, non-nil slice is returned, with
// gorm's not-found error mapped to repository.ErrRecordNotFound.
func (repo *userWallet) queryWalletsByType(ctx context.Context, kinds []wallet.Types, forUpdate bool) ([]entity.Wallet, error) {
	query := repo.buildCommonWhereSQL(repo.uid, repo.asset).
		WithContext(ctx).
		Model(&entity.Wallet{}).
		// "asset" is the column buildCommonWhereSQL filters on; "crypto" was stale.
		Select("id, asset, balance, type").
		Where("type IN ?", kinds)
	if forUpdate {
		query = query.Clauses(clause.Locking{Strength: "UPDATE"})
	}

	var wallets []entity.Wallet
	if err := query.Find(&wallets).Error; err != nil {
		return []entity.Wallet{}, notFoundError(err)
	}

	for _, w := range wallets {
		repo.localWalletBalance[w.Type] = w
	}
	return wallets, nil
}
// LocalBalance returns the cached balance for the given wallet type. It returns
// decimal.Zero when that type has not been loaded into the local cache. No
// database access is performed.
func (repo *userWallet) LocalBalance(kind wallet.Types) decimal.Decimal {
	if cached, found := repo.localWalletBalance[kind]; found {
		return cached.Balance
	}
	return decimal.Zero
}
func (repo *userWallet) LockByIDs(ctx context.Context, ids []int64) ([]entity.Wallet, error) {
var wallets []entity.Wallet
err := repo.db.WithContext(ctx).
Model(&entity.Wallet{}).
Select("id, crypto, balance, type").
Where("id IN ?", ids).
Clauses(clause.Locking{Strength: "UPDATE"}).
Find(&wallets).Error
if err != nil {
return []entity.Wallet{}, notFoundError(err)
}
2025-04-10 23:23:42 +00:00
2025-04-16 09:24:54 +00:00
for _, w := range wallets {
repo.localWalletBalance[w.Type] = w
2025-04-10 23:23:42 +00:00
}
2025-04-16 09:24:54 +00:00
return wallets, nil
}
// CheckReady reports whether this user already has a wallet row of type
// wallet.TypeAvailable for the repository's asset. It runs a SELECT 1 ... LIMIT 1
// probe. When no row matches, the result stays false; only query failures are
// returned as errors.
func (repo *userWallet) CheckReady(ctx context.Context) (bool, error) {
	probe := repo.buildCommonWhereSQL(repo.uid, repo.asset).
		WithContext(ctx).
		Model(&entity.Wallet{}).
		Select("1").
		Where("type = ?", wallet.TypeAvailable).
		Limit(1)

	var exists bool
	if err := probe.Scan(&exists).Error; err != nil {
		return false, err
	}
	return exists, nil
}
// Add applies amount (which may be negative) to the cached balance of the given
// wallet type and records the change in the pending transaction list.
//
// The wallet type must already be in the local cache, so one of the loaders
// (Init / All / Get / GetWithLock / LockByIDs) must have run first. Otherwise
// repository.ErrRecordNotFound is returned. If the new balance would be
// negative, repository.ErrBalanceInsufficient is returned and neither the cache
// nor the transaction list is modified. Nothing is written to the database
// here.
func (repo *userWallet) Add(kind wallet.Types, orderID string, amount decimal.Decimal) error {
	current, ok := repo.localWalletBalance[kind]
	if !ok {
		return repository.ErrRecordNotFound
	}

	next := current.Balance.Add(amount)
	if next.IsNegative() {
		return repository.ErrBalanceInsufficient
	}
	current.Balance = next

	repo.localWalletBalance[kind] = current
	repo.transactions = append(repo.transactions, entity.WalletTransaction{
		OrderID:    orderID,
		UID:        repo.uid,
		WalletType: kind,
		Asset:      repo.asset,
		Amount:     amount,
		Balance:    next,
	})
	return nil
}
// Sub deducts amount from the cached balance of the given wallet type. It is
// Add with the amount negated, so the same preconditions and errors apply.
func (repo *userWallet) Sub(kind wallet.Types, orderID string, amount decimal.Decimal) error {
	negated := decimal.Zero.Sub(amount)
	return repo.Add(kind, orderID, negated)
}
2025-04-17 09:00:42 +00:00
// Transactions 為本次整筆交易 (txID) 給所有暫存的 WalletTransaction 設置共用欄位,
// 並回傳整批交易紀錄以便後續寫入資料庫。
func (repo *userWallet) Transactions(
txID int64,
orderID string,
brand string,
businessType wallet.BusinessName,
) []entity.WalletTransaction {
for i := range repo.transactions {
repo.transactions[i].TransactionID = txID
repo.transactions[i].OrderID = orderID
repo.transactions[i].Brand = brand
repo.transactions[i].BusinessType = businessType.ToInt8()
}
return repo.transactions
2025-04-16 09:24:54 +00:00
}
// Commit writes every cached wallet balance back to its row, by wallet ID.
//
// All updates run in a single READ COMMITTED transaction, so either every
// balance is written or none is. All rows of one commit receive the same
// update_time, a Unix timestamp in seconds (UTC). Errors are wrapped with the
// uid/asset and the failing wallet ID.
//
// Every cached wallet is written, including types that were only read. A
// wallet read without a lock could therefore overwrite a concurrent change.
// NOTE(review): consider tracking which types Add/Sub actually modified.
func (repo *userWallet) Commit(ctx context.Context) error {
	// Isolation level for the write transaction.
	rc := &sql.TxOptions{
		Isolation: sql.LevelReadCommitted,
		ReadOnly:  false,
	}

	// Map iteration order is random. Writing rows in ascending ID order makes
	// concurrent commits on overlapping wallets take row locks in the same
	// order, which avoids lock-order deadlocks.
	wallets := make([]entity.Wallet, 0, len(repo.localWalletBalance))
	for _, w := range repo.localWalletBalance {
		wallets = append(wallets, w)
	}
	sort.Slice(wallets, func(i, j int) bool { return wallets[i].ID < wallets[j].ID })

	// A single timestamp for the whole batch.
	now := time.Now().UTC().Unix()

	// Attach ctx to the transaction itself so BEGIN is also cancellable.
	err := repo.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		for _, w := range wallets {
			err := tx.WithContext(ctx).
				Model(&entity.Wallet{}).
				Where("id = ?", w.ID).
				UpdateColumns(map[string]any{
					"balance":     w.Balance,
					"update_time": now,
				}).Error
			if err != nil {
				return fmt.Errorf("failed to update wallet id %d: %w", w.ID, err)
			}
		}
		return nil // commit only when every update succeeded
	}, rc)
	if err != nil {
		return fmt.Errorf("update uid: %s asset: %s error: %w", repo.uid, repo.asset, err)
	}
	return nil
}
// GetTransactions returns the balance-change records accumulated so far by
// Add/Sub. The returned slice shares its backing array with the repository;
// it is not a copy.
func (repo *userWallet) GetTransactions() []entity.WalletTransaction {
	pending := repo.transactions
	return pending
}
// CommitOrder writes every cached order balance (localOrderBalance, keyed by
// order row ID) to the balance column of the corresponding entity.Transaction
// row.
//
// All updates run in a single READ COMMITTED transaction, so either every
// balance is written or none is. Errors are wrapped with the uid/asset and the
// failing order ID.
//
// NOTE(review): no code visible in this file populates localOrderBalance, so as
// seen here this is a no-op. Confirm where entries are added.
func (repo *userWallet) CommitOrder(ctx context.Context) error {
	rc := &sql.TxOptions{
		Isolation: sql.LevelReadCommitted,
		ReadOnly:  false,
	}

	// Update rows in ascending ID order (map order is random) so concurrent
	// commits lock overlapping rows in the same order.
	ids := make([]int64, 0, len(repo.localOrderBalance))
	for id := range repo.localOrderBalance {
		ids = append(ids, id)
	}
	sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] })

	err := repo.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		for _, id := range ids {
			balance := repo.localOrderBalance[id]
			err := tx.WithContext(ctx).
				Model(&entity.Transaction{}).
				Where("id = ?", id).
				Update("balance", balance).Error
			if err != nil {
				return fmt.Errorf("failed to update order balance, id=%d, err=%w", id, err)
			}
		}
		return nil // commit only when every update succeeded
	}, rc)
	if err != nil {
		return fmt.Errorf("update uid: %s asset: %s error: %w", repo.uid, repo.asset, err)
	}
	return nil
}
2025-04-17 09:00:42 +00:00
// AddTransaction appends a fully populated balance-change record to the pending
// transaction list (the same list Add, Transactions and GetTransactions use).
// It does not change any cached balance.
//
// The record's Balance is a snapshot of the current cached balance for kind,
// or zero when that wallet type has not been loaded.
//
// NOTE(review): the intended semantics are not visible in this file. This
// implementation assumes the method only records a ledger entry, and that
// balances are changed separately via Add/Sub. Confirm with the
// UserWalletService interface owner.
func (repo *userWallet) AddTransaction(txID int64, orderID string, brand string, business wallet.BusinessName, kind wallet.Types, amount decimal.Decimal) {
	repo.transactions = append(repo.transactions, entity.WalletTransaction{
		TransactionID: txID,
		OrderID:       orderID,
		Brand:         brand,
		BusinessType:  business.ToInt8(),
		UID:           repo.uid,
		WalletType:    kind,
		Asset:         repo.asset,
		Amount:        amount,
		Balance:       repo.LocalBalance(kind),
	})
}
2025-04-16 09:24:54 +00:00
// =============================================================================
// buildCommonWhereSQL starts a query on repo.db restricted to the given uid and
// asset. Both conditions are ANDed. Callers chain further clauses onto the
// returned handle.
func (repo *userWallet) buildCommonWhereSQL(uid, asset string) *gorm.DB {
	byUser := repo.db.Where("uid = ?", uid)
	return byUser.Where("asset = ?", asset)
}
func notFoundError(err error) error {
if errors.Is(err, gorm.ErrRecordNotFound) {
return repository.ErrRecordNotFound
}
return err
2025-04-10 23:23:42 +00:00
}