app-cloudep-wallet-service/pkg/repository/user_wallet_test.go

1026 lines
27 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/entity"
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/repository"
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/wallet"
"context"
"fmt"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"testing"
"time"
)
// TestWalletService_InitializeWallets verifies that InitializeWallets creates one
// wallet row per entry in wallet.AllTypes for a (uid, asset, brand) triple, and
// that initializing the same triple a second time fails on the
// uq_brand_uid_asset_type UNIQUE KEY.
//
// The sub-tests share one table and run in order: the second case relies on the
// rows written by the first one.
func TestWalletService_InitializeWallets(t *testing.T) {
	_, db, teardown, err := SetupTestWalletRepository()
	if err != nil {
		// Nothing below can run without a database, and teardown may be unusable.
		t.Fatalf("setting up test wallet repository: %v", err)
	}
	defer teardown()

	// Create the table by hand so the UNIQUE KEY that the second case depends on
	// is explicit in the test.
	createTable := `
CREATE TABLE IF NOT EXISTS wallet (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
brand VARCHAR(50) NOT NULL,
uid VARCHAR(64) NOT NULL,
asset VARCHAR(32) NOT NULL,
balance DECIMAL(30,18) NOT NULL DEFAULT 0,
type TINYINT NOT NULL,
create_at INTEGER NOT NULL DEFAULT 0,
update_at INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY (id),
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
	if err := db.Exec(createTable).Error; err != nil {
		t.Fatalf("creating wallet table: %v", err)
	}

	type args struct {
		uid   string
		asset string
		brand string
	}
	tests := []struct {
		name       string
		args       args
		wantCount  int
		wantErr    bool
		validateDB bool
	}{
		{
			name:      "正常初始化一次",
			args:      args{uid: "user1", asset: "BTC", brand: "brandA"},
			wantCount: len(wallet.AllTypes),
			// Read the rows back: the returned slice alone does not prove they
			// were written to the table.
			validateDB: true,
		},
		{
			name:       "再次初始化同一 UID/asset/brand應因 UNIQUE KEY 失敗",
			args:       args{uid: "user1", asset: "BTC", brand: "brandA"},
			wantErr:    true,
			validateDB: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			service := NewWalletService(db, tt.args.uid, tt.args.asset)
			ctx := context.Background()

			got, err := service.InitializeWallets(ctx, tt.args.brand)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			// One returned wallet per type in wallet.AllTypes.
			assert.Len(t, got, tt.wantCount)

			if !tt.validateDB {
				return
			}

			// Query the table again to confirm the rows were actually written.
			var dbRows []entity.Wallet
			err = db.WithContext(ctx).
				Where("uid = ? AND asset = ? AND brand = ?", tt.args.uid, tt.args.asset, tt.args.brand).
				Find(&dbRows).Error
			assert.NoError(t, err)
			assert.Len(t, dbRows, tt.wantCount)

			for _, w := range dbRows {
				// Compare numerically rather than with assert.Equal: a zero read back
				// from a DECIMAL(30,18) column can carry a different exponent than
				// decimal.Zero, so struct equality is not a reliable check.
				assert.Truef(t, w.Balance.IsZero(), "wallet type %v balance = %s, want 0", w.Type, w.Balance)
			}
		})
	}
}
// TestWalletService_GetAllBalances verifies that GetAllBalances returns every
// wallet row for the service's uid/asset pair, that each returned balance equals
// the seeded one, and that each returned row is mirrored into localBalances.
func TestWalletService_GetAllBalances(t *testing.T) {
	_, db, teardown, err := SetupTestWalletRepository()
	if err != nil {
		// Nothing below can run without a database, and teardown may be unusable.
		t.Fatalf("setting up test wallet repository: %v", err)
	}
	defer teardown()

	createTable := `
CREATE TABLE IF NOT EXISTS wallet (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
brand VARCHAR(50) NOT NULL,
uid VARCHAR(64) NOT NULL,
asset VARCHAR(32) NOT NULL,
balance DECIMAL(30,18) NOT NULL DEFAULT 0,
type TINYINT NOT NULL,
create_at INTEGER NOT NULL DEFAULT 0,
update_at INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY (id),
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
	if err := db.Exec(createTable).Error; err != nil {
		t.Fatalf("creating wallet table: %v", err)
	}

	type seedRow struct {
		Brand   string
		UID     string
		Asset   string
		Balance decimal.Decimal
		Type    wallet.Types
	}
	type args struct {
		uid   string
		asset string
		brand string
	}
	tests := []struct {
		name      string
		args      args
		seed      []seedRow
		wantCount int
	}{
		{
			name:      "no data returns empty",
			args:      args{"u1", "BTC", "b1"},
			seed:      nil,
			wantCount: 0,
		},
		{
			name: "single user many types",
			args: args{"u1", "BTC", "b1"},
			seed: []seedRow{
				{"b1", "u1", "BTC", decimal.NewFromInt(10), wallet.AllTypes[0]},
				{"b1", "u1", "BTC", decimal.NewFromInt(20), wallet.AllTypes[1]},
				{"b1", "u1", "BTC", decimal.NewFromInt(30), wallet.AllTypes[2]},
			},
			wantCount: 3,
		},
		{
			name: "mixed users and assets",
			args: args{"u2", "ETH", "b2"},
			seed: []seedRow{
				{"b1", "u1", "BTC", decimal.NewFromInt(10), wallet.AllTypes[0]},
				{"b2", "u2", "ETH", decimal.NewFromInt(15), wallet.AllTypes[1]},
				{"b2", "u2", "ETH", decimal.NewFromInt(25), wallet.AllTypes[2]},
			},
			wantCount: 2,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := context.Background()

			// Start every case from an empty table.
			if err := db.Exec(`DELETE FROM wallet`).Error; err != nil {
				t.Fatalf("clearing wallet table: %v", err)
			}

			// Seed the table and remember the balance expected for each wallet
			// type of the queried uid/asset pair.
			wantBalances := make(map[wallet.Types]decimal.Decimal, len(tt.seed))
			now := time.Now().Unix()
			for _, row := range tt.seed {
				err := db.WithContext(ctx).Exec(
					`INSERT INTO wallet (brand, uid, asset, balance, type, create_at, update_at)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
					row.Brand, row.UID, row.Asset, row.Balance, row.Type,
					now, now,
				).Error
				if err != nil {
					t.Fatalf("seeding wallet row: %v", err)
				}
				if row.UID == tt.args.uid && row.Asset == tt.args.asset {
					wantBalances[row.Type] = row.Balance
				}
			}

			svc := NewWalletService(db, tt.args.uid, tt.args.asset)
			got, err := svc.GetAllBalances(ctx)
			assert.NoError(t, err)
			assert.Len(t, got, tt.wantCount)

			ws := svc.(*WalletService)
			for _, w := range got {
				// Each returned row must be one of the seeded rows for this
				// uid/asset, with the same balance (compared numerically).
				want, ok := wantBalances[w.Type]
				if assert.Truef(t, ok, "unexpected wallet type %v returned", w.Type) {
					assert.Truef(t, w.Balance.Equal(want), "type %v balance = %s, want %s", w.Type, w.Balance, want)
				}

				// Each returned row must also be cached in localBalances.
				cached, ok := ws.localBalances[w.Type]
				assert.True(t, ok)
				assert.Equal(t, w.Balance, cached.Balance)
				assert.Equal(t, w.Asset, cached.Asset)
				assert.Equal(t, w.Type, cached.Type)
			}
		})
	}
}
// TestWalletService_GetBalancesForTypes verifies that GetBalancesForTypes returns
// only the requested wallet types for the service's uid/asset pair and caches each
// returned row in localBalances.
func TestWalletService_GetBalancesForTypes(t *testing.T) {
	_, db, teardown, err := SetupTestWalletRepository()
	if err != nil {
		// Nothing below can run without a database, and teardown may be unusable.
		t.Fatalf("setting up test wallet repository: %v", err)
	}
	defer teardown()

	createTable := `
CREATE TABLE IF NOT EXISTS wallet (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
brand VARCHAR(50) NOT NULL,
uid VARCHAR(64) NOT NULL,
asset VARCHAR(32) NOT NULL,
balance DECIMAL(30,18) NOT NULL DEFAULT 0,
type TINYINT NOT NULL,
create_at INTEGER NOT NULL DEFAULT 0,
update_at INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY (id),
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
	if err := db.Exec(createTable).Error; err != nil {
		t.Fatalf("creating wallet table: %v", err)
	}

	type seedRow struct {
		Brand   string
		UID     string
		Asset   string
		Balance decimal.Decimal
		Type    wallet.Types
	}
	tests := []struct {
		name      string
		uid       string
		asset     string
		brand     string
		seed      []seedRow
		kinds     []wallet.Types
		wantCount int
	}{
		{
			name: "no matching types returns empty",
			uid:  "user1", asset: "BTC", brand: "b1",
			seed: []seedRow{
				{"b1", "user1", "BTC", decimal.NewFromInt(5), wallet.AllTypes[0]},
			},
			kinds:     []wallet.Types{wallet.AllTypes[1]},
			wantCount: 0,
		},
		{
			name: "single type match",
			uid:  "user1", asset: "BTC", brand: "b1",
			seed: []seedRow{
				{"b1", "user1", "BTC", decimal.NewFromInt(5), wallet.AllTypes[0]},
				{"b1", "user1", "BTC", decimal.NewFromInt(7), wallet.AllTypes[1]},
			},
			kinds:     []wallet.Types{wallet.AllTypes[1]},
			wantCount: 1,
		},
		{
			name: "multiple type matches",
			uid:  "user2", asset: "ETH", brand: "b2",
			seed: []seedRow{
				{"b2", "user2", "ETH", decimal.NewFromInt(3), wallet.AllTypes[0]},
				{"b2", "user2", "ETH", decimal.NewFromInt(8), wallet.AllTypes[2]},
				{"b2", "user2", "ETH", decimal.NewFromInt(10), wallet.AllTypes[1]},
			},
			kinds:     []wallet.Types{wallet.AllTypes[0], wallet.AllTypes[2]},
			wantCount: 2,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := context.Background()

			// Start every case from an empty table.
			if err := db.Exec(`DELETE FROM wallet`).Error; err != nil {
				t.Fatalf("clearing wallet table: %v", err)
			}

			now := time.Now().Unix()
			for _, r := range tt.seed {
				err := db.Exec(
					`INSERT INTO wallet (brand, uid, asset, balance, type, create_at, update_at)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
					r.Brand, r.UID, r.Asset, r.Balance, r.Type,
					now, now,
				).Error
				if err != nil {
					t.Fatalf("seeding wallet row: %v", err)
				}
			}

			svc := NewWalletService(db, tt.uid, tt.asset)
			got, err := svc.GetBalancesForTypes(ctx, tt.kinds)
			assert.NoError(t, err)
			assert.Len(t, got, tt.wantCount)

			requested := make(map[wallet.Types]struct{}, len(tt.kinds))
			for _, k := range tt.kinds {
				requested[k] = struct{}{}
			}

			ws := svc.(*WalletService)
			for _, w := range got {
				// Only the requested types may come back.
				_, wanted := requested[w.Type]
				assert.Truef(t, wanted, "returned wallet type %v was not requested", w.Type)

				// Every returned row must be cached in localBalances.
				cached, ok := ws.localBalances[w.Type]
				assert.True(t, ok)
				assert.Equal(t, w.Balance, cached.Balance)
				assert.Equal(t, w.Asset, cached.Asset)
			}
		})
	}
}
// TestForUpdateLockBehavior checks that a row read with SELECT ... FOR UPDATE in
// one transaction blocks an UPDATE of the same row from a second transaction until
// the first commits, after which the update succeeds and is visible.
//
// The "is it blocked?" check uses a fixed 100ms wait, so it relies on the second
// transaction reaching the lock within that window.
func TestForUpdateLockBehavior(t *testing.T) {
	_, db, teardown, err := SetupTestWalletRepository()
	if err != nil {
		// Nothing below can run without a database, and teardown may be unusable.
		t.Fatalf("setting up test wallet repository: %v", err)
	}
	defer teardown()

	createTable := `
CREATE TABLE IF NOT EXISTS wallet (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
brand VARCHAR(50) NOT NULL,
uid VARCHAR(64) NOT NULL,
asset VARCHAR(32) NOT NULL,
balance DECIMAL(30,18) NOT NULL DEFAULT 0,
type TINYINT NOT NULL,
create_at INTEGER NOT NULL DEFAULT 0,
update_at INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY (id),
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
	if err := db.Exec(createTable).Error; err != nil {
		t.Fatalf("creating wallet table: %v", err)
	}

	// Seed a single wallet row.
	ctx := context.Background()
	initial := entity.Wallet{
		Brand:   "b1",
		UID:     "user1",
		Asset:   "BTC",
		Balance: decimal.NewFromInt(100),
		Type:    wallet.AllTypes[0],
	}
	if err := db.WithContext(ctx).Create(&initial).Error; err != nil {
		t.Fatalf("seeding wallet: %v", err)
	}

	// Read the row back to get its auto-generated ID.
	var seeded entity.Wallet
	if err := db.WithContext(ctx).
		Where("uid = ? AND asset = ?", initial.UID, initial.Asset).
		First(&seeded).Error; err != nil {
		t.Fatalf("reading seeded wallet: %v", err)
	}

	// Transaction 1 takes the row lock with SELECT ... FOR UPDATE.
	tx1 := db.WithContext(ctx).Begin()
	if tx1.Error != nil {
		t.Fatalf("beginning tx1: %v", tx1.Error)
	}
	tx1Finished := false
	defer func() {
		// Release the lock if the test stops before tx1 is committed.
		if !tx1Finished {
			tx1.Rollback()
		}
	}()

	var locked entity.Wallet
	err = tx1.Clauses(clause.Locking{Strength: "UPDATE"}).
		Where("id = ?", seeded.ID).
		Take(&locked).Error
	if err != nil {
		t.Fatalf("locking row in tx1: %v", err)
	}

	// Transaction 2 tries to update the locked row from another goroutine. It
	// always finishes its transaction so the connection and any lock it takes
	// are released.
	done := make(chan error, 1)
	go func() {
		tx2 := db.WithContext(ctx).Begin()
		if tx2.Error != nil {
			done <- tx2.Error
			return
		}
		err := tx2.Model(&entity.Wallet{}).
			Where("id = ?", seeded.ID).
			Update("balance", seeded.Balance.Add(decimal.NewFromInt(50))).Error
		if err != nil {
			tx2.Rollback()
			done <- err
			return
		}
		done <- tx2.Commit().Error
	}()

	// After 100ms the update should still be waiting for tx1's lock.
	time.Sleep(100 * time.Millisecond)
	select {
	case err2 := <-done:
		t.Fatalf("expected update to be blocked, but completed with err=%v", err2)
	default:
		// Still waiting for the lock, as expected.
	}

	// Release the lock.
	err = tx1.Commit().Error
	tx1Finished = true
	assert.NoError(t, err)

	// Transaction 2 should now complete promptly.
	select {
	case err2 := <-done:
		assert.NoError(t, err2)
	case <-time.After(500 * time.Millisecond):
		t.Fatal("update did not complete after lock released")
	}

	// The blocked update must have been applied once the lock was released.
	var after entity.Wallet
	if assert.NoError(t, db.WithContext(ctx).Where("id = ?", seeded.ID).Take(&after).Error) {
		want := seeded.Balance.Add(decimal.NewFromInt(50))
		assert.Truef(t, after.Balance.Equal(want), "balance after update = %s, want %s", after.Balance, want)
	}
}
// TestWalletService_IncreaseBalance drives IncreaseBalance purely through a
// WalletService's in-memory state (it is built with a nil DB). Each case seeds
// localBalances, applies one signed amount, and then checks the returned error,
// the cached balance reported by CurrentBalance, and the wallet transactions the
// service recorded.
func TestWalletService_IncreaseBalance(t *testing.T) {
	const (
		uid   = "user1"
		asset = "BTC"
	)

	type incCase struct {
		name        string
		kind        wallet.Types
		initial     *entity.Wallet // nil: no cached wallet of this kind
		orderID     string
		amount      decimal.Decimal
		wantErr     error
		wantBalance decimal.Decimal
		wantTxCount int
	}

	cases := []incCase{
		{
			name:        "missing wallet type",
			kind:        wallet.AllTypes[0],
			orderID:     "ord1",
			amount:      decimal.NewFromInt(10),
			wantErr:     repository.ErrRecordNotFound,
			wantBalance: decimal.Zero,
		},
		{
			name:        "increase from zero",
			kind:        wallet.AllTypes[1],
			initial:     &entity.Wallet{Balance: decimal.Zero},
			orderID:     "ord2",
			amount:      decimal.NewFromInt(15),
			wantBalance: decimal.NewFromInt(15),
			wantTxCount: 1,
		},
		{
			name:        "successful increment on non-zero",
			kind:        wallet.AllTypes[2],
			initial:     &entity.Wallet{Balance: decimal.NewFromInt(5)},
			orderID:     "ord3",
			amount:      decimal.NewFromInt(7),
			wantBalance: decimal.NewFromInt(12),
			wantTxCount: 1,
		},
		{
			name:        "insufficient leads to error",
			kind:        wallet.AllTypes[2],
			initial:     &entity.Wallet{Balance: decimal.NewFromInt(3)},
			orderID:     "ord4",
			amount:      decimal.NewFromInt(-5),
			wantErr:     repository.ErrBalanceInsufficient,
			wantBalance: decimal.NewFromInt(3),
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			svc := NewWalletService(nil, uid, asset).(*WalletService)

			// Cache holds at most the one wallet under test.
			svc.localBalances = make(map[wallet.Types]entity.Wallet)
			if c.initial != nil {
				svc.localBalances[c.kind] = *c.initial
			}
			svc.transactions = nil

			err := svc.IncreaseBalance(c.kind, c.orderID, c.amount)
			switch c.wantErr {
			case nil:
				assert.NoError(t, err)
			default:
				assert.ErrorIs(t, err, c.wantErr)
			}

			balance := svc.CurrentBalance(c.kind)
			assert.True(t, balance.Equal(c.wantBalance), "balance = %s, want %s", balance, c.wantBalance)

			assert.Len(t, svc.transactions, c.wantTxCount)
			if c.wantTxCount <= 0 {
				return
			}

			// The first recorded transaction mirrors the call and the new balance.
			first := svc.transactions[0]
			assert.Equal(t, c.orderID, first.OrderID)
			assert.Equal(t, uid, first.UID)
			assert.Equal(t, asset, first.Asset)
			assert.True(t, first.Amount.Equal(c.amount))
			assert.True(t, first.Balance.Equal(c.wantBalance))
		})
	}
}
// TestWalletService_PrepareTransactions checks that PrepareTransactions stamps
// the shared transaction ID, order ID, brand and business type onto every pending
// wallet transaction. It also checks that each row's own fields (UID, wallet type,
// asset, amount, balance) are left as they were.
func TestWalletService_PrepareTransactions(t *testing.T) {
	const (
		txID       = int64(42)
		orderID    = "order-123"
		brand      = "brandX"
		bizNameStr = "business-test"
	)
	biz := wallet.BusinessName(bizNameStr)

	type prepCase struct {
		name       string
		initialTxs []entity.WalletTransaction
		wantCount  int
	}

	cases := []prepCase{
		{
			name:      "no transactions returns empty slice",
			wantCount: 0,
		},
		{
			name: "single transaction is populated",
			initialTxs: []entity.WalletTransaction{
				{
					OrderID:    "placeholder",
					UID:        "u1",
					WalletType: wallet.AllTypes[0],
					Asset:      "BTC",
					Amount:     decimal.NewFromInt(5),
					Balance:    decimal.NewFromInt(10),
				},
			},
			wantCount: 1,
		},
		{
			name: "multiple transactions are all populated",
			initialTxs: []entity.WalletTransaction{
				{UID: "u1", WalletType: wallet.AllTypes[1], Asset: "ETH", Amount: decimal.NewFromInt(1), Balance: decimal.NewFromInt(1)},
				{UID: "u2", WalletType: wallet.AllTypes[2], Asset: "TWD", Amount: decimal.NewFromInt(2), Balance: decimal.NewFromInt(2)},
				{UID: "u3", WalletType: wallet.AllTypes[3%len(wallet.AllTypes)], Asset: "USD", Amount: decimal.NewFromInt(3), Balance: decimal.NewFromInt(3)},
			},
			wantCount: 3,
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			// Hand the service its own copy so the fixture stays untouched for the
			// "unchanged" comparisons below.
			pending := make([]entity.WalletTransaction, len(c.initialTxs))
			copy(pending, c.initialTxs)
			svc := &WalletService{transactions: pending}

			result := svc.PrepareTransactions(txID, orderID, brand, biz)

			assert.Len(t, result, c.wantCount)
			assert.Len(t, svc.transactions, c.wantCount)

			for i := 0; i < c.wantCount; i++ {
				got := result[i]

				// Shared fields stamped onto every transaction.
				assert.Equal(t, txID, got.TransactionID, "tx[%d].TransactionID", i)
				assert.Equal(t, orderID, got.OrderID, "tx[%d].OrderID", i)
				assert.Equal(t, brand, got.Brand, "tx[%d].Brand", i)
				assert.Equal(t, biz.ToInt8(), got.BusinessType, "tx[%d].BusinessType", i)

				// Per-row fields carried over unchanged.
				want := c.initialTxs[i]
				assert.Equal(t, want.UID, got.UID, "tx[%d].UID unchanged", i)
				assert.Equal(t, want.WalletType, got.WalletType, "tx[%d].WalletType unchanged", i)
				assert.Equal(t, want.Asset, got.Asset, "tx[%d].Asset unchanged", i)
				assert.True(t, got.Amount.Equal(want.Amount), "tx[%d].Amount unchanged", i)
				assert.True(t, got.Balance.Equal(want.Balance), "tx[%d].Balance unchanged", i)
			}
		})
	}
}
// TestWalletService_PersistBalances verifies that PersistBalances writes the
// balances cached in localBalances back to their wallet rows (matched by ID). Rows
// with no cached entry must be left untouched.
func TestWalletService_PersistBalances(t *testing.T) {
	_, db, teardown, err := SetupTestWalletRepository()
	if err != nil {
		// Nothing below can run without a database, and teardown may be unusable.
		t.Fatalf("setting up test wallet repository: %v", err)
	}
	defer teardown()

	// seedWallets drops and re-migrates the wallet table, inserts two u1/BTC rows
	// and returns them with their generated IDs. It takes the sub-test's t so a
	// seeding failure aborts, and is reported against, the right test.
	seedWallets := func(t *testing.T) []entity.Wallet {
		t.Helper()
		if err := db.Migrator().DropTable(&entity.Wallet{}); err != nil {
			t.Fatalf("dropping wallet table: %v", err)
		}
		if err := db.AutoMigrate(&entity.Wallet{}); err != nil {
			t.Fatalf("migrating wallet table: %v", err)
		}
		base := []entity.Wallet{
			{UID: "u1", Asset: "BTC", Brand: "b", Balance: decimal.NewFromInt(10), Type: wallet.AllTypes[0]},
			{UID: "u1", Asset: "BTC", Brand: "b", Balance: decimal.NewFromInt(20), Type: wallet.AllTypes[1]},
		}
		if err := db.Create(&base).Error; err != nil {
			t.Fatalf("seeding wallets: %v", err)
		}
		return base
	}

	type fields struct {
		// updates is the cached balance per wallet type; types absent from the
		// map get no localBalances entry.
		updates map[wallet.Types]decimal.Decimal
	}
	type want struct {
		final []decimal.Decimal // expected balances in id order
		err   bool
	}
	tests := []struct {
		name   string
		fields fields
		want   want
	}{
		{
			name:   "no local balances → no change",
			fields: fields{updates: map[wallet.Types]decimal.Decimal{}},
			want: want{
				final: []decimal.Decimal{decimal.NewFromInt(10), decimal.NewFromInt(20)},
				err:   false,
			},
		},
		{
			name: "update both balances",
			fields: fields{updates: map[wallet.Types]decimal.Decimal{
				wallet.AllTypes[0]: decimal.NewFromInt(15),
				wallet.AllTypes[1]: decimal.NewFromInt(5),
			}},
			want: want{
				final: []decimal.Decimal{decimal.NewFromInt(15), decimal.NewFromInt(5)},
				err:   false,
			},
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			seed := seedWallets(t)

			// Build localBalances: only seeded rows whose type appears in updates
			// get an entry, carrying the row's ID and the new balance.
			localMap := make(map[wallet.Types]entity.Wallet, len(tc.fields.updates))
			for _, w := range seed {
				if nb, ok := tc.fields.updates[w.Type]; ok {
					localMap[w.Type] = entity.Wallet{
						ID:      w.ID,
						Balance: nb,
					}
				}
			}

			svc := NewWalletService(db, "u1", "BTC").(*WalletService)
			svc.localBalances = localMap

			err := svc.PersistBalances(context.Background())
			if tc.want.err {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)

			// Read the rows back in id order and compare balances numerically.
			var got []entity.Wallet
			assert.NoError(t, db.
				Where("uid = ?", "u1").
				Order("id").
				Find(&got).Error)
			// Stop before indexing want.final if the row count is already wrong.
			if !assert.Len(t, got, len(seed)) {
				return
			}
			for i, w := range got {
				assert.Truef(t,
					w.Balance.Equal(tc.want.final[i]),
					"第 %d 条记录 (id=%d) 余额 = %s, 期望 %s",
					i, w.ID, w.Balance, tc.want.final[i],
				)
			}
		})
	}
}
// TestWalletService_PersistOrderBalances verifies that PersistOrderBalances writes
// each balance cached in localOrderBalances (keyed by transaction ID) to that
// transaction's PostTransferBalance. Transactions without a cached entry must be
// left untouched.
func TestWalletService_PersistOrderBalances(t *testing.T) {
	_, db, teardown, err := SetupTestWalletRepository()
	if err != nil {
		// Nothing below can run without a database, and teardown may be unusable.
		t.Fatalf("setting up test wallet repository: %v", err)
	}
	defer teardown()

	// seedTransactions drops and re-migrates the transaction table, inserts two
	// rows and returns them with their generated IDs. It takes the sub-test's t
	// so a seeding failure aborts, and is reported against, the right test.
	seedTransactions := func(t *testing.T) []entity.Transaction {
		t.Helper()
		if err := db.Migrator().DropTable(&entity.Transaction{}); err != nil {
			t.Fatalf("dropping transaction table: %v", err)
		}
		if err := db.AutoMigrate(&entity.Transaction{}); err != nil {
			t.Fatalf("migrating transaction table: %v", err)
		}
		base := []entity.Transaction{
			{PostTransferBalance: decimal.NewFromInt(100)},
			{PostTransferBalance: decimal.NewFromInt(200)},
		}
		if err := db.Create(&base).Error; err != nil {
			t.Fatalf("seeding transactions: %v", err)
		}
		return base
	}

	type fields struct {
		// updates maps an index into the seeded transactions to the balance to
		// cache for that row. Keying by index instead of by literal ID keeps the
		// test independent of how the database assigns auto-increment IDs.
		updates map[int]decimal.Decimal
	}
	type want struct {
		final []decimal.Decimal // expected balances in id order
		err   bool
	}
	tests := []struct {
		name   string
		fields fields
		want   want
	}{
		{
			name:   "no local order balances → no change",
			fields: fields{updates: map[int]decimal.Decimal{}},
			want: want{
				final: []decimal.Decimal{decimal.NewFromInt(100), decimal.NewFromInt(200)},
				err:   false,
			},
		},
		{
			name: "update first order balance only",
			fields: fields{updates: map[int]decimal.Decimal{
				0: decimal.NewFromInt(150),
			}},
			want: want{
				final: []decimal.Decimal{decimal.NewFromInt(150), decimal.NewFromInt(200)},
				err:   false,
			},
		},
		{
			name: "update both order balances",
			fields: fields{updates: map[int]decimal.Decimal{
				0: decimal.NewFromInt(110),
				1: decimal.NewFromInt(220),
			}},
			want: want{
				final: []decimal.Decimal{decimal.NewFromInt(110), decimal.NewFromInt(220)},
				err:   false,
			},
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			seed := seedTransactions(t)

			// Build localOrderBalances keyed by the real IDs of the seeded rows.
			localMap := make(map[int64]decimal.Decimal, len(tc.fields.updates))
			for idx, newBal := range tc.fields.updates {
				localMap[int64(seed[idx].ID)] = newBal
			}

			svc := NewWalletService(db, "u1", "BTC").(*WalletService)
			svc.localOrderBalances = localMap

			err := svc.PersistOrderBalances(context.Background())
			if tc.want.err {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)

			// Read every row back in id order (the seed order).
			var got []entity.Transaction
			assert.NoError(t, db.Order("id").Find(&got).Error)
			// Stop before indexing want.final if the row count is already wrong.
			if !assert.Len(t, got, len(seed)) {
				return
			}
			for i, tr := range got {
				assert.Truef(t,
					tr.PostTransferBalance.Equal(tc.want.final[i]),
					"第 %d 筆交易(id=%d) balance = %s, want %s",
					i, tr.ID, tr.PostTransferBalance, tc.want.final[i],
				)
			}
		})
	}
}
// TestWalletService_HasAvailableBalance verifies that HasAvailableBalance reports
// true only when a wallet row of type wallet.TypeAvailable exists for the
// service's own uid and asset. Other types, other UIDs and other assets must not
// count.
func TestWalletService_HasAvailableBalance(t *testing.T) {
	_, db, teardown, err := SetupTestWalletRepository()
	if err != nil {
		// Nothing below can run without a database, and teardown may be unusable.
		t.Fatalf("setting up test wallet repository: %v", err)
	}
	defer teardown()

	createTable := `
CREATE TABLE IF NOT EXISTS wallet (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
brand VARCHAR(50) NOT NULL,
uid VARCHAR(64) NOT NULL,
asset VARCHAR(32) NOT NULL,
balance DECIMAL(30,18) NOT NULL DEFAULT 0,
type TINYINT NOT NULL,
create_at INTEGER NOT NULL DEFAULT 0,
update_at INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY (id),
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
	if err := db.Exec(createTable).Error; err != nil {
		t.Fatalf("creating wallet table: %v", err)
	}

	type seedRow struct {
		Brand   string
		UID     string
		Asset   string
		Type    wallet.Types
		Balance decimal.Decimal
	}
	tests := []struct {
		name      string
		uid       string
		asset     string
		seed      []seedRow
		wantExist bool
	}{
		{
			name:      "無任何紀錄",
			uid:       "user1",
			asset:     "BTC",
			seed:      nil,
			wantExist: false,
		},
		{
			name:  "只有其他類型錢包",
			uid:   "user1",
			asset: "BTC",
			seed: []seedRow{
				{"", "user1", "BTC", wallet.TypeFreeze, decimal.Zero},
				{"", "user1", "BTC", wallet.TypeUnconfirmed, decimal.Zero},
			},
			wantExist: false,
		},
		{
			name:  "已有可用錢包",
			uid:   "user1",
			asset: "BTC",
			seed: []seedRow{
				{"", "user1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(10)},
			},
			wantExist: true,
		},
		{
			name:  "不同 UID 不算",
			uid:   "user2",
			asset: "BTC",
			seed: []seedRow{
				{"", "user1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
			},
			wantExist: false,
		},
		{
			name:  "不同 Asset 不算",
			uid:   "user1",
			asset: "ETH",
			seed: []seedRow{
				{"", "user1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
			},
			wantExist: false,
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Start every case from an empty table.
			if err := db.Exec("DELETE FROM wallet").Error; err != nil {
				t.Fatalf("clearing wallet table: %v", err)
			}

			for _, r := range tc.seed {
				w := entity.Wallet{
					Brand:   r.Brand,
					UID:     r.UID,
					Asset:   r.Asset,
					Type:    r.Type,
					Balance: r.Balance,
				}
				if err := db.Create(&w).Error; err != nil {
					t.Fatalf("seeding wallet row: %v", err)
				}
			}

			svc := NewWalletService(db, tc.uid, tc.asset).(*WalletService)
			got, err := svc.HasAvailableBalance(context.Background())
			assert.NoError(t, err)
			assert.Equal(t, tc.wantExist, got)
		})
	}
}
// SetupTestDB starts a MySQL test container, opens a GORM connection to it and
// creates the wallet table. This table version has an UNSIGNED balance column and
// secondary indexes on uid and brand.
//
// It returns the connection and a cleanup function that closes the connection
// pool and then tears the container down. Any setup failure aborts the calling
// test immediately, because every caller needs a working *gorm.DB.
func SetupTestDB(t *testing.T) (*gorm.DB, func()) {
	t.Helper()

	host, port, _, tearDown, err := startMySQLContainer()
	if err != nil {
		t.Fatalf("starting MySQL container: %v", err)
	}

	dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true",
		MySQLUser, MySQLPassword, host, port, MySQLDatabase,
	)
	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
	if err != nil {
		tearDown()
		t.Fatalf("opening gorm connection: %v", err)
	}

	sqlDB, err := db.DB()
	if err != nil {
		tearDown()
		t.Fatalf("getting underlying *sql.DB: %v", err)
	}
	cleanup := func() {
		// Best effort: the container is removed right after anyway.
		_ = sqlDB.Close()
		tearDown()
	}

	create := `
CREATE TABLE IF NOT EXISTS wallet (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT COMMENT '主鍵 ID',
brand VARCHAR(50) NOT NULL DEFAULT '' COMMENT '品牌',
uid VARCHAR(64) NOT NULL COMMENT '使用者 UID',
asset VARCHAR(32) NOT NULL COMMENT '資產代號',
balance DECIMAL(30,18) UNSIGNED NOT NULL DEFAULT 0 COMMENT '餘額',
type TINYINT NOT NULL COMMENT '錢包類型',
create_at INTEGER NOT NULL DEFAULT 0 COMMENT '建立時間',
update_at INTEGER NOT NULL DEFAULT 0 COMMENT '更新時間',
PRIMARY KEY (id),
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type),
KEY idx_uid (uid),
KEY idx_brand (brand)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
	if err := db.Exec(create).Error; err != nil {
		cleanup()
		t.Fatalf("creating wallet table: %v", err)
	}
	return db, cleanup
}
// TestWalletService_GetBalancesForUpdate verifies that GetBalancesForUpdate
// returns exactly the wallet rows of the requested types for the service's uid
// and asset, and stores each returned row in localBalances.
//
// Expected rows are given as indexes into each case's seed, not as literal IDs.
// DELETE does not reset AUTO_INCREMENT, so hard-coded IDs would only hold when
// every sub-test runs in order. Results are compared order-insensitively.
func TestWalletService_GetBalancesForUpdate(t *testing.T) {
	db, tearDown := SetupTestDB(t)
	defer tearDown()

	type seedRow struct {
		Brand   string
		UID     string
		Asset   string
		Type    wallet.Types
		Balance decimal.Decimal
	}
	tests := []struct {
		name        string
		uid         string
		asset       string
		seed        []seedRow
		kinds       []wallet.Types
		wantSeedIdx []int // indexes into seed of the rows expected back
		expectErr   bool
	}{
		{
			name:        "查詢空結果",
			uid:         "u1",
			asset:       "BTC",
			seed:        nil,
			kinds:       []wallet.Types{wallet.TypeAvailable},
			wantSeedIdx: nil,
			expectErr:   false,
		},
		{
			name:  "單一類型查詢",
			uid:   "u1",
			asset: "BTC",
			seed: []seedRow{
				{"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
				{"", "u1", "BTC", wallet.TypeFreeze, decimal.NewFromInt(2)},
			},
			kinds:       []wallet.Types{wallet.TypeFreeze},
			wantSeedIdx: []int{1},
			expectErr:   false,
		},
		{
			name:  "多類型查詢",
			uid:   "u1",
			asset: "BTC",
			seed: []seedRow{
				{"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
				{"", "u1", "BTC", wallet.TypeFreeze, decimal.NewFromInt(2)},
				{"", "u1", "BTC", wallet.TypeUnconfirmed, decimal.NewFromInt(3)},
			},
			kinds:       []wallet.Types{wallet.TypeAvailable, wallet.TypeUnconfirmed},
			wantSeedIdx: []int{0, 2},
			expectErr:   false,
		},
		{
			name:  "不同 UID 不列入",
			uid:   "u2",
			asset: "BTC",
			seed: []seedRow{
				{"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
			},
			kinds:       []wallet.Types{wallet.TypeAvailable},
			wantSeedIdx: nil,
			expectErr:   false,
		},
		{
			name:  "不同 Asset 不列入",
			uid:   "u1",
			asset: "ETH",
			seed: []seedRow{
				{"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
			},
			kinds:       []wallet.Types{wallet.TypeAvailable},
			wantSeedIdx: nil,
			expectErr:   false,
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Start every case from an empty table.
			if err := db.Exec("DELETE FROM wallet").Error; err != nil {
				t.Fatalf("clearing wallet table: %v", err)
			}

			// Seed the rows and record the ID the database assigned to each.
			seededIDs := make([]int64, 0, len(tc.seed))
			for _, r := range tc.seed {
				w := entity.Wallet{
					Brand:   r.Brand,
					UID:     r.UID,
					Asset:   r.Asset,
					Type:    r.Type,
					Balance: r.Balance,
				}
				if err := db.Create(&w).Error; err != nil {
					t.Fatalf("seeding wallet row: %v", err)
				}
				seededIDs = append(seededIDs, int64(w.ID))
			}

			svc := NewWalletService(db, tc.uid, tc.asset).(*WalletService)
			got, err := svc.GetBalancesForUpdate(context.Background(), tc.kinds)
			if tc.expectErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)

			wantIDs := make([]int64, 0, len(tc.wantSeedIdx))
			for _, idx := range tc.wantSeedIdx {
				wantIDs = append(wantIDs, seededIDs[idx])
			}

			gotIDs := make([]int64, 0, len(got))
			for _, w := range got {
				gotIDs = append(gotIDs, w.ID)
				// Every returned row must also be stored in localBalances.
				assert.Equal(t, w, svc.localBalances[w.Type])
			}

			// The query order is not part of the contract, so compare as multisets.
			assert.ElementsMatch(t, wantIDs, gotIDs)
		})
	}
}