app-cloudep-wallet-service/pkg/repository/transaction_test.go

640 lines
18 KiB
Go
Raw Normal View History

2025-04-21 07:46:43 +00:00
package repository
import (
"code.30cm.net/digimon/app-cloudep-wallet-service/internal/config"
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/entity"
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/repository"
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/wallet"
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/lib/sql_client"
"context"
"errors"
"fmt"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
"gorm.io/gorm"
"testing"
"time"
)
// SetupTestTransactionRepository boots a throw-away MySQL container and returns
// a TransactionRepository backed by it.
//
// Results, in order: the repository under test; the underlying *gorm.DB so tests
// can create tables and seed or inspect rows directly; a teardown func that
// releases the container; and an error. When the error is non-nil every other
// result is nil, so callers must not defer the teardown in that case. If the DB
// client cannot be created, the container is torn down before returning.
func SetupTestTransactionRepository() (repository.TransactionRepository, *gorm.DB, func(), error) {
	host, port, _, tearDown, err := startMySQLContainer()
	if err != nil {
		return nil, nil, nil, fmt.Errorf("failed to start MySQL container: %w", err)
	}

	// Fill the (anonymous) MySQL section field by field rather than restating
	// its whole struct type inside a composite literal.
	var conf config.Config
	mc := &conf.MySQL
	mc.UserName = MySQLUser
	mc.Password = MySQLPassword
	mc.Host = host
	mc.Port = port
	mc.Database = MySQLDatabase
	mc.MaxIdleConns = 10
	mc.MaxOpenConns = 100
	// NOTE(review): ConnMaxLifetime is a time.Duration, so 300 means 300ns unless
	// sql_client scales it (e.g. by time.Second) — confirm the intended unit.
	mc.ConnMaxLifetime = 300
	mc.LogLevel = "info"

	db, err := sql_client.NewMySQLClient(conf)
	if err != nil {
		tearDown()
		return nil, nil, nil, fmt.Errorf("failed to create db client: %w", err)
	}

	return MustTransactionRepository(TransactionRepositoryParam{DB: db}), db, tearDown, nil
}
// createTransactionTable creates the `transaction` table used by the
// TransactionRepository tests if it does not already exist. The columns
// presumably mirror entity.Transaction's GORM column mapping — keep the two in
// sync when the entity changes.
//
// The calling test is aborted if the DDL fails. Every caller depends on this
// table, so continuing would only produce a cascade of misleading failures.
func createTransactionTable(t *testing.T, db *gorm.DB) {
	// Report failures at the caller's line, not inside this helper.
	t.Helper()

	ddl := `
CREATE TABLE IF NOT EXISTS transaction (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
order_id VARCHAR(64) NOT NULL,
transaction_id VARCHAR(64),
brand VARCHAR(32),
uid VARCHAR(64),
to_uid VARCHAR(64),
type TINYINT,
business_type TINYINT,
asset VARCHAR(32),
amount DECIMAL(30,18),
before_balance DECIMAL(30,18),
post_transfer_balance DECIMAL(30,18),
status TINYINT,
create_at BIGINT,
due_time BIGINT
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
	if err := db.Exec(ddl).Error; err != nil {
		t.Fatalf("create transaction table: %v", err)
	}
}
// createWalletTransactionTable creates the `wallet_transaction` table if it does
// not already exist.
//
// Per its table comment, the table is a log of wallet fund movements: one
// snapshot row per balance change. It records the user (uid), order and
// transaction ids, wallet_type, business_type, asset, the signed change
// (amount: positive = income, negative = expense) and the resulting balance.
// create_at is documented in the DDL as UnixNano.
//
// The calling test is aborted if the DDL fails, because every caller depends
// on the table.
func createWalletTransactionTable(t *testing.T, db *gorm.DB) {
	// Report failures at the caller's line, not inside this helper.
	t.Helper()

	createTableSQL := `
CREATE TABLE IF NOT EXISTS wallet_transaction (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT COMMENT '主鍵 ID自動遞增',
transaction_id BIGINT NOT NULL COMMENT '交易流水號可對應某次業務操作例如同一訂單的多筆變化',
order_id VARCHAR(64) NOT NULL COMMENT '訂單編號對應實際訂單或業務事件',
brand VARCHAR(50) NOT NULL COMMENT '品牌多租戶或多平台識別',
uid VARCHAR(64) NOT NULL COMMENT '使用者 UID',
wallet_type TINYINT NOT NULL COMMENT '錢包類型如主錢包獎勵錢包凍結錢包等',
business_type TINYINT NOT NULL COMMENT '業務類型如購物退款加值等',
asset VARCHAR(32) NOT NULL COMMENT '資產代號 BTCETHGEM_REDUSD ',
amount DECIMAL(30,18) NOT NULL COMMENT '變動金額正數為收入負數為支出',
balance DECIMAL(30,18) NOT NULL COMMENT '當前錢包餘額這筆交易後的餘額快照',
create_at BIGINT NOT NULL DEFAULT 0 COMMENT '建立時間UnixNano紀錄交易發生時間',
PRIMARY KEY (id),
KEY idx_uid (uid),
KEY idx_transaction_id (transaction_id),
KEY idx_order_id (order_id),
KEY idx_brand (brand),
KEY idx_wallet_type (wallet_type)
) ENGINE=InnoDB
DEFAULT CHARSET=utf8mb4
COLLATE=utf8mb4_unicode_ci
COMMENT='錢包資金異動紀錄每一次交易行為的快照記錄';`
	if err := db.Exec(createTableSQL).Error; err != nil {
		t.Fatalf("create wallet_transaction table: %v", err)
	}
}
// TestTransactionRepository_InsertAndBatchInsert checks that Insert writes a
// single row and BatchInsert writes every row it is given. It inspects the
// table directly through the raw *gorm.DB.
func TestTransactionRepository_InsertAndBatchInsert(t *testing.T) {
	// Start the container and connect. tearDown is nil when setup fails, so
	// abort before deferring it.
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setup transaction repository: %v", err)
	}
	defer tearDown()

	createTransactionTable(t, db)

	now := time.Now().Unix()
	// template is copied by value in each case; cases only override OrderID.
	template := entity.Transaction{
		OrderID:             "o1",
		TransactionID:       "tx1",
		Brand:               "b1",
		UID:                 "u1",
		ToUID:               "u2",
		TxType:              1,
		BusinessType:        2,
		Asset:               "BTC",
		Amount:              decimal.RequireFromString("100.5"),
		BeforeBalance:       decimal.RequireFromString("50.0"),
		PostTransferBalance: decimal.RequireFromString("150.5"),
		Status:              1,
		CreateAt:            now,
		DueTime:             now + 3600,
	}

	tests := []struct {
		name     string
		setup    func(t *testing.T)
		action   func() error
		validate func(t *testing.T)
	}{
		{
			name: "single insert",
			setup: func(t *testing.T) {
				// Start from an empty table so the row count is deterministic.
				assert.NoError(t, db.Exec("DELETE FROM transaction").Error)
			},
			action: func() error {
				tx := template
				return repo.Insert(context.Background(), &tx)
			},
			validate: func(t *testing.T) {
				var count int64
				assert.NoError(t, db.Raw("SELECT COUNT(*) FROM transaction").Scan(&count).Error)
				assert.Equal(t, int64(1), count)

				var got entity.Transaction
				assert.NoError(t, db.Take(&got, "order_id = ?", template.OrderID).Error)
			},
		},
		{
			name: "batch insert",
			setup: func(t *testing.T) {
				assert.NoError(t, db.Exec("DELETE FROM transaction").Error)
			},
			action: func() error {
				// Two copies of the template, distinguished only by order ID.
				tx1 := template
				tx2 := template
				tx1.OrderID = "o2"
				tx2.OrderID = "o3"
				return repo.BatchInsert(context.Background(), []*entity.Transaction{&tx1, &tx2})
			},
			validate: func(t *testing.T) {
				var count int64
				assert.NoError(t, db.Raw("SELECT COUNT(*) FROM transaction").Scan(&count).Error)
				assert.Equal(t, int64(2), count)

				// Check the query error before touching rows: on failure rows is
				// nil and deferring rows.Close() would panic.
				rows, err := db.Raw("SELECT order_id FROM transaction ORDER BY order_id").Rows()
				if !assert.NoError(t, err) {
					return
				}
				defer rows.Close()

				var orders []string
				for rows.Next() {
					var oid string
					if !assert.NoError(t, rows.Scan(&oid)) {
						return
					}
					orders = append(orders, oid)
				}
				// rows.Next() returns false on iteration errors too; surface them.
				assert.NoError(t, rows.Err())
				assert.Equal(t, []string{"o2", "o3"}, orders)
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.setup(t)
			assert.NoError(t, tt.action())
			tt.validate(t)
		})
	}
}
// TestTransactionRepository_FindByOrderID seeds one row with raw SQL and checks
// that FindByOrderID maps every column back onto entity.Transaction. It also
// checks that an unknown order ID yields repository.ErrRecordNotFound.
func TestTransactionRepository_FindByOrderID(t *testing.T) {
	// Start the container and connect. tearDown is nil when setup fails, so
	// abort before deferring it.
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setup transaction repository: %v", err)
	}
	defer tearDown()

	createTransactionTable(t, db)

	// Seed exactly one row. It is the first insert into the table created by
	// this test's own container, so it is expected to get AUTO_INCREMENT id 1.
	now := time.Now().Unix()
	res := db.Exec(
		`INSERT INTO transaction
(order_id, transaction_id, brand, uid, to_uid, type, business_type, asset,
amount, before_balance, post_transfer_balance, status, create_at, due_time)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		"order-123", "tx-abc", "brandX", "user1", "user2",
		1, 2, "BTC",
		"10.0", "5.0", "15.0",
		1, now, now+3600,
	)
	if res.Error != nil {
		t.Fatalf("seed transaction row: %v", res.Error)
	}
	assert.Equal(t, int64(1), res.RowsAffected)

	type want struct {
		tx      entity.Transaction
		wantErr bool
	}
	tests := []struct {
		name    string
		orderID string
		want    want
	}{
		{
			name:    "found existing",
			orderID: "order-123",
			want: want{
				tx: entity.Transaction{
					ID:                  1,
					OrderID:             "order-123",
					TransactionID:       "tx-abc",
					Brand:               "brandX",
					UID:                 "user1",
					ToUID:               "user2",
					TxType:              wallet.TxType(1),
					BusinessType:        int8(2),
					Asset:               "BTC",
					Amount:              decimal.RequireFromString("10.0"),
					BeforeBalance:       decimal.RequireFromString("5.0"),
					PostTransferBalance: decimal.RequireFromString("15.0"),
					Status:              wallet.Enable(1),
					CreateAt:            now,
					DueTime:             now + 3600,
				},
				wantErr: false,
			},
		},
		{
			name:    "not found",
			orderID: "missing",
			want:    want{wantErr: true},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := repo.FindByOrderID(context.Background(), tt.orderID)
			if tt.want.wantErr {
				assert.ErrorIs(t, err, repository.ErrRecordNotFound)
				return
			}
			// got is not meaningful on error (and may be a nil pointer), so stop
			// before dereferencing it.
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, tt.want.tx.ID, got.ID)
			assert.Equal(t, tt.want.tx.OrderID, got.OrderID)
			assert.Equal(t, tt.want.tx.TransactionID, got.TransactionID)
			assert.Equal(t, tt.want.tx.Brand, got.Brand)
			assert.Equal(t, tt.want.tx.UID, got.UID)
			assert.Equal(t, tt.want.tx.ToUID, got.ToUID)
			assert.Equal(t, tt.want.tx.TxType, got.TxType)
			assert.Equal(t, tt.want.tx.BusinessType, got.BusinessType)
			assert.Equal(t, tt.want.tx.Asset, got.Asset)
			// Compare decimals by value: "10.0" and "10" differ in exponent but
			// are numerically equal.
			assert.True(t, tt.want.tx.Amount.Equal(got.Amount))
			assert.True(t, tt.want.tx.BeforeBalance.Equal(got.BeforeBalance))
			assert.True(t, tt.want.tx.PostTransferBalance.Equal(got.PostTransferBalance))
			assert.Equal(t, tt.want.tx.Status, got.Status)
			assert.Equal(t, tt.want.tx.CreateAt, got.CreateAt)
			assert.Equal(t, tt.want.tx.DueTime, got.DueTime)
		})
	}
}
// TestTransactionRepository_List covers List's filters (UID, BusinessType,
// Asset, TxType, time range) and its paging. The expected IDs are newest-first
// by CreateAt, matching the order these cases assert. The total count must
// ignore paging, and an empty result must surface repository.ErrRecordNotFound.
func TestTransactionRepository_List(t *testing.T) {
	// tearDown is nil when setup fails, so abort before deferring it.
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setup transaction repository: %v", err)
	}
	defer tearDown()

	createTransactionTable(t, db)

	now := time.Now().Unix()
	rows := []entity.Transaction{
		{OrderID: "A", UID: "u1", TxType: wallet.Deposit, BusinessType: 1, Asset: "BTC", CreateAt: now - 30},
		{OrderID: "B", UID: "u2", TxType: wallet.Withdraw, BusinessType: 2, Asset: "ETH", CreateAt: now - 20},
		{OrderID: "C", UID: "u1", TxType: wallet.Deposit, BusinessType: 1, Asset: "BTC", CreateAt: now - 10},
	}
	// The expectations below use the IDs GORM writes back into rows; a failed
	// seed would leave them all 0, so abort instead of asserting on garbage.
	if err := db.Create(&rows).Error; err != nil {
		t.Fatalf("seed transactions: %v", err)
	}

	tests := []struct {
		name      string
		query     repository.TransactionQuery
		wantCount int64
		wantIDs   []int64
		wantErr   error
	}{
		{
			name: "no filter returns all",
			query: repository.TransactionQuery{
				PageIndex: 1,
				PageSize:  50,
			},
			wantCount: 3,
			wantIDs:   []int64{rows[2].ID, rows[1].ID, rows[0].ID},
			wantErr:   nil,
		},
		{
			name: "filter by UID",
			query: repository.TransactionQuery{
				PageIndex: 1,
				PageSize:  50,
				UID:       proto.String("u1"),
			},
			wantCount: 2,
			wantIDs:   []int64{rows[2].ID, rows[0].ID},
		},
		{
			name: "filter by BusinessType",
			query: repository.TransactionQuery{
				PageIndex:    1,
				PageSize:     50,
				BusinessType: []int8{2},
			},
			wantCount: 1,
			wantIDs:   []int64{rows[1].ID},
		},
		{
			name: "filter by Asset + TxType",
			query: repository.TransactionQuery{
				PageIndex: 1,
				PageSize:  50,
				Assets:    proto.String("BTC"),
				TxTypes:   []wallet.TxType{wallet.Deposit},
			},
			wantCount: 2,
			wantIDs:   []int64{rows[2].ID, rows[0].ID},
		},
		{
			// NOTE(review): paging is left unset here and relies on List's
			// defaults for a zero PageIndex/PageSize — confirm that is intended.
			name: "time range filter",
			query: repository.TransactionQuery{
				StartTime: proto.Int64(now - 25),
				EndTime:   proto.Int64(now - 5),
			},
			wantCount: 2,
			wantIDs:   []int64{rows[2].ID, rows[1].ID},
		},
		{
			name: "paging page 1 size 1",
			query: repository.TransactionQuery{
				PageIndex: 1,
				PageSize:  1,
			},
			wantCount: 3,
			wantIDs:   []int64{rows[2].ID},
		},
		{
			name: "paging page 2 size 1",
			query: repository.TransactionQuery{
				PageIndex: 2,
				PageSize:  1,
			},
			wantCount: 3,
			wantIDs:   []int64{rows[1].ID},
		},
		{
			name:    "no match returns ErrRecordNotFound",
			query:   repository.TransactionQuery{UID: proto.String("nonexist")},
			wantErr: repository.ErrRecordNotFound,
		},
	}

	ctx := context.Background()
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Pass the case's query as-is. Rebuilding it field by field would
			// silently drop any field a future case sets but the copy forgets.
			got, cnt, err := repo.List(ctx, tt.query)
			if tt.wantErr != nil {
				assert.ErrorIs(t, err, tt.wantErr)
				return
			}
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, tt.wantCount, cnt)

			var ids []int64
			for _, tx := range got {
				ids = append(ids, tx.ID)
			}
			assert.Equal(t, tt.wantIDs, ids)
		})
	}
}
// TestTransactionRepository_FindByDueTimeRange checks that FindByDueTimeRange
// returns rows of the requested types whose due_time is non-zero and before the
// cutoff. It also checks that rows whose status is not WalletNonStatus are
// excluded.
func TestTransactionRepository_FindByDueTimeRange(t *testing.T) {
	// tearDown is nil when setup fails, so abort before deferring it.
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setup transaction repository: %v", err)
	}
	defer tearDown()

	createTransactionTable(t, db)

	now := time.Now().Unix()
	rows := []entity.Transaction{
		{TxType: wallet.Deposit, BusinessType: 0, DueTime: now - 10, Status: wallet.WalletNonStatus}, // already due
		{TxType: wallet.Deposit, BusinessType: 0, DueTime: now + 10, Status: wallet.WalletNonStatus}, // due soon
		{TxType: wallet.Deposit, BusinessType: 0, DueTime: 0, Status: wallet.WalletNonStatus},        // no due time
		{TxType: wallet.Deposit, BusinessType: 0, DueTime: now - 5, Status: wallet.Enable(1)},        // status != non
	}
	// Expectations use the IDs GORM writes back into rows; abort if seeding fails.
	if err := db.Create(&rows).Error; err != nil {
		t.Fatalf("seed transactions: %v", err)
	}

	tests := []struct {
		name    string
		cutoff  time.Time
		types   []wallet.TxType
		wantIDs []int64
	}{
		{
			name:    "due before now for type=1",
			cutoff:  time.Unix(now, 0),
			types:   []wallet.TxType{wallet.Deposit},
			wantIDs: []int64{rows[0].ID},
		},
		{
			// Withdraw matches no seeded row; it is here so the type filter is
			// actually exercised with more than one value.
			name:    "include multiple types",
			cutoff:  time.Unix(now+20, 0),
			types:   []wallet.TxType{wallet.Deposit, wallet.Withdraw},
			wantIDs: []int64{rows[0].ID, rows[1].ID},
		},
		{
			name:    "zero due_time is skipped",
			cutoff:  time.Unix(now+100, 0),
			types:   []wallet.TxType{wallet.Deposit},
			wantIDs: []int64{rows[0].ID, rows[1].ID},
		},
		{
			name:    "no matches",
			cutoff:  time.Unix(now-100, 0),
			types:   []wallet.TxType{wallet.Deposit},
			wantIDs: nil,
		},
	}

	ctx := context.Background()
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := repo.FindByDueTimeRange(ctx, tt.cutoff, tt.types)
			if !assert.NoError(t, err) {
				return
			}
			// Left nil when nothing is returned so it compares equal to a nil wantIDs.
			var ids []int64
			for _, tx := range got {
				ids = append(ids, tx.ID)
			}
			assert.Equal(t, tt.wantIDs, ids)
		})
	}
}
// TestTransactionRepository_UpdateStatusByID checks that UpdateStatusByID
// rewrites the status of an existing row. It also checks that targeting a
// non-existent ID is not an error and creates no row.
func TestTransactionRepository_UpdateStatusByID(t *testing.T) {
	// tearDown is nil when setup fails, so abort before deferring it.
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setup transaction repository: %v", err)
	}
	defer tearDown()

	createTransactionTable(t, db)

	ctx := context.Background()
	tx := &entity.Transaction{
		OrderID:             "ord-123",
		TransactionID:       "tx-abc",
		Brand:               "brand1",
		UID:                 "user1",
		ToUID:               "user2",
		TxType:              wallet.TxType(1),
		BusinessType:        int8(2),
		Asset:               "BTC",
		Amount:              decimal.NewFromInt(100),
		PostTransferBalance: decimal.NewFromInt(1000),
		BeforeBalance:       decimal.NewFromInt(900),
		Status:              wallet.Enable(0),
		CreateAt:            time.Now().Unix(),
		DueTime:             time.Now().Add(time.Hour).Unix(),
	}
	if err := repo.Insert(ctx, tx); err != nil {
		t.Fatalf("insert seed transaction: %v", err)
	}
	// Both cases are derived from this ID; a zero ID would make them meaningless.
	existingID := tx.ID
	if existingID == 0 {
		t.Fatal("Insert did not populate tx.ID")
	}

	tests := []struct {
		name      string
		id        int64
		newStatus int
		wantErr   bool
		wantRow   bool
	}{
		{
			name:      "update existing row",
			id:        existingID,
			newStatus: 1,
			wantErr:   false,
			wantRow:   true,
		},
		{
			name:      "non-existent id does not error",
			id:        existingID + 999,
			newStatus: 5,
			wantErr:   false,
			wantRow:   false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := repo.UpdateStatusByID(ctx, tt.id, tt.newStatus)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			if !assert.NoError(t, err) {
				return
			}

			var got entity.Transaction
			res := db.First(&got, tt.id)
			if tt.wantRow {
				// The existing row must reflect the new status.
				assert.NoError(t, res.Error)
				assert.Equal(t, tt.newStatus, int(got.Status))
				return
			}
			// Non-existent id: the update must not have created a row.
			assert.Error(t, res.Error)
			assert.True(t, errors.Is(res.Error, gorm.ErrRecordNotFound))
		})
	}
}
// TestTransactionRepository_ListWalletTransactions checks ListWalletTransactions'
// filtering by uid, order IDs and wallet type. The cases pass an empty uid to
// mean "any uid"; for walletType the cases pass 0 where the filter is meant to
// be off. Results are compared order-insensitively.
func TestTransactionRepository_ListWalletTransactions(t *testing.T) {
	// tearDown is nil when setup fails, so abort before deferring it.
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setup transaction repository: %v", err)
	}
	defer tearDown()

	createWalletTransactionTable(t, db)

	ctx := context.Background()
	now := time.Now().UnixNano()
	transactions := []entity.WalletTransaction{
		{OrderID: "o1", UID: "u1", WalletType: wallet.TypeAvailable, Asset: "BTC", Amount: decimal.NewFromInt(10), Balance: decimal.NewFromInt(110), CreateAt: now - 3},
		{OrderID: "o2", UID: "u1", WalletType: wallet.TypeFreeze, Asset: "ETH", Amount: decimal.NewFromInt(20), Balance: decimal.NewFromInt(220), CreateAt: now - 2},
		{OrderID: "o1", UID: "u2", WalletType: wallet.TypeAvailable, Asset: "BTC", Amount: decimal.NewFromInt(30), Balance: decimal.NewFromInt(330), CreateAt: now - 1},
	}
	// Expectations use the IDs GORM writes back; abort if seeding fails.
	if err := db.Create(&transactions).Error; err != nil {
		t.Fatalf("seed wallet transactions: %v", err)
	}

	tests := []struct {
		name       string
		uid        string
		orderIDs   []string
		walletType wallet.Types
		wantIDs    []int64
		wantErr    bool
	}{
		{
			name:       "filter by uid=u1, both orders",
			uid:        "u1",
			orderIDs:   []string{"o1", "o2"},
			walletType: 0,
			wantIDs:    []int64{transactions[0].ID, transactions[1].ID},
		},
		{
			name:       "filter by uid=u1, walletType=1",
			uid:        "u1",
			orderIDs:   []string{"o1", "o2"},
			walletType: wallet.Types(1),
			wantIDs:    []int64{transactions[0].ID},
		},
		{
			// NOTE(review): this case also filters walletType=1, which both o1
			// rows happen to satisfy; the name only mentions the order filter.
			name:       "filter by order=o1 only, any uid",
			uid:        "",
			orderIDs:   []string{"o1"},
			walletType: wallet.Types(1),
			wantIDs:    []int64{transactions[0].ID, transactions[2].ID},
		},
		{
			name:       "no match yields empty",
			uid:        "nope",
			orderIDs:   []string{"o1"},
			walletType: 0,
			wantIDs:    []int64{},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := repo.ListWalletTransactions(ctx, tt.uid, tt.orderIDs, tt.walletType)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			if !assert.NoError(t, err) {
				return
			}
			// Collect the returned IDs for an order-insensitive comparison.
			gotIDs := make([]int64, len(got))
			for i, w := range got {
				gotIDs[i] = w.ID
			}
			assert.ElementsMatch(t, tt.wantIDs, gotIDs)
		})
	}
}