package repository import ( "app-cloudep-trade-service/internal/domain" "app-cloudep-trade-service/internal/domain/repository" "app-cloudep-trade-service/internal/model" "context" "github.com/shopspring/decimal" "github.com/zeromicro/go-zero/core/stores/sqlx" ) // 用戶某個幣種餘額 type userLocalWallet struct { wm model.WalletModel txConn sqlx.SqlConn uid string currency string // local wallet 相關計算的餘額存在這裡 walletBalance map[domain.WalletType]model.Wallet // local order wallet 相關計算的餘額存在這裡 localOrderBalance map[int64]decimal.Decimal // local wallet 內所有餘額變化紀錄 transactions []model.WalletJournal } func (use *userLocalWallet) GetBalancesByID(ctx context.Context, ids []int64, opts ...repository.WalletOperatorOption) ([]model.Wallet, error) { o := repository.ApplyOptions(opts...) tx := use.txConn if o.Tx != nil { tx = *o.Tx } wallets, err := use.wm.BalancesByIDs(ctx, ids, o.WithLock, tx) if err != nil { return nil, err } for _, wallet := range wallets { use.walletBalance[wallet.WalletType] = wallet } return wallets, nil } // LocalBalance 內存餘額 func (use *userLocalWallet) LocalBalance(kind domain.WalletType) decimal.Decimal { wallet, ok := use.walletBalance[kind] if !ok { return decimal.Zero } return wallet.Balance } func (use *userLocalWallet) Balances(ctx context.Context, kind []domain.WalletType, opts ...repository.WalletOperatorOption) ([]model.Wallet, error) { o := repository.ApplyOptions(opts...) tx := use.txConn if o.Tx != nil { tx = *o.Tx } wallets, err := use.wm.Balances(ctx, model.BalanceReq{ UID: []string{use.uid}, Currency: []string{use.currency}, Kind: kind, }, o.WithLock, tx) if err != nil { return nil, err } for _, wallet := range wallets { use.walletBalance[wallet.WalletType] = wallet } return wallets, nil } func NewUserWalletOperator(uid, currency string, wm model.WalletModel, txConn sqlx.SqlConn) repository.UserWalletOperator { return &userLocalWallet{ wm: wm, txConn: txConn, uid: uid, currency: currency, walletBalance: make(map[domain.WalletType]model.Wallet, len(domain.AllWalletType)), localOrderBalance: make(map[int64]decimal.Decimal, len(domain.AllWalletType)), } }