2025-04-10 23:23:42 +00:00
|
|
|
package repository
|
|
|
|
|
|
|
|
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"sort"
	"time"

	"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/entity"
	"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/repository"
	"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/wallet"

	"github.com/shopspring/decimal"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)
|
|
|
|
|
2025-04-16 09:24:54 +00:00
|
|
|
// 用戶某個幣種餘額
|
|
|
|
type userWallet struct {
|
|
|
|
db *gorm.DB
|
|
|
|
uid string
|
|
|
|
asset string
|
|
|
|
|
|
|
|
// local wallet 相關計算的餘額存在這裡
|
|
|
|
localWalletBalance map[wallet.Types]entity.Wallet
|
|
|
|
// local order wallet 相關計算的餘額存在這裡
|
|
|
|
localOrderBalance map[int64]decimal.Decimal
|
|
|
|
// local wallet 內所有餘額變化紀錄
|
|
|
|
transactions []entity.WalletTransaction
|
|
|
|
}
|
|
|
|
|
|
|
|
func NewUserWallet(db *gorm.DB, uid, asset string) repository.UserWalletService {
|
2025-04-10 23:23:42 +00:00
|
|
|
return &userWallet{
|
2025-04-16 09:24:54 +00:00
|
|
|
db: db,
|
|
|
|
uid: uid,
|
|
|
|
asset: asset,
|
|
|
|
localWalletBalance: make(map[wallet.Types]entity.Wallet, len(wallet.AllTypes)),
|
|
|
|
localOrderBalance: make(map[int64]decimal.Decimal, len(wallet.AllTypes)),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (repo *userWallet) Init(ctx context.Context, uid, asset, brand string) ([]entity.Wallet, error) {
|
|
|
|
wallets := make([]entity.Wallet, 0, len(wallet.AllTypes))
|
|
|
|
for _, t := range wallet.AllTypes {
|
|
|
|
balance := decimal.Zero
|
|
|
|
|
|
|
|
wallets = append(wallets, entity.Wallet{
|
|
|
|
Brand: brand,
|
|
|
|
UID: uid,
|
|
|
|
Asset: asset,
|
|
|
|
Balance: balance,
|
|
|
|
Type: t,
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := repo.db.WithContext(ctx).Create(&wallets).Error; err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, v := range wallets {
|
|
|
|
repo.localWalletBalance[v.Type] = v
|
|
|
|
}
|
|
|
|
|
|
|
|
return wallets, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (repo *userWallet) All(ctx context.Context) ([]entity.Wallet, error) {
|
|
|
|
var result []entity.Wallet
|
|
|
|
|
|
|
|
err := repo.buildCommonWhereSQL(repo.uid, repo.asset).
|
|
|
|
WithContext(ctx).
|
|
|
|
Select("id, crypto, balance, type").
|
|
|
|
Find(&result).Error
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return []entity.Wallet{}, err
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, v := range result {
|
|
|
|
repo.localWalletBalance[v.Type] = v
|
|
|
|
}
|
|
|
|
|
|
|
|
return result, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (repo *userWallet) Get(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) {
|
|
|
|
var wallets []entity.Wallet
|
|
|
|
|
|
|
|
err := repo.buildCommonWhereSQL(repo.uid, repo.asset).
|
|
|
|
WithContext(ctx).
|
|
|
|
Model(&entity.Wallet{}).
|
|
|
|
Select("id, crypto, balance, type").
|
|
|
|
Where("type IN ?", kinds).
|
|
|
|
Find(&wallets).Error
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return []entity.Wallet{}, notFoundError(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, w := range wallets {
|
|
|
|
repo.localWalletBalance[w.Type] = w
|
|
|
|
}
|
|
|
|
|
|
|
|
return wallets, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (repo *userWallet) GetWithLock(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) {
|
|
|
|
var wallets []entity.Wallet
|
|
|
|
|
|
|
|
err := repo.buildCommonWhereSQL(repo.uid, repo.asset).
|
|
|
|
WithContext(ctx).
|
|
|
|
Model(&entity.Wallet{}).
|
|
|
|
Select("id, crypto, balance, type").
|
|
|
|
Where("type IN ?", kinds).
|
|
|
|
Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
|
|
Find(&wallets).Error
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return []entity.Wallet{}, notFoundError(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, w := range wallets {
|
|
|
|
repo.localWalletBalance[w.Type] = w
|
|
|
|
}
|
|
|
|
|
|
|
|
return wallets, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (repo *userWallet) LocalBalance(kind wallet.Types) decimal.Decimal {
|
|
|
|
w, ok := repo.localWalletBalance[kind]
|
|
|
|
if !ok {
|
|
|
|
return decimal.Zero
|
|
|
|
}
|
|
|
|
|
|
|
|
return w.Balance
|
|
|
|
}
|
|
|
|
|
|
|
|
func (repo *userWallet) LockByIDs(ctx context.Context, ids []int64) ([]entity.Wallet, error) {
|
|
|
|
var wallets []entity.Wallet
|
|
|
|
|
|
|
|
err := repo.db.WithContext(ctx).
|
|
|
|
Model(&entity.Wallet{}).
|
|
|
|
Select("id, crypto, balance, type").
|
|
|
|
Where("id IN ?", ids).
|
|
|
|
Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
|
|
Find(&wallets).Error
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return []entity.Wallet{}, notFoundError(err)
|
|
|
|
}
|
2025-04-10 23:23:42 +00:00
|
|
|
|
2025-04-16 09:24:54 +00:00
|
|
|
for _, w := range wallets {
|
|
|
|
repo.localWalletBalance[w.Type] = w
|
2025-04-10 23:23:42 +00:00
|
|
|
}
|
2025-04-16 09:24:54 +00:00
|
|
|
|
|
|
|
return wallets, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (repo *userWallet) CheckReady(ctx context.Context) (bool, error) {
|
|
|
|
var exists bool
|
|
|
|
err := repo.buildCommonWhereSQL(repo.uid, repo.asset).WithContext(ctx).
|
|
|
|
Model(&entity.Wallet{}).
|
|
|
|
Select("1").
|
|
|
|
Where("type = ?", wallet.TypeAvailable).
|
|
|
|
Limit(1).
|
|
|
|
Scan(&exists).Error
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return false, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return exists, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Add 新增某種餘額餘額
|
|
|
|
// 使用前 localWalletBalance 必須有資料,所以必須執行過 GetWithLock / All 才會有資料
|
|
|
|
func (repo *userWallet) Add(kind wallet.Types, orderID string, amount decimal.Decimal) error {
|
|
|
|
w, ok := repo.localWalletBalance[kind]
|
|
|
|
if !ok {
|
|
|
|
return repository.ErrRecordNotFound
|
|
|
|
}
|
|
|
|
|
|
|
|
w.Balance = w.Balance.Add(amount)
|
|
|
|
if w.Balance.LessThan(decimal.Zero) {
|
|
|
|
return repository.ErrBalanceInsufficient
|
|
|
|
}
|
|
|
|
|
|
|
|
repo.transactions = append(repo.transactions, entity.WalletTransaction{
|
|
|
|
OrderID: orderID,
|
|
|
|
UID: repo.uid,
|
|
|
|
WalletType: kind,
|
|
|
|
Asset: repo.asset,
|
|
|
|
Amount: amount,
|
|
|
|
Balance: w.Balance,
|
|
|
|
})
|
|
|
|
|
|
|
|
repo.localWalletBalance[kind] = w
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (repo *userWallet) Sub(kind wallet.Types, orderID string, amount decimal.Decimal) error {
|
|
|
|
return repo.Add(kind, orderID, decimal.Zero.Sub(amount))
|
|
|
|
}
|
|
|
|
|
|
|
|
func (repo *userWallet) AddTransaction(txID int64, orderID string, brand string, business wallet.BusinessName, kind wallet.Types, amount decimal.Decimal) {
|
|
|
|
balance := repo.LocalBalance(kind).Add(amount)
|
|
|
|
repo.transactions = append(repo.transactions, entity.WalletTransaction{
|
|
|
|
TransactionID: txID,
|
|
|
|
OrderID: orderID,
|
|
|
|
Brand: brand,
|
|
|
|
UID: repo.uid,
|
|
|
|
WalletType: kind,
|
|
|
|
BusinessType: business.ToInt8(),
|
|
|
|
Asset: repo.asset,
|
|
|
|
Amount: amount,
|
|
|
|
Balance: balance,
|
|
|
|
CreateAt: time.Now().UTC().UnixNano(),
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
func (repo *userWallet) Commit(ctx context.Context) error {
|
|
|
|
// 事務隔離等級設定
|
|
|
|
rc := &sql.TxOptions{
|
|
|
|
Isolation: sql.LevelReadCommitted,
|
|
|
|
ReadOnly: false,
|
|
|
|
}
|
|
|
|
|
|
|
|
err := repo.db.Transaction(func(tx *gorm.DB) error {
|
|
|
|
for _, w := range repo.localWalletBalance {
|
|
|
|
err := tx.WithContext(ctx).
|
|
|
|
Model(&entity.Wallet{}).
|
|
|
|
Where("id = ?", w.ID).
|
|
|
|
UpdateColumns(map[string]any{
|
|
|
|
"balance": w.Balance,
|
|
|
|
"update_time": time.Now().UTC().Unix(),
|
|
|
|
}).Error
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed to update wallet id %d: %w", w.ID, err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil // 所有更新成功才 return nil
|
|
|
|
}, rc)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("update uid: %s asset: %s error: %w", repo.uid, repo.asset, err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (repo *userWallet) GetTransactions() []entity.WalletTransaction {
|
|
|
|
return repo.transactions
|
|
|
|
}
|
|
|
|
|
|
|
|
func (repo *userWallet) CommitOrder(ctx context.Context) error {
|
|
|
|
rc := &sql.TxOptions{
|
|
|
|
Isolation: sql.LevelReadCommitted,
|
|
|
|
ReadOnly: false,
|
|
|
|
}
|
|
|
|
err := repo.db.Transaction(func(tx *gorm.DB) error {
|
|
|
|
|
|
|
|
for id, balance := range repo.localOrderBalance {
|
|
|
|
err := tx.WithContext(ctx).
|
|
|
|
Model(&entity.Transaction{}).
|
|
|
|
Where("id = ?", id).
|
|
|
|
Update("balance", balance).Error
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed to update order balance, id=%d, err=%w", id, err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil // 所有更新成功才 return nil
|
|
|
|
}, rc)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("update uid: %s asset: %s error: %w", repo.uid, repo.asset, err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// =============================================================================
|
|
|
|
|
|
|
|
func (repo *userWallet) buildCommonWhereSQL(uid, asset string) *gorm.DB {
|
|
|
|
return repo.db.Where("uid = ?", uid).
|
|
|
|
Where("asset = ?", asset)
|
|
|
|
}
|
|
|
|
|
|
|
|
func notFoundError(err error) error {
|
|
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
|
return repository.ErrRecordNotFound
|
|
|
|
}
|
|
|
|
|
|
|
|
return err
|
2025-04-10 23:23:42 +00:00
|
|
|
}
|