package repository import ( "code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/entity" "code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/repository" "code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/wallet" "context" "database/sql" "errors" "fmt" "github.com/shopspring/decimal" "gorm.io/gorm" "gorm.io/gorm/clause" "time" ) // WalletService 代表一個使用者在某資產上的錢包服務, // 負責讀取/寫入資料庫並在記憶體暫存變動 type WalletService struct { db *gorm.DB uid string // 使用者識別碼 asset string // 資產代號 (如 BTC、ETH、TWD) localBalances map[wallet.Types]entity.Wallet // 暫存各類型錢包當前餘額 localOrderBalances map[int64]decimal.Decimal // 暫存各訂單變動後的餘額 transactions []entity.WalletTransaction // 暫存所有尚未落庫的錢包交易紀錄 } // NewWalletService 建立一個 WalletService 實例 func NewWalletService(db *gorm.DB, uid, asset string) repository.UserWalletService { return &WalletService{ db: db, uid: uid, asset: asset, localBalances: make(map[wallet.Types]entity.Wallet, len(wallet.AllTypes)), localOrderBalances: make(map[int64]decimal.Decimal, len(wallet.AllTypes)), } } // InitializeWallets 啟動時為新使用者初始化所有類型錢包,並寫入資料庫 func (s *WalletService) InitializeWallets(ctx context.Context, brand string) ([]entity.Wallet, error) { var wallets []entity.Wallet for _, t := range wallet.AllTypes { wallets = append(wallets, entity.Wallet{ Brand: brand, UID: s.uid, Asset: s.asset, Balance: decimal.Zero, Type: t, }) } if err := s.db.WithContext(ctx).Create(&wallets).Error; err != nil { return nil, err } // 將初始化後的錢包資料寫入本地緩存 for _, w := range wallets { s.localBalances[w.Type] = w } return wallets, nil } // GetAllBalances 查詢該使用者某資產所有錢包類型當前餘額 func (s *WalletService) GetAllBalances(ctx context.Context) ([]entity.Wallet, error) { var result []entity.Wallet err := s.db.WithContext(ctx). Where("uid = ? AND asset = ?", s.uid, s.asset). Select("id, asset, balance, type"). Find(&result).Error if err != nil { return nil, err } for _, w := range result { s.localBalances[w.Type] = w } return result, nil } // GetBalancesForTypes 查詢指定類型的錢包餘額,不上鎖 func (s *WalletService) GetBalancesForTypes(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) { var result []entity.Wallet err := s.db.WithContext(ctx). Where("uid = ? AND asset = ?", s.uid, s.asset). Where("type IN ?", kinds). Select("id, asset, balance, type"). Find(&result).Error if err != nil { return nil, translateNotFound(err) } for _, w := range result { s.localBalances[w.Type] = w } return result, nil } // GetBalancesForUpdate 查詢並鎖定指定類型的錢包 (FOR UPDATE) func (s *WalletService) GetBalancesForUpdate(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) { var result []entity.Wallet err := s.db.WithContext(ctx). Where("uid = ? AND asset = ?", s.uid, s.asset). Where("type IN ?", kinds). Clauses(clause.Locking{Strength: "UPDATE"}). Select("id, asset, balance, type"). Find(&result).Error if err != nil { return nil, translateNotFound(err) } for _, w := range result { s.localBalances[w.Type] = w } return result, nil } // CurrentBalance 從緩存中取得某種類型錢包的當前餘額 func (s *WalletService) CurrentBalance(kind wallet.Types) decimal.Decimal { if w, ok := s.localBalances[kind]; ok { return w.Balance } return decimal.Zero } // IncreaseBalance 在本地緩存新增餘額,並記錄一筆 WalletTransaction func (s *WalletService) IncreaseBalance(kind wallet.Types, orderID string, amount decimal.Decimal) error { w, ok := s.localBalances[kind] if !ok { return repository.ErrRecordNotFound } w.Balance = w.Balance.Add(amount) if w.Balance.LessThan(decimal.Zero) { return repository.ErrBalanceInsufficient } s.transactions = append(s.transactions, entity.WalletTransaction{ OrderID: orderID, UID: s.uid, WalletType: kind, Asset: s.asset, Amount: amount, Balance: w.Balance, }) s.localBalances[kind] = w return nil } // DecreaseBalance 本質上是 IncreaseBalance 的負數版本 func (s *WalletService) DecreaseBalance(kind wallet.Types, orderID string, amount decimal.Decimal) error { return s.IncreaseBalance(kind, orderID, amount.Neg()) } // PrepareTransactions 為每筆暫存的 WalletTransaction 填入共用欄位 (txID, brand, businessType), // 並回傳完整可落庫的切片 func (s *WalletService) PrepareTransactions( txID int64, orderID, brand string, businessType wallet.BusinessName, ) []entity.WalletTransaction { for i := range s.transactions { s.transactions[i].TransactionID = txID s.transactions[i].OrderID = orderID s.transactions[i].Brand = brand s.transactions[i].BusinessType = businessType.ToInt8() } return s.transactions } // PersistBalances 寫入本地緩存中所有錢包的最終餘額到資料庫 func (s *WalletService) PersistBalances(ctx context.Context) error { return s.db.Transaction(func(tx *gorm.DB) error { for _, w := range s.localBalances { if err := tx.WithContext(ctx). Model(&entity.Wallet{}). Where("id = ?", w.ID). UpdateColumns(map[string]interface{}{ "balance": w.Balance, "update_at": time.Now().Unix(), }).Error; err != nil { return fmt.Errorf("更新錢包餘額失敗 (id=%d): %w", w.ID, err) } } return nil }, &sql.TxOptions{Isolation: sql.LevelReadCommitted}) } // PersistOrderBalances 寫入所有訂單錢包的最終餘額到 transaction 表 func (s *WalletService) PersistOrderBalances(ctx context.Context) error { return s.db.Transaction(func(tx *gorm.DB) error { for id, bal := range s.localOrderBalances { if err := tx.WithContext(ctx). Model(&entity.Transaction{}). Where("id = ?", id). Update("post_transfer_balance", bal).Error; err != nil { return fmt.Errorf("更新訂單錢包餘額失敗 (id=%d): %w", id, err) } } return nil }, &sql.TxOptions{Isolation: sql.LevelReadCommitted}) } func (s *WalletService) HasAvailableBalance(ctx context.Context) (bool, error) { var exists bool err := s.db.WithContext(ctx). Model(&entity.Wallet{}). Select("1"). Where("uid = ? AND asset = ?", s.uid, s.asset). Where("type = ?", wallet.TypeAvailable). Limit(1). Scan(&exists).Error if err != nil { return false, err } return exists, nil } func (s *WalletService) GetOrderBalance(ctx context.Context, orderID string) (entity.Transaction, error) { var t entity.Transaction err := s.db.WithContext(ctx). Where("order_id = ?", orderID). Take(&t).Error if err != nil { return entity.Transaction{}, translateNotFound(err) } s.localOrderBalances[t.ID] = t.PostTransferBalance return t, nil } func (s *WalletService) GetOrderBalanceForUpdate(ctx context.Context, orderID string) (entity.Transaction, error) { var t entity.Transaction err := s.db.WithContext(ctx). Where("order_id = ?", orderID). Clauses(clause.Locking{Strength: "UPDATE"}). Take(&t).Error if err != nil { return entity.Transaction{}, translateNotFound(err) } s.localOrderBalances[t.ID] = t.PostTransferBalance return t, nil } func (s *WalletService) ClearCache() { s.localBalances = make(map[wallet.Types]entity.Wallet, len(wallet.AllTypes)) s.localOrderBalances = make(map[int64]decimal.Decimal, len(wallet.AllTypes)) s.transactions = nil } // translateNotFound 將 GORM 的 RecordNotFound 轉為自訂錯誤 func translateNotFound(err error) error { if errors.Is(err, gorm.ErrRecordNotFound) { return repository.ErrRecordNotFound } return err }