app-cloudep-wallet-service/pkg/repository/transaction_test.go

640 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"code.30cm.net/digimon/app-cloudep-wallet-service/internal/config"
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/entity"
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/repository"
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/wallet"
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/lib/sql_client"
"context"
"errors"
"fmt"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
"gorm.io/gorm"
"testing"
"time"
)
// SetupTestTransactionRepository starts a throw-away MySQL container, opens a
// client to it through sql_client and wraps that client in a
// TransactionRepository.
//
// It returns the repository, the underlying *gorm.DB (so tests can create,
// seed and inspect tables directly) and a tearDown func supplied by
// startMySQLContainer. On error all other return values are nil; if the
// container had already started, it is torn down before returning.
func SetupTestTransactionRepository() (repository.TransactionRepository, *gorm.DB, func(), error) {
	host, port, _, tearDown, err := startMySQLContainer()
	if err != nil {
		return nil, nil, nil, fmt.Errorf("failed to start MySQL container: %w", err)
	}

	// Fill in the MySQL section field by field instead of restating its
	// anonymous struct type, so this keeps compiling if the config gains fields.
	var conf config.Config
	conf.MySQL.UserName = MySQLUser
	conf.MySQL.Password = MySQLPassword
	conf.MySQL.Host = host
	conf.MySQL.Port = port
	conf.MySQL.Database = MySQLDatabase
	conf.MySQL.MaxIdleConns = 10
	conf.MySQL.MaxOpenConns = 100
	// NOTE(review): a raw 300 stored in a time.Duration is 300ns unless
	// sql_client scales it (e.g. by time.Second) — confirm the intended lifetime.
	conf.MySQL.ConnMaxLifetime = 300
	conf.MySQL.LogLevel = "info"

	db, err := sql_client.NewMySQLClient(conf)
	if err != nil {
		tearDown()
		return nil, nil, nil, fmt.Errorf("failed to create db client: %w", err)
	}

	return MustTransactionRepository(TransactionRepositoryParam{DB: db}), db, tearDown, nil
}
// createTransactionTable creates the `transaction` table used by the
// TransactionRepository tests, if it does not already exist.
//
// It is a test helper: failures are reported at the caller's line, and the
// calling test is stopped if the DDL fails, since every later step of that
// test depends on the table existing.
func createTransactionTable(t *testing.T, db *gorm.DB) {
	t.Helper()

	ddl := `
CREATE TABLE IF NOT EXISTS transaction (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
order_id VARCHAR(64) NOT NULL,
transaction_id VARCHAR(64),
brand VARCHAR(32),
uid VARCHAR(64),
to_uid VARCHAR(64),
type TINYINT,
business_type TINYINT,
asset VARCHAR(32),
amount DECIMAL(30,18),
before_balance DECIMAL(30,18),
post_transfer_balance DECIMAL(30,18),
status TINYINT,
create_at BIGINT,
due_time BIGINT
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
	if err := db.Exec(ddl).Error; err != nil {
		t.Fatalf("creating transaction table: %v", err)
	}
}
// createWalletTransactionTable creates the `wallet_transaction` table (one row
// per wallet balance change, with the resulting balance snapshot) if it does
// not already exist.
//
// It is a test helper: failures are reported at the caller's line, and the
// calling test is stopped if the DDL fails, since the test cannot proceed
// without the table.
func createWalletTransactionTable(t *testing.T, db *gorm.DB) {
	t.Helper()

	// The column COMMENTs below are part of the DDL text sent to MySQL and are
	// intentionally left as-is.
	createTableSQL := `
CREATE TABLE IF NOT EXISTS wallet_transaction (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT COMMENT '主鍵 ID自動遞增',
transaction_id BIGINT NOT NULL COMMENT '交易流水號(可對應某次業務操作,例如同一訂單的多筆變化)',
order_id VARCHAR(64) NOT NULL COMMENT '訂單編號(對應實際訂單或業務事件)',
brand VARCHAR(50) NOT NULL COMMENT '品牌(多租戶或多平台識別)',
uid VARCHAR(64) NOT NULL COMMENT '使用者 UID',
wallet_type TINYINT NOT NULL COMMENT '錢包類型(如主錢包、獎勵錢包、凍結錢包等)',
business_type TINYINT NOT NULL COMMENT '業務類型(如購物、退款、加值等)',
asset VARCHAR(32) NOT NULL COMMENT '資產代號(如 BTC、ETH、GEM_RED、USD 等)',
amount DECIMAL(30,18) NOT NULL COMMENT '變動金額(正數為收入,負數為支出)',
balance DECIMAL(30,18) NOT NULL COMMENT '當前錢包餘額(這筆交易後的餘額快照)',
create_at BIGINT NOT NULL DEFAULT 0 COMMENT '建立時間UnixNano紀錄交易發生時間',
PRIMARY KEY (id),
KEY idx_uid (uid),
KEY idx_transaction_id (transaction_id),
KEY idx_order_id (order_id),
KEY idx_brand (brand),
KEY idx_wallet_type (wallet_type)
) ENGINE=InnoDB
DEFAULT CHARSET=utf8mb4
COLLATE=utf8mb4_unicode_ci
COMMENT='錢包資金異動紀錄(每一次交易行為的快照記錄)';`
	if err := db.Exec(createTableSQL).Error; err != nil {
		t.Fatalf("creating wallet_transaction table: %v", err)
	}
}
// TestTransactionRepository_InsertAndBatchInsert checks that Insert stores one
// row and BatchInsert stores every row it is given. Each case empties the
// `transaction` table first so row counts are exact.
func TestTransactionRepository_InsertAndBatchInsert(t *testing.T) {
	// Start the container and connect. Abort on failure: tearDown and db are
	// nil then, and deferring/using them would panic.
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setting up transaction repository: %v", err)
	}
	defer tearDown()

	createTransactionTable(t, db)

	now := time.Now().Unix()
	template := entity.Transaction{
		OrderID:             "o1",
		TransactionID:       "tx1",
		Brand:               "b1",
		UID:                 "u1",
		ToUID:               "u2",
		TxType:              1,
		BusinessType:        2,
		Asset:               "BTC",
		Amount:              decimal.RequireFromString("100.5"),
		BeforeBalance:       decimal.RequireFromString("50.0"),
		PostTransferBalance: decimal.RequireFromString("150.5"),
		Status:              1,
		CreateAt:            now,
		DueTime:             now + 3600,
	}

	tests := []struct {
		name     string
		setup    func(t *testing.T)
		action   func() error
		validate func(t *testing.T)
	}{
		{
			name: "single insert",
			setup: func(t *testing.T) {
				assert.NoError(t, db.Exec("DELETE FROM transaction").Error)
			},
			action: func() error {
				tx := template
				return repo.Insert(context.Background(), &tx)
			},
			validate: func(t *testing.T) {
				var count int64
				assert.NoError(t, db.Raw("SELECT COUNT(*) FROM transaction").Scan(&count).Error)
				assert.Equal(t, int64(1), count)

				var got entity.Transaction
				if !assert.NoError(t, db.Take(&got, "order_id = ?", template.OrderID).Error) {
					return
				}
				// Spot-check that the stored row carries the inserted values.
				assert.Equal(t, template.UID, got.UID)
				assert.Equal(t, template.Asset, got.Asset)
				assert.True(t, template.Amount.Equal(got.Amount))
			},
		},
		{
			name: "batch insert",
			setup: func(t *testing.T) {
				assert.NoError(t, db.Exec("DELETE FROM transaction").Error)
			},
			action: func() error {
				// Two copies of the template that differ only by order ID.
				tx1 := template
				tx2 := template
				tx1.OrderID = "o2"
				tx2.OrderID = "o3"
				return repo.BatchInsert(context.Background(), []*entity.Transaction{&tx1, &tx2})
			},
			validate: func(t *testing.T) {
				var count int64
				assert.NoError(t, db.Raw("SELECT COUNT(*) FROM transaction").Scan(&count).Error)
				assert.Equal(t, int64(2), count)

				rows, err := db.Raw("SELECT order_id FROM transaction ORDER BY order_id").Rows()
				if !assert.NoError(t, err) {
					// rows is nil on error; closing or iterating it would panic.
					return
				}
				defer rows.Close()

				var orders []string
				for rows.Next() {
					var oid string
					if !assert.NoError(t, rows.Scan(&oid)) {
						return
					}
					orders = append(orders, oid)
				}
				// Next returning false can mean an iteration error, not just the end.
				assert.NoError(t, rows.Err())
				assert.Equal(t, []string{"o2", "o3"}, orders)
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.setup(t)
			// Validating after a failed write would only produce noise.
			if !assert.NoError(t, tt.action()) {
				return
			}
			tt.validate(t)
		})
	}
}
// TestTransactionRepository_FindByOrderID seeds a single transaction row with
// raw SQL. It then checks the following:
//   - FindByOrderID returns every field of that row for a known order ID.
//   - FindByOrderID returns an error wrapping repository.ErrRecordNotFound for
//     an unknown order ID.
func TestTransactionRepository_FindByOrderID(t *testing.T) {
	// Start the container and connect. Abort on failure: tearDown and db are
	// nil then, and deferring/using them would panic.
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setting up transaction repository: %v", err)
	}
	defer tearDown()

	createTransactionTable(t, db)

	// Seed exactly one row. The expectation below (ID: 1) assumes the freshly
	// created table is empty, so this row gets the first auto-increment id.
	now := time.Now().Unix()
	res := db.Exec(
		`INSERT INTO transaction
(order_id, transaction_id, brand, uid, to_uid, type, business_type, asset,
amount, before_balance, post_transfer_balance, status, create_at, due_time)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		"order-123", "tx-abc", "brandX", "user1", "user2",
		1, 2, "BTC",
		"10.0", "5.0", "15.0",
		1, now, now+3600,
	)
	if res.Error != nil {
		t.Fatalf("seeding transaction row: %v", res.Error)
	}
	assert.Equal(t, int64(1), res.RowsAffected)

	type want struct {
		tx      entity.Transaction
		wantErr bool
	}
	tests := []struct {
		name    string
		orderID string
		want    want
	}{
		{
			name:    "found existing",
			orderID: "order-123",
			want: want{
				tx: entity.Transaction{
					ID:                  1,
					OrderID:             "order-123",
					TransactionID:       "tx-abc",
					Brand:               "brandX",
					UID:                 "user1",
					ToUID:               "user2",
					TxType:              wallet.TxType(1),
					BusinessType:        int8(2),
					Asset:               "BTC",
					Amount:              decimal.RequireFromString("10.0"),
					BeforeBalance:       decimal.RequireFromString("5.0"),
					PostTransferBalance: decimal.RequireFromString("15.0"),
					Status:              wallet.Enable(1),
					CreateAt:            now,
					DueTime:             now + 3600,
				},
				wantErr: false,
			},
		},
		{
			name:    "not found",
			orderID: "missing",
			want:    want{wantErr: true},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := repo.FindByOrderID(context.Background(), tt.orderID)
			if tt.want.wantErr {
				assert.ErrorIs(t, err, repository.ErrRecordNotFound)
				return
			}
			// got must not be inspected when the lookup failed.
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, tt.want.tx.ID, got.ID)
			assert.Equal(t, tt.want.tx.OrderID, got.OrderID)
			assert.Equal(t, tt.want.tx.TransactionID, got.TransactionID)
			assert.Equal(t, tt.want.tx.Brand, got.Brand)
			assert.Equal(t, tt.want.tx.UID, got.UID)
			assert.Equal(t, tt.want.tx.ToUID, got.ToUID)
			assert.Equal(t, tt.want.tx.TxType, got.TxType)
			assert.Equal(t, tt.want.tx.BusinessType, got.BusinessType)
			assert.Equal(t, tt.want.tx.Asset, got.Asset)
			// Decimals are compared by value: DECIMAL(30,18) comes back with
			// trailing zeros, so struct equality would not hold.
			assert.True(t, tt.want.tx.Amount.Equal(got.Amount))
			assert.True(t, tt.want.tx.BeforeBalance.Equal(got.BeforeBalance))
			assert.True(t, tt.want.tx.PostTransferBalance.Equal(got.PostTransferBalance))
			assert.Equal(t, tt.want.tx.Status, got.Status)
			assert.Equal(t, tt.want.tx.CreateAt, got.CreateAt)
			assert.Equal(t, tt.want.tx.DueTime, got.DueTime)
		})
	}
}
// TestTransactionRepository_List seeds three transactions and exercises List
// with these filters: UID, BusinessType, Asset+TxType and a time range.
// It also covers paging and the no-match case. The expected IDs are listed
// newest-first by CreateAt, which is the order these cases expect List to
// return.
func TestTransactionRepository_List(t *testing.T) {
	// Abort on setup failure: tearDown and db are nil then, and
	// deferring/using them would panic.
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setting up transaction repository: %v", err)
	}
	defer tearDown()

	createTransactionTable(t, db)

	now := time.Now().Unix()
	rows := []entity.Transaction{
		{OrderID: "A", UID: "u1", TxType: wallet.Deposit, BusinessType: 1, Asset: "BTC", CreateAt: now - 30},
		{OrderID: "B", UID: "u2", TxType: wallet.Withdraw, BusinessType: 2, Asset: "ETH", CreateAt: now - 20},
		{OrderID: "C", UID: "u1", TxType: wallet.Deposit, BusinessType: 1, Asset: "BTC", CreateAt: now - 10},
	}
	// Every case refers to rows[i].ID, which is only populated by a successful Create.
	if err := db.Create(&rows).Error; err != nil {
		t.Fatalf("seeding transactions: %v", err)
	}

	tests := []struct {
		name      string
		query     repository.TransactionQuery
		wantCount int64
		wantIDs   []int64
		wantErr   error
	}{
		{
			name: "no filter returns all",
			query: repository.TransactionQuery{
				PageIndex: 1,
				PageSize:  50,
			},
			wantCount: 3,
			wantIDs:   []int64{rows[2].ID, rows[1].ID, rows[0].ID},
		},
		{
			name: "filter by UID",
			query: repository.TransactionQuery{
				PageIndex: 1,
				PageSize:  50,
				UID:       proto.String("u1"),
			},
			wantCount: 2,
			wantIDs:   []int64{rows[2].ID, rows[0].ID},
		},
		{
			name: "filter by BusinessType",
			query: repository.TransactionQuery{
				PageIndex:    1,
				PageSize:     50,
				BusinessType: []int8{2},
			},
			wantCount: 1,
			wantIDs:   []int64{rows[1].ID},
		},
		{
			name: "filter by Asset + TxType",
			query: repository.TransactionQuery{
				PageIndex: 1,
				PageSize:  50,
				Assets:    proto.String("BTC"),
				TxTypes:   []wallet.TxType{wallet.Deposit},
			},
			wantCount: 2,
			wantIDs:   []int64{rows[2].ID, rows[0].ID},
		},
		{
			// NOTE(review): PageIndex/PageSize are left zero here, so this relies
			// on List applying a default page — confirm against the repository.
			name: "time range filter",
			query: repository.TransactionQuery{
				StartTime: proto.Int64(now - 25),
				EndTime:   proto.Int64(now - 5),
			},
			wantCount: 2,
			wantIDs:   []int64{rows[2].ID, rows[1].ID},
		},
		{
			name: "paging page 1 size 1",
			query: repository.TransactionQuery{
				PageIndex: 1,
				PageSize:  1,
			},
			wantCount: 3,
			wantIDs:   []int64{rows[2].ID},
		},
		{
			name: "paging page 2 size 1",
			query: repository.TransactionQuery{
				PageIndex: 2,
				PageSize:  1,
			},
			wantCount: 3,
			wantIDs:   []int64{rows[1].ID},
		},
		{
			name:    "no match returns ErrRecordNotFound",
			query:   repository.TransactionQuery{UID: proto.String("nonexist")},
			wantErr: repository.ErrRecordNotFound,
		},
	}

	ctx := context.Background()
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, cnt, err := repo.List(ctx, tt.query)
			if tt.wantErr != nil {
				assert.ErrorIs(t, err, tt.wantErr)
				return
			}
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, tt.wantCount, cnt)

			var ids []int64
			for _, tx := range got {
				ids = append(ids, tx.ID)
			}
			assert.Equal(t, tt.wantIDs, ids)
		})
	}
}
// TestTransactionRepository_FindByDueTimeRange seeds deposit rows with varying
// due times and statuses. It checks that FindByDueTimeRange returns only rows
// whose due_time is non-zero and not after the cutoff, and whose status is
// wallet.WalletNonStatus.
func TestTransactionRepository_FindByDueTimeRange(t *testing.T) {
	// Abort on setup failure: tearDown and db are nil then, and
	// deferring/using them would panic.
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setting up transaction repository: %v", err)
	}
	defer tearDown()

	createTransactionTable(t, db)

	now := time.Now().Unix()
	rows := []entity.Transaction{
		{TxType: wallet.Deposit, BusinessType: 0, DueTime: now - 10, Status: wallet.WalletNonStatus}, // already due
		{TxType: wallet.Deposit, BusinessType: 0, DueTime: now + 10, Status: wallet.WalletNonStatus}, // due shortly
		{TxType: wallet.Deposit, BusinessType: 0, DueTime: 0, Status: wallet.WalletNonStatus},        // no due time
		{TxType: wallet.Deposit, BusinessType: 0, DueTime: now - 5, Status: wallet.Enable(1)},        // status != non
	}
	// Every case refers to rows[i].ID, which is only populated by a successful Create.
	if err := db.Create(&rows).Error; err != nil {
		t.Fatalf("seeding transactions: %v", err)
	}

	tests := []struct {
		name    string
		cutoff  time.Time
		types   []wallet.TxType
		wantIDs []int64
	}{
		{
			name:    "due before now for type=1",
			cutoff:  time.Unix(now, 0),
			types:   []wallet.TxType{wallet.Deposit},
			wantIDs: []int64{rows[0].ID},
		},
		{
			name:    "later cutoff includes not-yet-due row",
			cutoff:  time.Unix(now+20, 0),
			types:   []wallet.TxType{wallet.Deposit},
			wantIDs: []int64{rows[0].ID, rows[1].ID},
		},
		{
			name:    "zero due_time is skipped",
			cutoff:  time.Unix(now+100, 0),
			types:   []wallet.TxType{wallet.Deposit},
			wantIDs: []int64{rows[0].ID, rows[1].ID},
		},
		{
			name:    "no matches",
			cutoff:  time.Unix(now-100, 0),
			types:   []wallet.TxType{wallet.Deposit},
			wantIDs: nil,
		},
	}

	ctx := context.Background()
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := repo.FindByDueTimeRange(ctx, tt.cutoff, tt.types)
			if !assert.NoError(t, err) {
				return
			}

			var ids []int64
			for _, tx := range got {
				ids = append(ids, tx.ID)
			}
			assert.Equal(t, tt.wantIDs, ids)
		})
	}
}
// TestTransactionRepository_UpdateStatusByID inserts one transaction. It checks
// that UpdateStatusByID changes that row's status, and that an unknown ID is
// accepted without error and without creating a row.
func TestTransactionRepository_UpdateStatusByID(t *testing.T) {
	// Abort on setup failure: tearDown and db are nil then, and
	// deferring/using them would panic.
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setting up transaction repository: %v", err)
	}
	defer tearDown()

	createTransactionTable(t, db)

	ctx := context.Background()
	tx := &entity.Transaction{
		OrderID:             "ord-123",
		TransactionID:       "tx-abc",
		Brand:               "brand1",
		UID:                 "user1",
		ToUID:               "user2",
		TxType:              wallet.TxType(1),
		BusinessType:        int8(2),
		Asset:               "BTC",
		Amount:              decimal.NewFromInt(100),
		PostTransferBalance: decimal.NewFromInt(1000),
		BeforeBalance:       decimal.NewFromInt(900),
		Status:              wallet.Enable(0),
		CreateAt:            time.Now().Unix(),
		DueTime:             time.Now().Add(time.Hour).Unix(),
	}
	// existingID below is only meaningful if the insert succeeded.
	if err := repo.Insert(ctx, tx); err != nil {
		t.Fatalf("inserting seed transaction: %v", err)
	}
	existingID := tx.ID

	tests := []struct {
		name      string
		id        int64
		newStatus int
		wantErr   bool
		wantRow   bool
	}{
		{
			name:      "update existing row",
			id:        existingID,
			newStatus: 1,
			wantErr:   false,
			wantRow:   true,
		},
		{
			name:      "non-existent id does not error",
			id:        existingID + 999,
			newStatus: 5,
			wantErr:   false,
			wantRow:   false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := repo.UpdateStatusByID(ctx, tt.id, tt.newStatus)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			if !assert.NoError(t, err) {
				return
			}

			var got entity.Transaction
			res := db.First(&got, tt.id)
			if !tt.wantRow {
				// The update must not have created a row for an unknown id.
				assert.Error(t, res.Error)
				assert.True(t, errors.Is(res.Error, gorm.ErrRecordNotFound))
				return
			}
			// The existing row should now carry the new status.
			if assert.NoError(t, res.Error) {
				assert.Equal(t, tt.newStatus, int(got.Status))
			}
		})
	}
}
// TestTransactionRepository_ListWalletTransactions seeds three wallet_transaction
// rows. It checks that ListWalletTransactions filters them by uid, order IDs and
// wallet type. Results are compared without regard to order.
func TestTransactionRepository_ListWalletTransactions(t *testing.T) {
	// Abort on setup failure: tearDown and db are nil then, and
	// deferring/using them would panic.
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setting up transaction repository: %v", err)
	}
	defer tearDown()

	createWalletTransactionTable(t, db)

	ctx := context.Background()
	now := time.Now().UnixNano()
	transactions := []entity.WalletTransaction{
		{OrderID: "o1", UID: "u1", WalletType: wallet.TypeAvailable, Asset: "BTC", Amount: decimal.NewFromInt(10), Balance: decimal.NewFromInt(110), CreateAt: now - 3},
		{OrderID: "o2", UID: "u1", WalletType: wallet.TypeFreeze, Asset: "ETH", Amount: decimal.NewFromInt(20), Balance: decimal.NewFromInt(220), CreateAt: now - 2},
		{OrderID: "o1", UID: "u2", WalletType: wallet.TypeAvailable, Asset: "BTC", Amount: decimal.NewFromInt(30), Balance: decimal.NewFromInt(330), CreateAt: now - 1},
	}
	// Every case refers to transactions[i].ID, which is only populated by a successful Create.
	if err := db.Create(&transactions).Error; err != nil {
		t.Fatalf("seeding wallet transactions: %v", err)
	}

	// NOTE(review): these cases appear to assume that walletType 0 means "any
	// wallet type", that uid "" means "any uid", and that wallet.Types(1) is
	// wallet.TypeAvailable. Confirm all three against the repository and the
	// wallet package.
	tests := []struct {
		name       string
		uid        string
		orderIDs   []string
		walletType wallet.Types
		wantIDs    []int64
		wantErr    bool
	}{
		{
			name:       "filter by uid=u1, both orders",
			uid:        "u1",
			orderIDs:   []string{"o1", "o2"},
			walletType: 0,
			wantIDs:    []int64{transactions[0].ID, transactions[1].ID},
		},
		{
			name:       "filter by uid=u1, walletType=1",
			uid:        "u1",
			orderIDs:   []string{"o1", "o2"},
			walletType: wallet.Types(1),
			wantIDs:    []int64{transactions[0].ID},
		},
		{
			name:       "filter by order=o1 only, any uid",
			uid:        "",
			orderIDs:   []string{"o1"},
			walletType: wallet.Types(1),
			wantIDs:    []int64{transactions[0].ID, transactions[2].ID},
		},
		{
			name:       "no match yields empty",
			uid:        "nope",
			orderIDs:   []string{"o1"},
			walletType: 0,
			wantIDs:    []int64{},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := repo.ListWalletTransactions(ctx, tt.uid, tt.orderIDs, tt.walletType)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			if !assert.NoError(t, err) {
				return
			}

			// Collect the returned IDs for an order-insensitive comparison.
			gotIDs := make([]int64, len(got))
			for i, w := range got {
				gotIDs[i] = w.ID
			}
			assert.ElementsMatch(t, tt.wantIDs, gotIDs)
		})
	}
}