From 47fb3139c2c3a656fdd7eb5ba9ba3c9567422a89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=80=A7=E9=A9=8A?= Date: Fri, 18 Apr 2025 17:10:40 +0800 Subject: [PATCH] feat: add wallet translation --- pkg/domain/repository/wallet.go | 52 +- pkg/repository/user_wallet.go | 359 ++++------ pkg/repository/user_wallet_test.go | 1025 ++++++++++++++++++++++++++++ pkg/repository/wallet.go | 4 +- pkg/usecase/wallet_tx_option.go | 8 +- pkg/usecase/wallet_tx_processer.go | 8 +- 6 files changed, 1196 insertions(+), 260 deletions(-) create mode 100644 pkg/repository/user_wallet_test.go diff --git a/pkg/domain/repository/wallet.go b/pkg/domain/repository/wallet.go index e2136c6..86e8de6 100644 --- a/pkg/domain/repository/wallet.go +++ b/pkg/domain/repository/wallet.go @@ -58,36 +58,32 @@ type BalanceQuery struct { Kinds []wallet.Types // 錢包類型(如可用、凍結等) } -// UserWalletService 專注於某位使用者在單一資產下的錢包操作邏輯 +// UserWalletService 定義了一個「單一使用者、單一資產」的錢包操作合約 type UserWalletService interface { - // Init 初始化錢包(如建立可用、凍結、未確認等錢包) - Init(ctx context.Context, uid, asset, brand string) ([]entity.Wallet, error) - // All 查詢所有錢包餘額 - All(ctx context.Context) ([]entity.Wallet, error) - // Get 查詢單一或多種類型的餘額 - Get(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) - // GetWithLock 查詢鎖定後的錢包(交易使用) - GetWithLock(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) - // LocalBalance 查詢記憶中的快取值(非查資料庫) - LocalBalance(kind wallet.Types) decimal.Decimal - // LockByIDs 根據錢包 ID 鎖定(資料一致性用) - LockByIDs(ctx context.Context, ids []int64) ([]entity.Wallet, error) - // CheckReady 檢查錢包是否已經存在並準備好(可用餘額的錢包) - CheckReady(ctx context.Context) (bool, error) - // Add 加值與扣款邏輯(含業務類別) - Add(kind wallet.Types, orderID string, amount decimal.Decimal) error - Sub(kind wallet.Types, orderID string, amount decimal.Decimal) error - // AddTransaction 新增一筆交易紀錄(建立資料) - AddTransaction(txID int64, orderID string, brand string, business wallet.BusinessName, kind wallet.Types, amount decimal.Decimal) - Transactions( + // InitializeWallets 為新使用者初始化所有錢包類型並寫入資料庫 + InitializeWallets(ctx context.Context, brand string) ([]entity.Wallet, error) + // GetAllBalances 查詢此使用者此資產下所有錢包類型的當前餘額 + GetAllBalances(ctx context.Context) ([]entity.Wallet, error) + // GetBalancesForTypes 查詢指定錢包類型的一組餘額(不加鎖) + GetBalancesForTypes(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) + // GetBalancesForUpdate 查詢並鎖定指定錢包類型(FOR UPDATE) + GetBalancesForUpdate(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) + // CurrentBalance 從本地緩存取得某種錢包類型的餘額 + CurrentBalance(kind wallet.Types) decimal.Decimal + // IncreaseBalance 增加指定錢包類型的餘額,並累積一筆交易紀錄 + IncreaseBalance(kind wallet.Types, orderID string, amount decimal.Decimal) error + // DecreaseBalance 減少指定錢包類型的餘額(等同於 IncreaseBalance 的負數版本) + DecreaseBalance(kind wallet.Types, orderID string, amount decimal.Decimal) error + // PrepareTransactions 為所有暫存交易紀錄填上 TXID/OrderID/Brand/BusinessType,並回傳可落庫的切片 + PrepareTransactions( txID int64, - orderID string, - brand string, + orderID, brand string, businessType wallet.BusinessName, ) []entity.WalletTransaction - - // Commit 提交所有操作(更新錢包與新增交易紀錄) - Commit(ctx context.Context) error - // CommitOrder 提交所有訂單 - CommitOrder(ctx context.Context) error + // PersistBalances 將本地緩存中所有錢包最終餘額批次寫入資料庫 + PersistBalances(ctx context.Context) error + // PersistOrderBalances 將本地緩存中所有訂單相關餘額批次寫入 transaction 表 + PersistOrderBalances(ctx context.Context) error + // HasAvailableBalance 確認此使用者此資產是否已有可用餘額錢包 + HasAvailableBalance(ctx context.Context) (bool, error) } diff --git a/pkg/repository/user_wallet.go b/pkg/repository/user_wallet.go index e4180a0..27002df 100644 --- a/pkg/repository/user_wallet.go +++ b/pkg/repository/user_wallet.go @@ -14,152 +14,191 @@ import ( "time" ) -// 用戶某個幣種餘額 -type userWallet struct { - db *gorm.DB - uid string - asset string - - // local wallet 相關計算的餘額存在這裡 - localWalletBalance map[wallet.Types]entity.Wallet - // local order wallet 相關計算的餘額存在這裡 - localOrderBalance map[int64]decimal.Decimal - // local wallet 內所有餘額變化紀錄 - transactions []entity.WalletTransaction +// WalletService 代表一個使用者在某資產上的錢包服務, +// 負責讀取/寫入資料庫並在記憶體暫存變動 +type WalletService struct { + db *gorm.DB + uid string // 使用者識別碼 + asset string // 資產代號 (如 BTC、ETH、TWD) + localBalances map[wallet.Types]entity.Wallet // 暫存各類型錢包當前餘額 + localOrderBalances map[int64]decimal.Decimal // 暫存各訂單變動後的餘額 + transactions []entity.WalletTransaction // 暫存所有尚未落庫的錢包交易紀錄 } -func NewUserWallet(db *gorm.DB, uid, asset string) repository.UserWalletService { - return &userWallet{ +// NewWalletService 建立一個 WalletService 實例 +func NewWalletService(db *gorm.DB, uid, asset string) repository.UserWalletService { + return &WalletService{ db: db, uid: uid, asset: asset, - localWalletBalance: make(map[wallet.Types]entity.Wallet, len(wallet.AllTypes)), - localOrderBalance: make(map[int64]decimal.Decimal, len(wallet.AllTypes)), + localBalances: make(map[wallet.Types]entity.Wallet, len(wallet.AllTypes)), + localOrderBalances: make(map[int64]decimal.Decimal, len(wallet.AllTypes)), } } -func (repo *userWallet) Init(ctx context.Context, uid, asset, brand string) ([]entity.Wallet, error) { - wallets := make([]entity.Wallet, 0, len(wallet.AllTypes)) +// InitializeWallets 啟動時為新使用者初始化所有類型錢包,並寫入資料庫 +func (s *WalletService) InitializeWallets(ctx context.Context, brand string) ([]entity.Wallet, error) { + var wallets []entity.Wallet for _, t := range wallet.AllTypes { - balance := decimal.Zero - wallets = append(wallets, entity.Wallet{ Brand: brand, - UID: uid, - Asset: asset, - Balance: balance, + UID: s.uid, + Asset: s.asset, + Balance: decimal.Zero, Type: t, }) } - - if err := repo.db.WithContext(ctx).Create(&wallets).Error; err != nil { + if err := s.db.WithContext(ctx).Create(&wallets).Error; err != nil { return nil, err } - - for _, v := range wallets { - repo.localWalletBalance[v.Type] = v + // 將初始化後的錢包資料寫入本地緩存 + for _, w := range wallets { + s.localBalances[w.Type] = w } - return wallets, nil } -func (repo *userWallet) All(ctx context.Context) ([]entity.Wallet, error) { +// GetAllBalances 查詢該使用者某資產所有錢包類型當前餘額 +func (s *WalletService) GetAllBalances(ctx context.Context) ([]entity.Wallet, error) { var result []entity.Wallet - - err := repo.buildCommonWhereSQL(repo.uid, repo.asset). - WithContext(ctx). - Select("id, crypto, balance, type"). + err := s.db.WithContext(ctx). + Where("uid = ? AND asset = ?", s.uid, s.asset). + Select("id, asset, balance, type"). Find(&result).Error - if err != nil { - return []entity.Wallet{}, err + return nil, err } - - for _, v := range result { - repo.localWalletBalance[v.Type] = v + for _, w := range result { + s.localBalances[w.Type] = w } - return result, nil } -func (repo *userWallet) Get(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) { - var wallets []entity.Wallet - - err := repo.buildCommonWhereSQL(repo.uid, repo.asset). - WithContext(ctx). - Model(&entity.Wallet{}). - Select("id, crypto, balance, type"). +// GetBalancesForTypes 查詢指定類型的錢包餘額,不上鎖 +func (s *WalletService) GetBalancesForTypes(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) { + var result []entity.Wallet + err := s.db.WithContext(ctx). + Where("uid = ? AND asset = ?", s.uid, s.asset). Where("type IN ?", kinds). - Find(&wallets).Error - + Select("id, asset, balance, type"). + Find(&result).Error if err != nil { - return []entity.Wallet{}, notFoundError(err) + return nil, translateNotFound(err) } - - for _, w := range wallets { - repo.localWalletBalance[w.Type] = w + for _, w := range result { + s.localBalances[w.Type] = w } - - return wallets, nil + return result, nil } -func (repo *userWallet) GetWithLock(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) { - var wallets []entity.Wallet - - err := repo.buildCommonWhereSQL(repo.uid, repo.asset). - WithContext(ctx). - Model(&entity.Wallet{}). - Select("id, crypto, balance, type"). +// GetBalancesForUpdate 查詢並鎖定指定類型的錢包 (FOR UPDATE) +func (s *WalletService) GetBalancesForUpdate(ctx context.Context, kinds []wallet.Types) ([]entity.Wallet, error) { + var result []entity.Wallet + err := s.db.WithContext(ctx). + Where("uid = ? AND asset = ?", s.uid, s.asset). Where("type IN ?", kinds). Clauses(clause.Locking{Strength: "UPDATE"}). - Find(&wallets).Error - + Select("id, asset, balance, type"). + Find(&result).Error if err != nil { - return []entity.Wallet{}, notFoundError(err) + return nil, translateNotFound(err) } - - for _, w := range wallets { - repo.localWalletBalance[w.Type] = w + for _, w := range result { + s.localBalances[w.Type] = w } - - return wallets, nil + return result, nil } -func (repo *userWallet) LocalBalance(kind wallet.Types) decimal.Decimal { - w, ok := repo.localWalletBalance[kind] +// CurrentBalance 從緩存中取得某種類型錢包的當前餘額 +func (s *WalletService) CurrentBalance(kind wallet.Types) decimal.Decimal { + if w, ok := s.localBalances[kind]; ok { + return w.Balance + } + return decimal.Zero +} + +// IncreaseBalance 在本地緩存新增餘額,並記錄一筆 WalletTransaction +func (s *WalletService) IncreaseBalance(kind wallet.Types, orderID string, amount decimal.Decimal) error { + w, ok := s.localBalances[kind] if !ok { - return decimal.Zero + return repository.ErrRecordNotFound } - - return w.Balance + w.Balance = w.Balance.Add(amount) + if w.Balance.LessThan(decimal.Zero) { + return repository.ErrBalanceInsufficient + } + s.transactions = append(s.transactions, entity.WalletTransaction{ + OrderID: orderID, + UID: s.uid, + WalletType: kind, + Asset: s.asset, + Amount: amount, + Balance: w.Balance, + }) + s.localBalances[kind] = w + return nil } -func (repo *userWallet) LockByIDs(ctx context.Context, ids []int64) ([]entity.Wallet, error) { - var wallets []entity.Wallet - - err := repo.db.WithContext(ctx). - Model(&entity.Wallet{}). - Select("id, crypto, balance, type"). - Where("id IN ?", ids). - Clauses(clause.Locking{Strength: "UPDATE"}). - Find(&wallets).Error - - if err != nil { - return []entity.Wallet{}, notFoundError(err) - } - - for _, w := range wallets { - repo.localWalletBalance[w.Type] = w - } - - return wallets, nil +// DecreaseBalance 本質上是 IncreaseBalance 的負數版本 +func (s *WalletService) DecreaseBalance(kind wallet.Types, orderID string, amount decimal.Decimal) error { + return s.IncreaseBalance(kind, orderID, amount.Neg()) } -func (repo *userWallet) CheckReady(ctx context.Context) (bool, error) { +// PrepareTransactions 為每筆暫存的 WalletTransaction 填入共用欄位 (txID, brand, businessType), +// 並回傳完整可落庫的切片 +func (s *WalletService) PrepareTransactions( + txID int64, + orderID, brand string, + businessType wallet.BusinessName, +) []entity.WalletTransaction { + for i := range s.transactions { + s.transactions[i].TransactionID = txID + s.transactions[i].OrderID = orderID + s.transactions[i].Brand = brand + s.transactions[i].BusinessType = businessType.ToInt8() + } + return s.transactions +} + +// PersistBalances 寫入本地緩存中所有錢包的最終餘額到資料庫 +func (s *WalletService) PersistBalances(ctx context.Context) error { + return s.db.Transaction(func(tx *gorm.DB) error { + for _, w := range s.localBalances { + if err := tx.WithContext(ctx). + Model(&entity.Wallet{}). + Where("id = ?", w.ID). + UpdateColumns(map[string]interface{}{ + "balance": w.Balance, + "update_at": time.Now().Unix(), + }).Error; err != nil { + return fmt.Errorf("更新錢包餘額失敗 (id=%d): %w", w.ID, err) + } + } + return nil + }, &sql.TxOptions{Isolation: sql.LevelReadCommitted}) +} + +// PersistOrderBalances 寫入所有訂單錢包的最終餘額到 transaction 表 +func (s *WalletService) PersistOrderBalances(ctx context.Context) error { + return s.db.Transaction(func(tx *gorm.DB) error { + for id, bal := range s.localOrderBalances { + if err := tx.WithContext(ctx). + Model(&entity.Transaction{}). + Where("id = ?", id). + Update("post_transfer_balance", bal).Error; err != nil { + return fmt.Errorf("更新訂單錢包餘額失敗 (id=%d): %w", id, err) + } + } + return nil + }, &sql.TxOptions{Isolation: sql.LevelReadCommitted}) +} + +func (s *WalletService) HasAvailableBalance(ctx context.Context) (bool, error) { var exists bool - err := repo.buildCommonWhereSQL(repo.uid, repo.asset).WithContext(ctx). + err := s.db.WithContext(ctx). Model(&entity.Wallet{}). Select("1"). + Where("uid = ? AND asset = ?", s.uid, s.asset). Where("type = ?", wallet.TypeAvailable). Limit(1). Scan(&exists).Error @@ -171,134 +210,10 @@ func (repo *userWallet) CheckReady(ctx context.Context) (bool, error) { return exists, nil } -// Add 新增某種餘額餘額 -// 使用前 localWalletBalance 必須有資料,所以必須執行過 GetWithLock / All 才會有資料 -func (repo *userWallet) Add(kind wallet.Types, orderID string, amount decimal.Decimal) error { - w, ok := repo.localWalletBalance[kind] - if !ok { - return repository.ErrRecordNotFound - } - - w.Balance = w.Balance.Add(amount) - if w.Balance.LessThan(decimal.Zero) { - return repository.ErrBalanceInsufficient - } - - repo.transactions = append(repo.transactions, entity.WalletTransaction{ - OrderID: orderID, - UID: repo.uid, - WalletType: kind, - Asset: repo.asset, - Amount: amount, - Balance: w.Balance, - }) - - repo.localWalletBalance[kind] = w - - return nil -} - -func (repo *userWallet) Sub(kind wallet.Types, orderID string, amount decimal.Decimal) error { - return repo.Add(kind, orderID, decimal.Zero.Sub(amount)) -} - -// Transactions 為本次整筆交易 (txID) 給所有暫存的 WalletTransaction 設置共用欄位, -// 並回傳整批交易紀錄以便後續寫入資料庫。 -func (repo *userWallet) Transactions( - txID int64, - orderID string, - brand string, - businessType wallet.BusinessName, -) []entity.WalletTransaction { - for i := range repo.transactions { - repo.transactions[i].TransactionID = txID - repo.transactions[i].OrderID = orderID - repo.transactions[i].Brand = brand - repo.transactions[i].BusinessType = businessType.ToInt8() - } - return repo.transactions -} - -func (repo *userWallet) Commit(ctx context.Context) error { - // 事務隔離等級設定 - rc := &sql.TxOptions{ - Isolation: sql.LevelReadCommitted, - ReadOnly: false, - } - - err := repo.db.Transaction(func(tx *gorm.DB) error { - for _, w := range repo.localWalletBalance { - err := tx.WithContext(ctx). - Model(&entity.Wallet{}). - Where("id = ?", w.ID). - UpdateColumns(map[string]any{ - "balance": w.Balance, - "update_time": time.Now().UTC().Unix(), - }).Error - - if err != nil { - return fmt.Errorf("failed to update wallet id %d: %w", w.ID, err) - } - } - - return nil // 所有更新成功才 return nil - }, rc) - - if err != nil { - return fmt.Errorf("update uid: %s asset: %s error: %w", repo.uid, repo.asset, err) - } - - return nil -} - -func (repo *userWallet) GetTransactions() []entity.WalletTransaction { - return repo.transactions -} - -func (repo *userWallet) CommitOrder(ctx context.Context) error { - rc := &sql.TxOptions{ - Isolation: sql.LevelReadCommitted, - ReadOnly: false, - } - err := repo.db.Transaction(func(tx *gorm.DB) error { - - for id, balance := range repo.localOrderBalance { - err := tx.WithContext(ctx). - Model(&entity.Transaction{}). - Where("id = ?", id). - Update("balance", balance).Error - - if err != nil { - return fmt.Errorf("failed to update order balance, id=%d, err=%w", id, err) - } - } - - return nil // 所有更新成功才 return nil - }, rc) - - if err != nil { - return fmt.Errorf("update uid: %s asset: %s error: %w", repo.uid, repo.asset, err) - } - - return nil -} - -func (repo *userWallet) AddTransaction(txID int64, orderID string, brand string, business wallet.BusinessName, kind wallet.Types, amount decimal.Decimal) { - //TODO implement me - panic("implement me") -} - -// ============================================================================= - -func (repo *userWallet) buildCommonWhereSQL(uid, asset string) *gorm.DB { - return repo.db.Where("uid = ?", uid). - Where("asset = ?", asset) -} - -func notFoundError(err error) error { +// translateNotFound 將 GORM 的 RecordNotFound 轉為自訂錯誤 +func translateNotFound(err error) error { if errors.Is(err, gorm.ErrRecordNotFound) { return repository.ErrRecordNotFound } - return err } diff --git a/pkg/repository/user_wallet_test.go b/pkg/repository/user_wallet_test.go new file mode 100644 index 0000000..0882b74 --- /dev/null +++ b/pkg/repository/user_wallet_test.go @@ -0,0 +1,1025 @@ +package repository + +import ( + "code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/entity" + "code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/repository" + "code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/wallet" + "context" + "fmt" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "testing" + "time" +) + +func TestWalletService_InitializeWallets(t *testing.T) { + _, db, teardown, err := SetupTestWalletRepository() + assert.NoError(t, err) + defer teardown() + + // 先手動建表 + createTable := ` + CREATE TABLE IF NOT EXISTS wallet ( + id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + brand VARCHAR(50) NOT NULL, + uid VARCHAR(64) NOT NULL, + asset VARCHAR(32) NOT NULL, + balance DECIMAL(30,18) NOT NULL DEFAULT 0, + type TINYINT NOT NULL, + create_at INTEGER NOT NULL DEFAULT 0, + update_at INTEGER NOT NULL DEFAULT 0, + PRIMARY KEY (id), + UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;` + assert.NoError(t, db.Exec(createTable).Error) + + type args struct { + uid string + asset string + brand string + } + tests := []struct { + name string + args args + wantCount int + wantErr bool + validateDB bool + }{ + { + name: "正常初始化一次", + args: args{uid: "user1", asset: "BTC", brand: "brandA"}, + wantCount: len(wallet.AllTypes), + }, + { + name: "再次初始化同一 UID/asset/brand,應因 UNIQUE KEY 失敗", + args: args{uid: "user1", asset: "BTC", brand: "brandA"}, + wantErr: true, + validateDB: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + service := NewWalletService(db, tt.args.uid, tt.args.asset) + ctx := context.Background() + + got, err := service.InitializeWallets(ctx, tt.args.brand) + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + // 回傳 slice 長度應等於 AllTypes + assert.Len(t, got, tt.wantCount) + + if tt.validateDB { + // 再查一次 DB,確認確實寫入 + var dbRows []entity.Wallet + err := db.WithContext(ctx). + Where("uid = ? AND asset = ? AND brand = ?", tt.args.uid, tt.args.asset, tt.args.brand). + Find(&dbRows).Error + assert.NoError(t, err) + assert.Len(t, dbRows, tt.wantCount) + // 檢查每一筆都初始為零 + for _, w := range dbRows { + assert.Equal(t, decimal.Zero, w.Balance) + } + } + }) + } +} + +func TestWalletService_GetAllBalances(t *testing.T) { + _, db, teardown, err := SetupTestWalletRepository() + assert.NoError(t, err) + defer teardown() + + // 建表 + createTable := ` + CREATE TABLE IF NOT EXISTS wallet ( + id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + brand VARCHAR(50) NOT NULL, + uid VARCHAR(64) NOT NULL, + asset VARCHAR(32) NOT NULL, + balance DECIMAL(30,18) NOT NULL DEFAULT 0, + type TINYINT NOT NULL, + create_at INTEGER NOT NULL DEFAULT 0, + update_at INTEGER NOT NULL DEFAULT 0, + PRIMARY KEY (id), + UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;` + assert.NoError(t, db.Exec(createTable).Error) + + type seedRow struct { + Brand string + UID string + Asset string + Balance decimal.Decimal + Type wallet.Types + } + + type args struct { + uid string + asset string + brand string + } + tests := []struct { + name string + args args + seed []seedRow + wantCount int + }{ + { + name: "no data returns empty", + args: args{"u1", "BTC", "b1"}, + seed: nil, + wantCount: 0, + }, + { + name: "single user many types", + args: args{"u1", "BTC", "b1"}, + seed: []seedRow{ + {"b1", "u1", "BTC", decimal.NewFromInt(10), wallet.AllTypes[0]}, + {"b1", "u1", "BTC", decimal.NewFromInt(20), wallet.AllTypes[1]}, + {"b1", "u1", "BTC", decimal.NewFromInt(30), wallet.AllTypes[2]}, + }, + wantCount: 3, + }, + { + name: "mixed users and assets", + args: args{"u2", "ETH", "b2"}, + seed: []seedRow{ + {"b1", "u1", "BTC", decimal.NewFromInt(10), wallet.AllTypes[0]}, + {"b2", "u2", "ETH", decimal.NewFromInt(15), wallet.AllTypes[1]}, + {"b2", "u2", "ETH", decimal.NewFromInt(25), wallet.AllTypes[2]}, + }, + wantCount: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + // 清空表 + assert.NoError(t, db.Exec(`DELETE FROM wallet`).Error) + + // 插入 seed + for _, row := range tt.seed { + err := db.WithContext(ctx).Exec( + `INSERT INTO wallet (brand, uid, asset, balance, type, create_at, update_at) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + row.Brand, row.UID, row.Asset, row.Balance, row.Type, + time.Now().Unix(), time.Now().Unix(), + ).Error + assert.NoError(t, err) + } + + // 建立 service + svc := NewWalletService(db, tt.args.uid, tt.args.asset) + + // 呼叫 GetAllBalances + got, err := svc.GetAllBalances(ctx) + assert.NoError(t, err) + assert.Len(t, got, tt.wantCount) + + // 檢查回傳資料與本地快取一致 + ws := svc.(*WalletService) + for _, w := range got { + // 回傳的每筆都應該存在於 localBalances + cached, ok := ws.localBalances[w.Type] + assert.True(t, ok) + assert.Equal(t, w.Balance, cached.Balance) + assert.Equal(t, w.Asset, cached.Asset) + assert.Equal(t, w.Type, cached.Type) + } + }) + } +} + +func TestWalletService_GetBalancesForTypes(t *testing.T) { + _, db, teardown, err := SetupTestWalletRepository() + assert.NoError(t, err) + defer teardown() + + // 建表 + createTable := ` + CREATE TABLE IF NOT EXISTS wallet ( + id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + brand VARCHAR(50) NOT NULL, + uid VARCHAR(64) NOT NULL, + asset VARCHAR(32) NOT NULL, + balance DECIMAL(30,18) NOT NULL DEFAULT 0, + type TINYINT NOT NULL, + create_at INTEGER NOT NULL DEFAULT 0, + update_at INTEGER NOT NULL DEFAULT 0, + PRIMARY KEY (id), + UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;` + assert.NoError(t, db.Exec(createTable).Error) + + type seedRow struct { + Brand string + UID string + Asset string + Balance decimal.Decimal + Type wallet.Types + } + + tests := []struct { + name string + uid string + asset string + brand string + seed []seedRow + kinds []wallet.Types + wantCount int + }{ + { + name: "no matching types returns empty", + uid: "user1", asset: "BTC", brand: "b1", + seed: []seedRow{ + {"b1", "user1", "BTC", decimal.NewFromInt(5), wallet.AllTypes[0]}, + }, + kinds: []wallet.Types{wallet.AllTypes[1]}, + wantCount: 0, + }, + { + name: "single type match", + uid: "user1", asset: "BTC", brand: "b1", + seed: []seedRow{ + {"b1", "user1", "BTC", decimal.NewFromInt(5), wallet.AllTypes[0]}, + {"b1", "user1", "BTC", decimal.NewFromInt(7), wallet.AllTypes[1]}, + }, + kinds: []wallet.Types{wallet.AllTypes[1]}, + wantCount: 1, + }, + { + name: "multiple type matches", + uid: "user2", asset: "ETH", brand: "b2", + seed: []seedRow{ + {"b2", "user2", "ETH", decimal.NewFromInt(3), wallet.AllTypes[0]}, + {"b2", "user2", "ETH", decimal.NewFromInt(8), wallet.AllTypes[2]}, + {"b2", "user2", "ETH", decimal.NewFromInt(10), wallet.AllTypes[1]}, + }, + kinds: []wallet.Types{wallet.AllTypes[0], wallet.AllTypes[2]}, + wantCount: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + // 清表 + assert.NoError(t, db.Exec(`DELETE FROM wallet`).Error) + + // 插入種子資料 + for _, r := range tt.seed { + assert.NoError(t, db.Exec( + `INSERT INTO wallet (brand, uid, asset, balance, type, create_at, update_at) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + r.Brand, r.UID, r.Asset, r.Balance, r.Type, + time.Now().Unix(), time.Now().Unix(), + ).Error) + } + + // 建立 service + svc := NewWalletService(db, tt.uid, tt.asset) + + // 呼叫 GetBalancesForTypes + got, err := svc.GetBalancesForTypes(ctx, tt.kinds) + assert.NoError(t, err) + assert.Len(t, got, tt.wantCount) + + // 檢查每筆結果皆正確緩存 + ws := svc.(*WalletService) + for _, w := range got { + cached, ok := ws.localBalances[w.Type] + assert.True(t, ok) + assert.Equal(t, w.Balance, cached.Balance) + assert.Equal(t, w.Asset, cached.Asset) + } + }) + } +} + +func TestForUpdateLockBehavior(t *testing.T) { + // 建立測試環境 + _, db, teardown, err := SetupTestWalletRepository() + assert.NoError(t, err) + defer teardown() + + // 建表 + createTable := ` + CREATE TABLE IF NOT EXISTS wallet ( + id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + brand VARCHAR(50) NOT NULL, + uid VARCHAR(64) NOT NULL, + asset VARCHAR(32) NOT NULL, + balance DECIMAL(30,18) NOT NULL DEFAULT 0, + type TINYINT NOT NULL, + create_at INTEGER NOT NULL DEFAULT 0, + update_at INTEGER NOT NULL DEFAULT 0, + PRIMARY KEY (id), + UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;` + assert.NoError(t, db.Exec(createTable).Error) + + // 種子資料:一筆 wallet + ctx := context.Background() + initial := entity.Wallet{ + Brand: "b1", + UID: "user1", + Asset: "BTC", + Balance: decimal.NewFromInt(100), + Type: wallet.AllTypes[0], + } + err = db.WithContext(ctx).Create(&initial).Error + assert.NoError(t, err) + + // 讀回自動產生的 ID + var seeded entity.Wallet + assert.NoError(t, db.Where("uid = ? AND asset = ?", initial.UID, initial.Asset). + First(&seeded).Error) + + // 開啟第一個 transaction 並 SELECT … FOR UPDATE + tx1 := db.Begin() + var locked entity.Wallet + err = tx1.Clauses(clause.Locking{Strength: "UPDATE"}). + WithContext(ctx). + Where("id = ?", seeded.ID). + Take(&locked).Error + assert.NoError(t, err) + + // 啟動 goroutine 嘗試在第二 transaction 更新該 row + done := make(chan error, 1) + go func() { + tx2 := db.Begin() + // 試圖更新,被 FOR UPDATE 鎖住前應該會 block + err := tx2.WithContext(ctx). + Model(&entity.Wallet{}). + Where("id = ?", seeded.ID). + Update("balance", seeded.Balance.Add(decimal.NewFromInt(50))).Error + done <- err + }() + + // 等 100ms 確認尚未完成 + time.Sleep(100 * time.Millisecond) + select { + case err2 := <-done: + t.Fatalf("expected update to be blocked, but completed with err=%v", err2) + default: + // 正在等待鎖 + } + + // 釋放鎖 + assert.NoError(t, tx1.Commit().Error) + + // 現在第二 transaction 應該很快完成 + select { + case err2 := <-done: + assert.NoError(t, err2) + case <-time.After(500 * time.Millisecond): + t.Fatal("update did not complete after lock released") + } +} + +func TestWalletService_IncreaseBalance(t *testing.T) { + uid := "user1" + asset := "BTC" + + type testCase struct { + name string + kind wallet.Types + initial *entity.Wallet // nil 表示不存在 + orderID string + amount decimal.Decimal + wantErr error + wantBalance decimal.Decimal + wantTxCount int + } + tests := []testCase{ + { + name: "missing wallet type", + kind: wallet.AllTypes[0], + initial: nil, + orderID: "ord1", + amount: decimal.NewFromInt(10), + wantErr: repository.ErrRecordNotFound, + wantBalance: decimal.Zero, + wantTxCount: 0, + }, + { + name: "increase from zero", + kind: wallet.AllTypes[1], + initial: &entity.Wallet{Balance: decimal.Zero}, + orderID: "ord2", + amount: decimal.NewFromInt(15), + wantErr: nil, + wantBalance: decimal.NewFromInt(15), + wantTxCount: 1, + }, + { + name: "successful increment on non-zero", + kind: wallet.AllTypes[2], + initial: &entity.Wallet{Balance: decimal.NewFromInt(5)}, + orderID: "ord3", + amount: decimal.NewFromInt(7), + wantErr: nil, + wantBalance: decimal.NewFromInt(12), + wantTxCount: 1, + }, + { + name: "insufficient leads to error", + kind: wallet.AllTypes[2], + initial: &entity.Wallet{Balance: decimal.NewFromInt(3)}, + orderID: "ord4", + amount: decimal.NewFromInt(-5), + wantErr: repository.ErrBalanceInsufficient, + wantBalance: decimal.NewFromInt(3), + wantTxCount: 0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // 準備 WalletService + svc := NewWalletService(nil, uid, asset).(*WalletService) + // 初始化本地快取 + if tc.initial != nil { + svc.localBalances = map[wallet.Types]entity.Wallet{tc.kind: *tc.initial} + } else { + svc.localBalances = map[wallet.Types]entity.Wallet{} + } + // 清空交易紀錄 + svc.transactions = nil + + // 執行 + err := svc.IncreaseBalance(tc.kind, tc.orderID, tc.amount) + + // 驗證錯誤 + if tc.wantErr != nil { + assert.ErrorIs(t, err, tc.wantErr) + } else { + assert.NoError(t, err) + } + + // 驗證快取餘額 + gotBal := svc.CurrentBalance(tc.kind) + assert.True(t, gotBal.Equal(tc.wantBalance), "balance = %s, want %s", gotBal, tc.wantBalance) + + // 驗證交易筆數 + assert.Len(t, svc.transactions, tc.wantTxCount) + if tc.wantTxCount > 0 { + tx := svc.transactions[0] + assert.Equal(t, tc.orderID, tx.OrderID) + assert.Equal(t, uid, tx.UID) + assert.Equal(t, asset, tx.Asset) + assert.True(t, tx.Amount.Equal(tc.amount)) + assert.True(t, tx.Balance.Equal(tc.wantBalance)) + } + }) + } +} + +func TestWalletService_PrepareTransactions(t *testing.T) { + const ( + txID = int64(42) + orderID = "order-123" + brand = "brandX" + bizNameStr = "business-test" + ) + biz := wallet.BusinessName(bizNameStr) + + tests := []struct { + name string + initialTxs []entity.WalletTransaction + wantCount int + }{ + { + name: "no transactions returns empty slice", + initialTxs: nil, + wantCount: 0, + }, + { + name: "single transaction is populated", + initialTxs: []entity.WalletTransaction{ + { + OrderID: "placeholder", + UID: "u1", + WalletType: wallet.AllTypes[0], + Asset: "BTC", + Amount: decimal.NewFromInt(5), + Balance: decimal.NewFromInt(10), + }, + }, + wantCount: 1, + }, + { + name: "multiple transactions are all populated", + initialTxs: []entity.WalletTransaction{ + {UID: "u1", WalletType: wallet.AllTypes[1], Asset: "ETH", Amount: decimal.NewFromInt(1), Balance: decimal.NewFromInt(1)}, + {UID: "u2", WalletType: wallet.AllTypes[2], Asset: "TWD", Amount: decimal.NewFromInt(2), Balance: decimal.NewFromInt(2)}, + {UID: "u3", WalletType: wallet.AllTypes[3%len(wallet.AllTypes)], Asset: "USD", Amount: decimal.NewFromInt(3), Balance: decimal.NewFromInt(3)}, + }, + wantCount: 3, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // 建立 WalletService 並注入初始交易 + svc := &WalletService{ + transactions: make([]entity.WalletTransaction, len(tc.initialTxs)), + } + copy(svc.transactions, tc.initialTxs) + + // 執行 PrepareTransactions + result := svc.PrepareTransactions(txID, orderID, brand, biz) + + // 檢查回傳長度 + assert.Len(t, result, tc.wantCount) + assert.Len(t, svc.transactions, tc.wantCount) + + // 驗證每筆交易的共用欄位是否正確 + for i := 0; i < tc.wantCount; i++ { + tx := result[i] + assert.Equal(t, txID, tx.TransactionID, "tx[%d].TransactionID", i) + assert.Equal(t, orderID, tx.OrderID, "tx[%d].OrderID", i) + assert.Equal(t, brand, tx.Brand, "tx[%d].Brand", i) + assert.Equal(t, biz.ToInt8(), tx.BusinessType, "tx[%d].BusinessType", i) + // 原有欄位保持不變 + assert.Equal(t, tc.initialTxs[i].UID, tx.UID, "tx[%d].UID unchanged", i) + assert.Equal(t, tc.initialTxs[i].WalletType, tx.WalletType, "tx[%d].WalletType unchanged", i) + assert.Equal(t, tc.initialTxs[i].Asset, tx.Asset, "tx[%d].Asset unchanged", i) + assert.True(t, tx.Amount.Equal(tc.initialTxs[i].Amount), "tx[%d].Amount unchanged", i) + assert.True(t, tx.Balance.Equal(tc.initialTxs[i].Balance), "tx[%d].Balance unchanged", i) + } + }) + } +} + +func TestWalletService_PersistBalances(t *testing.T) { + _, db, teardown, err := SetupTestWalletRepository() + assert.NoError(t, err) + defer teardown() + + // helper:删除表、重建表、插入两笔 seed 数据,返回这两笔带 ID 的 slice + seedWallets := func() []entity.Wallet { + // DROP + AutoMigrate + assert.NoError(t, db.Migrator().DropTable(&entity.Wallet{})) + assert.NoError(t, db.AutoMigrate(&entity.Wallet{})) + + base := []entity.Wallet{ + {UID: "u1", Asset: "BTC", Brand: "b", Balance: decimal.NewFromInt(10), Type: wallet.AllTypes[0]}, + {UID: "u1", Asset: "BTC", Brand: "b", Balance: decimal.NewFromInt(20), Type: wallet.AllTypes[1]}, + } + assert.NoError(t, db.Create(&base).Error) + return base + } + + type fields struct { + updates map[wallet.Types]decimal.Decimal + } + type want struct { + final []decimal.Decimal + err bool + } + + tests := []struct { + name string + fields fields + want want + }{ + { + name: "no local balances → no change", + fields: fields{updates: map[wallet.Types]decimal.Decimal{}}, + want: want{ + final: []decimal.Decimal{decimal.NewFromInt(10), decimal.NewFromInt(20)}, + err: false, + }, + }, + { + name: "update both balances", + fields: fields{updates: map[wallet.Types]decimal.Decimal{ + wallet.AllTypes[0]: decimal.NewFromInt(15), + wallet.AllTypes[1]: decimal.NewFromInt(5), + }}, + want: want{ + final: []decimal.Decimal{decimal.NewFromInt(15), decimal.NewFromInt(5)}, + err: false, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // seed 并拿到带 ID 的记录 + seed := seedWallets() + + // 构造 localBalances:只有 Type 在 updates 里的才覆盖 + localMap := make(map[wallet.Types]entity.Wallet, len(tc.fields.updates)) + for _, w := range seed { + if nb, ok := tc.fields.updates[w.Type]; ok { + localMap[w.Type] = entity.Wallet{ + ID: w.ID, + Balance: nb, + } + } + } + + // 初始化 service 并注入 localBalances + svc := NewWalletService(db, "u1", "BTC").(*WalletService) + svc.localBalances = localMap + + // 执行 PersistBalances + err := svc.PersistBalances(context.Background()) + if tc.want.err { + assert.Error(t, err) + return + } + assert.NoError(t, err) + + // 重新查 DB,按 ID 顺序比较余额 + var got []entity.Wallet + assert.NoError(t, db. + Where("uid = ?", "u1"). + Order("id"). + Find(&got).Error) + + assert.Len(t, got, len(seed)) + for i, w := range got { + assert.Truef(t, + w.Balance.Equal(tc.want.final[i]), + "第 %d 条记录 (id=%d) 余额 = %s, 期望 %s", + i, w.ID, w.Balance, tc.want.final[i], + ) + } + }) + } +} + +func TestWalletService_PersistOrderBalances(t *testing.T) { + _, db, teardown, err := SetupTestWalletRepository() + assert.NoError(t, err) + defer teardown() + + // helper:重建 transaction 表並 seed 資料 + seedTransactions := func() []entity.Transaction { + // DROP + AutoMigrate + _ = db.Migrator().DropTable(&entity.Transaction{}) + assert.NoError(t, db.AutoMigrate(&entity.Transaction{})) + + // 插入兩筆 transaction + base := []entity.Transaction{ + {PostTransferBalance: decimal.NewFromInt(100)}, + {PostTransferBalance: decimal.NewFromInt(200)}, + } + assert.NoError(t, db.Create(&base).Error) + return base + } + + type fields struct { + updates map[int64]decimal.Decimal + } + type want struct { + final []decimal.Decimal + err bool + } + + tests := []struct { + name string + fields fields + want want + }{ + { + name: "no local order balances → no change", + fields: fields{updates: map[int64]decimal.Decimal{}}, + want: want{ + final: []decimal.Decimal{decimal.NewFromInt(100), decimal.NewFromInt(200)}, + err: false, + }, + }, + { + name: "update first order balance only", + fields: fields{updates: map[int64]decimal.Decimal{ + 1: decimal.NewFromInt(150), + }}, + want: want{ + final: []decimal.Decimal{decimal.NewFromInt(150), decimal.NewFromInt(200)}, + err: false, + }, + }, + { + name: "update both order balances", + fields: fields{updates: map[int64]decimal.Decimal{ + 1: decimal.NewFromInt(110), + 2: decimal.NewFromInt(220), + }}, + want: want{ + final: []decimal.Decimal{decimal.NewFromInt(110), decimal.NewFromInt(220)}, + err: false, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // seed 資料並拿回 slice(包含自動產生的 ID) + seed := seedTransactions() + + // 建構 localOrderBalances:key 是 seed[i].ID + localMap := make(map[int64]decimal.Decimal, len(tc.fields.updates)) + for id, newBal := range tc.fields.updates { + localMap[id] = newBal + } + + // 初始化 service、注入 localOrderBalances + svc := NewWalletService(db, "u1", "BTC").(*WalletService) + svc.localOrderBalances = localMap + + // 執行 PersistOrderBalances + err := svc.PersistOrderBalances(context.Background()) + if tc.want.err { + assert.Error(t, err) + return + } + assert.NoError(t, err) + + // 依照 seed 的 ID 順序讀回資料庫 + var got []entity.Transaction + assert.NoError(t, db. + Where("1 = 1"). + Order("id"). + Find(&got).Error) + // 長度要和 seed 相同 + assert.Len(t, got, len(seed)) + // 比對每一筆 balance + for i, tr := range got { + assert.Truef(t, + tr.PostTransferBalance.Equal(tc.want.final[i]), + "第 %d 筆交易(id=%d) balance = %s, want %s", + i, tr.ID, tr.PostTransferBalance, tc.want.final[i], + ) + } + }) + } +} + +func TestWalletService_HasAvailableBalance(t *testing.T) { + _, db, teardown, err := SetupTestWalletRepository() + assert.NoError(t, err) + defer teardown() + + // 建表 + createTable := ` + CREATE TABLE IF NOT EXISTS wallet ( + id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + brand VARCHAR(50) NOT NULL, + uid VARCHAR(64) NOT NULL, + asset VARCHAR(32) NOT NULL, + balance DECIMAL(30,18) NOT NULL DEFAULT 0, + type TINYINT NOT NULL, + create_at INTEGER NOT NULL DEFAULT 0, + update_at INTEGER NOT NULL DEFAULT 0, + PRIMARY KEY (id), + UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;` + assert.NoError(t, db.Exec(createTable).Error) + + type seedRow struct { + Brand string + UID string + Asset string + Type wallet.Types + Balance decimal.Decimal + } + + tests := []struct { + name string + uid string + asset string + seed []seedRow + wantExist bool + }{ + { + name: "無任何紀錄", + uid: "user1", + asset: "BTC", + seed: nil, + wantExist: false, + }, + { + name: "只有其他類型錢包", + uid: "user1", + asset: "BTC", + seed: []seedRow{ + {"", "user1", "BTC", wallet.TypeFreeze, decimal.Zero}, + {"", "user1", "BTC", wallet.TypeUnconfirmed, decimal.Zero}, + }, + wantExist: false, + }, + { + name: "已有可用錢包", + uid: "user1", + asset: "BTC", + seed: []seedRow{ + {"", "user1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(10)}, + }, + wantExist: true, + }, + { + name: "不同 UID 不算", + uid: "user2", + asset: "BTC", + seed: []seedRow{ + {"", "user1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)}, + }, + wantExist: false, + }, + { + name: "不同 Asset 不算", + uid: "user1", + asset: "ETH", + seed: []seedRow{ + {"", "user1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)}, + }, + wantExist: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // 每個子測試前都清空 wallet table + assert.NoError(t, db.Exec("DELETE FROM wallet").Error) + + // seed 資料到 wallet + for _, r := range tc.seed { + w := entity.Wallet{ + Brand: r.Brand, + UID: r.UID, + Asset: r.Asset, + Type: r.Type, + Balance: r.Balance, + } + assert.NoError(t, db.Create(&w).Error) + } + + // 建 service + svc := NewWalletService(db, tc.uid, tc.asset).(*WalletService) + got, err := svc.HasAvailableBalance(context.Background()) + assert.NoError(t, err) + assert.Equal(t, tc.wantExist, got) + }) + } +} + +func SetupTestDB(t *testing.T) (*gorm.DB, func()) { + host, port, _, tearDown, err := startMySQLContainer() + assert.NoError(t, err) + + dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", + MySQLUser, MySQLPassword, host, port, MySQLDatabase, + ) + db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) + assert.NoError(t, err) + + // 建表 + create := ` + CREATE TABLE IF NOT EXISTS wallet ( + id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT COMMENT '主鍵 ID', + brand VARCHAR(50) NOT NULL DEFAULT '' COMMENT '品牌', + uid VARCHAR(64) NOT NULL COMMENT '使用者 UID', + asset VARCHAR(32) NOT NULL COMMENT '資產代號', + balance DECIMAL(30,18) UNSIGNED NOT NULL DEFAULT 0 COMMENT '餘額', + type TINYINT NOT NULL COMMENT '錢包類型', + create_at INTEGER NOT NULL DEFAULT 0 COMMENT '建立時間', + update_at INTEGER NOT NULL DEFAULT 0 COMMENT '更新時間', + PRIMARY KEY (id), + UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type), + KEY idx_uid (uid), + KEY idx_brand (brand) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;` + assert.NoError(t, db.Exec(create).Error) + + return db, tearDown +} + +func TestWalletService_GetBalancesForUpdate(t *testing.T) { + db, tearDown := SetupTestDB(t) + defer tearDown() + + type seedRow struct { + Brand string + UID string + Asset string + Type wallet.Types + Balance decimal.Decimal + } + + tests := []struct { + name string + uid string + asset string + seed []seedRow + kinds []wallet.Types + wantIDs []int64 + expectErr bool + }{ + { + name: "查詢空結果", + uid: "u1", + asset: "BTC", + seed: nil, + kinds: []wallet.Types{wallet.TypeAvailable}, + wantIDs: nil, + expectErr: false, + }, + { + name: "單一類型查詢", + uid: "u1", + asset: "BTC", + seed: []seedRow{ + {"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)}, + {"", "u1", "BTC", wallet.TypeFreeze, decimal.NewFromInt(2)}, + }, + kinds: []wallet.Types{wallet.TypeFreeze}, + wantIDs: []int64{2}, + expectErr: false, + }, + { + name: "多類型查詢", + uid: "u1", + asset: "BTC", + seed: []seedRow{ + {"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)}, + {"", "u1", "BTC", wallet.TypeFreeze, decimal.NewFromInt(2)}, + {"", "u1", "BTC", wallet.TypeUnconfirmed, decimal.NewFromInt(3)}, + }, + kinds: []wallet.Types{wallet.TypeAvailable, wallet.TypeUnconfirmed}, + wantIDs: []int64{3, 5}, + expectErr: false, + }, + { + name: "不同 UID 不列入", + uid: "u2", + asset: "BTC", + seed: []seedRow{ + {"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)}, + }, + kinds: []wallet.Types{wallet.TypeAvailable}, + wantIDs: nil, + expectErr: false, + }, + { + name: "不同 Asset 不列入", + uid: "u1", + asset: "ETH", + seed: []seedRow{ + {"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)}, + }, + kinds: []wallet.Types{wallet.TypeAvailable}, + wantIDs: nil, + expectErr: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // 清空資料 + assert.NoError(t, db.Exec("DELETE FROM wallet").Error) + // Seed + for _, r := range tc.seed { + w := entity.Wallet{ + Brand: r.Brand, + UID: r.UID, + Asset: r.Asset, + Type: r.Type, + Balance: r.Balance, + } + // Create will auto-assign incremental IDs + assert.NoError(t, db.Create(&w).Error) + } + + // 建 Service + svc := NewWalletService(db, tc.uid, tc.asset).(*WalletService) + got, err := svc.GetBalancesForUpdate(context.Background(), tc.kinds) + if tc.expectErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + + // 取回 IDs 並排序 + var gotIDs []int64 + for _, w := range got { + gotIDs = append(gotIDs, w.ID) + // localBalances 應該被更新 + assert.Equal(t, w, svc.localBalances[w.Type]) + } + assert.Equal(t, tc.wantIDs, gotIDs) + }) + } +} diff --git a/pkg/repository/wallet.go b/pkg/repository/wallet.go index f8b3991..2b141c5 100644 --- a/pkg/repository/wallet.go +++ b/pkg/repository/wallet.go @@ -46,11 +46,11 @@ func (repo *WalletRepository) Transaction(fn func(db *gorm.DB) error) error { } func (repo *WalletRepository) Session(uid, asset string) repository.UserWalletService { - return NewUserWallet(repo.DB, uid, asset) + return NewWalletService(repo.DB, uid, asset) } func (repo *WalletRepository) SessionWithTx(db *gorm.DB, uid, asset string) repository.UserWalletService { - return NewUserWallet(db, uid, asset) + return NewWalletService(db, uid, asset) } func (repo *WalletRepository) InitWallets(ctx context.Context, param []repository.Wallet) error { diff --git a/pkg/usecase/wallet_tx_option.go b/pkg/usecase/wallet_tx_option.go index 3441c32..78ced43 100644 --- a/pkg/usecase/wallet_tx_option.go +++ b/pkg/usecase/wallet_tx_option.go @@ -31,7 +31,7 @@ func (use *WalletUseCase) withLockAvailable() walletActionOption { if !use.checkWalletExistence(uidAsset) { // 找不到錢包存不存在 - wStatus, err := w.CheckReady(ctx) + wStatus, err := w.HasAvailableBalance(ctx) if err != nil { return fmt.Errorf("failed to check wallet: %w", err) } @@ -43,7 +43,7 @@ func (use *WalletUseCase) withLockAvailable() walletActionOption { // return use.translateError(err) //} - if _, err := w.Init(ctx, tx.FromUID, tx.Asset, tx.Brand); err != nil { + if _, err := w.InitializeWallets(ctx, tx.Brand); err != nil { return err } @@ -53,7 +53,7 @@ func (use *WalletUseCase) withLockAvailable() walletActionOption { use.markWalletAsExisting(uidAsset) } - _, err := w.GetWithLock(ctx, []wallet.Types{wallet.TypeAvailable}) + _, err := w.GetBalancesForUpdate(ctx, []wallet.Types{wallet.TypeAvailable}) if err != nil { return err } @@ -65,7 +65,7 @@ func (use *WalletUseCase) withLockAvailable() walletActionOption { // subAvailable 減少用戶可用餘額 func (use *WalletUseCase) withSubAvailable() walletActionOption { return func(_ context.Context, tx *usecase.WalletTransferRequest, w repository.UserWalletService) error { - if err := w.Sub(wallet.TypeAvailable, tx.ReferenceOrderID, tx.Amount); err != nil { + if err := w.DecreaseBalance(wallet.TypeAvailable, tx.ReferenceOrderID, tx.Amount); err != nil { if errors.Is(err, repository.ErrBalanceInsufficient) { // todo 錯誤要看怎麼給(餘額不足) return fmt.Errorf("balance insufficient") diff --git a/pkg/usecase/wallet_tx_processer.go b/pkg/usecase/wallet_tx_processer.go index c34fcfa..33ff987 100644 --- a/pkg/usecase/wallet_tx_processer.go +++ b/pkg/usecase/wallet_tx_processer.go @@ -40,7 +40,7 @@ func (use *WalletUseCase) ProcessTransaction( // flows 會按照順序做.順序是重要的 for _, flow := range flows { // 1️⃣ 建立針對該使用者+資產的 UserWalletService - wSvc := repo.NewUserWallet(db, flow.UID, flow.Asset) + wSvc := repo.NewWalletService(db, flow.UID, flow.Asset) // 2️⃣ 依序執行所有定義好的錢包操作 for _, action := range flow.Actions { @@ -80,7 +80,7 @@ func (use *WalletUseCase) ProcessTransaction( for _, w := range wallets { walletTxs = append( walletTxs, - w.Transactions( + w.PrepareTransactions( txRecord.ID, txRecord.OrderID, req.Brand, @@ -96,10 +96,10 @@ func (use *WalletUseCase) ProcessTransaction( // 8️⃣ 最後才真正把錢包的餘額更新到資料庫(同一事務) for _, wSvc := range wallets { - if err := wSvc.Commit(ctx); err != nil { + if err := wSvc.PersistBalances(ctx); err != nil { return err } - if err := wSvc.CommitOrder(ctx); err != nil { + if err := wSvc.PersistOrderBalances(ctx); err != nil { return err } }