package repository

import (
	"context"
	"fmt"
	"sort"
	"testing"
	"time"

	"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/entity"
	"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/repository"
	"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/wallet"
	"github.com/shopspring/decimal"
	"github.com/stretchr/testify/assert"
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// walletTestTableDDL is the wallet schema shared by the tests in this file
// that create the table by hand (rather than via AutoMigrate). The unique key
// on (brand, uid, asset, type) is what the duplicate-initialization case in
// TestWalletService_InitializeWallets expects to be rejected by.
const walletTestTableDDL = `
CREATE TABLE IF NOT EXISTS wallet (
  id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
  brand VARCHAR(50) NOT NULL,
  uid VARCHAR(64) NOT NULL,
  asset VARCHAR(32) NOT NULL,
  balance DECIMAL(30,18) NOT NULL DEFAULT 0,
  type TINYINT NOT NULL,
  create_at INTEGER NOT NULL DEFAULT 0,
  update_at INTEGER NOT NULL DEFAULT 0,
  PRIMARY KEY (id),
  UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`

// createWalletTestTable creates the wallet table if it does not exist yet and
// aborts the test when the DDL fails, since every later step needs the table.
func createWalletTestTable(t *testing.T, db *gorm.DB) {
	t.Helper()
	if err := db.Exec(walletTestTableDDL).Error; err != nil {
		t.Fatalf("create wallet table: %v", err)
	}
}

// TestWalletService_InitializeWallets checks that InitializeWallets returns one
// wallet per wallet.AllTypes entry, persists them with a zero balance, and that
// a second call for the same uid/asset/brand fails on the unique key.
func TestWalletService_InitializeWallets(t *testing.T) {
	_, db, teardown, err := SetupTestWalletRepository()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	defer teardown()

	createWalletTestTable(t, db)

	type args struct {
		uid   string
		asset string
		brand string
	}
	tests := []struct {
		name       string
		args       args
		wantCount  int
		wantErr    bool
		validateDB bool
	}{
		{
			name:      "正常初始化一次",
			args:      args{uid: "user1", asset: "BTC", brand: "brandA"},
			wantCount: len(wallet.AllTypes),
			// Also verify the rows really landed in the table.
			validateDB: true,
		},
		{
			// Relies on the previous case having inserted the same rows.
			name:       "再次初始化同一 UID/asset/brand,應因 UNIQUE KEY 失敗",
			args:       args{uid: "user1", asset: "BTC", brand: "brandA"},
			wantErr:    true,
			validateDB: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			service := NewWalletService(db, tt.args.uid, tt.args.asset)
			ctx := context.Background()

			got, err := service.InitializeWallets(ctx, tt.args.brand)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			// One wallet is returned per entry in AllTypes.
			assert.Len(t, got, tt.wantCount)

			if !tt.validateDB {
				return
			}

			// Query the DB again to confirm the rows were actually written.
			var dbRows []entity.Wallet
			err = db.WithContext(ctx).
				Where("uid = ? AND asset = ? AND brand = ?", tt.args.uid, tt.args.asset, tt.args.brand).
				Find(&dbRows).Error
			assert.NoError(t, err)
			assert.Len(t, dbRows, tt.wantCount)

			// Every row must start at zero. Compare numerically: a DECIMAL(30,18)
			// scanned back from MySQL carries a different exponent than
			// decimal.Zero, so assert.Equal on the structs would fail even though
			// both values are 0.
			for _, w := range dbRows {
				assert.Truef(t, w.Balance.IsZero(), "wallet type %v balance = %s, want 0", w.Type, w.Balance)
			}
		})
	}
}

// TestWalletService_GetAllBalances seeds wallet rows and checks that
// GetAllBalances returns only the rows matching the service's uid/asset, and
// that every returned wallet is mirrored into the service's localBalances.
func TestWalletService_GetAllBalances(t *testing.T) {
	_, db, teardown, err := SetupTestWalletRepository()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	defer teardown()

	createWalletTestTable(t, db)

	type seedRow struct {
		Brand   string
		UID     string
		Asset   string
		Balance decimal.Decimal
		Type    wallet.Types
	}
	type args struct {
		uid   string
		asset string
		brand string
	}
	tests := []struct {
		name      string
		args      args
		seed      []seedRow
		wantCount int
	}{
		{
			name:      "no data returns empty",
			args:      args{"u1", "BTC", "b1"},
			seed:      nil,
			wantCount: 0,
		},
		{
			name: "single user many types",
			args: args{"u1", "BTC", "b1"},
			seed: []seedRow{
				{"b1", "u1", "BTC", decimal.NewFromInt(10), wallet.AllTypes[0]},
				{"b1", "u1", "BTC", decimal.NewFromInt(20), wallet.AllTypes[1]},
				{"b1", "u1", "BTC", decimal.NewFromInt(30), wallet.AllTypes[2]},
			},
			wantCount: 3,
		},
		{
			name: "mixed users and assets",
			args: args{"u2", "ETH", "b2"},
			seed: []seedRow{
				{"b1", "u1", "BTC", decimal.NewFromInt(10), wallet.AllTypes[0]},
				{"b2", "u2", "ETH", decimal.NewFromInt(15), wallet.AllTypes[1]},
				{"b2", "u2", "ETH", decimal.NewFromInt(25), wallet.AllTypes[2]},
			},
			wantCount: 2,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := context.Background()

			// Start every case from an empty table.
			assert.NoError(t, db.Exec(`DELETE FROM wallet`).Error)

			now := time.Now().Unix()
			for _, row := range tt.seed {
				err := db.WithContext(ctx).Exec(
					`INSERT INTO wallet (brand, uid, asset, balance, type, create_at, update_at) VALUES (?, ?, ?, ?, ?, ?, ?)`,
					row.Brand, row.UID, row.Asset, row.Balance, row.Type, now, now,
				).Error
				assert.NoError(t, err)
			}

			svc := NewWalletService(db, tt.args.uid, tt.args.asset)

			got, err := svc.GetAllBalances(ctx)
			assert.NoError(t, err)
			assert.Len(t, got, tt.wantCount)

			// Every returned wallet must also be present in the local cache.
			ws := svc.(*WalletService)
			for _, w := range got {
				cached, ok := ws.localBalances[w.Type]
				assert.True(t, ok)
				assert.Equal(t, w.Balance, cached.Balance)
				assert.Equal(t, w.Asset, cached.Asset)
				assert.Equal(t, w.Type, cached.Type)
			}
		})
	}
}

// TestWalletService_GetBalancesForTypes checks that GetBalancesForTypes
// returns only the requested wallet types for the service's uid/asset and
// caches each returned wallet in localBalances.
func TestWalletService_GetBalancesForTypes(t *testing.T) {
	_, db, teardown, err := SetupTestWalletRepository()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	defer teardown()

	createWalletTestTable(t, db)

	type seedRow struct {
		Brand   string
		UID     string
		Asset   string
		Balance decimal.Decimal
		Type    wallet.Types
	}

	tests := []struct {
		name      string
		uid       string
		asset     string
		brand     string
		seed      []seedRow
		kinds     []wallet.Types
		wantCount int
	}{
		{
			name:  "no matching types returns empty",
			uid:   "user1",
			asset: "BTC",
			brand: "b1",
			seed: []seedRow{
				{"b1", "user1", "BTC", decimal.NewFromInt(5), wallet.AllTypes[0]},
			},
			kinds:     []wallet.Types{wallet.AllTypes[1]},
			wantCount: 0,
		},
		{
			name:  "single type match",
			uid:   "user1",
			asset: "BTC",
			brand: "b1",
			seed: []seedRow{
				{"b1", "user1", "BTC", decimal.NewFromInt(5), wallet.AllTypes[0]},
				{"b1", "user1", "BTC", decimal.NewFromInt(7), wallet.AllTypes[1]},
			},
			kinds:     []wallet.Types{wallet.AllTypes[1]},
			wantCount: 1,
		},
		{
			name:  "multiple type matches",
			uid:   "user2",
			asset: "ETH",
			brand: "b2",
			seed: []seedRow{
				{"b2", "user2", "ETH", decimal.NewFromInt(3), wallet.AllTypes[0]},
				{"b2", "user2", "ETH", decimal.NewFromInt(8), wallet.AllTypes[2]},
				{"b2", "user2", "ETH", decimal.NewFromInt(10), wallet.AllTypes[1]},
			},
			kinds:     []wallet.Types{wallet.AllTypes[0], wallet.AllTypes[2]},
			wantCount: 2,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := context.Background()

			// Start every case from an empty table.
			assert.NoError(t, db.Exec(`DELETE FROM wallet`).Error)

			now := time.Now().Unix()
			for _, r := range tt.seed {
				assert.NoError(t, db.Exec(
					`INSERT INTO wallet (brand, uid, asset, balance, type, create_at, update_at) VALUES (?, ?, ?, ?, ?, ?, ?)`,
					r.Brand, r.UID, r.Asset, r.Balance, r.Type, now, now,
				).Error)
			}

			svc := NewWalletService(db, tt.uid, tt.asset)

			got, err := svc.GetBalancesForTypes(ctx, tt.kinds)
			assert.NoError(t, err)
			assert.Len(t, got, tt.wantCount)

			// Every returned wallet must be cached with matching fields.
			ws := svc.(*WalletService)
			for _, w := range got {
				cached, ok := ws.localBalances[w.Type]
				assert.True(t, ok)
				assert.Equal(t, w.Balance, cached.Balance)
				assert.Equal(t, w.Asset, cached.Asset)
			}
		})
	}
}

// TestForUpdateLockBehavior verifies that a row read with SELECT ... FOR UPDATE
// in one transaction blocks an UPDATE of the same row from a second
// transaction until the first transaction commits.
func TestForUpdateLockBehavior(t *testing.T) {
	_, db, teardown, err := SetupTestWalletRepository()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	defer teardown()

	createWalletTestTable(t, db)

	// Seed a single wallet row.
	ctx := context.Background()
	initial := entity.Wallet{
		Brand:   "b1",
		UID:     "user1",
		Asset:   "BTC",
		Balance: decimal.NewFromInt(100),
		Type:    wallet.AllTypes[0],
	}
	if !assert.NoError(t, db.WithContext(ctx).Create(&initial).Error) {
		t.FailNow()
	}

	// Read the row back to get its auto-generated ID.
	var seeded entity.Wallet
	if !assert.NoError(t, db.Where("uid = ? AND asset = ?", initial.UID, initial.Asset).
		First(&seeded).Error) {
		t.FailNow()
	}

	// Transaction 1 takes the row lock with SELECT ... FOR UPDATE.
	tx1 := db.WithContext(ctx).Begin()
	if !assert.NoError(t, tx1.Error) {
		t.FailNow()
	}
	// Release the lock on any early exit (including t.Fatalf below). After a
	// successful Commit this Rollback is harmless; its error is ignored.
	defer tx1.Rollback()

	var locked entity.Wallet
	err = tx1.Clauses(clause.Locking{Strength: "UPDATE"}).
		Where("id = ?", seeded.ID).
		Take(&locked).Error
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	// Transaction 2 tries to update the same row in the background and should
	// block on tx1's lock. It always finishes its own transaction so it does
	// not keep a pooled connection (and the row lock) after the test.
	done := make(chan error, 1)
	go func() {
		tx2 := db.WithContext(ctx).Begin()
		if tx2.Error != nil {
			done <- tx2.Error
			return
		}
		err := tx2.Model(&entity.Wallet{}).
			Where("id = ?", seeded.ID).
			Update("balance", seeded.Balance.Add(decimal.NewFromInt(50))).Error
		if err != nil {
			tx2.Rollback()
			done <- err
			return
		}
		done <- tx2.Commit().Error
	}()

	// After 100ms the update must still be waiting for the lock.
	time.Sleep(100 * time.Millisecond)
	select {
	case err2 := <-done:
		t.Fatalf("expected update to be blocked, but completed with err=%v", err2)
	default:
		// Still waiting for the lock, as expected.
	}

	// Releasing the lock lets transaction 2 proceed.
	assert.NoError(t, tx1.Commit().Error)

	// Transaction 2 should now finish quickly.
	select {
	case err2 := <-done:
		assert.NoError(t, err2)
	case <-time.After(500 * time.Millisecond):
		t.Fatal("update did not complete after lock released")
	}
}

// TestWalletService_IncreaseBalance exercises IncreaseBalance purely against
// the in-memory cache (no DB). It checks the returned error, the resulting
// cached balance, and the transaction records appended to svc.transactions.
func TestWalletService_IncreaseBalance(t *testing.T) {
	uid := "user1"
	asset := "BTC"

	type testCase struct {
		name        string
		kind        wallet.Types
		initial     *entity.Wallet // nil means the wallet type is not cached
		orderID     string
		amount      decimal.Decimal
		wantErr     error
		wantBalance decimal.Decimal
		wantTxCount int
	}

	tests := []testCase{
		{
			name:        "missing wallet type",
			kind:        wallet.AllTypes[0],
			initial:     nil,
			orderID:     "ord1",
			amount:      decimal.NewFromInt(10),
			wantErr:     repository.ErrRecordNotFound,
			wantBalance: decimal.Zero,
			wantTxCount: 0,
		},
		{
			name:        "increase from zero",
			kind:        wallet.AllTypes[1],
			initial:     &entity.Wallet{Balance: decimal.Zero},
			orderID:     "ord2",
			amount:      decimal.NewFromInt(15),
			wantErr:     nil,
			wantBalance: decimal.NewFromInt(15),
			wantTxCount: 1,
		},
		{
			name:        "successful increment on non-zero",
			kind:        wallet.AllTypes[2],
			initial:     &entity.Wallet{Balance: decimal.NewFromInt(5)},
			orderID:     "ord3",
			amount:      decimal.NewFromInt(7),
			wantErr:     nil,
			wantBalance: decimal.NewFromInt(12),
			wantTxCount: 1,
		},
		{
			name:        "insufficient leads to error",
			kind:        wallet.AllTypes[2],
			initial:     &entity.Wallet{Balance: decimal.NewFromInt(3)},
			orderID:     "ord4",
			amount:      decimal.NewFromInt(-5),
			wantErr:     repository.ErrBalanceInsufficient,
			wantBalance: decimal.NewFromInt(3),
			wantTxCount: 0,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// No DB is needed: only the local cache is touched.
			svc := NewWalletService(nil, uid, asset).(*WalletService)

			// Seed the local cache.
			if tc.initial != nil {
				svc.localBalances = map[wallet.Types]entity.Wallet{tc.kind: *tc.initial}
			} else {
				svc.localBalances = map[wallet.Types]entity.Wallet{}
			}
			// Start with no recorded transactions.
			svc.transactions = nil

			err := svc.IncreaseBalance(tc.kind, tc.orderID, tc.amount)

			if tc.wantErr != nil {
				assert.ErrorIs(t, err, tc.wantErr)
			} else {
				assert.NoError(t, err)
			}

			// Cached balance after the call.
			gotBal := svc.CurrentBalance(tc.kind)
			assert.True(t, gotBal.Equal(tc.wantBalance), "balance = %s, want %s", gotBal, tc.wantBalance)

			// Recorded transactions.
			assert.Len(t, svc.transactions, tc.wantTxCount)
			if tc.wantTxCount > 0 {
				tx := svc.transactions[0]
				assert.Equal(t, tc.orderID, tx.OrderID)
				assert.Equal(t, uid, tx.UID)
				assert.Equal(t, asset, tx.Asset)
				assert.True(t, tx.Amount.Equal(tc.amount))
				assert.True(t, tx.Balance.Equal(tc.wantBalance))
			}
		})
	}
}

// TestWalletService_PrepareTransactions checks that PrepareTransactions stamps
// the shared fields (transaction ID, order ID, brand, business type) onto every
// pending transaction and leaves the per-transaction fields untouched.
func TestWalletService_PrepareTransactions(t *testing.T) {
	const (
		txID       = int64(42)
		orderID    = "order-123"
		brand      = "brandX"
		bizNameStr = "business-test"
	)
	biz := wallet.BusinessName(bizNameStr)

	tests := []struct {
		name       string
		initialTxs []entity.WalletTransaction
		wantCount  int
	}{
		{
			name:       "no transactions returns empty slice",
			initialTxs: nil,
			wantCount:  0,
		},
		{
			name: "single transaction is populated",
			initialTxs: []entity.WalletTransaction{
				{
					OrderID:    "placeholder",
					UID:        "u1",
					WalletType: wallet.AllTypes[0],
					Asset:      "BTC",
					Amount:     decimal.NewFromInt(5),
					Balance:    decimal.NewFromInt(10),
				},
			},
			wantCount: 1,
		},
		{
			name: "multiple transactions are all populated",
			initialTxs: []entity.WalletTransaction{
				{UID: "u1", WalletType: wallet.AllTypes[1], Asset: "ETH", Amount: decimal.NewFromInt(1), Balance: decimal.NewFromInt(1)},
				{UID: "u2", WalletType: wallet.AllTypes[2], Asset: "TWD", Amount: decimal.NewFromInt(2), Balance: decimal.NewFromInt(2)},
				{UID: "u3", WalletType: wallet.AllTypes[3%len(wallet.AllTypes)], Asset: "USD", Amount: decimal.NewFromInt(3), Balance: decimal.NewFromInt(3)},
			},
			wantCount: 3,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Inject a private copy so tc.initialTxs stays available for comparison.
			svc := &WalletService{
				transactions: make([]entity.WalletTransaction, len(tc.initialTxs)),
			}
			copy(svc.transactions, tc.initialTxs)

			result := svc.PrepareTransactions(txID, orderID, brand, biz)

			assert.Len(t, result, tc.wantCount)
			assert.Len(t, svc.transactions, tc.wantCount)

			for i := 0; i < tc.wantCount; i++ {
				tx := result[i]
				// Shared fields are stamped on every transaction.
				assert.Equal(t, txID, tx.TransactionID, "tx[%d].TransactionID", i)
				assert.Equal(t, orderID, tx.OrderID, "tx[%d].OrderID", i)
				assert.Equal(t, brand, tx.Brand, "tx[%d].Brand", i)
				assert.Equal(t, biz.ToInt8(), tx.BusinessType, "tx[%d].BusinessType", i)

				// Per-transaction fields are left unchanged.
				assert.Equal(t, tc.initialTxs[i].UID, tx.UID, "tx[%d].UID unchanged", i)
				assert.Equal(t, tc.initialTxs[i].WalletType, tx.WalletType, "tx[%d].WalletType unchanged", i)
				assert.Equal(t, tc.initialTxs[i].Asset, tx.Asset, "tx[%d].Asset unchanged", i)
				assert.True(t, tx.Amount.Equal(tc.initialTxs[i].Amount), "tx[%d].Amount unchanged", i)
				assert.True(t, tx.Balance.Equal(tc.initialTxs[i].Balance), "tx[%d].Balance unchanged", i)
			}
		})
	}
}

// TestWalletService_PersistBalances checks that PersistBalances writes the
// balances staged in localBalances back to the matching wallet rows and leaves
// rows without a staged balance unchanged.
func TestWalletService_PersistBalances(t *testing.T) {
	_, db, teardown, err := SetupTestWalletRepository()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	defer teardown()

	// seedWallets drops and recreates the wallet table, inserts two rows and
	// returns them with their generated IDs.
	seedWallets := func() []entity.Wallet {
		assert.NoError(t, db.Migrator().DropTable(&entity.Wallet{}))
		assert.NoError(t, db.AutoMigrate(&entity.Wallet{}))

		base := []entity.Wallet{
			{UID: "u1", Asset: "BTC", Brand: "b", Balance: decimal.NewFromInt(10), Type: wallet.AllTypes[0]},
			{UID: "u1", Asset: "BTC", Brand: "b", Balance: decimal.NewFromInt(20), Type: wallet.AllTypes[1]},
		}
		assert.NoError(t, db.Create(&base).Error)
		return base
	}

	type fields struct {
		updates map[wallet.Types]decimal.Decimal
	}
	type want struct {
		final []decimal.Decimal
		err   bool
	}

	tests := []struct {
		name   string
		fields fields
		want   want
	}{
		{
			name:   "no local balances → no change",
			fields: fields{updates: map[wallet.Types]decimal.Decimal{}},
			want: want{
				final: []decimal.Decimal{decimal.NewFromInt(10), decimal.NewFromInt(20)},
				err:   false,
			},
		},
		{
			name: "update both balances",
			fields: fields{updates: map[wallet.Types]decimal.Decimal{
				wallet.AllTypes[0]: decimal.NewFromInt(15),
				wallet.AllTypes[1]: decimal.NewFromInt(5),
			}},
			want: want{
				final: []decimal.Decimal{decimal.NewFromInt(15), decimal.NewFromInt(5)},
				err:   false,
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			seed := seedWallets()

			// Stage a new balance only for the types listed in updates.
			localMap := make(map[wallet.Types]entity.Wallet, len(tc.fields.updates))
			for _, w := range seed {
				if nb, ok := tc.fields.updates[w.Type]; ok {
					localMap[w.Type] = entity.Wallet{
						ID:      w.ID,
						Balance: nb,
					}
				}
			}

			svc := NewWalletService(db, "u1", "BTC").(*WalletService)
			svc.localBalances = localMap

			err := svc.PersistBalances(context.Background())
			if tc.want.err {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)

			// Read the rows back in ID order and compare balances.
			var got []entity.Wallet
			assert.NoError(t, db.
				Where("uid = ?", "u1").
				Order("id").
				Find(&got).Error)

			assert.Len(t, got, len(seed))
			for i, w := range got {
				assert.Truef(t,
					w.Balance.Equal(tc.want.final[i]),
					"第 %d 条记录 (id=%d) 余额 = %s, 期望 %s",
					i, w.ID, w.Balance, tc.want.final[i],
				)
			}
		})
	}
}

// TestWalletService_PersistOrderBalances checks that PersistOrderBalances
// writes the post-transfer balances staged in localOrderBalances (keyed by
// transaction ID) back to the transaction table.
func TestWalletService_PersistOrderBalances(t *testing.T) {
	_, db, teardown, err := SetupTestWalletRepository()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	defer teardown()

	// seedTransactions drops and recreates the transaction table, inserts two
	// rows and returns them with their generated IDs.
	seedTransactions := func() []entity.Transaction {
		assert.NoError(t, db.Migrator().DropTable(&entity.Transaction{}))
		assert.NoError(t, db.AutoMigrate(&entity.Transaction{}))

		base := []entity.Transaction{
			{PostTransferBalance: decimal.NewFromInt(100)},
			{PostTransferBalance: decimal.NewFromInt(200)},
		}
		assert.NoError(t, db.Create(&base).Error)
		return base
	}

	type fields struct {
		// updates maps an index into the seeded transactions to the balance
		// staged for it; the index is translated to the row's real ID below so
		// the test does not depend on auto-increment values.
		updates map[int]decimal.Decimal
	}
	type want struct {
		final []decimal.Decimal
		err   bool
	}

	tests := []struct {
		name   string
		fields fields
		want   want
	}{
		{
			name:   "no local order balances → no change",
			fields: fields{updates: map[int]decimal.Decimal{}},
			want: want{
				final: []decimal.Decimal{decimal.NewFromInt(100), decimal.NewFromInt(200)},
				err:   false,
			},
		},
		{
			name: "update first order balance only",
			fields: fields{updates: map[int]decimal.Decimal{
				0: decimal.NewFromInt(150),
			}},
			want: want{
				final: []decimal.Decimal{decimal.NewFromInt(150), decimal.NewFromInt(200)},
				err:   false,
			},
		},
		{
			name: "update both order balances",
			fields: fields{updates: map[int]decimal.Decimal{
				0: decimal.NewFromInt(110),
				1: decimal.NewFromInt(220),
			}},
			want: want{
				final: []decimal.Decimal{decimal.NewFromInt(110), decimal.NewFromInt(220)},
				err:   false,
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			seed := seedTransactions()

			// Key localOrderBalances by the seeded rows' real IDs.
			localMap := make(map[int64]decimal.Decimal, len(tc.fields.updates))
			for idx, newBal := range tc.fields.updates {
				localMap[int64(seed[idx].ID)] = newBal
			}

			svc := NewWalletService(db, "u1", "BTC").(*WalletService)
			svc.localOrderBalances = localMap

			err := svc.PersistOrderBalances(context.Background())
			if tc.want.err {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)

			// Read every row back in ID order (same order as seed).
			var got []entity.Transaction
			assert.NoError(t, db.
				Order("id").
				Find(&got).Error)

			assert.Len(t, got, len(seed))
			for i, tr := range got {
				assert.Truef(t,
					tr.PostTransferBalance.Equal(tc.want.final[i]),
					"第 %d 筆交易(id=%d) balance = %s, want %s",
					i, tr.ID, tr.PostTransferBalance, tc.want.final[i],
				)
			}
		})
	}
}

// TestWalletService_HasAvailableBalance checks that HasAvailableBalance reports
// true only when a TypeAvailable wallet exists for the service's uid/asset.
func TestWalletService_HasAvailableBalance(t *testing.T) {
	_, db, teardown, err := SetupTestWalletRepository()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	defer teardown()

	createWalletTestTable(t, db)

	type seedRow struct {
		Brand   string
		UID     string
		Asset   string
		Type    wallet.Types
		Balance decimal.Decimal
	}

	tests := []struct {
		name      string
		uid       string
		asset     string
		seed      []seedRow
		wantExist bool
	}{
		{
			name:      "無任何紀錄",
			uid:       "user1",
			asset:     "BTC",
			seed:      nil,
			wantExist: false,
		},
		{
			name:  "只有其他類型錢包",
			uid:   "user1",
			asset: "BTC",
			seed: []seedRow{
				{"", "user1", "BTC", wallet.TypeFreeze, decimal.Zero},
				{"", "user1", "BTC", wallet.TypeUnconfirmed, decimal.Zero},
			},
			wantExist: false,
		},
		{
			name:  "已有可用錢包",
			uid:   "user1",
			asset: "BTC",
			seed: []seedRow{
				{"", "user1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(10)},
			},
			wantExist: true,
		},
		{
			name:  "不同 UID 不算",
			uid:   "user2",
			asset: "BTC",
			seed: []seedRow{
				{"", "user1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
			},
			wantExist: false,
		},
		{
			name:  "不同 Asset 不算",
			uid:   "user1",
			asset: "ETH",
			seed: []seedRow{
				{"", "user1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
			},
			wantExist: false,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Start every case from an empty table.
			assert.NoError(t, db.Exec("DELETE FROM wallet").Error)

			for _, r := range tc.seed {
				w := entity.Wallet{
					Brand:   r.Brand,
					UID:     r.UID,
					Asset:   r.Asset,
					Type:    r.Type,
					Balance: r.Balance,
				}
				assert.NoError(t, db.Create(&w).Error)
			}

			svc := NewWalletService(db, tc.uid, tc.asset).(*WalletService)
			got, err := svc.HasAvailableBalance(context.Background())
			assert.NoError(t, err)
			assert.Equal(t, tc.wantExist, got)
		})
	}
}

// SetupTestDB starts a MySQL container, connects to it with gorm and creates
// the wallet table. It returns the connection and the container teardown
// function. Any setup failure tears the container down (when one was started)
// and aborts the test instead of handing back a nil *gorm.DB.
func SetupTestDB(t *testing.T) (*gorm.DB, func()) {
	t.Helper()

	host, port, _, tearDown, err := startMySQLContainer()
	if err != nil {
		if tearDown != nil {
			tearDown()
		}
		t.Fatalf("start mysql container: %v", err)
	}

	dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true",
		MySQLUser, MySQLPassword, host, port, MySQLDatabase,
	)
	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
	if err != nil {
		tearDown()
		t.Fatalf("open mysql connection: %v", err)
	}

	create := `
CREATE TABLE IF NOT EXISTS wallet (
  id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT COMMENT '主鍵 ID',
  brand VARCHAR(50) NOT NULL DEFAULT '' COMMENT '品牌',
  uid VARCHAR(64) NOT NULL COMMENT '使用者 UID',
  asset VARCHAR(32) NOT NULL COMMENT '資產代號',
  balance DECIMAL(30,18) UNSIGNED NOT NULL DEFAULT 0 COMMENT '餘額',
  type TINYINT NOT NULL COMMENT '錢包類型',
  create_at INTEGER NOT NULL DEFAULT 0 COMMENT '建立時間',
  update_at INTEGER NOT NULL DEFAULT 0 COMMENT '更新時間',
  PRIMARY KEY (id),
  UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type),
  KEY idx_uid (uid),
  KEY idx_brand (brand)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
	if err := db.Exec(create).Error; err != nil {
		tearDown()
		t.Fatalf("create wallet table: %v", err)
	}

	return db, tearDown
}

// TestWalletService_GetBalancesForUpdate checks that GetBalancesForUpdate
// returns exactly the requested wallet types for the service's uid/asset and
// caches each returned wallet in localBalances.
func TestWalletService_GetBalancesForUpdate(t *testing.T) {
	db, tearDown := SetupTestDB(t)
	defer tearDown()

	type seedRow struct {
		Brand   string
		UID     string
		Asset   string
		Type    wallet.Types
		Balance decimal.Decimal
	}

	tests := []struct {
		name  string
		uid   string
		asset string
		seed  []seedRow
		kinds []wallet.Types
		// wantSeeds lists indexes into seed whose rows must be returned. Indexes
		// are used instead of literal IDs because DELETE does not reset
		// AUTO_INCREMENT, so literal IDs would depend on earlier subtests.
		wantSeeds []int
		expectErr bool
	}{
		{
			name:      "查詢空結果",
			uid:       "u1",
			asset:     "BTC",
			seed:      nil,
			kinds:     []wallet.Types{wallet.TypeAvailable},
			wantSeeds: nil,
			expectErr: false,
		},
		{
			name:  "單一類型查詢",
			uid:   "u1",
			asset: "BTC",
			seed: []seedRow{
				{"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
				{"", "u1", "BTC", wallet.TypeFreeze, decimal.NewFromInt(2)},
			},
			kinds:     []wallet.Types{wallet.TypeFreeze},
			wantSeeds: []int{1},
			expectErr: false,
		},
		{
			name:  "多類型查詢",
			uid:   "u1",
			asset: "BTC",
			seed: []seedRow{
				{"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
				{"", "u1", "BTC", wallet.TypeFreeze, decimal.NewFromInt(2)},
				{"", "u1", "BTC", wallet.TypeUnconfirmed, decimal.NewFromInt(3)},
			},
			kinds:     []wallet.Types{wallet.TypeAvailable, wallet.TypeUnconfirmed},
			wantSeeds: []int{0, 2},
			expectErr: false,
		},
		{
			name:  "不同 UID 不列入",
			uid:   "u2",
			asset: "BTC",
			seed: []seedRow{
				{"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
			},
			kinds:     []wallet.Types{wallet.TypeAvailable},
			wantSeeds: nil,
			expectErr: false,
		},
		{
			name:  "不同 Asset 不列入",
			uid:   "u1",
			asset: "ETH",
			seed: []seedRow{
				{"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
			},
			kinds:     []wallet.Types{wallet.TypeAvailable},
			wantSeeds: nil,
			expectErr: false,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Start every case from an empty table.
			assert.NoError(t, db.Exec("DELETE FROM wallet").Error)

			// Seed rows and remember the IDs the DB assigned to them.
			seededIDs := make([]int64, 0, len(tc.seed))
			for _, r := range tc.seed {
				w := entity.Wallet{
					Brand:   r.Brand,
					UID:     r.UID,
					Asset:   r.Asset,
					Type:    r.Type,
					Balance: r.Balance,
				}
				assert.NoError(t, db.Create(&w).Error)
				seededIDs = append(seededIDs, w.ID)
			}

			svc := NewWalletService(db, tc.uid, tc.asset).(*WalletService)
			got, err := svc.GetBalancesForUpdate(context.Background(), tc.kinds)
			if tc.expectErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)

			var wantIDs []int64
			for _, idx := range tc.wantSeeds {
				wantIDs = append(wantIDs, seededIDs[idx])
			}

			var gotIDs []int64
			for _, w := range got {
				gotIDs = append(gotIDs, w.ID)
				// Every returned wallet must be reflected in localBalances.
				assert.Equal(t, w, svc.localBalances[w.Type])
			}

			// The query's row order is not part of the contract; compare sorted.
			sort.Slice(wantIDs, func(i, j int) bool { return wantIDs[i] < wantIDs[j] })
			sort.Slice(gotIDs, func(i, j int) bool { return gotIDs[i] < gotIDs[j] })
			assert.Equal(t, wantIDs, gotIDs)
		})
	}
}