package repository

import (
	"context"
	"errors"
	"fmt"
	"testing"
	"time"

	"code.30cm.net/digimon/app-cloudep-wallet-service/internal/config"
	"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/entity"
	"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/repository"
	"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/domain/wallet"
	"code.30cm.net/digimon/app-cloudep-wallet-service/pkg/lib/sql_client"
	"github.com/shopspring/decimal"
	"github.com/stretchr/testify/assert"
	"google.golang.org/protobuf/proto"
	"gorm.io/gorm"
)

// SetupTestTransactionRepository starts a MySQL container through
// startMySQLContainer, opens a client against it and builds a
// TransactionRepository on top of that connection.
//
// It returns the repository, the underlying *gorm.DB (so tests can seed and
// inspect rows directly) and a teardown function for the container. On error
// every other return value is nil; if the container started but the client
// could not be created, the container is torn down before returning.
func SetupTestTransactionRepository() (repository.TransactionRepository, *gorm.DB, func(), error) {
	host, port, _, tearDown, err := startMySQLContainer()
	if err != nil {
		return nil, nil, nil, fmt.Errorf("failed to start MySQL container: %w", err)
	}

	conf := config.Config{
		MySQL: struct {
			UserName        string
			Password        string
			Host            string
			Port            string
			Database        string
			MaxIdleConns    int
			MaxOpenConns    int
			ConnMaxLifetime time.Duration
			LogLevel        string
		}{
			UserName:     MySQLUser,
			Password:     MySQLPassword,
			Host:         host,
			Port:         port,
			Database:     MySQLDatabase,
			MaxIdleConns: 10,
			MaxOpenConns: 100,
			// NOTE(review): as a bare time.Duration this is 300ns. Confirm
			// whether sql_client scales it or whether 300*time.Second was meant.
			ConnMaxLifetime: 300,
			LogLevel:        "info",
		},
	}

	db, err := sql_client.NewMySQLClient(conf)
	if err != nil {
		// The container is already running at this point; do not leak it.
		tearDown()
		return nil, nil, nil, fmt.Errorf("failed to create db client: %w", err)
	}

	repo := MustTransactionRepository(TransactionRepositoryParam{DB: db})
	return repo, db, tearDown, nil
}

// createTransactionTable creates the `transaction` table used by the tests if
// it does not exist yet. The test is aborted when the DDL fails, because no
// later assertion can be meaningful without the table.
func createTransactionTable(t *testing.T, db *gorm.DB) {
	t.Helper()

	sql := `
CREATE TABLE IF NOT EXISTS transaction (
  id BIGINT AUTO_INCREMENT PRIMARY KEY,
  order_id VARCHAR(64) NOT NULL,
  transaction_id VARCHAR(64),
  brand VARCHAR(32),
  uid VARCHAR(64),
  to_uid VARCHAR(64),
  type TINYINT,
  business_type TINYINT,
  asset VARCHAR(32),
  amount DECIMAL(30,18),
  before_balance DECIMAL(30,18),
  post_transfer_balance DECIMAL(30,18),
  status TINYINT,
  create_at BIGINT,
  due_time BIGINT
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`

	if err := db.Exec(sql).Error; err != nil {
		t.Fatalf("creating transaction table: %v", err)
	}
}

// createWalletTransactionTable creates the `wallet_transaction` table used by
// the tests if it does not exist yet. The test is aborted when the DDL fails.
func createWalletTransactionTable(t *testing.T, db *gorm.DB) {
	t.Helper()

	createTableSQL := `
CREATE TABLE IF NOT EXISTS wallet_transaction (
  id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT COMMENT '主鍵 ID,自動遞增',
  transaction_id BIGINT NOT NULL COMMENT '交易流水號(可對應某次業務操作,例如同一訂單的多筆變化)',
  order_id VARCHAR(64) NOT NULL COMMENT '訂單編號(對應實際訂單或業務事件)',
  brand VARCHAR(50) NOT NULL COMMENT '品牌(多租戶或多平台識別)',
  uid VARCHAR(64) NOT NULL COMMENT '使用者 UID',
  wallet_type TINYINT NOT NULL COMMENT '錢包類型(如主錢包、獎勵錢包、凍結錢包等)',
  business_type TINYINT NOT NULL COMMENT '業務類型(如購物、退款、加值等)',
  asset VARCHAR(32) NOT NULL COMMENT '資產代號(如 BTC、ETH、GEM_RED、USD 等)',
  amount DECIMAL(30,18) NOT NULL COMMENT '變動金額(正數為收入,負數為支出)',
  balance DECIMAL(30,18) NOT NULL COMMENT '當前錢包餘額(這筆交易後的餘額快照)',
  create_at BIGINT NOT NULL DEFAULT 0 COMMENT '建立時間(UnixNano,紀錄交易發生時間)',
  PRIMARY KEY (id),
  KEY idx_uid (uid),
  KEY idx_transaction_id (transaction_id),
  KEY idx_order_id (order_id),
  KEY idx_brand (brand),
  KEY idx_wallet_type (wallet_type)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='錢包資金異動紀錄(每一次交易行為的快照記錄)';`

	if err := db.Exec(createTableSQL).Error; err != nil {
		t.Fatalf("creating wallet_transaction table: %v", err)
	}
}

// TestTransactionRepository_InsertAndBatchInsert checks that Insert stores a
// single row and that BatchInsert stores every row it is given. Each case
// empties the table first and then inspects it with raw SQL.
func TestTransactionRepository_InsertAndBatchInsert(t *testing.T) {
	// Start the container and connect; nothing below works without it.
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setting up transaction repository: %v", err)
	}
	defer tearDown()

	createTransactionTable(t, db)

	now := time.Now().Unix()
	template := entity.Transaction{
		OrderID:             "o1",
		TransactionID:       "tx1",
		Brand:               "b1",
		UID:                 "u1",
		ToUID:               "u2",
		TxType:              1,
		BusinessType:        2,
		Asset:               "BTC",
		Amount:              decimal.RequireFromString("100.5"),
		BeforeBalance:       decimal.RequireFromString("50.0"),
		PostTransferBalance: decimal.RequireFromString("150.5"),
		Status:              1,
		CreateAt:            now,
		DueTime:             now + 3600,
	}

	tests := []struct {
		name     string
		setup    func(t *testing.T)
		action   func() error
		validate func(t *testing.T)
	}{
		{
			name: "single insert",
			setup: func(t *testing.T) {
				// Start from an empty table so the count below is exact.
				assert.NoError(t, db.Exec("DELETE FROM transaction").Error)
			},
			action: func() error {
				// Insert a copy so the shared template is never mutated.
				tx := template
				return repo.Insert(context.Background(), &tx)
			},
			validate: func(t *testing.T) {
				var count int64
				assert.NoError(t, db.Raw("SELECT COUNT(*) FROM transaction").Scan(&count).Error)
				assert.Equal(t, int64(1), count)

				var got entity.Transaction
				err := db.Take(&got, "order_id = ?", template.OrderID).Error
				assert.NoError(t, err)
			},
		},
		{
			name: "batch insert",
			setup: func(t *testing.T) {
				assert.NoError(t, db.Exec("DELETE FROM transaction").Error)
			},
			action: func() error {
				// Clone two entries with different order IDs.
				tx1 := template
				tx2 := template
				tx1.OrderID = "o2"
				tx2.OrderID = "o3"
				return repo.BatchInsert(context.Background(), []*entity.Transaction{&tx1, &tx2})
			},
			validate: func(t *testing.T) {
				var count int64
				assert.NoError(t, db.Raw("SELECT COUNT(*) FROM transaction").Scan(&count).Error)
				assert.Equal(t, int64(2), count)

				rows, err := db.Raw("SELECT order_id FROM transaction ORDER BY order_id").Rows()
				if !assert.NoError(t, err) {
					// rows is unusable on error; closing it would panic.
					return
				}
				defer rows.Close()

				var orders []string
				for rows.Next() {
					var oid string
					if !assert.NoError(t, rows.Scan(&oid)) {
						return
					}
					orders = append(orders, oid)
				}
				// Next returns false on both exhaustion and failure.
				assert.NoError(t, rows.Err())
				assert.Equal(t, []string{"o2", "o3"}, orders)
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.setup(t)
			err := tt.action()
			assert.NoError(t, err)
			tt.validate(t)
		})
	}
}

// TestTransactionRepository_FindByOrderID seeds one row with raw SQL and
// checks that FindByOrderID returns it field by field, and that an unknown
// order ID yields repository.ErrRecordNotFound.
func TestTransactionRepository_FindByOrderID(t *testing.T) {
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setting up transaction repository: %v", err)
	}
	defer tearDown()

	createTransactionTable(t, db)

	// Seed one row.
	now := time.Now().Unix()
	res := db.Exec(
		`INSERT INTO transaction
  (order_id, transaction_id, brand, uid, to_uid, type, business_type, asset, amount, before_balance, post_transfer_balance, status, create_at, due_time)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		"order-123", "tx-abc", "brandX", "user1", "user2",
		1, 2, "BTC", "10.0", "5.0", "15.0", 1, now, now+3600,
	)
	assert.NoError(t, res.Error)
	assert.Equal(t, int64(1), res.RowsAffected)

	type want struct {
		tx      entity.Transaction
		wantErr bool
	}
	tests := []struct {
		name    string
		orderID string
		want    want
	}{
		{
			name:    "found existing",
			orderID: "order-123",
			want: want{
				tx: entity.Transaction{
					// The seeded row is the first insert into the table.
					ID:                  1,
					OrderID:             "order-123",
					TransactionID:       "tx-abc",
					Brand:               "brandX",
					UID:                 "user1",
					ToUID:               "user2",
					TxType:              wallet.TxType(1),
					BusinessType:        int8(2),
					Asset:               "BTC",
					Amount:              decimal.RequireFromString("10.0"),
					BeforeBalance:       decimal.RequireFromString("5.0"),
					PostTransferBalance: decimal.RequireFromString("15.0"),
					Status:              wallet.Enable(1),
					CreateAt:            now,
					DueTime:             now + 3600,
				},
				wantErr: false,
			},
		},
		{
			name:    "not found",
			orderID: "missing",
			want:    want{wantErr: true},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := repo.FindByOrderID(context.Background(), tt.orderID)
			if tt.want.wantErr {
				assert.Error(t, err)
				assert.True(t, errors.Is(err, repository.ErrRecordNotFound))
				return
			}
			if !assert.NoError(t, err) {
				// Do not touch got after a failed lookup.
				return
			}

			assert.Equal(t, tt.want.tx.ID, got.ID)
			assert.Equal(t, tt.want.tx.OrderID, got.OrderID)
			assert.Equal(t, tt.want.tx.TransactionID, got.TransactionID)
			assert.Equal(t, tt.want.tx.Brand, got.Brand)
			assert.Equal(t, tt.want.tx.UID, got.UID)
			assert.Equal(t, tt.want.tx.ToUID, got.ToUID)
			assert.Equal(t, tt.want.tx.TxType, got.TxType)
			assert.Equal(t, tt.want.tx.BusinessType, got.BusinessType)
			assert.Equal(t, tt.want.tx.Asset, got.Asset)
			// Compare decimals by value; their internal representation may differ.
			assert.True(t, tt.want.tx.Amount.Equal(got.Amount))
			assert.True(t, tt.want.tx.BeforeBalance.Equal(got.BeforeBalance))
			assert.True(t, tt.want.tx.PostTransferBalance.Equal(got.PostTransferBalance))
			assert.Equal(t, tt.want.tx.Status, got.Status)
			assert.Equal(t, tt.want.tx.CreateAt, got.CreateAt)
			assert.Equal(t, tt.want.tx.DueTime, got.DueTime)
		})
	}
}

// TestTransactionRepository_List seeds three rows and exercises List with
// different filters and paging settings. The expected IDs are listed in
// reverse insertion order (newest first), and a query with no match is
// expected to return repository.ErrRecordNotFound.
func TestTransactionRepository_List(t *testing.T) {
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setting up transaction repository: %v", err)
	}
	defer tearDown()

	createTransactionTable(t, db)

	now := time.Now().Unix()
	rows := []entity.Transaction{
		{OrderID: "A", UID: "u1", TxType: wallet.Deposit, BusinessType: 1, Asset: "BTC", CreateAt: now - 30},
		{OrderID: "B", UID: "u2", TxType: wallet.Withdraw, BusinessType: 2, Asset: "ETH", CreateAt: now - 20},
		{OrderID: "C", UID: "u1", TxType: wallet.Deposit, BusinessType: 1, Asset: "BTC", CreateAt: now - 10},
	}
	// Create fills in the generated IDs, which the expectations below rely on.
	assert.NoError(t, db.Create(&rows).Error)

	tests := []struct {
		name      string
		query     repository.TransactionQuery
		wantCount int64
		wantIDs   []int64
		wantErr   error
	}{
		{
			name: "no filter returns all",
			query: repository.TransactionQuery{
				PageIndex: 1,
				PageSize:  50,
			},
			wantCount: 3,
			wantIDs:   []int64{rows[2].ID, rows[1].ID, rows[0].ID},
			wantErr:   nil,
		},
		{
			name: "filter by UID",
			query: repository.TransactionQuery{
				PageIndex: 1,
				PageSize:  50,
				UID:       proto.String("u1"),
			},
			wantCount: 2,
			wantIDs:   []int64{rows[2].ID, rows[0].ID},
		},
		{
			name: "filter by BusinessType",
			query: repository.TransactionQuery{
				PageIndex:    1,
				PageSize:     50,
				BusinessType: []int8{2},
			},
			wantCount: 1,
			wantIDs:   []int64{rows[1].ID},
		},
		{
			name: "filter by Asset + TxType",
			query: repository.TransactionQuery{
				PageIndex: 1,
				PageSize:  50,
				Assets:    proto.String("BTC"),
				TxTypes:   []wallet.TxType{wallet.Deposit},
			},
			wantCount: 2,
			wantIDs:   []int64{rows[2].ID, rows[0].ID},
		},
		{
			name: "time range filter",
			query: repository.TransactionQuery{
				StartTime: proto.Int64(now - 25),
				EndTime:   proto.Int64(now - 5),
			},
			wantCount: 2,
			wantIDs:   []int64{rows[2].ID, rows[1].ID},
		},
		{
			name: "paging page 1 size 1",
			query: repository.TransactionQuery{
				PageIndex: 1,
				PageSize:  1,
			},
			wantCount: 3,
			wantIDs:   []int64{rows[2].ID},
		},
		{
			name: "paging page 2 size 1",
			query: repository.TransactionQuery{
				PageIndex: 2,
				PageSize:  1,
			},
			wantCount: 3,
			wantIDs:   []int64{rows[1].ID},
		},
		{
			name:    "no match returns ErrRecordNotFound",
			query:   repository.TransactionQuery{UID: proto.String("nonexist")},
			wantErr: repository.ErrRecordNotFound,
		},
	}

	ctx := context.Background()
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, cnt, err := repo.List(ctx, tt.query)
			if tt.wantErr != nil {
				assert.ErrorIs(t, err, tt.wantErr)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tt.wantCount, cnt)

			var ids []int64
			for _, tx := range got {
				ids = append(ids, tx.ID)
			}
			assert.Equal(t, tt.wantIDs, ids)
		})
	}
}

// TestTransactionRepository_FindByDueTimeRange seeds rows with different due
// times and statuses and checks which of them FindByDueTimeRange returns for
// a given cutoff. The expectations encode that rows with a zero due_time or a
// status other than wallet.WalletNonStatus are not returned.
func TestTransactionRepository_FindByDueTimeRange(t *testing.T) {
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setting up transaction repository: %v", err)
	}
	defer tearDown()

	createTransactionTable(t, db)

	// Seed rows.
	now := time.Now().Unix()
	rows := []entity.Transaction{
		{TxType: wallet.Deposit, BusinessType: 0, DueTime: now - 10, Status: wallet.WalletNonStatus},
		{TxType: wallet.Deposit, BusinessType: 0, DueTime: now + 10, Status: wallet.WalletNonStatus},
		{TxType: wallet.Deposit, BusinessType: 0, DueTime: 0, Status: wallet.WalletNonStatus},
		{TxType: wallet.Deposit, BusinessType: 0, DueTime: now - 5, Status: wallet.Enable(1)}, // status != non
	}
	assert.NoError(t, db.Create(&rows).Error)

	tests := []struct {
		name    string
		cutoff  time.Time
		types   []wallet.TxType
		wantIDs []int64
	}{
		{
			name:    "due before now for type=1",
			cutoff:  time.Unix(now, 0),
			types:   []wallet.TxType{wallet.Deposit},
			wantIDs: []int64{rows[0].ID},
		},
		{
			// NOTE(review): despite its name this case passes a single type;
			// what it actually varies is the later cutoff.
			name:    "include multiple types",
			cutoff:  time.Unix(now+20, 0),
			types:   []wallet.TxType{wallet.Deposit},
			wantIDs: []int64{rows[0].ID, rows[1].ID},
		},
		{
			name:    "zero due_time is skipped",
			cutoff:  time.Unix(now+100, 0),
			types:   []wallet.TxType{wallet.Deposit},
			wantIDs: []int64{rows[0].ID, rows[1].ID},
		},
		{
			name:    "no matches",
			cutoff:  time.Unix(now-100, 0),
			types:   []wallet.TxType{wallet.Deposit},
			wantIDs: nil,
		},
	}

	ctx := context.Background()
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := repo.FindByDueTimeRange(ctx, tt.cutoff, tt.types)
			assert.NoError(t, err)

			// ids stays nil when nothing is returned, matching the nil
			// expectation of the "no matches" case.
			var ids []int64
			for _, tx := range got {
				ids = append(ids, tx.ID)
			}
			assert.Equal(t, tt.wantIDs, ids)
		})
	}
}

// TestTransactionRepository_UpdateStatusByID inserts one row through the
// repository and checks that UpdateStatusByID changes its status, and that
// updating an ID that does not exist returns no error and creates no row.
func TestTransactionRepository_UpdateStatusByID(t *testing.T) {
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setting up transaction repository: %v", err)
	}
	defer tearDown()

	createTransactionTable(t, db)

	ctx := context.Background()
	tx := &entity.Transaction{
		OrderID:             "ord-123",
		TransactionID:       "tx-abc",
		Brand:               "brand1",
		UID:                 "user1",
		ToUID:               "user2",
		TxType:              wallet.TxType(1),
		BusinessType:        int8(2),
		Asset:               "BTC",
		Amount:              decimal.NewFromInt(100),
		PostTransferBalance: decimal.NewFromInt(1000),
		BeforeBalance:       decimal.NewFromInt(900),
		Status:              wallet.Enable(0),
		CreateAt:            time.Now().Unix(),
		DueTime:             time.Now().Add(time.Hour).Unix(),
	}
	err = repo.Insert(ctx, tx)
	assert.NoError(t, err)

	// The ID of the inserted row is read back from the entity after Insert.
	existingID := tx.ID

	tests := []struct {
		name      string
		id        int64
		newStatus int
		wantErr   bool
		wantRow   bool
	}{
		{
			name:      "update existing row",
			id:        existingID,
			newStatus: 1,
			wantErr:   false,
			wantRow:   true,
		},
		{
			name:      "non-existent id does not error",
			id:        existingID + 999,
			newStatus: 5,
			wantErr:   false,
			wantRow:   false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := repo.UpdateStatusByID(ctx, tt.id, tt.newStatus)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)

			var got entity.Transaction
			res := db.First(&got, tt.id)
			if tt.wantRow {
				// Existing row should reflect the new status.
				assert.NoError(t, res.Error)
				assert.Equal(t, tt.newStatus, int(got.Status))
			} else {
				// Non-existent ID: the update must not have created a row.
				assert.Error(t, res.Error)
				assert.True(t, errors.Is(res.Error, gorm.ErrRecordNotFound))
			}
		})
	}
}

// TestTransactionRepository_ListWalletTransactions seeds three wallet
// transactions and checks which IDs ListWalletTransactions returns for
// different combinations of uid, order IDs and wallet type. The returned IDs
// are compared without regard to order.
func TestTransactionRepository_ListWalletTransactions(t *testing.T) {
	repo, db, tearDown, err := SetupTestTransactionRepository()
	if err != nil {
		t.Fatalf("setting up transaction repository: %v", err)
	}
	defer tearDown()

	createWalletTransactionTable(t, db)

	ctx := context.Background()
	now := time.Now().UnixNano()
	transactions := []entity.WalletTransaction{
		{OrderID: "o1", UID: "u1", WalletType: wallet.TypeAvailable, Asset: "BTC", Amount: decimal.NewFromInt(10), Balance: decimal.NewFromInt(110), CreateAt: now - 3},
		{OrderID: "o2", UID: "u1", WalletType: wallet.TypeFreeze, Asset: "ETH", Amount: decimal.NewFromInt(20), Balance: decimal.NewFromInt(220), CreateAt: now - 2},
		{OrderID: "o1", UID: "u2", WalletType: wallet.TypeAvailable, Asset: "BTC", Amount: decimal.NewFromInt(30), Balance: decimal.NewFromInt(330), CreateAt: now - 1},
	}
	assert.NoError(t, db.Create(&transactions).Error)

	tests := []struct {
		name       string
		uid        string
		orderIDs   []string
		walletType wallet.Types
		wantIDs    []int64
		wantErr    bool
	}{
		{
			name:       "filter by uid=u1, both orders",
			uid:        "u1",
			orderIDs:   []string{"o1", "o2"},
			walletType: 0,
			wantIDs:    []int64{transactions[0].ID, transactions[1].ID},
		},
		{
			name:       "filter by uid=u1, walletType=1",
			uid:        "u1",
			orderIDs:   []string{"o1", "o2"},
			walletType: wallet.Types(1),
			wantIDs:    []int64{transactions[0].ID},
		},
		{
			name:       "filter by order=o1 only, any uid",
			uid:        "",
			orderIDs:   []string{"o1"},
			walletType: wallet.Types(1),
			wantIDs:    []int64{transactions[0].ID, transactions[2].ID},
		},
		{
			name:       "no match yields empty",
			uid:        "nope",
			orderIDs:   []string{"o1"},
			walletType: 0,
			wantIDs:    []int64{},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := repo.ListWalletTransactions(ctx, tt.uid, tt.orderIDs, tt.walletType)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)

			// Extract the returned IDs for an order-insensitive comparison.
			gotIDs := make([]int64, len(got))
			for i, w := range got {
				gotIDs[i] = w.ID
			}
			assert.ElementsMatch(t, tt.wantIDs, gotIDs)
		})
	}
}