1026 lines
27 KiB
Go
1026 lines
27 KiB
Go
|
package repository
|
|||
|
|
|||
|
import (
|
|||
|
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/entity"
|
|||
|
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/repository"
|
|||
|
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/wallet"
|
|||
|
"context"
|
|||
|
"fmt"
|
|||
|
"github.com/shopspring/decimal"
|
|||
|
"github.com/stretchr/testify/assert"
|
|||
|
"gorm.io/driver/mysql"
|
|||
|
"gorm.io/gorm"
|
|||
|
"gorm.io/gorm/clause"
|
|||
|
"testing"
|
|||
|
"time"
|
|||
|
)
|
|||
|
|
|||
|
func TestWalletService_InitializeWallets(t *testing.T) {
|
|||
|
_, db, teardown, err := SetupTestWalletRepository()
|
|||
|
assert.NoError(t, err)
|
|||
|
defer teardown()
|
|||
|
|
|||
|
// 先手動建表
|
|||
|
createTable := `
|
|||
|
CREATE TABLE IF NOT EXISTS wallet (
|
|||
|
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
|
|||
|
brand VARCHAR(50) NOT NULL,
|
|||
|
uid VARCHAR(64) NOT NULL,
|
|||
|
asset VARCHAR(32) NOT NULL,
|
|||
|
balance DECIMAL(30,18) NOT NULL DEFAULT 0,
|
|||
|
type TINYINT NOT NULL,
|
|||
|
create_at INTEGER NOT NULL DEFAULT 0,
|
|||
|
update_at INTEGER NOT NULL DEFAULT 0,
|
|||
|
PRIMARY KEY (id),
|
|||
|
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type)
|
|||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
|
|||
|
assert.NoError(t, db.Exec(createTable).Error)
|
|||
|
|
|||
|
type args struct {
|
|||
|
uid string
|
|||
|
asset string
|
|||
|
brand string
|
|||
|
}
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
args args
|
|||
|
wantCount int
|
|||
|
wantErr bool
|
|||
|
validateDB bool
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "正常初始化一次",
|
|||
|
args: args{uid: "user1", asset: "BTC", brand: "brandA"},
|
|||
|
wantCount: len(wallet.AllTypes),
|
|||
|
},
|
|||
|
{
|
|||
|
name: "再次初始化同一 UID/asset/brand,應因 UNIQUE KEY 失敗",
|
|||
|
args: args{uid: "user1", asset: "BTC", brand: "brandA"},
|
|||
|
wantErr: true,
|
|||
|
validateDB: false,
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
for _, tt := range tests {
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
service := NewWalletService(db, tt.args.uid, tt.args.asset)
|
|||
|
ctx := context.Background()
|
|||
|
|
|||
|
got, err := service.InitializeWallets(ctx, tt.args.brand)
|
|||
|
if tt.wantErr {
|
|||
|
assert.Error(t, err)
|
|||
|
return
|
|||
|
}
|
|||
|
assert.NoError(t, err)
|
|||
|
// 回傳 slice 長度應等於 AllTypes
|
|||
|
assert.Len(t, got, tt.wantCount)
|
|||
|
|
|||
|
if tt.validateDB {
|
|||
|
// 再查一次 DB,確認確實寫入
|
|||
|
var dbRows []entity.Wallet
|
|||
|
err := db.WithContext(ctx).
|
|||
|
Where("uid = ? AND asset = ? AND brand = ?", tt.args.uid, tt.args.asset, tt.args.brand).
|
|||
|
Find(&dbRows).Error
|
|||
|
assert.NoError(t, err)
|
|||
|
assert.Len(t, dbRows, tt.wantCount)
|
|||
|
// 檢查每一筆都初始為零
|
|||
|
for _, w := range dbRows {
|
|||
|
assert.Equal(t, decimal.Zero, w.Balance)
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func TestWalletService_GetAllBalances(t *testing.T) {
|
|||
|
_, db, teardown, err := SetupTestWalletRepository()
|
|||
|
assert.NoError(t, err)
|
|||
|
defer teardown()
|
|||
|
|
|||
|
// 建表
|
|||
|
createTable := `
|
|||
|
CREATE TABLE IF NOT EXISTS wallet (
|
|||
|
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
|
|||
|
brand VARCHAR(50) NOT NULL,
|
|||
|
uid VARCHAR(64) NOT NULL,
|
|||
|
asset VARCHAR(32) NOT NULL,
|
|||
|
balance DECIMAL(30,18) NOT NULL DEFAULT 0,
|
|||
|
type TINYINT NOT NULL,
|
|||
|
create_at INTEGER NOT NULL DEFAULT 0,
|
|||
|
update_at INTEGER NOT NULL DEFAULT 0,
|
|||
|
PRIMARY KEY (id),
|
|||
|
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type)
|
|||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
|
|||
|
assert.NoError(t, db.Exec(createTable).Error)
|
|||
|
|
|||
|
type seedRow struct {
|
|||
|
Brand string
|
|||
|
UID string
|
|||
|
Asset string
|
|||
|
Balance decimal.Decimal
|
|||
|
Type wallet.Types
|
|||
|
}
|
|||
|
|
|||
|
type args struct {
|
|||
|
uid string
|
|||
|
asset string
|
|||
|
brand string
|
|||
|
}
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
args args
|
|||
|
seed []seedRow
|
|||
|
wantCount int
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "no data returns empty",
|
|||
|
args: args{"u1", "BTC", "b1"},
|
|||
|
seed: nil,
|
|||
|
wantCount: 0,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "single user many types",
|
|||
|
args: args{"u1", "BTC", "b1"},
|
|||
|
seed: []seedRow{
|
|||
|
{"b1", "u1", "BTC", decimal.NewFromInt(10), wallet.AllTypes[0]},
|
|||
|
{"b1", "u1", "BTC", decimal.NewFromInt(20), wallet.AllTypes[1]},
|
|||
|
{"b1", "u1", "BTC", decimal.NewFromInt(30), wallet.AllTypes[2]},
|
|||
|
},
|
|||
|
wantCount: 3,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "mixed users and assets",
|
|||
|
args: args{"u2", "ETH", "b2"},
|
|||
|
seed: []seedRow{
|
|||
|
{"b1", "u1", "BTC", decimal.NewFromInt(10), wallet.AllTypes[0]},
|
|||
|
{"b2", "u2", "ETH", decimal.NewFromInt(15), wallet.AllTypes[1]},
|
|||
|
{"b2", "u2", "ETH", decimal.NewFromInt(25), wallet.AllTypes[2]},
|
|||
|
},
|
|||
|
wantCount: 2,
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
for _, tt := range tests {
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
ctx := context.Background()
|
|||
|
// 清空表
|
|||
|
assert.NoError(t, db.Exec(`DELETE FROM wallet`).Error)
|
|||
|
|
|||
|
// 插入 seed
|
|||
|
for _, row := range tt.seed {
|
|||
|
err := db.WithContext(ctx).Exec(
|
|||
|
`INSERT INTO wallet (brand, uid, asset, balance, type, create_at, update_at)
|
|||
|
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
|||
|
row.Brand, row.UID, row.Asset, row.Balance, row.Type,
|
|||
|
time.Now().Unix(), time.Now().Unix(),
|
|||
|
).Error
|
|||
|
assert.NoError(t, err)
|
|||
|
}
|
|||
|
|
|||
|
// 建立 service
|
|||
|
svc := NewWalletService(db, tt.args.uid, tt.args.asset)
|
|||
|
|
|||
|
// 呼叫 GetAllBalances
|
|||
|
got, err := svc.GetAllBalances(ctx)
|
|||
|
assert.NoError(t, err)
|
|||
|
assert.Len(t, got, tt.wantCount)
|
|||
|
|
|||
|
// 檢查回傳資料與本地快取一致
|
|||
|
ws := svc.(*WalletService)
|
|||
|
for _, w := range got {
|
|||
|
// 回傳的每筆都應該存在於 localBalances
|
|||
|
cached, ok := ws.localBalances[w.Type]
|
|||
|
assert.True(t, ok)
|
|||
|
assert.Equal(t, w.Balance, cached.Balance)
|
|||
|
assert.Equal(t, w.Asset, cached.Asset)
|
|||
|
assert.Equal(t, w.Type, cached.Type)
|
|||
|
}
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func TestWalletService_GetBalancesForTypes(t *testing.T) {
|
|||
|
_, db, teardown, err := SetupTestWalletRepository()
|
|||
|
assert.NoError(t, err)
|
|||
|
defer teardown()
|
|||
|
|
|||
|
// 建表
|
|||
|
createTable := `
|
|||
|
CREATE TABLE IF NOT EXISTS wallet (
|
|||
|
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
|
|||
|
brand VARCHAR(50) NOT NULL,
|
|||
|
uid VARCHAR(64) NOT NULL,
|
|||
|
asset VARCHAR(32) NOT NULL,
|
|||
|
balance DECIMAL(30,18) NOT NULL DEFAULT 0,
|
|||
|
type TINYINT NOT NULL,
|
|||
|
create_at INTEGER NOT NULL DEFAULT 0,
|
|||
|
update_at INTEGER NOT NULL DEFAULT 0,
|
|||
|
PRIMARY KEY (id),
|
|||
|
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type)
|
|||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
|
|||
|
assert.NoError(t, db.Exec(createTable).Error)
|
|||
|
|
|||
|
type seedRow struct {
|
|||
|
Brand string
|
|||
|
UID string
|
|||
|
Asset string
|
|||
|
Balance decimal.Decimal
|
|||
|
Type wallet.Types
|
|||
|
}
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
uid string
|
|||
|
asset string
|
|||
|
brand string
|
|||
|
seed []seedRow
|
|||
|
kinds []wallet.Types
|
|||
|
wantCount int
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "no matching types returns empty",
|
|||
|
uid: "user1", asset: "BTC", brand: "b1",
|
|||
|
seed: []seedRow{
|
|||
|
{"b1", "user1", "BTC", decimal.NewFromInt(5), wallet.AllTypes[0]},
|
|||
|
},
|
|||
|
kinds: []wallet.Types{wallet.AllTypes[1]},
|
|||
|
wantCount: 0,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "single type match",
|
|||
|
uid: "user1", asset: "BTC", brand: "b1",
|
|||
|
seed: []seedRow{
|
|||
|
{"b1", "user1", "BTC", decimal.NewFromInt(5), wallet.AllTypes[0]},
|
|||
|
{"b1", "user1", "BTC", decimal.NewFromInt(7), wallet.AllTypes[1]},
|
|||
|
},
|
|||
|
kinds: []wallet.Types{wallet.AllTypes[1]},
|
|||
|
wantCount: 1,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "multiple type matches",
|
|||
|
uid: "user2", asset: "ETH", brand: "b2",
|
|||
|
seed: []seedRow{
|
|||
|
{"b2", "user2", "ETH", decimal.NewFromInt(3), wallet.AllTypes[0]},
|
|||
|
{"b2", "user2", "ETH", decimal.NewFromInt(8), wallet.AllTypes[2]},
|
|||
|
{"b2", "user2", "ETH", decimal.NewFromInt(10), wallet.AllTypes[1]},
|
|||
|
},
|
|||
|
kinds: []wallet.Types{wallet.AllTypes[0], wallet.AllTypes[2]},
|
|||
|
wantCount: 2,
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
for _, tt := range tests {
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
ctx := context.Background()
|
|||
|
// 清表
|
|||
|
assert.NoError(t, db.Exec(`DELETE FROM wallet`).Error)
|
|||
|
|
|||
|
// 插入種子資料
|
|||
|
for _, r := range tt.seed {
|
|||
|
assert.NoError(t, db.Exec(
|
|||
|
`INSERT INTO wallet (brand, uid, asset, balance, type, create_at, update_at)
|
|||
|
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
|||
|
r.Brand, r.UID, r.Asset, r.Balance, r.Type,
|
|||
|
time.Now().Unix(), time.Now().Unix(),
|
|||
|
).Error)
|
|||
|
}
|
|||
|
|
|||
|
// 建立 service
|
|||
|
svc := NewWalletService(db, tt.uid, tt.asset)
|
|||
|
|
|||
|
// 呼叫 GetBalancesForTypes
|
|||
|
got, err := svc.GetBalancesForTypes(ctx, tt.kinds)
|
|||
|
assert.NoError(t, err)
|
|||
|
assert.Len(t, got, tt.wantCount)
|
|||
|
|
|||
|
// 檢查每筆結果皆正確緩存
|
|||
|
ws := svc.(*WalletService)
|
|||
|
for _, w := range got {
|
|||
|
cached, ok := ws.localBalances[w.Type]
|
|||
|
assert.True(t, ok)
|
|||
|
assert.Equal(t, w.Balance, cached.Balance)
|
|||
|
assert.Equal(t, w.Asset, cached.Asset)
|
|||
|
}
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func TestForUpdateLockBehavior(t *testing.T) {
|
|||
|
// 建立測試環境
|
|||
|
_, db, teardown, err := SetupTestWalletRepository()
|
|||
|
assert.NoError(t, err)
|
|||
|
defer teardown()
|
|||
|
|
|||
|
// 建表
|
|||
|
createTable := `
|
|||
|
CREATE TABLE IF NOT EXISTS wallet (
|
|||
|
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
|
|||
|
brand VARCHAR(50) NOT NULL,
|
|||
|
uid VARCHAR(64) NOT NULL,
|
|||
|
asset VARCHAR(32) NOT NULL,
|
|||
|
balance DECIMAL(30,18) NOT NULL DEFAULT 0,
|
|||
|
type TINYINT NOT NULL,
|
|||
|
create_at INTEGER NOT NULL DEFAULT 0,
|
|||
|
update_at INTEGER NOT NULL DEFAULT 0,
|
|||
|
PRIMARY KEY (id),
|
|||
|
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type)
|
|||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
|
|||
|
assert.NoError(t, db.Exec(createTable).Error)
|
|||
|
|
|||
|
// 種子資料:一筆 wallet
|
|||
|
ctx := context.Background()
|
|||
|
initial := entity.Wallet{
|
|||
|
Brand: "b1",
|
|||
|
UID: "user1",
|
|||
|
Asset: "BTC",
|
|||
|
Balance: decimal.NewFromInt(100),
|
|||
|
Type: wallet.AllTypes[0],
|
|||
|
}
|
|||
|
err = db.WithContext(ctx).Create(&initial).Error
|
|||
|
assert.NoError(t, err)
|
|||
|
|
|||
|
// 讀回自動產生的 ID
|
|||
|
var seeded entity.Wallet
|
|||
|
assert.NoError(t, db.Where("uid = ? AND asset = ?", initial.UID, initial.Asset).
|
|||
|
First(&seeded).Error)
|
|||
|
|
|||
|
// 開啟第一個 transaction 並 SELECT … FOR UPDATE
|
|||
|
tx1 := db.Begin()
|
|||
|
var locked entity.Wallet
|
|||
|
err = tx1.Clauses(clause.Locking{Strength: "UPDATE"}).
|
|||
|
WithContext(ctx).
|
|||
|
Where("id = ?", seeded.ID).
|
|||
|
Take(&locked).Error
|
|||
|
assert.NoError(t, err)
|
|||
|
|
|||
|
// 啟動 goroutine 嘗試在第二 transaction 更新該 row
|
|||
|
done := make(chan error, 1)
|
|||
|
go func() {
|
|||
|
tx2 := db.Begin()
|
|||
|
// 試圖更新,被 FOR UPDATE 鎖住前應該會 block
|
|||
|
err := tx2.WithContext(ctx).
|
|||
|
Model(&entity.Wallet{}).
|
|||
|
Where("id = ?", seeded.ID).
|
|||
|
Update("balance", seeded.Balance.Add(decimal.NewFromInt(50))).Error
|
|||
|
done <- err
|
|||
|
}()
|
|||
|
|
|||
|
// 等 100ms 確認尚未完成
|
|||
|
time.Sleep(100 * time.Millisecond)
|
|||
|
select {
|
|||
|
case err2 := <-done:
|
|||
|
t.Fatalf("expected update to be blocked, but completed with err=%v", err2)
|
|||
|
default:
|
|||
|
// 正在等待鎖
|
|||
|
}
|
|||
|
|
|||
|
// 釋放鎖
|
|||
|
assert.NoError(t, tx1.Commit().Error)
|
|||
|
|
|||
|
// 現在第二 transaction 應該很快完成
|
|||
|
select {
|
|||
|
case err2 := <-done:
|
|||
|
assert.NoError(t, err2)
|
|||
|
case <-time.After(500 * time.Millisecond):
|
|||
|
t.Fatal("update did not complete after lock released")
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func TestWalletService_IncreaseBalance(t *testing.T) {
|
|||
|
uid := "user1"
|
|||
|
asset := "BTC"
|
|||
|
|
|||
|
type testCase struct {
|
|||
|
name string
|
|||
|
kind wallet.Types
|
|||
|
initial *entity.Wallet // nil 表示不存在
|
|||
|
orderID string
|
|||
|
amount decimal.Decimal
|
|||
|
wantErr error
|
|||
|
wantBalance decimal.Decimal
|
|||
|
wantTxCount int
|
|||
|
}
|
|||
|
tests := []testCase{
|
|||
|
{
|
|||
|
name: "missing wallet type",
|
|||
|
kind: wallet.AllTypes[0],
|
|||
|
initial: nil,
|
|||
|
orderID: "ord1",
|
|||
|
amount: decimal.NewFromInt(10),
|
|||
|
wantErr: repository.ErrRecordNotFound,
|
|||
|
wantBalance: decimal.Zero,
|
|||
|
wantTxCount: 0,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "increase from zero",
|
|||
|
kind: wallet.AllTypes[1],
|
|||
|
initial: &entity.Wallet{Balance: decimal.Zero},
|
|||
|
orderID: "ord2",
|
|||
|
amount: decimal.NewFromInt(15),
|
|||
|
wantErr: nil,
|
|||
|
wantBalance: decimal.NewFromInt(15),
|
|||
|
wantTxCount: 1,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "successful increment on non-zero",
|
|||
|
kind: wallet.AllTypes[2],
|
|||
|
initial: &entity.Wallet{Balance: decimal.NewFromInt(5)},
|
|||
|
orderID: "ord3",
|
|||
|
amount: decimal.NewFromInt(7),
|
|||
|
wantErr: nil,
|
|||
|
wantBalance: decimal.NewFromInt(12),
|
|||
|
wantTxCount: 1,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "insufficient leads to error",
|
|||
|
kind: wallet.AllTypes[2],
|
|||
|
initial: &entity.Wallet{Balance: decimal.NewFromInt(3)},
|
|||
|
orderID: "ord4",
|
|||
|
amount: decimal.NewFromInt(-5),
|
|||
|
wantErr: repository.ErrBalanceInsufficient,
|
|||
|
wantBalance: decimal.NewFromInt(3),
|
|||
|
wantTxCount: 0,
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
for _, tc := range tests {
|
|||
|
t.Run(tc.name, func(t *testing.T) {
|
|||
|
// 準備 WalletService
|
|||
|
svc := NewWalletService(nil, uid, asset).(*WalletService)
|
|||
|
// 初始化本地快取
|
|||
|
if tc.initial != nil {
|
|||
|
svc.localBalances = map[wallet.Types]entity.Wallet{tc.kind: *tc.initial}
|
|||
|
} else {
|
|||
|
svc.localBalances = map[wallet.Types]entity.Wallet{}
|
|||
|
}
|
|||
|
// 清空交易紀錄
|
|||
|
svc.transactions = nil
|
|||
|
|
|||
|
// 執行
|
|||
|
err := svc.IncreaseBalance(tc.kind, tc.orderID, tc.amount)
|
|||
|
|
|||
|
// 驗證錯誤
|
|||
|
if tc.wantErr != nil {
|
|||
|
assert.ErrorIs(t, err, tc.wantErr)
|
|||
|
} else {
|
|||
|
assert.NoError(t, err)
|
|||
|
}
|
|||
|
|
|||
|
// 驗證快取餘額
|
|||
|
gotBal := svc.CurrentBalance(tc.kind)
|
|||
|
assert.True(t, gotBal.Equal(tc.wantBalance), "balance = %s, want %s", gotBal, tc.wantBalance)
|
|||
|
|
|||
|
// 驗證交易筆數
|
|||
|
assert.Len(t, svc.transactions, tc.wantTxCount)
|
|||
|
if tc.wantTxCount > 0 {
|
|||
|
tx := svc.transactions[0]
|
|||
|
assert.Equal(t, tc.orderID, tx.OrderID)
|
|||
|
assert.Equal(t, uid, tx.UID)
|
|||
|
assert.Equal(t, asset, tx.Asset)
|
|||
|
assert.True(t, tx.Amount.Equal(tc.amount))
|
|||
|
assert.True(t, tx.Balance.Equal(tc.wantBalance))
|
|||
|
}
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func TestWalletService_PrepareTransactions(t *testing.T) {
|
|||
|
const (
|
|||
|
txID = int64(42)
|
|||
|
orderID = "order-123"
|
|||
|
brand = "brandX"
|
|||
|
bizNameStr = "business-test"
|
|||
|
)
|
|||
|
biz := wallet.BusinessName(bizNameStr)
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
initialTxs []entity.WalletTransaction
|
|||
|
wantCount int
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "no transactions returns empty slice",
|
|||
|
initialTxs: nil,
|
|||
|
wantCount: 0,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "single transaction is populated",
|
|||
|
initialTxs: []entity.WalletTransaction{
|
|||
|
{
|
|||
|
OrderID: "placeholder",
|
|||
|
UID: "u1",
|
|||
|
WalletType: wallet.AllTypes[0],
|
|||
|
Asset: "BTC",
|
|||
|
Amount: decimal.NewFromInt(5),
|
|||
|
Balance: decimal.NewFromInt(10),
|
|||
|
},
|
|||
|
},
|
|||
|
wantCount: 1,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "multiple transactions are all populated",
|
|||
|
initialTxs: []entity.WalletTransaction{
|
|||
|
{UID: "u1", WalletType: wallet.AllTypes[1], Asset: "ETH", Amount: decimal.NewFromInt(1), Balance: decimal.NewFromInt(1)},
|
|||
|
{UID: "u2", WalletType: wallet.AllTypes[2], Asset: "TWD", Amount: decimal.NewFromInt(2), Balance: decimal.NewFromInt(2)},
|
|||
|
{UID: "u3", WalletType: wallet.AllTypes[3%len(wallet.AllTypes)], Asset: "USD", Amount: decimal.NewFromInt(3), Balance: decimal.NewFromInt(3)},
|
|||
|
},
|
|||
|
wantCount: 3,
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
for _, tc := range tests {
|
|||
|
t.Run(tc.name, func(t *testing.T) {
|
|||
|
// 建立 WalletService 並注入初始交易
|
|||
|
svc := &WalletService{
|
|||
|
transactions: make([]entity.WalletTransaction, len(tc.initialTxs)),
|
|||
|
}
|
|||
|
copy(svc.transactions, tc.initialTxs)
|
|||
|
|
|||
|
// 執行 PrepareTransactions
|
|||
|
result := svc.PrepareTransactions(txID, orderID, brand, biz)
|
|||
|
|
|||
|
// 檢查回傳長度
|
|||
|
assert.Len(t, result, tc.wantCount)
|
|||
|
assert.Len(t, svc.transactions, tc.wantCount)
|
|||
|
|
|||
|
// 驗證每筆交易的共用欄位是否正確
|
|||
|
for i := 0; i < tc.wantCount; i++ {
|
|||
|
tx := result[i]
|
|||
|
assert.Equal(t, txID, tx.TransactionID, "tx[%d].TransactionID", i)
|
|||
|
assert.Equal(t, orderID, tx.OrderID, "tx[%d].OrderID", i)
|
|||
|
assert.Equal(t, brand, tx.Brand, "tx[%d].Brand", i)
|
|||
|
assert.Equal(t, biz.ToInt8(), tx.BusinessType, "tx[%d].BusinessType", i)
|
|||
|
// 原有欄位保持不變
|
|||
|
assert.Equal(t, tc.initialTxs[i].UID, tx.UID, "tx[%d].UID unchanged", i)
|
|||
|
assert.Equal(t, tc.initialTxs[i].WalletType, tx.WalletType, "tx[%d].WalletType unchanged", i)
|
|||
|
assert.Equal(t, tc.initialTxs[i].Asset, tx.Asset, "tx[%d].Asset unchanged", i)
|
|||
|
assert.True(t, tx.Amount.Equal(tc.initialTxs[i].Amount), "tx[%d].Amount unchanged", i)
|
|||
|
assert.True(t, tx.Balance.Equal(tc.initialTxs[i].Balance), "tx[%d].Balance unchanged", i)
|
|||
|
}
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func TestWalletService_PersistBalances(t *testing.T) {
|
|||
|
_, db, teardown, err := SetupTestWalletRepository()
|
|||
|
assert.NoError(t, err)
|
|||
|
defer teardown()
|
|||
|
|
|||
|
// helper:删除表、重建表、插入两笔 seed 数据,返回这两笔带 ID 的 slice
|
|||
|
seedWallets := func() []entity.Wallet {
|
|||
|
// DROP + AutoMigrate
|
|||
|
assert.NoError(t, db.Migrator().DropTable(&entity.Wallet{}))
|
|||
|
assert.NoError(t, db.AutoMigrate(&entity.Wallet{}))
|
|||
|
|
|||
|
base := []entity.Wallet{
|
|||
|
{UID: "u1", Asset: "BTC", Brand: "b", Balance: decimal.NewFromInt(10), Type: wallet.AllTypes[0]},
|
|||
|
{UID: "u1", Asset: "BTC", Brand: "b", Balance: decimal.NewFromInt(20), Type: wallet.AllTypes[1]},
|
|||
|
}
|
|||
|
assert.NoError(t, db.Create(&base).Error)
|
|||
|
return base
|
|||
|
}
|
|||
|
|
|||
|
type fields struct {
|
|||
|
updates map[wallet.Types]decimal.Decimal
|
|||
|
}
|
|||
|
type want struct {
|
|||
|
final []decimal.Decimal
|
|||
|
err bool
|
|||
|
}
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
fields fields
|
|||
|
want want
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "no local balances → no change",
|
|||
|
fields: fields{updates: map[wallet.Types]decimal.Decimal{}},
|
|||
|
want: want{
|
|||
|
final: []decimal.Decimal{decimal.NewFromInt(10), decimal.NewFromInt(20)},
|
|||
|
err: false,
|
|||
|
},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "update both balances",
|
|||
|
fields: fields{updates: map[wallet.Types]decimal.Decimal{
|
|||
|
wallet.AllTypes[0]: decimal.NewFromInt(15),
|
|||
|
wallet.AllTypes[1]: decimal.NewFromInt(5),
|
|||
|
}},
|
|||
|
want: want{
|
|||
|
final: []decimal.Decimal{decimal.NewFromInt(15), decimal.NewFromInt(5)},
|
|||
|
err: false,
|
|||
|
},
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
for _, tc := range tests {
|
|||
|
t.Run(tc.name, func(t *testing.T) {
|
|||
|
// seed 并拿到带 ID 的记录
|
|||
|
seed := seedWallets()
|
|||
|
|
|||
|
// 构造 localBalances:只有 Type 在 updates 里的才覆盖
|
|||
|
localMap := make(map[wallet.Types]entity.Wallet, len(tc.fields.updates))
|
|||
|
for _, w := range seed {
|
|||
|
if nb, ok := tc.fields.updates[w.Type]; ok {
|
|||
|
localMap[w.Type] = entity.Wallet{
|
|||
|
ID: w.ID,
|
|||
|
Balance: nb,
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// 初始化 service 并注入 localBalances
|
|||
|
svc := NewWalletService(db, "u1", "BTC").(*WalletService)
|
|||
|
svc.localBalances = localMap
|
|||
|
|
|||
|
// 执行 PersistBalances
|
|||
|
err := svc.PersistBalances(context.Background())
|
|||
|
if tc.want.err {
|
|||
|
assert.Error(t, err)
|
|||
|
return
|
|||
|
}
|
|||
|
assert.NoError(t, err)
|
|||
|
|
|||
|
// 重新查 DB,按 ID 顺序比较余额
|
|||
|
var got []entity.Wallet
|
|||
|
assert.NoError(t, db.
|
|||
|
Where("uid = ?", "u1").
|
|||
|
Order("id").
|
|||
|
Find(&got).Error)
|
|||
|
|
|||
|
assert.Len(t, got, len(seed))
|
|||
|
for i, w := range got {
|
|||
|
assert.Truef(t,
|
|||
|
w.Balance.Equal(tc.want.final[i]),
|
|||
|
"第 %d 条记录 (id=%d) 余额 = %s, 期望 %s",
|
|||
|
i, w.ID, w.Balance, tc.want.final[i],
|
|||
|
)
|
|||
|
}
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func TestWalletService_PersistOrderBalances(t *testing.T) {
|
|||
|
_, db, teardown, err := SetupTestWalletRepository()
|
|||
|
assert.NoError(t, err)
|
|||
|
defer teardown()
|
|||
|
|
|||
|
// helper:重建 transaction 表並 seed 資料
|
|||
|
seedTransactions := func() []entity.Transaction {
|
|||
|
// DROP + AutoMigrate
|
|||
|
_ = db.Migrator().DropTable(&entity.Transaction{})
|
|||
|
assert.NoError(t, db.AutoMigrate(&entity.Transaction{}))
|
|||
|
|
|||
|
// 插入兩筆 transaction
|
|||
|
base := []entity.Transaction{
|
|||
|
{PostTransferBalance: decimal.NewFromInt(100)},
|
|||
|
{PostTransferBalance: decimal.NewFromInt(200)},
|
|||
|
}
|
|||
|
assert.NoError(t, db.Create(&base).Error)
|
|||
|
return base
|
|||
|
}
|
|||
|
|
|||
|
type fields struct {
|
|||
|
updates map[int64]decimal.Decimal
|
|||
|
}
|
|||
|
type want struct {
|
|||
|
final []decimal.Decimal
|
|||
|
err bool
|
|||
|
}
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
fields fields
|
|||
|
want want
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "no local order balances → no change",
|
|||
|
fields: fields{updates: map[int64]decimal.Decimal{}},
|
|||
|
want: want{
|
|||
|
final: []decimal.Decimal{decimal.NewFromInt(100), decimal.NewFromInt(200)},
|
|||
|
err: false,
|
|||
|
},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "update first order balance only",
|
|||
|
fields: fields{updates: map[int64]decimal.Decimal{
|
|||
|
1: decimal.NewFromInt(150),
|
|||
|
}},
|
|||
|
want: want{
|
|||
|
final: []decimal.Decimal{decimal.NewFromInt(150), decimal.NewFromInt(200)},
|
|||
|
err: false,
|
|||
|
},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "update both order balances",
|
|||
|
fields: fields{updates: map[int64]decimal.Decimal{
|
|||
|
1: decimal.NewFromInt(110),
|
|||
|
2: decimal.NewFromInt(220),
|
|||
|
}},
|
|||
|
want: want{
|
|||
|
final: []decimal.Decimal{decimal.NewFromInt(110), decimal.NewFromInt(220)},
|
|||
|
err: false,
|
|||
|
},
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
for _, tc := range tests {
|
|||
|
t.Run(tc.name, func(t *testing.T) {
|
|||
|
// seed 資料並拿回 slice(包含自動產生的 ID)
|
|||
|
seed := seedTransactions()
|
|||
|
|
|||
|
// 建構 localOrderBalances:key 是 seed[i].ID
|
|||
|
localMap := make(map[int64]decimal.Decimal, len(tc.fields.updates))
|
|||
|
for id, newBal := range tc.fields.updates {
|
|||
|
localMap[id] = newBal
|
|||
|
}
|
|||
|
|
|||
|
// 初始化 service、注入 localOrderBalances
|
|||
|
svc := NewWalletService(db, "u1", "BTC").(*WalletService)
|
|||
|
svc.localOrderBalances = localMap
|
|||
|
|
|||
|
// 執行 PersistOrderBalances
|
|||
|
err := svc.PersistOrderBalances(context.Background())
|
|||
|
if tc.want.err {
|
|||
|
assert.Error(t, err)
|
|||
|
return
|
|||
|
}
|
|||
|
assert.NoError(t, err)
|
|||
|
|
|||
|
// 依照 seed 的 ID 順序讀回資料庫
|
|||
|
var got []entity.Transaction
|
|||
|
assert.NoError(t, db.
|
|||
|
Where("1 = 1").
|
|||
|
Order("id").
|
|||
|
Find(&got).Error)
|
|||
|
// 長度要和 seed 相同
|
|||
|
assert.Len(t, got, len(seed))
|
|||
|
// 比對每一筆 balance
|
|||
|
for i, tr := range got {
|
|||
|
assert.Truef(t,
|
|||
|
tr.PostTransferBalance.Equal(tc.want.final[i]),
|
|||
|
"第 %d 筆交易(id=%d) balance = %s, want %s",
|
|||
|
i, tr.ID, tr.PostTransferBalance, tc.want.final[i],
|
|||
|
)
|
|||
|
}
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func TestWalletService_HasAvailableBalance(t *testing.T) {
|
|||
|
_, db, teardown, err := SetupTestWalletRepository()
|
|||
|
assert.NoError(t, err)
|
|||
|
defer teardown()
|
|||
|
|
|||
|
// 建表
|
|||
|
createTable := `
|
|||
|
CREATE TABLE IF NOT EXISTS wallet (
|
|||
|
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
|
|||
|
brand VARCHAR(50) NOT NULL,
|
|||
|
uid VARCHAR(64) NOT NULL,
|
|||
|
asset VARCHAR(32) NOT NULL,
|
|||
|
balance DECIMAL(30,18) NOT NULL DEFAULT 0,
|
|||
|
type TINYINT NOT NULL,
|
|||
|
create_at INTEGER NOT NULL DEFAULT 0,
|
|||
|
update_at INTEGER NOT NULL DEFAULT 0,
|
|||
|
PRIMARY KEY (id),
|
|||
|
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type)
|
|||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
|
|||
|
assert.NoError(t, db.Exec(createTable).Error)
|
|||
|
|
|||
|
type seedRow struct {
|
|||
|
Brand string
|
|||
|
UID string
|
|||
|
Asset string
|
|||
|
Type wallet.Types
|
|||
|
Balance decimal.Decimal
|
|||
|
}
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
uid string
|
|||
|
asset string
|
|||
|
seed []seedRow
|
|||
|
wantExist bool
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "無任何紀錄",
|
|||
|
uid: "user1",
|
|||
|
asset: "BTC",
|
|||
|
seed: nil,
|
|||
|
wantExist: false,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "只有其他類型錢包",
|
|||
|
uid: "user1",
|
|||
|
asset: "BTC",
|
|||
|
seed: []seedRow{
|
|||
|
{"", "user1", "BTC", wallet.TypeFreeze, decimal.Zero},
|
|||
|
{"", "user1", "BTC", wallet.TypeUnconfirmed, decimal.Zero},
|
|||
|
},
|
|||
|
wantExist: false,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "已有可用錢包",
|
|||
|
uid: "user1",
|
|||
|
asset: "BTC",
|
|||
|
seed: []seedRow{
|
|||
|
{"", "user1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(10)},
|
|||
|
},
|
|||
|
wantExist: true,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "不同 UID 不算",
|
|||
|
uid: "user2",
|
|||
|
asset: "BTC",
|
|||
|
seed: []seedRow{
|
|||
|
{"", "user1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
|
|||
|
},
|
|||
|
wantExist: false,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "不同 Asset 不算",
|
|||
|
uid: "user1",
|
|||
|
asset: "ETH",
|
|||
|
seed: []seedRow{
|
|||
|
{"", "user1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
|
|||
|
},
|
|||
|
wantExist: false,
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
for _, tc := range tests {
|
|||
|
t.Run(tc.name, func(t *testing.T) {
|
|||
|
// 每個子測試前都清空 wallet table
|
|||
|
assert.NoError(t, db.Exec("DELETE FROM wallet").Error)
|
|||
|
|
|||
|
// seed 資料到 wallet
|
|||
|
for _, r := range tc.seed {
|
|||
|
w := entity.Wallet{
|
|||
|
Brand: r.Brand,
|
|||
|
UID: r.UID,
|
|||
|
Asset: r.Asset,
|
|||
|
Type: r.Type,
|
|||
|
Balance: r.Balance,
|
|||
|
}
|
|||
|
assert.NoError(t, db.Create(&w).Error)
|
|||
|
}
|
|||
|
|
|||
|
// 建 service
|
|||
|
svc := NewWalletService(db, tc.uid, tc.asset).(*WalletService)
|
|||
|
got, err := svc.HasAvailableBalance(context.Background())
|
|||
|
assert.NoError(t, err)
|
|||
|
assert.Equal(t, tc.wantExist, got)
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func SetupTestDB(t *testing.T) (*gorm.DB, func()) {
|
|||
|
host, port, _, tearDown, err := startMySQLContainer()
|
|||
|
assert.NoError(t, err)
|
|||
|
|
|||
|
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true",
|
|||
|
MySQLUser, MySQLPassword, host, port, MySQLDatabase,
|
|||
|
)
|
|||
|
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
|||
|
assert.NoError(t, err)
|
|||
|
|
|||
|
// 建表
|
|||
|
create := `
|
|||
|
CREATE TABLE IF NOT EXISTS wallet (
|
|||
|
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT COMMENT '主鍵 ID',
|
|||
|
brand VARCHAR(50) NOT NULL DEFAULT '' COMMENT '品牌',
|
|||
|
uid VARCHAR(64) NOT NULL COMMENT '使用者 UID',
|
|||
|
asset VARCHAR(32) NOT NULL COMMENT '資產代號',
|
|||
|
balance DECIMAL(30,18) UNSIGNED NOT NULL DEFAULT 0 COMMENT '餘額',
|
|||
|
type TINYINT NOT NULL COMMENT '錢包類型',
|
|||
|
create_at INTEGER NOT NULL DEFAULT 0 COMMENT '建立時間',
|
|||
|
update_at INTEGER NOT NULL DEFAULT 0 COMMENT '更新時間',
|
|||
|
PRIMARY KEY (id),
|
|||
|
UNIQUE KEY uq_brand_uid_asset_type (brand, uid, asset, type),
|
|||
|
KEY idx_uid (uid),
|
|||
|
KEY idx_brand (brand)
|
|||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
|
|||
|
assert.NoError(t, db.Exec(create).Error)
|
|||
|
|
|||
|
return db, tearDown
|
|||
|
}
|
|||
|
|
|||
|
func TestWalletService_GetBalancesForUpdate(t *testing.T) {
|
|||
|
db, tearDown := SetupTestDB(t)
|
|||
|
defer tearDown()
|
|||
|
|
|||
|
type seedRow struct {
|
|||
|
Brand string
|
|||
|
UID string
|
|||
|
Asset string
|
|||
|
Type wallet.Types
|
|||
|
Balance decimal.Decimal
|
|||
|
}
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
uid string
|
|||
|
asset string
|
|||
|
seed []seedRow
|
|||
|
kinds []wallet.Types
|
|||
|
wantIDs []int64
|
|||
|
expectErr bool
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "查詢空結果",
|
|||
|
uid: "u1",
|
|||
|
asset: "BTC",
|
|||
|
seed: nil,
|
|||
|
kinds: []wallet.Types{wallet.TypeAvailable},
|
|||
|
wantIDs: nil,
|
|||
|
expectErr: false,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "單一類型查詢",
|
|||
|
uid: "u1",
|
|||
|
asset: "BTC",
|
|||
|
seed: []seedRow{
|
|||
|
{"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
|
|||
|
{"", "u1", "BTC", wallet.TypeFreeze, decimal.NewFromInt(2)},
|
|||
|
},
|
|||
|
kinds: []wallet.Types{wallet.TypeFreeze},
|
|||
|
wantIDs: []int64{2},
|
|||
|
expectErr: false,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "多類型查詢",
|
|||
|
uid: "u1",
|
|||
|
asset: "BTC",
|
|||
|
seed: []seedRow{
|
|||
|
{"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
|
|||
|
{"", "u1", "BTC", wallet.TypeFreeze, decimal.NewFromInt(2)},
|
|||
|
{"", "u1", "BTC", wallet.TypeUnconfirmed, decimal.NewFromInt(3)},
|
|||
|
},
|
|||
|
kinds: []wallet.Types{wallet.TypeAvailable, wallet.TypeUnconfirmed},
|
|||
|
wantIDs: []int64{3, 5},
|
|||
|
expectErr: false,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "不同 UID 不列入",
|
|||
|
uid: "u2",
|
|||
|
asset: "BTC",
|
|||
|
seed: []seedRow{
|
|||
|
{"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
|
|||
|
},
|
|||
|
kinds: []wallet.Types{wallet.TypeAvailable},
|
|||
|
wantIDs: nil,
|
|||
|
expectErr: false,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "不同 Asset 不列入",
|
|||
|
uid: "u1",
|
|||
|
asset: "ETH",
|
|||
|
seed: []seedRow{
|
|||
|
{"", "u1", "BTC", wallet.TypeAvailable, decimal.NewFromInt(5)},
|
|||
|
},
|
|||
|
kinds: []wallet.Types{wallet.TypeAvailable},
|
|||
|
wantIDs: nil,
|
|||
|
expectErr: false,
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
for _, tc := range tests {
|
|||
|
t.Run(tc.name, func(t *testing.T) {
|
|||
|
// 清空資料
|
|||
|
assert.NoError(t, db.Exec("DELETE FROM wallet").Error)
|
|||
|
// Seed
|
|||
|
for _, r := range tc.seed {
|
|||
|
w := entity.Wallet{
|
|||
|
Brand: r.Brand,
|
|||
|
UID: r.UID,
|
|||
|
Asset: r.Asset,
|
|||
|
Type: r.Type,
|
|||
|
Balance: r.Balance,
|
|||
|
}
|
|||
|
// Create will auto-assign incremental IDs
|
|||
|
assert.NoError(t, db.Create(&w).Error)
|
|||
|
}
|
|||
|
|
|||
|
// 建 Service
|
|||
|
svc := NewWalletService(db, tc.uid, tc.asset).(*WalletService)
|
|||
|
got, err := svc.GetBalancesForUpdate(context.Background(), tc.kinds)
|
|||
|
if tc.expectErr {
|
|||
|
assert.Error(t, err)
|
|||
|
return
|
|||
|
}
|
|||
|
assert.NoError(t, err)
|
|||
|
|
|||
|
// 取回 IDs 並排序
|
|||
|
var gotIDs []int64
|
|||
|
for _, w := range got {
|
|||
|
gotIDs = append(gotIDs, w.ID)
|
|||
|
// localBalances 應該被更新
|
|||
|
assert.Equal(t, w, svc.localBalances[w.Type])
|
|||
|
}
|
|||
|
assert.Equal(t, tc.wantIDs, gotIDs)
|
|||
|
})
|
|||
|
}
|
|||
|
}
|