640 lines
18 KiB
Go
640 lines
18 KiB
Go
|
package repository
|
|||
|
|
|||
|
import (
|
|||
|
"code.30cm.net/digimon/app-cloudep-wallet-service/internal/config"
|
|||
|
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/entity"
|
|||
|
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/repository"
|
|||
|
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/wallet"
|
|||
|
"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/lib/sql_client"
|
|||
|
"context"
|
|||
|
"errors"
|
|||
|
"fmt"
|
|||
|
"github.com/shopspring/decimal"
|
|||
|
"github.com/stretchr/testify/assert"
|
|||
|
"google.golang.org/protobuf/proto"
|
|||
|
"gorm.io/gorm"
|
|||
|
"testing"
|
|||
|
"time"
|
|||
|
)
|
|||
|
|
|||
|
func SetupTestTransactionRepository() (repository.TransactionRepository, *gorm.DB, func(), error) {
|
|||
|
host, port, _, tearDown, err := startMySQLContainer()
|
|||
|
if err != nil {
|
|||
|
return nil, nil, nil, fmt.Errorf("failed to start MySQL container: %w", err)
|
|||
|
}
|
|||
|
|
|||
|
conf := config.Config{
|
|||
|
MySQL: struct {
|
|||
|
UserName string
|
|||
|
Password string
|
|||
|
Host string
|
|||
|
Port string
|
|||
|
Database string
|
|||
|
MaxIdleConns int
|
|||
|
MaxOpenConns int
|
|||
|
ConnMaxLifetime time.Duration
|
|||
|
LogLevel string
|
|||
|
}{
|
|||
|
UserName: MySQLUser,
|
|||
|
Password: MySQLPassword,
|
|||
|
Host: host,
|
|||
|
Port: port,
|
|||
|
Database: MySQLDatabase,
|
|||
|
MaxIdleConns: 10,
|
|||
|
MaxOpenConns: 100,
|
|||
|
ConnMaxLifetime: 300,
|
|||
|
LogLevel: "info",
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
db, err := sql_client.NewMySQLClient(conf)
|
|||
|
if err != nil {
|
|||
|
tearDown()
|
|||
|
return nil, nil, nil, fmt.Errorf("failed to create db client: %w", err)
|
|||
|
}
|
|||
|
|
|||
|
repo := MustTransactionRepository(TransactionRepositoryParam{DB: db})
|
|||
|
return repo, db, tearDown, nil
|
|||
|
}
|
|||
|
|
|||
|
func createTransactionTable(t *testing.T, db *gorm.DB) {
|
|||
|
sql := `
|
|||
|
CREATE TABLE IF NOT EXISTS transaction (
|
|||
|
id BIGINT AUTO_INCREMENT PRIMARY KEY,
|
|||
|
order_id VARCHAR(64) NOT NULL,
|
|||
|
transaction_id VARCHAR(64),
|
|||
|
brand VARCHAR(32),
|
|||
|
uid VARCHAR(64),
|
|||
|
to_uid VARCHAR(64),
|
|||
|
type TINYINT,
|
|||
|
business_type TINYINT,
|
|||
|
asset VARCHAR(32),
|
|||
|
amount DECIMAL(30,18),
|
|||
|
before_balance DECIMAL(30,18),
|
|||
|
post_transfer_balance DECIMAL(30,18),
|
|||
|
status TINYINT,
|
|||
|
create_at BIGINT,
|
|||
|
due_time BIGINT
|
|||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`
|
|||
|
assert.NoError(t, db.Exec(sql).Error)
|
|||
|
}
|
|||
|
|
|||
|
func createWalletTransactionTable(t *testing.T, db *gorm.DB) {
|
|||
|
createTableSQL := `
|
|||
|
CREATE TABLE IF NOT EXISTS wallet_transaction (
|
|||
|
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT COMMENT '主鍵 ID,自動遞增',
|
|||
|
transaction_id BIGINT NOT NULL COMMENT '交易流水號(可對應某次業務操作,例如同一訂單的多筆變化)',
|
|||
|
order_id VARCHAR(64) NOT NULL COMMENT '訂單編號(對應實際訂單或業務事件)',
|
|||
|
brand VARCHAR(50) NOT NULL COMMENT '品牌(多租戶或多平台識別)',
|
|||
|
uid VARCHAR(64) NOT NULL COMMENT '使用者 UID',
|
|||
|
wallet_type TINYINT NOT NULL COMMENT '錢包類型(如主錢包、獎勵錢包、凍結錢包等)',
|
|||
|
business_type TINYINT NOT NULL COMMENT '業務類型(如購物、退款、加值等)',
|
|||
|
asset VARCHAR(32) NOT NULL COMMENT '資產代號(如 BTC、ETH、GEM_RED、USD 等)',
|
|||
|
amount DECIMAL(30,18) NOT NULL COMMENT '變動金額(正數為收入,負數為支出)',
|
|||
|
balance DECIMAL(30,18) NOT NULL COMMENT '當前錢包餘額(這筆交易後的餘額快照)',
|
|||
|
create_at BIGINT NOT NULL DEFAULT 0 COMMENT '建立時間(UnixNano,紀錄交易發生時間)',
|
|||
|
PRIMARY KEY (id),
|
|||
|
KEY idx_uid (uid),
|
|||
|
KEY idx_transaction_id (transaction_id),
|
|||
|
KEY idx_order_id (order_id),
|
|||
|
KEY idx_brand (brand),
|
|||
|
KEY idx_wallet_type (wallet_type)
|
|||
|
) ENGINE=InnoDB
|
|||
|
DEFAULT CHARSET=utf8mb4
|
|||
|
COLLATE=utf8mb4_unicode_ci
|
|||
|
COMMENT='錢包資金異動紀錄(每一次交易行為的快照記錄)';`
|
|||
|
|
|||
|
assert.NoError(t, db.Exec(createTableSQL).Error)
|
|||
|
}
|
|||
|
|
|||
|
func TestTransactionRepository_InsertAndBatchInsert(t *testing.T) {
|
|||
|
// start container and connect
|
|||
|
repo, db, tearDown, err := SetupTestTransactionRepository()
|
|||
|
assert.NoError(t, err)
|
|||
|
defer tearDown()
|
|||
|
|
|||
|
// prepare table
|
|||
|
createTransactionTable(t, db)
|
|||
|
|
|||
|
now := time.Now().Unix()
|
|||
|
template := entity.Transaction{
|
|||
|
OrderID: "o1",
|
|||
|
TransactionID: "tx1",
|
|||
|
Brand: "b1",
|
|||
|
UID: "u1",
|
|||
|
ToUID: "u2",
|
|||
|
TxType: 1,
|
|||
|
BusinessType: 2,
|
|||
|
Asset: "BTC",
|
|||
|
Amount: decimal.RequireFromString("100.5"),
|
|||
|
BeforeBalance: decimal.RequireFromString("50.0"),
|
|||
|
PostTransferBalance: decimal.RequireFromString("150.5"),
|
|||
|
Status: 1,
|
|||
|
CreateAt: now,
|
|||
|
DueTime: now + 3600,
|
|||
|
}
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
setup func(t *testing.T)
|
|||
|
action func() error
|
|||
|
validate func(t *testing.T)
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "single insert",
|
|||
|
setup: func(t *testing.T) {
|
|||
|
// clean table
|
|||
|
assert.NoError(t, db.Exec("DELETE FROM transaction").Error)
|
|||
|
},
|
|||
|
action: func() error {
|
|||
|
tx := template
|
|||
|
return repo.Insert(context.Background(), &tx)
|
|||
|
},
|
|||
|
validate: func(t *testing.T) {
|
|||
|
var count int64
|
|||
|
assert.NoError(t, db.Raw("SELECT COUNT(*) FROM transaction").Scan(&count).Error)
|
|||
|
assert.Equal(t, int64(1), count)
|
|||
|
|
|||
|
var got entity.Transaction
|
|||
|
err := db.Take(&got, "order_id = ?", template.OrderID).Error
|
|||
|
assert.NoError(t, err)
|
|||
|
},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "batch insert",
|
|||
|
setup: func(t *testing.T) {
|
|||
|
assert.NoError(t, db.Exec("DELETE FROM transaction").Error)
|
|||
|
},
|
|||
|
action: func() error {
|
|||
|
// clone two entries with different order IDs
|
|||
|
tx1 := template
|
|||
|
tx2 := template
|
|||
|
tx1.OrderID = "o2"
|
|||
|
tx2.OrderID = "o3"
|
|||
|
return repo.BatchInsert(context.Background(), []*entity.Transaction{&tx1, &tx2})
|
|||
|
},
|
|||
|
validate: func(t *testing.T) {
|
|||
|
var count int64
|
|||
|
assert.NoError(t, db.Raw("SELECT COUNT(*) FROM transaction").Scan(&count).Error)
|
|||
|
assert.Equal(t, int64(2), count)
|
|||
|
|
|||
|
var orders []string
|
|||
|
rows, _ := db.Raw("SELECT order_id FROM transaction ORDER BY order_id").Rows()
|
|||
|
defer rows.Close()
|
|||
|
for rows.Next() {
|
|||
|
var oid string
|
|||
|
rows.Scan(&oid)
|
|||
|
orders = append(orders, oid)
|
|||
|
}
|
|||
|
assert.Equal(t, []string{"o2", "o3"}, orders)
|
|||
|
},
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
for _, tt := range tests {
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
tt.setup(t)
|
|||
|
err := tt.action()
|
|||
|
assert.NoError(t, err)
|
|||
|
tt.validate(t)
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func TestTransactionRepository_FindByOrderID(t *testing.T) {
|
|||
|
// start container and connect
|
|||
|
repo, db, tearDown, err := SetupTestTransactionRepository()
|
|||
|
assert.NoError(t, err)
|
|||
|
defer tearDown()
|
|||
|
|
|||
|
// prepare table
|
|||
|
createTransactionTable(t, db)
|
|||
|
|
|||
|
// 4) seed one row
|
|||
|
now := time.Now().Unix()
|
|||
|
res := db.Exec(
|
|||
|
`INSERT INTO transaction
|
|||
|
(order_id, transaction_id, brand, uid, to_uid, type, business_type, asset,
|
|||
|
amount, before_balance, post_transfer_balance, status, create_at, due_time)
|
|||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|||
|
"order-123", "tx-abc", "brandX", "user1", "user2",
|
|||
|
1, 2, "BTC",
|
|||
|
"10.0", "5.0", "15.0",
|
|||
|
1, now, now+3600,
|
|||
|
)
|
|||
|
assert.NoError(t, res.Error)
|
|||
|
assert.Equal(t, int64(1), res.RowsAffected)
|
|||
|
|
|||
|
type want struct {
|
|||
|
tx entity.Transaction
|
|||
|
wantErr bool
|
|||
|
}
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
orderID string
|
|||
|
want want
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "found existing",
|
|||
|
orderID: "order-123",
|
|||
|
want: want{
|
|||
|
tx: entity.Transaction{
|
|||
|
ID: 1,
|
|||
|
OrderID: "order-123",
|
|||
|
TransactionID: "tx-abc",
|
|||
|
Brand: "brandX",
|
|||
|
UID: "user1",
|
|||
|
ToUID: "user2",
|
|||
|
TxType: wallet.TxType(1),
|
|||
|
BusinessType: int8(2),
|
|||
|
Asset: "BTC",
|
|||
|
Amount: decimal.RequireFromString("10.0"),
|
|||
|
BeforeBalance: decimal.RequireFromString("5.0"),
|
|||
|
PostTransferBalance: decimal.RequireFromString("15.0"),
|
|||
|
Status: wallet.Enable(1),
|
|||
|
CreateAt: now,
|
|||
|
DueTime: now + 3600,
|
|||
|
},
|
|||
|
wantErr: false,
|
|||
|
},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "not found",
|
|||
|
orderID: "missing",
|
|||
|
want: want{wantErr: true},
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
for _, tt := range tests {
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
got, err := repo.FindByOrderID(context.Background(), tt.orderID)
|
|||
|
if tt.want.wantErr {
|
|||
|
assert.Error(t, err)
|
|||
|
assert.True(t, errors.Is(err, repository.ErrRecordNotFound))
|
|||
|
return
|
|||
|
}
|
|||
|
assert.NoError(t, err)
|
|||
|
assert.Equal(t, tt.want.tx.ID, got.ID)
|
|||
|
assert.Equal(t, tt.want.tx.OrderID, got.OrderID)
|
|||
|
assert.Equal(t, tt.want.tx.TransactionID, got.TransactionID)
|
|||
|
assert.Equal(t, tt.want.tx.Brand, got.Brand)
|
|||
|
assert.Equal(t, tt.want.tx.UID, got.UID)
|
|||
|
assert.Equal(t, tt.want.tx.ToUID, got.ToUID)
|
|||
|
assert.Equal(t, tt.want.tx.TxType, got.TxType)
|
|||
|
assert.Equal(t, tt.want.tx.BusinessType, got.BusinessType)
|
|||
|
assert.Equal(t, tt.want.tx.Asset, got.Asset)
|
|||
|
assert.True(t, tt.want.tx.Amount.Equal(got.Amount))
|
|||
|
assert.True(t, tt.want.tx.BeforeBalance.Equal(got.BeforeBalance))
|
|||
|
assert.True(t, tt.want.tx.PostTransferBalance.Equal(got.PostTransferBalance))
|
|||
|
assert.Equal(t, tt.want.tx.Status, got.Status)
|
|||
|
assert.Equal(t, tt.want.tx.CreateAt, got.CreateAt)
|
|||
|
assert.Equal(t, tt.want.tx.DueTime, got.DueTime)
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func TestTransactionRepository_List(t *testing.T) {
|
|||
|
repo, db, tearDown, err := SetupTestTransactionRepository()
|
|||
|
assert.NoError(t, err)
|
|||
|
defer tearDown()
|
|||
|
|
|||
|
createTransactionTable(t, db)
|
|||
|
|
|||
|
now := time.Now().Unix()
|
|||
|
rows := []entity.Transaction{
|
|||
|
{OrderID: "A", UID: "u1", TxType: wallet.Deposit, BusinessType: 1, Asset: "BTC", CreateAt: now - 30},
|
|||
|
{OrderID: "B", UID: "u2", TxType: wallet.Withdraw, BusinessType: 2, Asset: "ETH", CreateAt: now - 20},
|
|||
|
{OrderID: "C", UID: "u1", TxType: wallet.Deposit, BusinessType: 1, Asset: "BTC", CreateAt: now - 10},
|
|||
|
}
|
|||
|
|
|||
|
assert.NoError(t, db.Create(&rows).Error)
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
query repository.TransactionQuery
|
|||
|
wantCount int64
|
|||
|
wantIDs []int64
|
|||
|
wantErr error
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "no filter returns all",
|
|||
|
query: repository.TransactionQuery{
|
|||
|
PageIndex: 1,
|
|||
|
PageSize: 50,
|
|||
|
},
|
|||
|
wantCount: 3,
|
|||
|
wantIDs: []int64{rows[2].ID, rows[1].ID, rows[0].ID},
|
|||
|
wantErr: nil,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "filter by UID",
|
|||
|
query: repository.TransactionQuery{
|
|||
|
PageIndex: 1,
|
|||
|
PageSize: 50,
|
|||
|
UID: proto.String("u1"),
|
|||
|
},
|
|||
|
wantCount: 2,
|
|||
|
wantIDs: []int64{rows[2].ID, rows[0].ID},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "filter by BusinessType",
|
|||
|
query: repository.TransactionQuery{
|
|||
|
PageIndex: 1,
|
|||
|
PageSize: 50,
|
|||
|
BusinessType: []int8{2},
|
|||
|
},
|
|||
|
wantCount: 1,
|
|||
|
wantIDs: []int64{rows[1].ID},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "filter by Asset + TxType",
|
|||
|
query: repository.TransactionQuery{
|
|||
|
PageIndex: 1,
|
|||
|
PageSize: 50,
|
|||
|
Assets: proto.String("BTC"),
|
|||
|
TxTypes: []wallet.TxType{wallet.Deposit},
|
|||
|
},
|
|||
|
wantCount: 2,
|
|||
|
wantIDs: []int64{rows[2].ID, rows[0].ID},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "time range filter",
|
|||
|
query: repository.TransactionQuery{
|
|||
|
StartTime: proto.Int64(now - 25),
|
|||
|
EndTime: proto.Int64(now - 5),
|
|||
|
},
|
|||
|
wantCount: 2,
|
|||
|
wantIDs: []int64{rows[2].ID, rows[1].ID},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "paging page 1 size 1",
|
|||
|
query: repository.TransactionQuery{
|
|||
|
PageIndex: 1,
|
|||
|
PageSize: 1,
|
|||
|
},
|
|||
|
wantCount: 3,
|
|||
|
wantIDs: []int64{rows[2].ID},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "paging page 2 size 1",
|
|||
|
query: repository.TransactionQuery{
|
|||
|
PageIndex: 2,
|
|||
|
PageSize: 1,
|
|||
|
},
|
|||
|
wantCount: 3,
|
|||
|
wantIDs: []int64{rows[1].ID},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "no match returns ErrRecordNotFound",
|
|||
|
query: repository.TransactionQuery{UID: proto.String("nonexist")},
|
|||
|
wantErr: repository.ErrRecordNotFound,
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
ctx := context.Background()
|
|||
|
for _, tt := range tests {
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
got, cnt, err := repo.List(ctx, repository.TransactionQuery{
|
|||
|
BusinessType: tt.query.BusinessType,
|
|||
|
UID: tt.query.UID,
|
|||
|
OrderID: tt.query.OrderID,
|
|||
|
Assets: tt.query.Assets,
|
|||
|
TxTypes: tt.query.TxTypes,
|
|||
|
StartTime: tt.query.StartTime,
|
|||
|
EndTime: tt.query.EndTime,
|
|||
|
PageIndex: tt.query.PageIndex,
|
|||
|
PageSize: tt.query.PageSize,
|
|||
|
})
|
|||
|
if tt.wantErr != nil {
|
|||
|
assert.ErrorIs(t, err, tt.wantErr)
|
|||
|
return
|
|||
|
}
|
|||
|
assert.NoError(t, err)
|
|||
|
assert.Equal(t, tt.wantCount, cnt)
|
|||
|
|
|||
|
var ids []int64
|
|||
|
for _, tx := range got {
|
|||
|
ids = append(ids, tx.ID)
|
|||
|
}
|
|||
|
assert.Equal(t, tt.wantIDs, ids)
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func TestTransactionRepository_FindByDueTimeRange(t *testing.T) {
|
|||
|
// start container and connect
|
|||
|
repo, db, tearDown, err := SetupTestTransactionRepository()
|
|||
|
assert.NoError(t, err)
|
|||
|
defer tearDown()
|
|||
|
|
|||
|
// prepare table
|
|||
|
createTransactionTable(t, db)
|
|||
|
|
|||
|
// seed rows
|
|||
|
now := time.Now().Unix()
|
|||
|
rows := []entity.Transaction{
|
|||
|
{TxType: wallet.Deposit, BusinessType: 0, DueTime: now - 10, Status: wallet.WalletNonStatus},
|
|||
|
{TxType: wallet.Deposit, BusinessType: 0, DueTime: now + 10, Status: wallet.WalletNonStatus},
|
|||
|
{TxType: wallet.Deposit, BusinessType: 0, DueTime: 0, Status: wallet.WalletNonStatus},
|
|||
|
{TxType: wallet.Deposit, BusinessType: 0, DueTime: now - 5, Status: wallet.Enable(1)}, // status != non
|
|||
|
}
|
|||
|
|
|||
|
assert.NoError(t, db.Create(&rows).Error)
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
cutoff time.Time
|
|||
|
types []wallet.TxType
|
|||
|
wantIDs []int64
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "due before now for type=1",
|
|||
|
cutoff: time.Unix(now, 0),
|
|||
|
types: []wallet.TxType{wallet.Deposit},
|
|||
|
wantIDs: []int64{rows[0].ID},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "include multiple types",
|
|||
|
cutoff: time.Unix(now+20, 0),
|
|||
|
types: []wallet.TxType{wallet.Deposit},
|
|||
|
wantIDs: []int64{rows[0].ID, rows[1].ID},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "zero due_time is skipped",
|
|||
|
cutoff: time.Unix(now+100, 0),
|
|||
|
types: []wallet.TxType{wallet.Deposit},
|
|||
|
wantIDs: []int64{rows[0].ID, rows[1].ID},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "no matches",
|
|||
|
cutoff: time.Unix(now-100, 0),
|
|||
|
types: []wallet.TxType{wallet.Deposit},
|
|||
|
wantIDs: nil,
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
ctx := context.Background()
|
|||
|
for _, tt := range tests {
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
got, err := repo.FindByDueTimeRange(ctx, tt.cutoff, tt.types)
|
|||
|
assert.NoError(t, err)
|
|||
|
var ids []int64
|
|||
|
for _, tx := range got {
|
|||
|
ids = append(ids, tx.ID)
|
|||
|
}
|
|||
|
assert.Equal(t, tt.wantIDs, ids)
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func TestTransactionRepository_UpdateStatusByID(t *testing.T) {
|
|||
|
// start container and connect
|
|||
|
repo, db, tearDown, err := SetupTestTransactionRepository()
|
|||
|
assert.NoError(t, err)
|
|||
|
defer tearDown()
|
|||
|
|
|||
|
createTransactionTable(t, db)
|
|||
|
|
|||
|
ctx := context.Background()
|
|||
|
|
|||
|
tx := &entity.Transaction{
|
|||
|
OrderID: "ord-123",
|
|||
|
TransactionID: "tx-abc",
|
|||
|
Brand: "brand1",
|
|||
|
UID: "user1",
|
|||
|
ToUID: "user2",
|
|||
|
TxType: wallet.TxType(1),
|
|||
|
BusinessType: int8(2),
|
|||
|
Asset: "BTC",
|
|||
|
Amount: decimal.NewFromInt(100),
|
|||
|
PostTransferBalance: decimal.NewFromInt(1000),
|
|||
|
BeforeBalance: decimal.NewFromInt(900),
|
|||
|
Status: wallet.Enable(0),
|
|||
|
CreateAt: time.Now().Unix(),
|
|||
|
DueTime: time.Now().Add(time.Hour).Unix(),
|
|||
|
}
|
|||
|
err = repo.Insert(ctx, tx)
|
|||
|
assert.NoError(t, err)
|
|||
|
existingID := tx.ID
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
id int64
|
|||
|
newStatus int
|
|||
|
wantErr bool
|
|||
|
wantRow bool
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "update existing row",
|
|||
|
id: existingID,
|
|||
|
newStatus: 1,
|
|||
|
wantErr: false,
|
|||
|
wantRow: true,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "non-existent id does not error",
|
|||
|
id: existingID + 999,
|
|||
|
newStatus: 5,
|
|||
|
wantErr: false,
|
|||
|
wantRow: false,
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
for _, tt := range tests {
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
err := repo.UpdateStatusByID(ctx, tt.id, tt.newStatus)
|
|||
|
if tt.wantErr {
|
|||
|
assert.Error(t, err)
|
|||
|
return
|
|||
|
}
|
|||
|
assert.NoError(t, err)
|
|||
|
|
|||
|
var got entity.Transaction
|
|||
|
res := db.First(&got, tt.id)
|
|||
|
if tt.wantRow {
|
|||
|
// existing row should reflect the new status
|
|||
|
assert.NoError(t, res.Error)
|
|||
|
assert.Equal(t, tt.newStatus, int(got.Status))
|
|||
|
} else {
|
|||
|
// non-existent id: no record found
|
|||
|
assert.Error(t, res.Error)
|
|||
|
assert.True(t, errors.Is(res.Error, gorm.ErrRecordNotFound))
|
|||
|
}
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func TestTransactionRepository_ListWalletTransactions(t *testing.T) {
|
|||
|
// start container and connect
|
|||
|
repo, db, tearDown, err := SetupTestTransactionRepository()
|
|||
|
assert.NoError(t, err)
|
|||
|
defer tearDown()
|
|||
|
|
|||
|
createWalletTransactionTable(t, db)
|
|||
|
ctx := context.Background()
|
|||
|
now := time.Now().UnixNano()
|
|||
|
transactions := []entity.WalletTransaction{
|
|||
|
{OrderID: "o1", UID: "u1", WalletType: wallet.TypeAvailable, Asset: "BTC", Amount: decimal.NewFromInt(10), Balance: decimal.NewFromInt(110), CreateAt: now - 3},
|
|||
|
{OrderID: "o2", UID: "u1", WalletType: wallet.TypeFreeze, Asset: "ETH", Amount: decimal.NewFromInt(20), Balance: decimal.NewFromInt(220), CreateAt: now - 2},
|
|||
|
{OrderID: "o1", UID: "u2", WalletType: wallet.TypeAvailable, Asset: "BTC", Amount: decimal.NewFromInt(30), Balance: decimal.NewFromInt(330), CreateAt: now - 1},
|
|||
|
}
|
|||
|
|
|||
|
assert.NoError(t, db.Create(&transactions).Error)
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
uid string
|
|||
|
orderIDs []string
|
|||
|
walletType wallet.Types
|
|||
|
wantIDs []int64
|
|||
|
wantErr bool
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "filter by uid=u1, both orders",
|
|||
|
uid: "u1",
|
|||
|
orderIDs: []string{"o1", "o2"},
|
|||
|
walletType: 0,
|
|||
|
wantIDs: []int64{transactions[0].ID, transactions[1].ID},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "filter by uid=u1, walletType=1",
|
|||
|
uid: "u1",
|
|||
|
orderIDs: []string{"o1", "o2"},
|
|||
|
walletType: wallet.Types(1),
|
|||
|
wantIDs: []int64{transactions[0].ID},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "filter by order=o1 only, any uid",
|
|||
|
uid: "",
|
|||
|
orderIDs: []string{"o1"},
|
|||
|
walletType: wallet.Types(1),
|
|||
|
wantIDs: []int64{transactions[0].ID, transactions[2].ID},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "no match yields empty",
|
|||
|
uid: "nope",
|
|||
|
orderIDs: []string{"o1"},
|
|||
|
walletType: 0,
|
|||
|
wantIDs: []int64{},
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
for _, tt := range tests {
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
got, err := repo.ListWalletTransactions(ctx, tt.uid, tt.orderIDs, tt.walletType)
|
|||
|
if tt.wantErr {
|
|||
|
assert.Error(t, err)
|
|||
|
return
|
|||
|
}
|
|||
|
assert.NoError(t, err)
|
|||
|
// 提取返回的 IDs 进行无序比较
|
|||
|
gotIDs := make([]int64, len(got))
|
|||
|
for i, w := range got {
|
|||
|
gotIDs[i] = w.ID
|
|||
|
}
|
|||
|
assert.ElementsMatch(t, tt.wantIDs, gotIDs)
|
|||
|
})
|
|||
|
}
|
|||
|
}
|