640 lines
18 KiB
Go
640 lines
18 KiB
Go
package repository
|
||
|
||
import (
|
||
"code.30cm.net/digimon/app-cloudep-wallet-service/internal/config"
|
||
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/entity"
|
||
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/repository"
|
||
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/wallet"
|
||
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/lib/sql_client"
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"github.com/shopspring/decimal"
|
||
"github.com/stretchr/testify/assert"
|
||
"google.golang.org/protobuf/proto"
|
||
"gorm.io/gorm"
|
||
"testing"
|
||
"time"
|
||
)
|
||
|
||
func SetupTestTransactionRepository() (repository.TransactionRepository, *gorm.DB, func(), error) {
|
||
host, port, _, tearDown, err := startMySQLContainer()
|
||
if err != nil {
|
||
return nil, nil, nil, fmt.Errorf("failed to start MySQL container: %w", err)
|
||
}
|
||
|
||
conf := config.Config{
|
||
MySQL: struct {
|
||
UserName string
|
||
Password string
|
||
Host string
|
||
Port string
|
||
Database string
|
||
MaxIdleConns int
|
||
MaxOpenConns int
|
||
ConnMaxLifetime time.Duration
|
||
LogLevel string
|
||
}{
|
||
UserName: MySQLUser,
|
||
Password: MySQLPassword,
|
||
Host: host,
|
||
Port: port,
|
||
Database: MySQLDatabase,
|
||
MaxIdleConns: 10,
|
||
MaxOpenConns: 100,
|
||
ConnMaxLifetime: 300,
|
||
LogLevel: "info",
|
||
},
|
||
}
|
||
|
||
db, err := sql_client.NewMySQLClient(conf)
|
||
if err != nil {
|
||
tearDown()
|
||
return nil, nil, nil, fmt.Errorf("failed to create db client: %w", err)
|
||
}
|
||
|
||
repo := MustTransactionRepository(TransactionRepositoryParam{DB: db})
|
||
return repo, db, tearDown, nil
|
||
}
|
||
|
||
func createTransactionTable(t *testing.T, db *gorm.DB) {
|
||
sql := `
|
||
CREATE TABLE IF NOT EXISTS transaction (
|
||
id BIGINT AUTO_INCREMENT PRIMARY KEY,
|
||
order_id VARCHAR(64) NOT NULL,
|
||
transaction_id VARCHAR(64),
|
||
brand VARCHAR(32),
|
||
uid VARCHAR(64),
|
||
to_uid VARCHAR(64),
|
||
type TINYINT,
|
||
business_type TINYINT,
|
||
asset VARCHAR(32),
|
||
amount DECIMAL(30,18),
|
||
before_balance DECIMAL(30,18),
|
||
post_transfer_balance DECIMAL(30,18),
|
||
status TINYINT,
|
||
create_at BIGINT,
|
||
due_time BIGINT
|
||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
|
||
assert.NoError(t, db.Exec(sql).Error)
|
||
}
|
||
|
||
func createWalletTransactionTable(t *testing.T, db *gorm.DB) {
|
||
createTableSQL := `
|
||
CREATE TABLE IF NOT EXISTS wallet_transaction (
|
||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT COMMENT '主鍵 ID,自動遞增',
|
||
transaction_id BIGINT NOT NULL COMMENT '交易流水號(可對應某次業務操作,例如同一訂單的多筆變化)',
|
||
order_id VARCHAR(64) NOT NULL COMMENT '訂單編號(對應實際訂單或業務事件)',
|
||
brand VARCHAR(50) NOT NULL COMMENT '品牌(多租戶或多平台識別)',
|
||
uid VARCHAR(64) NOT NULL COMMENT '使用者 UID',
|
||
wallet_type TINYINT NOT NULL COMMENT '錢包類型(如主錢包、獎勵錢包、凍結錢包等)',
|
||
business_type TINYINT NOT NULL COMMENT '業務類型(如購物、退款、加值等)',
|
||
asset VARCHAR(32) NOT NULL COMMENT '資產代號(如 BTC、ETH、GEM_RED、USD 等)',
|
||
amount DECIMAL(30,18) NOT NULL COMMENT '變動金額(正數為收入,負數為支出)',
|
||
balance DECIMAL(30,18) NOT NULL COMMENT '當前錢包餘額(這筆交易後的餘額快照)',
|
||
create_at BIGINT NOT NULL DEFAULT 0 COMMENT '建立時間(UnixNano,紀錄交易發生時間)',
|
||
PRIMARY KEY (id),
|
||
KEY idx_uid (uid),
|
||
KEY idx_transaction_id (transaction_id),
|
||
KEY idx_order_id (order_id),
|
||
KEY idx_brand (brand),
|
||
KEY idx_wallet_type (wallet_type)
|
||
) ENGINE=InnoDB
|
||
DEFAULT CHARSET=utf8mb4
|
||
COLLATE=utf8mb4_unicode_ci
|
||
COMMENT='錢包資金異動紀錄(每一次交易行為的快照記錄)';`
|
||
|
||
assert.NoError(t, db.Exec(createTableSQL).Error)
|
||
}
|
||
|
||
func TestTransactionRepository_InsertAndBatchInsert(t *testing.T) {
|
||
// start container and connect
|
||
repo, db, tearDown, err := SetupTestTransactionRepository()
|
||
assert.NoError(t, err)
|
||
defer tearDown()
|
||
|
||
// prepare table
|
||
createTransactionTable(t, db)
|
||
|
||
now := time.Now().Unix()
|
||
template := entity.Transaction{
|
||
OrderID: "o1",
|
||
TransactionID: "tx1",
|
||
Brand: "b1",
|
||
UID: "u1",
|
||
ToUID: "u2",
|
||
TxType: 1,
|
||
BusinessType: 2,
|
||
Asset: "BTC",
|
||
Amount: decimal.RequireFromString("100.5"),
|
||
BeforeBalance: decimal.RequireFromString("50.0"),
|
||
PostTransferBalance: decimal.RequireFromString("150.5"),
|
||
Status: 1,
|
||
CreateAt: now,
|
||
DueTime: now + 3600,
|
||
}
|
||
|
||
tests := []struct {
|
||
name string
|
||
setup func(t *testing.T)
|
||
action func() error
|
||
validate func(t *testing.T)
|
||
}{
|
||
{
|
||
name: "single insert",
|
||
setup: func(t *testing.T) {
|
||
// clean table
|
||
assert.NoError(t, db.Exec("DELETE FROM transaction").Error)
|
||
},
|
||
action: func() error {
|
||
tx := template
|
||
return repo.Insert(context.Background(), &tx)
|
||
},
|
||
validate: func(t *testing.T) {
|
||
var count int64
|
||
assert.NoError(t, db.Raw("SELECT COUNT(*) FROM transaction").Scan(&count).Error)
|
||
assert.Equal(t, int64(1), count)
|
||
|
||
var got entity.Transaction
|
||
err := db.Take(&got, "order_id = ?", template.OrderID).Error
|
||
assert.NoError(t, err)
|
||
},
|
||
},
|
||
{
|
||
name: "batch insert",
|
||
setup: func(t *testing.T) {
|
||
assert.NoError(t, db.Exec("DELETE FROM transaction").Error)
|
||
},
|
||
action: func() error {
|
||
// clone two entries with different order IDs
|
||
tx1 := template
|
||
tx2 := template
|
||
tx1.OrderID = "o2"
|
||
tx2.OrderID = "o3"
|
||
return repo.BatchInsert(context.Background(), []*entity.Transaction{&tx1, &tx2})
|
||
},
|
||
validate: func(t *testing.T) {
|
||
var count int64
|
||
assert.NoError(t, db.Raw("SELECT COUNT(*) FROM transaction").Scan(&count).Error)
|
||
assert.Equal(t, int64(2), count)
|
||
|
||
var orders []string
|
||
rows, _ := db.Raw("SELECT order_id FROM transaction ORDER BY order_id").Rows()
|
||
defer rows.Close()
|
||
for rows.Next() {
|
||
var oid string
|
||
rows.Scan(&oid)
|
||
orders = append(orders, oid)
|
||
}
|
||
assert.Equal(t, []string{"o2", "o3"}, orders)
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
tt.setup(t)
|
||
err := tt.action()
|
||
assert.NoError(t, err)
|
||
tt.validate(t)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestTransactionRepository_FindByOrderID(t *testing.T) {
|
||
// start container and connect
|
||
repo, db, tearDown, err := SetupTestTransactionRepository()
|
||
assert.NoError(t, err)
|
||
defer tearDown()
|
||
|
||
// prepare table
|
||
createTransactionTable(t, db)
|
||
|
||
// 4) seed one row
|
||
now := time.Now().Unix()
|
||
res := db.Exec(
|
||
`INSERT INTO transaction
|
||
(order_id, transaction_id, brand, uid, to_uid, type, business_type, asset,
|
||
amount, before_balance, post_transfer_balance, status, create_at, due_time)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||
"order-123", "tx-abc", "brandX", "user1", "user2",
|
||
1, 2, "BTC",
|
||
"10.0", "5.0", "15.0",
|
||
1, now, now+3600,
|
||
)
|
||
assert.NoError(t, res.Error)
|
||
assert.Equal(t, int64(1), res.RowsAffected)
|
||
|
||
type want struct {
|
||
tx entity.Transaction
|
||
wantErr bool
|
||
}
|
||
|
||
tests := []struct {
|
||
name string
|
||
orderID string
|
||
want want
|
||
}{
|
||
{
|
||
name: "found existing",
|
||
orderID: "order-123",
|
||
want: want{
|
||
tx: entity.Transaction{
|
||
ID: 1,
|
||
OrderID: "order-123",
|
||
TransactionID: "tx-abc",
|
||
Brand: "brandX",
|
||
UID: "user1",
|
||
ToUID: "user2",
|
||
TxType: wallet.TxType(1),
|
||
BusinessType: int8(2),
|
||
Asset: "BTC",
|
||
Amount: decimal.RequireFromString("10.0"),
|
||
BeforeBalance: decimal.RequireFromString("5.0"),
|
||
PostTransferBalance: decimal.RequireFromString("15.0"),
|
||
Status: wallet.Enable(1),
|
||
CreateAt: now,
|
||
DueTime: now + 3600,
|
||
},
|
||
wantErr: false,
|
||
},
|
||
},
|
||
{
|
||
name: "not found",
|
||
orderID: "missing",
|
||
want: want{wantErr: true},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
got, err := repo.FindByOrderID(context.Background(), tt.orderID)
|
||
if tt.want.wantErr {
|
||
assert.Error(t, err)
|
||
assert.True(t, errors.Is(err, repository.ErrRecordNotFound))
|
||
return
|
||
}
|
||
assert.NoError(t, err)
|
||
assert.Equal(t, tt.want.tx.ID, got.ID)
|
||
assert.Equal(t, tt.want.tx.OrderID, got.OrderID)
|
||
assert.Equal(t, tt.want.tx.TransactionID, got.TransactionID)
|
||
assert.Equal(t, tt.want.tx.Brand, got.Brand)
|
||
assert.Equal(t, tt.want.tx.UID, got.UID)
|
||
assert.Equal(t, tt.want.tx.ToUID, got.ToUID)
|
||
assert.Equal(t, tt.want.tx.TxType, got.TxType)
|
||
assert.Equal(t, tt.want.tx.BusinessType, got.BusinessType)
|
||
assert.Equal(t, tt.want.tx.Asset, got.Asset)
|
||
assert.True(t, tt.want.tx.Amount.Equal(got.Amount))
|
||
assert.True(t, tt.want.tx.BeforeBalance.Equal(got.BeforeBalance))
|
||
assert.True(t, tt.want.tx.PostTransferBalance.Equal(got.PostTransferBalance))
|
||
assert.Equal(t, tt.want.tx.Status, got.Status)
|
||
assert.Equal(t, tt.want.tx.CreateAt, got.CreateAt)
|
||
assert.Equal(t, tt.want.tx.DueTime, got.DueTime)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestTransactionRepository_List(t *testing.T) {
|
||
repo, db, tearDown, err := SetupTestTransactionRepository()
|
||
assert.NoError(t, err)
|
||
defer tearDown()
|
||
|
||
createTransactionTable(t, db)
|
||
|
||
now := time.Now().Unix()
|
||
rows := []entity.Transaction{
|
||
{OrderID: "A", UID: "u1", TxType: wallet.Deposit, BusinessType: 1, Asset: "BTC", CreateAt: now - 30},
|
||
{OrderID: "B", UID: "u2", TxType: wallet.Withdraw, BusinessType: 2, Asset: "ETH", CreateAt: now - 20},
|
||
{OrderID: "C", UID: "u1", TxType: wallet.Deposit, BusinessType: 1, Asset: "BTC", CreateAt: now - 10},
|
||
}
|
||
|
||
assert.NoError(t, db.Create(&rows).Error)
|
||
|
||
tests := []struct {
|
||
name string
|
||
query repository.TransactionQuery
|
||
wantCount int64
|
||
wantIDs []int64
|
||
wantErr error
|
||
}{
|
||
{
|
||
name: "no filter returns all",
|
||
query: repository.TransactionQuery{
|
||
PageIndex: 1,
|
||
PageSize: 50,
|
||
},
|
||
wantCount: 3,
|
||
wantIDs: []int64{rows[2].ID, rows[1].ID, rows[0].ID},
|
||
wantErr: nil,
|
||
},
|
||
{
|
||
name: "filter by UID",
|
||
query: repository.TransactionQuery{
|
||
PageIndex: 1,
|
||
PageSize: 50,
|
||
UID: proto.String("u1"),
|
||
},
|
||
wantCount: 2,
|
||
wantIDs: []int64{rows[2].ID, rows[0].ID},
|
||
},
|
||
{
|
||
name: "filter by BusinessType",
|
||
query: repository.TransactionQuery{
|
||
PageIndex: 1,
|
||
PageSize: 50,
|
||
BusinessType: []int8{2},
|
||
},
|
||
wantCount: 1,
|
||
wantIDs: []int64{rows[1].ID},
|
||
},
|
||
{
|
||
name: "filter by Asset + TxType",
|
||
query: repository.TransactionQuery{
|
||
PageIndex: 1,
|
||
PageSize: 50,
|
||
Assets: proto.String("BTC"),
|
||
TxTypes: []wallet.TxType{wallet.Deposit},
|
||
},
|
||
wantCount: 2,
|
||
wantIDs: []int64{rows[2].ID, rows[0].ID},
|
||
},
|
||
{
|
||
name: "time range filter",
|
||
query: repository.TransactionQuery{
|
||
StartTime: proto.Int64(now - 25),
|
||
EndTime: proto.Int64(now - 5),
|
||
},
|
||
wantCount: 2,
|
||
wantIDs: []int64{rows[2].ID, rows[1].ID},
|
||
},
|
||
{
|
||
name: "paging page 1 size 1",
|
||
query: repository.TransactionQuery{
|
||
PageIndex: 1,
|
||
PageSize: 1,
|
||
},
|
||
wantCount: 3,
|
||
wantIDs: []int64{rows[2].ID},
|
||
},
|
||
{
|
||
name: "paging page 2 size 1",
|
||
query: repository.TransactionQuery{
|
||
PageIndex: 2,
|
||
PageSize: 1,
|
||
},
|
||
wantCount: 3,
|
||
wantIDs: []int64{rows[1].ID},
|
||
},
|
||
{
|
||
name: "no match returns ErrRecordNotFound",
|
||
query: repository.TransactionQuery{UID: proto.String("nonexist")},
|
||
wantErr: repository.ErrRecordNotFound,
|
||
},
|
||
}
|
||
|
||
ctx := context.Background()
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
got, cnt, err := repo.List(ctx, repository.TransactionQuery{
|
||
BusinessType: tt.query.BusinessType,
|
||
UID: tt.query.UID,
|
||
OrderID: tt.query.OrderID,
|
||
Assets: tt.query.Assets,
|
||
TxTypes: tt.query.TxTypes,
|
||
StartTime: tt.query.StartTime,
|
||
EndTime: tt.query.EndTime,
|
||
PageIndex: tt.query.PageIndex,
|
||
PageSize: tt.query.PageSize,
|
||
})
|
||
if tt.wantErr != nil {
|
||
assert.ErrorIs(t, err, tt.wantErr)
|
||
return
|
||
}
|
||
assert.NoError(t, err)
|
||
assert.Equal(t, tt.wantCount, cnt)
|
||
|
||
var ids []int64
|
||
for _, tx := range got {
|
||
ids = append(ids, tx.ID)
|
||
}
|
||
assert.Equal(t, tt.wantIDs, ids)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestTransactionRepository_FindByDueTimeRange(t *testing.T) {
|
||
// start container and connect
|
||
repo, db, tearDown, err := SetupTestTransactionRepository()
|
||
assert.NoError(t, err)
|
||
defer tearDown()
|
||
|
||
// prepare table
|
||
createTransactionTable(t, db)
|
||
|
||
// seed rows
|
||
now := time.Now().Unix()
|
||
rows := []entity.Transaction{
|
||
{TxType: wallet.Deposit, BusinessType: 0, DueTime: now - 10, Status: wallet.WalletNonStatus},
|
||
{TxType: wallet.Deposit, BusinessType: 0, DueTime: now + 10, Status: wallet.WalletNonStatus},
|
||
{TxType: wallet.Deposit, BusinessType: 0, DueTime: 0, Status: wallet.WalletNonStatus},
|
||
{TxType: wallet.Deposit, BusinessType: 0, DueTime: now - 5, Status: wallet.Enable(1)}, // status != non
|
||
}
|
||
|
||
assert.NoError(t, db.Create(&rows).Error)
|
||
|
||
tests := []struct {
|
||
name string
|
||
cutoff time.Time
|
||
types []wallet.TxType
|
||
wantIDs []int64
|
||
}{
|
||
{
|
||
name: "due before now for type=1",
|
||
cutoff: time.Unix(now, 0),
|
||
types: []wallet.TxType{wallet.Deposit},
|
||
wantIDs: []int64{rows[0].ID},
|
||
},
|
||
{
|
||
name: "include multiple types",
|
||
cutoff: time.Unix(now+20, 0),
|
||
types: []wallet.TxType{wallet.Deposit},
|
||
wantIDs: []int64{rows[0].ID, rows[1].ID},
|
||
},
|
||
{
|
||
name: "zero due_time is skipped",
|
||
cutoff: time.Unix(now+100, 0),
|
||
types: []wallet.TxType{wallet.Deposit},
|
||
wantIDs: []int64{rows[0].ID, rows[1].ID},
|
||
},
|
||
{
|
||
name: "no matches",
|
||
cutoff: time.Unix(now-100, 0),
|
||
types: []wallet.TxType{wallet.Deposit},
|
||
wantIDs: nil,
|
||
},
|
||
}
|
||
|
||
ctx := context.Background()
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
got, err := repo.FindByDueTimeRange(ctx, tt.cutoff, tt.types)
|
||
assert.NoError(t, err)
|
||
var ids []int64
|
||
for _, tx := range got {
|
||
ids = append(ids, tx.ID)
|
||
}
|
||
assert.Equal(t, tt.wantIDs, ids)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestTransactionRepository_UpdateStatusByID(t *testing.T) {
|
||
// start container and connect
|
||
repo, db, tearDown, err := SetupTestTransactionRepository()
|
||
assert.NoError(t, err)
|
||
defer tearDown()
|
||
|
||
createTransactionTable(t, db)
|
||
|
||
ctx := context.Background()
|
||
|
||
tx := &entity.Transaction{
|
||
OrderID: "ord-123",
|
||
TransactionID: "tx-abc",
|
||
Brand: "brand1",
|
||
UID: "user1",
|
||
ToUID: "user2",
|
||
TxType: wallet.TxType(1),
|
||
BusinessType: int8(2),
|
||
Asset: "BTC",
|
||
Amount: decimal.NewFromInt(100),
|
||
PostTransferBalance: decimal.NewFromInt(1000),
|
||
BeforeBalance: decimal.NewFromInt(900),
|
||
Status: wallet.Enable(0),
|
||
CreateAt: time.Now().Unix(),
|
||
DueTime: time.Now().Add(time.Hour).Unix(),
|
||
}
|
||
err = repo.Insert(ctx, tx)
|
||
assert.NoError(t, err)
|
||
existingID := tx.ID
|
||
|
||
tests := []struct {
|
||
name string
|
||
id int64
|
||
newStatus int
|
||
wantErr bool
|
||
wantRow bool
|
||
}{
|
||
{
|
||
name: "update existing row",
|
||
id: existingID,
|
||
newStatus: 1,
|
||
wantErr: false,
|
||
wantRow: true,
|
||
},
|
||
{
|
||
name: "non-existent id does not error",
|
||
id: existingID + 999,
|
||
newStatus: 5,
|
||
wantErr: false,
|
||
wantRow: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
err := repo.UpdateStatusByID(ctx, tt.id, tt.newStatus)
|
||
if tt.wantErr {
|
||
assert.Error(t, err)
|
||
return
|
||
}
|
||
assert.NoError(t, err)
|
||
|
||
var got entity.Transaction
|
||
res := db.First(&got, tt.id)
|
||
if tt.wantRow {
|
||
// existing row should reflect the new status
|
||
assert.NoError(t, res.Error)
|
||
assert.Equal(t, tt.newStatus, int(got.Status))
|
||
} else {
|
||
// non-existent id: no record found
|
||
assert.Error(t, res.Error)
|
||
assert.True(t, errors.Is(res.Error, gorm.ErrRecordNotFound))
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestTransactionRepository_ListWalletTransactions(t *testing.T) {
|
||
// start container and connect
|
||
repo, db, tearDown, err := SetupTestTransactionRepository()
|
||
assert.NoError(t, err)
|
||
defer tearDown()
|
||
|
||
createWalletTransactionTable(t, db)
|
||
ctx := context.Background()
|
||
now := time.Now().UnixNano()
|
||
transactions := []entity.WalletTransaction{
|
||
{OrderID: "o1", UID: "u1", WalletType: wallet.TypeAvailable, Asset: "BTC", Amount: decimal.NewFromInt(10), Balance: decimal.NewFromInt(110), CreateAt: now - 3},
|
||
{OrderID: "o2", UID: "u1", WalletType: wallet.TypeFreeze, Asset: "ETH", Amount: decimal.NewFromInt(20), Balance: decimal.NewFromInt(220), CreateAt: now - 2},
|
||
{OrderID: "o1", UID: "u2", WalletType: wallet.TypeAvailable, Asset: "BTC", Amount: decimal.NewFromInt(30), Balance: decimal.NewFromInt(330), CreateAt: now - 1},
|
||
}
|
||
|
||
assert.NoError(t, db.Create(&transactions).Error)
|
||
|
||
tests := []struct {
|
||
name string
|
||
uid string
|
||
orderIDs []string
|
||
walletType wallet.Types
|
||
wantIDs []int64
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "filter by uid=u1, both orders",
|
||
uid: "u1",
|
||
orderIDs: []string{"o1", "o2"},
|
||
walletType: 0,
|
||
wantIDs: []int64{transactions[0].ID, transactions[1].ID},
|
||
},
|
||
{
|
||
name: "filter by uid=u1, walletType=1",
|
||
uid: "u1",
|
||
orderIDs: []string{"o1", "o2"},
|
||
walletType: wallet.Types(1),
|
||
wantIDs: []int64{transactions[0].ID},
|
||
},
|
||
{
|
||
name: "filter by order=o1 only, any uid",
|
||
uid: "",
|
||
orderIDs: []string{"o1"},
|
||
walletType: wallet.Types(1),
|
||
wantIDs: []int64{transactions[0].ID, transactions[2].ID},
|
||
},
|
||
{
|
||
name: "no match yields empty",
|
||
uid: "nope",
|
||
orderIDs: []string{"o1"},
|
||
walletType: 0,
|
||
wantIDs: []int64{},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
got, err := repo.ListWalletTransactions(ctx, tt.uid, tt.orderIDs, tt.walletType)
|
||
if tt.wantErr {
|
||
assert.Error(t, err)
|
||
return
|
||
}
|
||
assert.NoError(t, err)
|
||
// 提取返回的 IDs 进行无序比较
|
||
gotIDs := make([]int64, len(got))
|
||
for i, w := range got {
|
||
gotIDs[i] = w.ID
|
||
}
|
||
assert.ElementsMatch(t, tt.wantIDs, gotIDs)
|
||
})
|
||
}
|
||
}
|