1026 lines
27 KiB
Go
1026 lines
27 KiB
Go
package repository
|
||
|
||
import (
|
||
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/entity"
|
||
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/repository"
|
||
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/wallet"
|
||
"context"
|
||
"fmt"
|
||
"github.com/shopspring/decimal"
|
||
"github.com/stretchr/testify/assert"
|
||
"gorm.io/driver/mysql"
|
||
"gorm.io/gorm"
|
||
"gorm.io/gorm/clause"
|
||
"testing"
|
||
"time"
|
||
)
|
||
|
||
func TestWalletService_InitializeWallets(t *testing.T) {
|
||
_, db, teardown, err := SetupTestWalletRepository()
|
||
assert.NoError(t, err)
|
||
defer teardown()
|
||
|
||
// 先手動建表
|
||
createTable := `
|
||
CREATE TABLE IF NOT EXISTS wallet (
|
||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
|
||
brand VARCHAR(50) NOT NULL,
|
||
uid VARCHAR(64) NOT NULL,
|
||
asset VARCHAR(32) NOT NULL,
|
||
balance DECIMAL(30,18) NOT NULL DEFAULT 0,
|
||
type TINYINT NOT NULL,
|
||
create_at INTEGER NOT NULL DEFAULT 0,
|
||
update_at INTEGER NOT NULL DEFAULT 0,
|
||
PRIMARY KEY (id),
|
||
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type)
|
||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
|
||
assert.NoError(t, db.Exec(createTable).Error)
|
||
|
||
type args struct {
|
||
uid string
|
||
asset string
|
||
brand string
|
||
}
|
||
tests := []struct {
|
||
name string
|
||
args args
|
||
wantCount int
|
||
wantErr bool
|
||
validateDB bool
|
||
}{
|
||
{
|
||
name: "正常初始化一次",
|
||
args: args{uid: "user1", asset: "BTC", brand: "brandA"},
|
||
wantCount: len(wallet.AllTypes),
|
||
},
|
||
{
|
||
name: "再次初始化同一 UID/asset/brand,應因 UNIQUE KEY 失敗",
|
||
args: args{uid: "user1", asset: "BTC", brand: "brandA"},
|
||
wantErr: true,
|
||
validateDB: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
service := NewWalletService(db, tt.args.uid, tt.args.asset)
|
||
ctx := context.Background()
|
||
|
||
got, err := service.InitializeWallets(ctx, tt.args.brand)
|
||
if tt.wantErr {
|
||
assert.Error(t, err)
|
||
return
|
||
}
|
||
assert.NoError(t, err)
|
||
// 回傳 slice 長度應等於 AllTypes
|
||
assert.Len(t, got, tt.wantCount)
|
||
|
||
if tt.validateDB {
|
||
// 再查一次 DB,確認確實寫入
|
||
var dbRows []entity.Wallet
|
||
err := db.WithContext(ctx).
|
||
Where("uid = ? AND asset = ? AND brand = ?", tt.args.uid, tt.args.asset, tt.args.brand).
|
||
Find(&dbRows).Error
|
||
assert.NoError(t, err)
|
||
assert.Len(t, dbRows, tt.wantCount)
|
||
// 檢查每一筆都初始為零
|
||
for _, w := range dbRows {
|
||
assert.Equal(t, decimal.Zero, w.Balance)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestWalletService_GetAllBalances(t *testing.T) {
|
||
_, db, teardown, err := SetupTestWalletRepository()
|
||
assert.NoError(t, err)
|
||
defer teardown()
|
||
|
||
// 建表
|
||
createTable := `
|
||
CREATE TABLE IF NOT EXISTS wallet (
|
||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
|
||
brand VARCHAR(50) NOT NULL,
|
||
uid VARCHAR(64) NOT NULL,
|
||
asset VARCHAR(32) NOT NULL,
|
||
balance DECIMAL(30,18) NOT NULL DEFAULT 0,
|
||
type TINYINT NOT NULL,
|
||
create_at INTEGER NOT NULL DEFAULT 0,
|
||
update_at INTEGER NOT NULL DEFAULT 0,
|
||
PRIMARY KEY (id),
|
||
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type)
|
||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
|
||
assert.NoError(t, db.Exec(createTable).Error)
|
||
|
||
type seedRow struct {
|
||
Brand string
|
||
UID string
|
||
Asset string
|
||
Balance decimal.Decimal
|
||
Type wallet.Types
|
||
}
|
||
|
||
type args struct {
|
||
uid string
|
||
asset string
|
||
brand string
|
||
}
|
||
tests := []struct {
|
||
name string
|
||
args args
|
||
seed []seedRow
|
||
wantCount int
|
||
}{
|
||
{
|
||
name: "no data returns empty",
|
||
args: args{"u1", "BTC", "b1"},
|
||
seed: nil,
|
||
wantCount: 0,
|
||
},
|
||
{
|
||
name: "single user many types",
|
||
args: args{"u1", "BTC", "b1"},
|
||
seed: []seedRow{
|
||
{"b1", "u1", "BTC", decimal.NewFromInt(10), wallet.AllTypes[0]},
|
||
{"b1", "u1", "BTC", decimal.NewFromInt(20), wallet.AllTypes[1]},
|
||
{"b1", "u1", "BTC", decimal.NewFromInt(30), wallet.AllTypes[2]},
|
||
},
|
||
wantCount: 3,
|
||
},
|
||
{
|
||
name: "mixed users and assets",
|
||
args: args{"u2", "ETH", "b2"},
|
||
seed: []seedRow{
|
||
{"b1", "u1", "BTC", decimal.NewFromInt(10), wallet.AllTypes[0]},
|
||
{"b2", "u2", "ETH", decimal.NewFromInt(15), wallet.AllTypes[1]},
|
||
{"b2", "u2", "ETH", decimal.NewFromInt(25), wallet.AllTypes[2]},
|
||
},
|
||
wantCount: 2,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
ctx := context.Background()
|
||
// 清空表
|
||
assert.NoError(t, db.Exec(`DELETE FROM wallet`).Error)
|
||
|
||
// 插入 seed
|
||
for _, row := range tt.seed {
|
||
err := db.WithContext(ctx).Exec(
|
||
`INSERT INTO wallet (brand, uid, asset, balance, type, create_at, update_at)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||
row.Brand, row.UID, row.Asset, row.Balance, row.Type,
|
||
time.Now().Unix(), time.Now().Unix(),
|
||
).Error
|
||
assert.NoError(t, err)
|
||
}
|
||
|
||
// 建立 service
|
||
svc := NewWalletService(db, tt.args.uid, tt.args.asset)
|
||
|
||
// 呼叫 GetAllBalances
|
||
got, err := svc.GetAllBalances(ctx)
|
||
assert.NoError(t, err)
|
||
assert.Len(t, got, tt.wantCount)
|
||
|
||
// 檢查回傳資料與本地快取一致
|
||
ws := svc.(*WalletService)
|
||
for _, w := range got {
|
||
// 回傳的每筆都應該存在於 localBalances
|
||
cached, ok := ws.localBalances[w.Type]
|
||
assert.True(t, ok)
|
||
assert.Equal(t, w.Balance, cached.Balance)
|
||
assert.Equal(t, w.Asset, cached.Asset)
|
||
assert.Equal(t, w.Type, cached.Type)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestWalletService_GetBalancesForTypes(t *testing.T) {
|
||
_, db, teardown, err := SetupTestWalletRepository()
|
||
assert.NoError(t, err)
|
||
defer teardown()
|
||
|
||
// 建表
|
||
createTable := `
|
||
CREATE TABLE IF NOT EXISTS wallet (
|
||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
|
||
brand VARCHAR(50) NOT NULL,
|
||
uid VARCHAR(64) NOT NULL,
|
||
asset VARCHAR(32) NOT NULL,
|
||
balance DECIMAL(30,18) NOT NULL DEFAULT 0,
|
||
type TINYINT NOT NULL,
|
||
create_at INTEGER NOT NULL DEFAULT 0,
|
||
update_at INTEGER NOT NULL DEFAULT 0,
|
||
PRIMARY KEY (id),
|
||
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type)
|
||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
|
||
assert.NoError(t, db.Exec(createTable).Error)
|
||
|
||
type seedRow struct {
|
||
Brand string
|
||
UID string
|
||
Asset string
|
||
Balance decimal.Decimal
|
||
Type wallet.Types
|
||
}
|
||
|
||
tests := []struct {
|
||
name string
|
||
uid string
|
||
asset string
|
||
brand string
|
||
seed []seedRow
|
||
kinds []wallet.Types
|
||
wantCount int
|
||
}{
|
||
{
|
||
name: "no matching types returns empty",
|
||
uid: "user1", asset: "BTC", brand: "b1",
|
||
seed: []seedRow{
|
||
{"b1", "user1", "BTC", decimal.NewFromInt(5), wallet.AllTypes[0]},
|
||
},
|
||
kinds: []wallet.Types{wallet.AllTypes[1]},
|
||
wantCount: 0,
|
||
},
|
||
{
|
||
name: "single type match",
|
||
uid: "user1", asset: "BTC", brand: "b1",
|
||
seed: []seedRow{
|
||
{"b1", "user1", "BTC", decimal.NewFromInt(5), wallet.AllTypes[0]},
|
||
{"b1", "user1", "BTC", decimal.NewFromInt(7), wallet.AllTypes[1]},
|
||
},
|
||
kinds: []wallet.Types{wallet.AllTypes[1]},
|
||
wantCount: 1,
|
||
},
|
||
{
|
||
name: "multiple type matches",
|
||
uid: "user2", asset: "ETH", brand: "b2",
|
||
seed: []seedRow{
|
||
{"b2", "user2", "ETH", decimal.NewFromInt(3), wallet.AllTypes[0]},
|
||
{"b2", "user2", "ETH", decimal.NewFromInt(8), wallet.AllTypes[2]},
|
||
{"b2", "user2", "ETH", decimal.NewFromInt(10), wallet.AllTypes[1]},
|
||
},
|
||
kinds: []wallet.Types{wallet.AllTypes[0], wallet.AllTypes[2]},
|
||
wantCount: 2,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
ctx := context.Background()
|
||
// 清表
|
||
assert.NoError(t, db.Exec(`DELETE FROM wallet`).Error)
|
||
|
||
// 插入種子資料
|
||
for _, r := range tt.seed {
|
||
assert.NoError(t, db.Exec(
|
||
`INSERT INTO wallet (brand, uid, asset, balance, type, create_at, update_at)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||
r.Brand, r.UID, r.Asset, r.Balance, r.Type,
|
||
time.Now().Unix(), time.Now().Unix(),
|
||
).Error)
|
||
}
|
||
|
||
// 建立 service
|
||
svc := NewWalletService(db, tt.uid, tt.asset)
|
||
|
||
// 呼叫 GetBalancesForTypes
|
||
got, err := svc.GetBalancesForTypes(ctx, tt.kinds)
|
||
assert.NoError(t, err)
|
||
assert.Len(t, got, tt.wantCount)
|
||
|
||
// 檢查每筆結果皆正確緩存
|
||
ws := svc.(*WalletService)
|
||
for _, w := range got {
|
||
cached, ok := ws.localBalances[w.Type]
|
||
assert.True(t, ok)
|
||
assert.Equal(t, w.Balance, cached.Balance)
|
||
assert.Equal(t, w.Asset, cached.Asset)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestForUpdateLockBehavior(t *testing.T) {
|
||
// 建立測試環境
|
||
_, db, teardown, err := SetupTestWalletRepository()
|
||
assert.NoError(t, err)
|
||
defer teardown()
|
||
|
||
// 建表
|
||
createTable := `
|
||
CREATE TABLE IF NOT EXISTS wallet (
|
||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
|
||
brand VARCHAR(50) NOT NULL,
|
||
uid VARCHAR(64) NOT NULL,
|
||
asset VARCHAR(32) NOT NULL,
|
||
balance DECIMAL(30,18) NOT NULL DEFAULT 0,
|
||
type TINYINT NOT NULL,
|
||
create_at INTEGER NOT NULL DEFAULT 0,
|
||
update_at INTEGER NOT NULL DEFAULT 0,
|
||
PRIMARY KEY (id),
|
||
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type)
|
||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
|
||
assert.NoError(t, db.Exec(createTable).Error)
|
||
|
||
// 種子資料:一筆 wallet
|
||
ctx := context.Background()
|
||
initial := entity.Wallet{
|
||
Brand: "b1",
|
||
UID: "user1",
|
||
Asset: "BTC",
|
||
Balance: decimal.NewFromInt(100),
|
||
Type: wallet.AllTypes[0],
|
||
}
|
||
err = db.WithContext(ctx).Create(&initial).Error
|
||
assert.NoError(t, err)
|
||
|
||
// 讀回自動產生的 ID
|
||
var seeded entity.Wallet
|
||
assert.NoError(t, db.Where("uid = ? AND asset = ?", initial.UID, initial.Asset).
|
||
First(&seeded).Error)
|
||
|
||
// 開啟第一個 transaction 並 SELECT … FOR UPDATE
|
||
tx1 := db.Begin()
|
||
var locked entity.Wallet
|
||
err = tx1.Clauses(clause.Locking{Strength: "UPDATE"}).
|
||
WithContext(ctx).
|
||
Where("id = ?", seeded.ID).
|
||
Take(&locked).Error
|
||
assert.NoError(t, err)
|
||
|
||
// 啟動 goroutine 嘗試在第二 transaction 更新該 row
|
||
done := make(chan error, 1)
|
||
go func() {
|
||
tx2 := db.Begin()
|
||
// 試圖更新,被 FOR UPDATE 鎖住前應該會 block
|
||
err := tx2.WithContext(ctx).
|
||
Model(&entity.Wallet{}).
|
||
Where("id = ?", seeded.ID).
|
||
Update("balance", seeded.Balance.Add(decimal.NewFromInt(50))).Error
|
||
done <- err
|
||
}()
|
||
|
||
// 等 100ms 確認尚未完成
|
||
time.Sleep(100 * time.Millisecond)
|
||
select {
|
||
case err2 := <-done:
|
||
t.Fatalf("expected update to be blocked, but completed with err=%v", err2)
|
||
default:
|
||
// 正在等待鎖
|
||
}
|
||
|
||
// 釋放鎖
|
||
assert.NoError(t, tx1.Commit().Error)
|
||
|
||
// 現在第二 transaction 應該很快完成
|
||
select {
|
||
case err2 := <-done:
|
||
assert.NoError(t, err2)
|
||
case <-time.After(500 * time.Millisecond):
|
||
t.Fatal("update did not complete after lock released")
|
||
}
|
||
}
|
||
|
||
func TestWalletService_IncreaseBalance(t *testing.T) {
|
||
uid := "user1"
|
||
asset := "BTC"
|
||
|
||
type testCase struct {
|
||
name string
|
||
kind wallet.Types
|
||
initial *entity.Wallet // nil 表示不存在
|
||
orderID string
|
||
amount decimal.Decimal
|
||
wantErr error
|
||
wantBalance decimal.Decimal
|
||
wantTxCount int
|
||
}
|
||
tests := []testCase{
|
||
{
|
||
name: "missing wallet type",
|
||
kind: wallet.AllTypes[0],
|
||
initial: nil,
|
||
orderID: "ord1",
|
||
amount: decimal.NewFromInt(10),
|
||
wantErr: repository.ErrRecordNotFound,
|
||
wantBalance: decimal.Zero,
|
||
wantTxCount: 0,
|
||
},
|
||
{
|
||
name: "increase from zero",
|
||
kind: wallet.AllTypes[1],
|
||
initial: &entity.Wallet{Balance: decimal.Zero},
|
||
orderID: "ord2",
|
||
amount: decimal.NewFromInt(15),
|
||
wantErr: nil,
|
||
wantBalance: decimal.NewFromInt(15),
|
||
wantTxCount: 1,
|
||
},
|
||
{
|
||
name: "successful increment on non-zero",
|
||
kind: wallet.AllTypes[2],
|
||
initial: &entity.Wallet{Balance: decimal.NewFromInt(5)},
|
||
orderID: "ord3",
|
||
amount: decimal.NewFromInt(7),
|
||
wantErr: nil,
|
||
wantBalance: decimal.NewFromInt(12),
|
||
wantTxCount: 1,
|
||
},
|
||
{
|
||
name: "insufficient leads to error",
|
||
kind: wallet.AllTypes[2],
|
||
initial: &entity.Wallet{Balance: decimal.NewFromInt(3)},
|
||
orderID: "ord4",
|
||
amount: decimal.NewFromInt(-5),
|
||
wantErr: repository.ErrBalanceInsufficient,
|
||
wantBalance: decimal.NewFromInt(3),
|
||
wantTxCount: 0,
|
||
},
|
||
}
|
||
|
||
for _, tc := range tests {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
// 準備 WalletService
|
||
svc := NewWalletService(nil, uid, asset).(*WalletService)
|
||
// 初始化本地快取
|
||
if tc.initial != nil {
|
||
svc.localBalances = map[wallet.Types]entity.Wallet{tc.kind: *tc.initial}
|
||
} else {
|
||
svc.localBalances = map[wallet.Types]entity.Wallet{}
|
||
}
|
||
// 清空交易紀錄
|
||
svc.transactions = nil
|
||
|
||
// 執行
|
||
err := svc.IncreaseBalance(tc.kind, tc.orderID, tc.amount)
|
||
|
||
// 驗證錯誤
|
||
if tc.wantErr != nil {
|
||
assert.ErrorIs(t, err, tc.wantErr)
|
||
} else {
|
||
assert.NoError(t, err)
|
||
}
|
||
|
||
// 驗證快取餘額
|
||
gotBal := svc.CurrentBalance(tc.kind)
|
||
assert.True(t, gotBal.Equal(tc.wantBalance), "balance = %s, want %s", gotBal, tc.wantBalance)
|
||
|
||
// 驗證交易筆數
|
||
assert.Len(t, svc.transactions, tc.wantTxCount)
|
||
if tc.wantTxCount > 0 {
|
||
tx := svc.transactions[0]
|
||
assert.Equal(t, tc.orderID, tx.OrderID)
|
||
assert.Equal(t, uid, tx.UID)
|
||
assert.Equal(t, asset, tx.Asset)
|
||
assert.True(t, tx.Amount.Equal(tc.amount))
|
||
assert.True(t, tx.Balance.Equal(tc.wantBalance))
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestWalletService_PrepareTransactions(t *testing.T) {
|
||
const (
|
||
txID = int64(42)
|
||
orderID = "order-123"
|
||
brand = "brandX"
|
||
bizNameStr = "business-test"
|
||
)
|
||
biz := wallet.BusinessName(bizNameStr)
|
||
|
||
tests := []struct {
|
||
name string
|
||
initialTxs []entity.WalletTransaction
|
||
wantCount int
|
||
}{
|
||
{
|
||
name: "no transactions returns empty slice",
|
||
initialTxs: nil,
|
||
wantCount: 0,
|
||
},
|
||
{
|
||
name: "single transaction is populated",
|
||
initialTxs: []entity.WalletTransaction{
|
||
{
|
||
OrderID: "placeholder",
|
||
UID: "u1",
|
||
WalletType: wallet.AllTypes[0],
|
||
Asset: "BTC",
|
||
Amount: decimal.NewFromInt(5),
|
||
Balance: decimal.NewFromInt(10),
|
||
},
|
||
},
|
||
wantCount: 1,
|
||
},
|
||
{
|
||
name: "multiple transactions are all populated",
|
||
initialTxs: []entity.WalletTransaction{
|
||
{UID: "u1", WalletType: wallet.AllTypes[1], Asset: "ETH", Amount: decimal.NewFromInt(1), Balance: decimal.NewFromInt(1)},
|
||
{UID: "u2", WalletType: wallet.AllTypes[2], Asset: "TWD", Amount: decimal.NewFromInt(2), Balance: decimal.NewFromInt(2)},
|
||
{UID: "u3", WalletType: wallet.AllTypes[3%len(wallet.AllTypes)], Asset: "USD", Amount: decimal.NewFromInt(3), Balance: decimal.NewFromInt(3)},
|
||
},
|
||
wantCount: 3,
|
||
},
|
||
}
|
||
|
||
for _, tc := range tests {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
// 建立 WalletService 並注入初始交易
|
||
svc := &WalletService{
|
||
transactions: make([]entity.WalletTransaction, len(tc.initialTxs)),
|
||
}
|
||
copy(svc.transactions, tc.initialTxs)
|
||
|
||
// 執行 PrepareTransactions
|
||
result := svc.PrepareTransactions(txID, orderID, brand, biz)
|
||
|
||
// 檢查回傳長度
|
||
assert.Len(t, result, tc.wantCount)
|
||
assert.Len(t, svc.transactions, tc.wantCount)
|
||
|
||
// 驗證每筆交易的共用欄位是否正確
|
||
for i := 0; i < tc.wantCount; i++ {
|
||
tx := result[i]
|
||
assert.Equal(t, txID, tx.TransactionID, "tx[%d].TransactionID", i)
|
||
assert.Equal(t, orderID, tx.OrderID, "tx[%d].OrderID", i)
|
||
assert.Equal(t, brand, tx.Brand, "tx[%d].Brand", i)
|
||
assert.Equal(t, biz.ToInt8(), tx.BusinessType, "tx[%d].BusinessType", i)
|
||
// 原有欄位保持不變
|
||
assert.Equal(t, tc.initialTxs[i].UID, tx.UID, "tx[%d].UID unchanged", i)
|
||
assert.Equal(t, tc.initialTxs[i].WalletType, tx.WalletType, "tx[%d].WalletType unchanged", i)
|
||
assert.Equal(t, tc.initialTxs[i].Asset, tx.Asset, "tx[%d].Asset unchanged", i)
|
||
assert.True(t, tx.Amount.Equal(tc.initialTxs[i].Amount), "tx[%d].Amount unchanged", i)
|
||
assert.True(t, tx.Balance.Equal(tc.initialTxs[i].Balance), "tx[%d].Balance unchanged", i)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestWalletService_PersistBalances(t *testing.T) {
|
||
_, db, teardown, err := SetupTestWalletRepository()
|
||
assert.NoError(t, err)
|
||
defer teardown()
|
||
|
||
// helper:删除表、重建表、插入两笔 seed 数据,返回这两笔带 ID 的 slice
|
||
seedWallets := func() []entity.Wallet {
|
||
// DROP + AutoMigrate
|
||
assert.NoError(t, db.Migrator().DropTable(&entity.Wallet{}))
|
||
assert.NoError(t, db.AutoMigrate(&entity.Wallet{}))
|
||
|
||
base := []entity.Wallet{
|
||
{UID: "u1", Asset: "BTC", Brand: "b", Balance: decimal.NewFromInt(10), Type: wallet.AllTypes[0]},
|
||
{UID: "u1", Asset: "BTC", Brand: "b", Balance: decimal.NewFromInt(20), Type: wallet.AllTypes[1]},
|
||
}
|
||
assert.NoError(t, db.Create(&base).Error)
|
||
return base
|
||
}
|
||
|
||
type fields struct {
|
||
updates map[wallet.Types]decimal.Decimal
|
||
}
|
||
type want struct {
|
||
final []decimal.Decimal
|
||
err bool
|
||
}
|
||
|
||
tests := []struct {
|
||
name string
|
||
fields fields
|
||
want want
|
||
}{
|
||
{
|
||
name: "no local balances → no change",
|
||
fields: fields{updates: map[wallet.Types]decimal.Decimal{}},
|
||
want: want{
|
||
final: []decimal.Decimal{decimal.NewFromInt(10), decimal.NewFromInt(20)},
|
||
err: false,
|
||
},
|
||
},
|
||
{
|
||
name: "update both balances",
|
||
fields: fields{updates: map[wallet.Types]decimal.Decimal{
|
||
wallet.AllTypes[0]: decimal.NewFromInt(15),
|
||
wallet.AllTypes[1]: decimal.NewFromInt(5),
|
||
}},
|
||
want: want{
|
||
final: []decimal.Decimal{decimal.NewFromInt(15), decimal.NewFromInt(5)},
|
||
err: false,
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tc := range tests {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
// seed 并拿到带 ID 的记录
|
||
seed := seedWallets()
|
||
|
||
// 构造 localBalances:只有 Type 在 updates 里的才覆盖
|
||
localMap := make(map[wallet.Types]entity.Wallet, len(tc.fields.updates))
|
||
for _, w := range seed {
|
||
if nb, ok := tc.fields.updates[w.Type]; ok {
|
||
localMap[w.Type] = entity.Wallet{
|
||
ID: w.ID,
|
||
Balance: nb,
|
||
}
|
||
}
|
||
}
|
||
|
||
// 初始化 service 并注入 localBalances
|
||
svc := NewWalletService(db, "u1", "BTC").(*WalletService)
|
||
svc.localBalances = localMap
|
||
|
||
// 执行 PersistBalances
|
||
err := svc.PersistBalances(context.Background())
|
||
if tc.want.err {
|
||
assert.Error(t, err)
|
||
return
|
||
}
|
||
assert.NoError(t, err)
|
||
|
||
// 重新查 DB,按 ID 顺序比较余额
|
||
var got []entity.Wallet
|
||
assert.NoError(t, db.
|
||
Where("uid = ?", "u1").
|
||
Order("id").
|
||
Find(&got).Error)
|
||
|
||
assert.Len(t, got, len(seed))
|
||
for i, w := range got {
|
||
assert.Truef(t,
|
||
w.Balance.Equal(tc.want.final[i]),
|
||
"第 %d 条记录 (id=%d) 余额 = %s, 期望 %s",
|
||
i, w.ID, w.Balance, tc.want.final[i],
|
||
)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestWalletService_PersistOrderBalances(t *testing.T) {
|
||
_, db, teardown, err := SetupTestWalletRepository()
|
||
assert.NoError(t, err)
|
||
defer teardown()
|
||
|
||
// helper:重建 transaction 表並 seed 資料
|
||
seedTransactions := func() []entity.Transaction {
|
||
// DROP + AutoMigrate
|
||
_ = db.Migrator().DropTable(&entity.Transaction{})
|
||
assert.NoError(t, db.AutoMigrate(&entity.Transaction{}))
|
||
|
||
// 插入兩筆 transaction
|
||
base := []entity.Transaction{
|
||
{PostTransferBalance: decimal.NewFromInt(100)},
|
||
{PostTransferBalance: decimal.NewFromInt(200)},
|
||
}
|
||
assert.NoError(t, db.Create(&base).Error)
|
||
return base
|
||
}
|
||
|
||
type fields struct {
|
||
updates map[int64]decimal.Decimal
|
||
}
|
||
type want struct {
|
||
final []decimal.Decimal
|
||
err bool
|
||
}
|
||
|
||
tests := []struct {
|
||
name string
|
||
fields fields
|
||
want want
|
||
}{
|
||
{
|
||
name: "no local order balances → no change",
|
||
fields: fields{updates: map[int64]decimal.Decimal{}},
|
||
want: want{
|
||
final: []decimal.Decimal{decimal.NewFromInt(100), decimal.NewFromInt(200)},
|
||
err: false,
|
||
},
|
||
},
|
||
{
|
||
name: "update first order balance only",
|
||
fields: fields{updates: map[int64]decimal.Decimal{
|
||
1: decimal.NewFromInt(150),
|
||
}},
|
||
want: want{
|
||
final: []decimal.Decimal{decimal.NewFromInt(150), decimal.NewFromInt(200)},
|
||
err: false,
|
||
},
|
||
},
|
||
{
|
||
name: "update both order balances",
|
||
fields: fields{updates: map[int64]decimal.Decimal{
|
||
1: decimal.NewFromInt(110),
|
||
2: decimal.NewFromInt(220),
|
||
}},
|
||
want: want{
|
||
final: []decimal.Decimal{decimal.NewFromInt(110), decimal.NewFromInt(220)},
|
||
err: false,
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tc := range tests {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
// seed 資料並拿回 slice(包含自動產生的 ID)
|
||
seed := seedTransactions()
|
||
|
||
// 建構 localOrderBalances:key 是 seed[i].ID
|
||
localMap := make(map[int64]decimal.Decimal, len(tc.fields.updates))
|
||
for id, newBal := range tc.fields.updates {
|
||
localMap[id] = newBal
|
||
}
|
||
|
||
// 初始化 service、注入 localOrderBalances
|
||
svc := NewWalletService(db, "u1", "BTC").(*WalletService)
|
||
svc.localOrderBalances = localMap
|
||
|
||
// 執行 PersistOrderBalances
|
||
err := svc.PersistOrderBalances(context.Background())
|
||
if tc.want.err {
|
||
assert.Error(t, err)
|
||
return
|
||
}
|
||
assert.NoError(t, err)
|
||
|
||
// 依照 seed 的 ID 順序讀回資料庫
|
||
var got []entity.Transaction
|
||
assert.NoError(t, db.
|
||
Where("1 = 1").
|
||
Order("id").
|
||
Find(&got).Error)
|
||
// 長度要和 seed 相同
|
||
assert.Len(t, got, len(seed))
|
||
// 比對每一筆 balance
|
||
for i, tr := range got {
|
||
assert.Truef(t,
|
||
tr.PostTransferBalance.Equal(tc.want.final[i]),
|
||
"第 %d 筆交易(id=%d) balance = %s, want %s",
|
||
i, tr.ID, tr.PostTransferBalance, tc.want.final[i],
|
||
)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestWalletService_HasAvailableBalance(t *testing.T) {
|
||
_, db, teardown, err := SetupTestWalletRepository()
|
||
assert.NoError(t, err)
|
||
defer teardown()
|
||
|
||
// 建表
|
||
createTable := `
|
||
CREATE TABLE IF NOT EXISTS wallet (
|
||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
|
||
brand VARCHAR(50) NOT NULL,
|
||
uid VARCHAR(64) NOT NULL,
|
||
asset VARCHAR(32) NOT NULL,
|
||
balance DECIMAL(30,18) NOT NULL DEFAULT 0,
|
||
type TINYINT NOT NULL,
|
||
create_at INTEGER NOT NULL DEFAULT 0,
|
||
update_at INTEGER NOT NULL DEFAULT 0,
|
||
PRIMARY KEY (id),
|
||
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type)
|
||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
|
||
assert.NoError(t, db.Exec(createTable).Error)
|
||
|
||
type seedRow struct {
|
||
Brand string
|
||
UID string
|
||
Asset string
|
||
Type wallet.Types
|
||
Balance decimal.Decimal
|
||
}
|
||
|
||
tests := []struct {
|
||
name string
|
||
uid string
|
||
asset string
|
||
seed []seedRow
|
||
wantExist bool
|
||
}{
|
||
{
|
||
name: "無任何紀錄",
|
||
uid: "user1",
|
||
asset: "BTC",
|
||
seed: nil,
|
||
wantExist: false,
|
||
},
|
||
{
|
||
name: "只有其他類型錢包",
|
||
uid: "user1",
|
||
asset: "BTC",
|
||
seed: []seedRow{
|
||
{"", "user1", "BTC", wallet.TypeFreeze, decimal.Zero},
|
||
{"", "user1", "BTC", wallet.TypeUnconfirmed, decimal.Zero},
|
||
},
|
||
wantExist: false,
|
||
},
|
||
{
|
||
name: "已有可用錢包",
|
||
uid: "user1",
|
||
asset: "BTC",
|
||
seed: []seedRow{
|
||
{"", "user1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(10)},
|
||
},
|
||
wantExist: true,
|
||
},
|
||
{
|
||
name: "不同 UID 不算",
|
||
uid: "user2",
|
||
asset: "BTC",
|
||
seed: []seedRow{
|
||
{"", "user1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
|
||
},
|
||
wantExist: false,
|
||
},
|
||
{
|
||
name: "不同 Asset 不算",
|
||
uid: "user1",
|
||
asset: "ETH",
|
||
seed: []seedRow{
|
||
{"", "user1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
|
||
},
|
||
wantExist: false,
|
||
},
|
||
}
|
||
|
||
for _, tc := range tests {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
// 每個子測試前都清空 wallet table
|
||
assert.NoError(t, db.Exec("DELETE FROM wallet").Error)
|
||
|
||
// seed 資料到 wallet
|
||
for _, r := range tc.seed {
|
||
w := entity.Wallet{
|
||
Brand: r.Brand,
|
||
UID: r.UID,
|
||
Asset: r.Asset,
|
||
Type: r.Type,
|
||
Balance: r.Balance,
|
||
}
|
||
assert.NoError(t, db.Create(&w).Error)
|
||
}
|
||
|
||
// 建 service
|
||
svc := NewWalletService(db, tc.uid, tc.asset).(*WalletService)
|
||
got, err := svc.HasAvailableBalance(context.Background())
|
||
assert.NoError(t, err)
|
||
assert.Equal(t, tc.wantExist, got)
|
||
})
|
||
}
|
||
}
|
||
|
||
func SetupTestDB(t *testing.T) (*gorm.DB, func()) {
|
||
host, port, _, tearDown, err := startMySQLContainer()
|
||
assert.NoError(t, err)
|
||
|
||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true",
|
||
MySQLUser, MySQLPassword, host, port, MySQLDatabase,
|
||
)
|
||
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
||
assert.NoError(t, err)
|
||
|
||
// 建表
|
||
create := `
|
||
CREATE TABLE IF NOT EXISTS wallet (
|
||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT COMMENT '主鍵 ID',
|
||
brand VARCHAR(50) NOT NULL DEFAULT '' COMMENT '品牌',
|
||
uid VARCHAR(64) NOT NULL COMMENT '使用者 UID',
|
||
asset VARCHAR(32) NOT NULL COMMENT '資產代號',
|
||
balance DECIMAL(30,18) UNSIGNED NOT NULL DEFAULT 0 COMMENT '餘額',
|
||
type TINYINT NOT NULL COMMENT '錢包類型',
|
||
create_at INTEGER NOT NULL DEFAULT 0 COMMENT '建立時間',
|
||
update_at INTEGER NOT NULL DEFAULT 0 COMMENT '更新時間',
|
||
PRIMARY KEY (id),
|
||
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type),
|
||
KEY idx_uid (uid),
|
||
KEY idx_brand (brand)
|
||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
|
||
assert.NoError(t, db.Exec(create).Error)
|
||
|
||
return db, tearDown
|
||
}
|
||
|
||
func TestWalletService_GetBalancesForUpdate(t *testing.T) {
|
||
db, tearDown := SetupTestDB(t)
|
||
defer tearDown()
|
||
|
||
type seedRow struct {
|
||
Brand string
|
||
UID string
|
||
Asset string
|
||
Type wallet.Types
|
||
Balance decimal.Decimal
|
||
}
|
||
|
||
tests := []struct {
|
||
name string
|
||
uid string
|
||
asset string
|
||
seed []seedRow
|
||
kinds []wallet.Types
|
||
wantIDs []int64
|
||
expectErr bool
|
||
}{
|
||
{
|
||
name: "查詢空結果",
|
||
uid: "u1",
|
||
asset: "BTC",
|
||
seed: nil,
|
||
kinds: []wallet.Types{wallet.TypeAvailable},
|
||
wantIDs: nil,
|
||
expectErr: false,
|
||
},
|
||
{
|
||
name: "單一類型查詢",
|
||
uid: "u1",
|
||
asset: "BTC",
|
||
seed: []seedRow{
|
||
{"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
|
||
{"", "u1", "BTC", wallet.TypeFreeze, decimal.NewFromInt(2)},
|
||
},
|
||
kinds: []wallet.Types{wallet.TypeFreeze},
|
||
wantIDs: []int64{2},
|
||
expectErr: false,
|
||
},
|
||
{
|
||
name: "多類型查詢",
|
||
uid: "u1",
|
||
asset: "BTC",
|
||
seed: []seedRow{
|
||
{"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
|
||
{"", "u1", "BTC", wallet.TypeFreeze, decimal.NewFromInt(2)},
|
||
{"", "u1", "BTC", wallet.TypeUnconfirmed, decimal.NewFromInt(3)},
|
||
},
|
||
kinds: []wallet.Types{wallet.TypeAvailable, wallet.TypeUnconfirmed},
|
||
wantIDs: []int64{3, 5},
|
||
expectErr: false,
|
||
},
|
||
{
|
||
name: "不同 UID 不列入",
|
||
uid: "u2",
|
||
asset: "BTC",
|
||
seed: []seedRow{
|
||
{"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
|
||
},
|
||
kinds: []wallet.Types{wallet.TypeAvailable},
|
||
wantIDs: nil,
|
||
expectErr: false,
|
||
},
|
||
{
|
||
name: "不同 Asset 不列入",
|
||
uid: "u1",
|
||
asset: "ETH",
|
||
seed: []seedRow{
|
||
{"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
|
||
},
|
||
kinds: []wallet.Types{wallet.TypeAvailable},
|
||
wantIDs: nil,
|
||
expectErr: false,
|
||
},
|
||
}
|
||
|
||
for _, tc := range tests {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
// 清空資料
|
||
assert.NoError(t, db.Exec("DELETE FROM wallet").Error)
|
||
// Seed
|
||
for _, r := range tc.seed {
|
||
w := entity.Wallet{
|
||
Brand: r.Brand,
|
||
UID: r.UID,
|
||
Asset: r.Asset,
|
||
Type: r.Type,
|
||
Balance: r.Balance,
|
||
}
|
||
// Create will auto-assign incremental IDs
|
||
assert.NoError(t, db.Create(&w).Error)
|
||
}
|
||
|
||
// 建 Service
|
||
svc := NewWalletService(db, tc.uid, tc.asset).(*WalletService)
|
||
got, err := svc.GetBalancesForUpdate(context.Background(), tc.kinds)
|
||
if tc.expectErr {
|
||
assert.Error(t, err)
|
||
return
|
||
}
|
||
assert.NoError(t, err)
|
||
|
||
// 取回 IDs 並排序
|
||
var gotIDs []int64
|
||
for _, w := range got {
|
||
gotIDs = append(gotIDs, w.ID)
|
||
// localBalances 應該被更新
|
||
assert.Equal(t, w, svc.localBalances[w.Type])
|
||
}
|
||
assert.Equal(t, tc.wantIDs, gotIDs)
|
||
})
|
||
}
|
||
}
|