package repository import ( "context" "fmt" "testing" "time" "code.30cm.net/digimon/app-cloudep-product-service/pkg/domain/entity" "code.30cm.net/digimon/app-cloudep-product-service/pkg/domain/product" "code.30cm.net/digimon/app-cloudep-product-service/pkg/domain/repository" mgo "code.30cm.net/digimon/library-go/mongo" "github.com/alicebob/miniredis/v2" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/zeromicro/go-zero/core/stores/cache" "github.com/zeromicro/go-zero/core/stores/mon" "github.com/zeromicro/go-zero/core/stores/redis" "go.mongodb.org/mongo-driver/bson/primitive" "google.golang.org/protobuf/proto" ) func SetupTestProductItemRepository(db string) (repository.ProductItemRepository, func(), error) { h, p, tearDown, err := startMongoContainer() if err != nil { return nil, nil, err } s, _ := miniredis.Run() conf := &mgo.Conf{ Schema: Schema, Host: fmt.Sprintf("%s:%s", h, p), Database: db, MaxStaleness: 300, MaxPoolSize: 100, MinPoolSize: 100, MaxConnIdleTime: 300, Compressors: []string{}, EnableStandardReadWriteSplitMode: false, ConnectTimeoutMs: 3000, } cacheConf := cache.CacheConf{ cache.NodeConf{ RedisConf: redis.RedisConf{ Host: s.Addr(), Type: redis.NodeType, }, Weight: 100, }, } cacheOpts := []cache.Option{ cache.WithExpiry(1000 * time.Microsecond), cache.WithNotFoundExpiry(1000 * time.Microsecond), } param := ProductItemRepositoryParam{ Conf: conf, CacheConf: cacheConf, CacheOpts: cacheOpts, DBOpts: []mon.Option{ mgo.SetCustomDecimalType(), mgo.InitMongoOptions(*conf), }, } repo := NewProductItemRepository(param) _, _ = repo.Index20250317001UP(context.Background()) return repo, tearDown, nil } func TestInsertProductItems(t *testing.T) { model, tearDown, err := SetupTestProductItemRepository("testDB") defer tearDown() assert.NoError(t, err) // 建立多筆 items items := []entity.ProductItems{ { ReferenceID: primitive.NewObjectID().Hex(), Name: "Item A", }, { ReferenceID: primitive.NewObjectID().Hex(), Name: "Item B", }, } // 呼叫插入 err = model.Insert(context.Background(), items) assert.NoError(t, err) // 驗證插入是否成功(逐筆查詢) for _, item := range items { // 檢查 ID 是否自動填入 assert.False(t, item.ID.IsZero(), "ID should be generated") assert.NotZero(t, item.CreatedAt) assert.NotZero(t, item.UpdatedAt) // 查詢 DB 確認存在 var result *entity.ProductItems result, err = model.FindByID(context.Background(), item.ID.Hex()) assert.NoError(t, err) assert.Equal(t, item.Name, result.Name) } // 🧪 空陣列插入測試(不應報錯) t.Run("Insert empty slice", func(t *testing.T) { err := model.Insert(context.Background(), []entity.ProductItems{}) assert.NoError(t, err) }) } func TestDeleteProductItem(t *testing.T) { model, tearDown, err := SetupTestProductItemRepository("testDB") defer tearDown() assert.NoError(t, err) // 建立多筆 items items := []entity.ProductItems{ { ReferenceID: primitive.NewObjectID().Hex(), Name: "Item A", }, { ReferenceID: primitive.NewObjectID().Hex(), Name: "Item B", }, } // 呼叫插入 err = model.Insert(context.Background(), items) assert.NoError(t, err) tests := []struct { name string inputID string expectErr error checkAfter bool // 是否需要確認資料已刪除 }{ { name: "Delete existing product", inputID: items[0].ID.Hex(), expectErr: nil, checkAfter: true, }, { name: "Delete non-existing product", inputID: primitive.NewObjectID().Hex(), expectErr: nil, checkAfter: false, }, { name: "Invalid ObjectID format", inputID: "not-an-object-id", expectErr: ErrInvalidObjectID, checkAfter: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := model.Delete(context.Background(), []string{tt.inputID}) if tt.expectErr != nil { assert.ErrorIs(t, err, tt.expectErr) } else { assert.NoError(t, err) } // 驗證資料是否真的刪除了(僅當需要) if tt.checkAfter { _, err := model.FindByID(context.Background(), tt.inputID) assert.ErrorIs(t, err, ErrNotFound) } }) } } func TestListProductItem(t *testing.T) { model, tearDown, err := SetupTestProductItemRepository("testDB") defer tearDown() assert.NoError(t, err) now := time.Now().UnixNano() rfcID := primitive.NewObjectID().Hex() // 插入測試資料 item1 := entity.ProductItems{ ID: primitive.NewObjectID(), ReferenceID: rfcID, Name: "Item A", IsFree: true, Status: product.StatusActive, CreatedAt: now, UpdatedAt: now, } item2 := entity.ProductItems{ ID: primitive.NewObjectID(), ReferenceID: rfcID, Name: "Item B", IsFree: false, Status: product.StatusInactive, CreatedAt: now, UpdatedAt: now, } item3 := entity.ProductItems{ ID: primitive.NewObjectID(), ReferenceID: rfcID, Name: "Item C", IsFree: true, Status: product.StatusActive, CreatedAt: now, UpdatedAt: now, } err = model.Insert(context.Background(), []entity.ProductItems{item1, item2, item3}) assert.NoError(t, err) tests := []struct { name string params repository.ProductItemQueryParams expectCount int64 expectIDs []primitive.ObjectID }{ { name: "Filter by ReferenceID", params: repository.ProductItemQueryParams{ ReferenceID: ptr(rfcID), PageSize: 10, PageIndex: 1, }, expectCount: 3, expectIDs: []primitive.ObjectID{item3.ID, item2.ID, item1.ID}, }, { name: "Filter by IsFree = true", params: repository.ProductItemQueryParams{ IsFree: ptr(true), PageSize: 10, PageIndex: 1, }, expectCount: 2, expectIDs: []primitive.ObjectID{item3.ID, item1.ID}, }, { name: "Filter by Status = 2", params: repository.ProductItemQueryParams{ Status: ptr(product.StatusInactive), PageSize: 10, PageIndex: 1, }, expectCount: 1, expectIDs: []primitive.ObjectID{item2.ID}, }, { name: "Filter by ItemIDs", params: repository.ProductItemQueryParams{ ItemID: []string{item1.ID.Hex(), item2.ID.Hex()}, PageSize: 10, PageIndex: 1, }, expectCount: 2, expectIDs: []primitive.ObjectID{item2.ID, item1.ID}, }, { name: "Pagination works", params: repository.ProductItemQueryParams{ PageSize: 1, PageIndex: 2, }, expectCount: 3, expectIDs: []primitive.ObjectID{item2.ID}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { results, count, err := model.ListProductItem(context.Background(), tt.params) assert.NoError(t, err) assert.Equal(t, tt.expectCount, count) var gotIDs []primitive.ObjectID for _, r := range results { gotIDs = append(gotIDs, r.ID) } assert.ElementsMatch(t, tt.expectIDs, gotIDs) }) } } func TestDeleteByReferenceID(t *testing.T) { model, tearDown, err := SetupTestProductItemRepository("testDB") defer tearDown() assert.NoError(t, err) ctx := context.Background() now := time.Now().UnixNano() // 插入測試資料 refID := primitive.NewObjectID().Hex() item1 := entity.ProductItems{ ID: primitive.NewObjectID(), ReferenceID: refID, Name: "Item A", CreatedAt: now, UpdatedAt: now, } item2 := entity.ProductItems{ ID: primitive.NewObjectID(), ReferenceID: refID, Name: "Item B", CreatedAt: now, UpdatedAt: now, } itemOther := entity.ProductItems{ ID: primitive.NewObjectID(), ReferenceID: primitive.NewObjectID().Hex(), Name: "Should not be deleted", CreatedAt: now, UpdatedAt: now, } err = model.Insert(ctx, []entity.ProductItems{item1, item2, itemOther}) assert.NoError(t, err) tests := []struct { name string referenceID string expectDeleted []primitive.ObjectID expectRemained []primitive.ObjectID }{ { name: "Delete existing reference_id items", referenceID: refID, expectDeleted: []primitive.ObjectID{item1.ID, item2.ID}, expectRemained: []primitive.ObjectID{itemOther.ID}, }, { name: "Delete non-existent reference_id", referenceID: "no-match-ref", expectDeleted: nil, expectRemained: []primitive.ObjectID{itemOther.ID}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := model.DeleteByReferenceID(ctx, tt.referenceID) assert.NoError(t, err) // 檢查指定應被刪除的項目是否真的被刪除 for _, id := range tt.expectDeleted { _, err = model.FindByID(ctx, id.Hex()) assert.Error(t, err, "Expected item to be deleted: "+id.Hex()) } // 檢查不應刪除的項目仍存在 for _, id := range tt.expectRemained { _, err = model.FindByID(ctx, id.Hex()) assert.NoError(t, err, "Expected item to remain: "+id.Hex()) } }) } } func TestUpdateProductItem(t *testing.T) { model, tearDown, err := SetupTestProductItemRepository("testDB") defer tearDown() assert.NoError(t, err) ctx := context.Background() // 建立一筆測試資料 item := entity.ProductItems{ ID: primitive.NewObjectID(), ReferenceID: primitive.NewObjectID().Hex(), Name: "Original Name", Stock: 10, Price: decimal.NewFromInt(500), IsUnLimit: false, IsFree: false, CreatedAt: time.Now().UnixNano(), UpdatedAt: time.Now().UnixNano(), } err = model.Insert(ctx, []entity.ProductItems{item}) assert.NoError(t, err) tests := []struct { name string id string update *repository.ProductUpdateItem expectErr error validate func(t *testing.T, updated *entity.ProductItems) }{ { name: "Update is_un_limit and is_free", id: item.ID.Hex(), update: &repository.ProductUpdateItem{ IsUnLimit: ptr(true), IsFree: ptr(true), }, expectErr: nil, validate: func(t *testing.T, updated *entity.ProductItems) { assert.True(t, updated.IsUnLimit) assert.True(t, updated.IsFree) }, }, { name: "Update SKU and TimeSeries", id: item.ID.Hex(), update: &repository.ProductUpdateItem{ SKU: ptr("SKU-XYZ-001"), TimeSeries: ptr(product.TimeSeriesTenMinutes), }, expectErr: nil, validate: func(t *testing.T, updated *entity.ProductItems) { assert.Equal(t, "SKU-XYZ-001", updated.SKU) assert.Equal(t, product.TimeSeriesTenMinutes, updated.TimeSeries) }, }, { name: "Update Media field", id: item.ID.Hex(), update: &repository.ProductUpdateItem{ Media: []entity.Media{ {Sort: 1, Type: "image", URL: "https://example.com/img.jpg"}, }, }, expectErr: nil, validate: func(t *testing.T, updated *entity.ProductItems) { assert.Len(t, updated.Media, 1) assert.Equal(t, "image", updated.Media[0].Type) }, }, { name: "Update CustomFields", id: item.ID.Hex(), update: &repository.ProductUpdateItem{ CustomFields: []entity.CustomFields{ {Key: "color", Value: "red"}, }, }, expectErr: nil, validate: func(t *testing.T, updated *entity.ProductItems) { assert.Len(t, updated.CustomFields, 1) assert.Equal(t, "color", updated.CustomFields[0].Key) }, }, { name: "Update Freight", id: item.ID.Hex(), update: &repository.ProductUpdateItem{ Freight: []entity.CustomFields{ { Key: "color", Value: "red", }, }, }, expectErr: nil, validate: func(t *testing.T, updated *entity.ProductItems) { assert.Equal(t, "color", updated.Freight[0].Key) assert.Equal(t, "red", updated.Freight[0].Value) }, }, { name: "Update name field", id: item.ID.Hex(), update: &repository.ProductUpdateItem{ Name: ptr("Updated Name"), }, expectErr: nil, validate: func(t *testing.T, updated *entity.ProductItems) { assert.Equal(t, "Updated Name", updated.Name) }, }, { name: "Update stock and price", id: item.ID.Hex(), update: &repository.ProductUpdateItem{ Stock: proto.Int64(99), Price: ptr(decimal.NewFromInt(999)), }, expectErr: nil, validate: func(t *testing.T, updated *entity.ProductItems) { assert.Equal(t, uint64(99), updated.Stock) assert.Equal(t, "999", updated.Price.String()) }, }, { name: "Invalid ObjectID", id: "not-an-id", update: &repository.ProductUpdateItem{Name: ptr("Invalid")}, expectErr: ErrInvalidObjectID, }, { name: "Empty update struct", id: item.ID.Hex(), update: &repository.ProductUpdateItem{}, // no fields expectErr: fmt.Errorf("no fields to update"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := model.Update(ctx, tt.id, tt.update) if tt.expectErr != nil { assert.Error(t, err) assert.Contains(t, err.Error(), tt.expectErr.Error()) return } assert.NoError(t, err) updated, err := model.FindByID(ctx, item.ID.Hex()) assert.NoError(t, err) if tt.validate != nil { tt.validate(t, updated) } }) } } func TestUpdateStatus_TableDriven(t *testing.T) { repo, tearDown, err := SetupTestProductItemRepository("testDB") require.NoError(t, err) defer tearDown() ctx := context.Background() // Insert a sample product item. now := time.Now().UnixNano() item := entity.ProductItems{ ID: primitive.NewObjectID(), ReferenceID: primitive.NewObjectID().Hex(), Name: "Test Item", Status: product.StatusInactive, // initial status CreatedAt: now, UpdatedAt: now, } err = repo.Insert(ctx, []entity.ProductItems{item}) require.NoError(t, err) tests := []struct { name string id string newStatus product.ItemStatus expectErr error check bool // whether to verify the update in DB }{ { name: "Valid update", id: item.ID.Hex(), newStatus: product.StatusActive, expectErr: nil, check: true, }, { name: "Invalid ObjectID", id: "invalid-id", newStatus: product.StatusActive, expectErr: ErrInvalidObjectID, check: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := repo.UpdateStatus(ctx, tt.id, tt.newStatus) if tt.expectErr != nil { assert.ErrorIs(t, err, tt.expectErr) } else { assert.NoError(t, err) } // If expected to check update, verify that the product's status is updated. if tt.check { updated, err := repo.FindByID(ctx, tt.id) assert.NoError(t, err) assert.Equal(t, tt.newStatus, updated.Status) } }) } } func TestIncSalesCount(t *testing.T) { repo, tearDown, err := SetupTestProductItemRepository("testDB") require.NoError(t, err) defer tearDown() ctx := context.Background() // Insert a sample product item with initial SalesCount = 0. now := time.Now().UnixNano() item := entity.ProductItems{ ID: primitive.NewObjectID(), ReferenceID: primitive.NewObjectID().Hex(), Name: "Sales Count Test Item", SalesCount: 0, CreatedAt: now, UpdatedAt: now, } err = repo.Insert(ctx, []entity.ProductItems{item}) require.NoError(t, err) tests := []struct { name string id string count int64 expectErr error check bool // whether to verify the updated sales count in the DB. }{ { name: "Increment sales count by 5", id: item.ID.Hex(), count: 5, expectErr: nil, check: true, }, { name: "Invalid ObjectID", id: "invalid-id", count: 3, expectErr: ErrInvalidObjectID, check: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := repo.IncSalesCount(ctx, tt.id, tt.count) if tt.expectErr != nil { assert.ErrorIs(t, err, tt.expectErr) } else { assert.NoError(t, err) } // If check is true, verify that the sales_count is updated correctly. if tt.check { updated, err := repo.FindByID(ctx, tt.id) assert.NoError(t, err) // Since initial SalesCount was 0, after increment it should equal tt.count. assert.Equal(t, uint64(tt.count), updated.SalesCount) } }) } } func TestDecSalesCount(t *testing.T) { repo, tearDown, err := SetupTestProductItemRepository("testDB") require.NoError(t, err) defer tearDown() ctx := context.Background() now := time.Now().UnixNano() // Insert an item with an initial SalesCount of 10. item := entity.ProductItems{ ID: primitive.NewObjectID(), ReferenceID: primitive.NewObjectID().Hex(), Name: "Dec Sales Count Test Item", SalesCount: 10, CreatedAt: now, UpdatedAt: now, } // Insert an item with SalesCount equal to 0 (to test underflow behavior). zeroItem := entity.ProductItems{ ID: primitive.NewObjectID(), ReferenceID: primitive.NewObjectID().Hex(), Name: "Zero Sales Count Item", SalesCount: 0, CreatedAt: now, UpdatedAt: now, } err = repo.Insert(ctx, []entity.ProductItems{item, zeroItem}) require.NoError(t, err) tests := []struct { name string id string decCount int64 expectErr error check bool // whether to verify the updated sales count in the DB expected uint64 // expected SalesCount if check is true }{ { name: "Valid decrement from 10", id: item.ID.Hex(), decCount: 3, expectErr: nil, check: true, expected: 7, }, { name: "Invalid ObjectID", id: "invalid-id", decCount: 3, expectErr: ErrInvalidObjectID, check: false, }, { name: "Decrement from zero (should not allow underflow)", id: zeroItem.ID.Hex(), decCount: 5, expectErr: nil, check: true, expected: 0, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := repo.DecSalesCount(ctx, tt.id, tt.decCount) if tt.expectErr != nil { assert.Error(t, err) // Check that the error message contains "failed to decrease stock" (for the underflow case) if tt.expectErr.Error() != ErrInvalidObjectID.Error() { assert.Contains(t, err.Error(), "failed to decrease stock") } else { assert.ErrorIs(t, err, tt.expectErr) } } else { assert.NoError(t, err) } // If expected to check, verify the updated SalesCount. if tt.check { updated, err := repo.FindByID(ctx, tt.id) assert.NoError(t, err) assert.Equal(t, tt.expected, updated.SalesCount) } }) } } func TestGetSalesCount(t *testing.T) { // 取得測試 repository repo, tearDown, err := SetupTestProductItemRepository("testDB") require.NoError(t, err) defer tearDown() ctx := context.Background() // 預先建立測試資料,設定各項目的 SalesCount 值 testItems := []entity.ProductItems{ { ReferenceID: primitive.NewObjectID().Hex(), Name: "Test Item 1", SalesCount: 5, }, { ReferenceID: primitive.NewObjectID().Hex(), Name: "Test Item 2", SalesCount: 15, }, } // 插入測試資料 err = repo.Insert(ctx, testItems) require.NoError(t, err) // 建立一組包含有效與無效 ID 的字串陣列 validIDs := []string{ testItems[0].ID.Hex(), testItems[1].ID.Hex(), } // 在陣列前面加上一個無法轉換的 ID mixedIDs := append([]string{"invalidID"}, validIDs...) t.Run("with valid and invalid IDs", func(t *testing.T) { salesCounts, err := repo.GetSalesCount(ctx, mixedIDs) require.NoError(t, err) // 預期只會回傳有效的兩筆資料 require.Len(t, salesCounts, len(validIDs)) // 驗證每筆資料的 SalesCount 是否正確 for _, sc := range salesCounts { switch sc.ID { case testItems[0].ID.Hex(): assert.Equal(t, uint64(5), sc.Count) case testItems[1].ID.Hex(): assert.Equal(t, uint64(15), sc.Count) default: t.Errorf("unexpected product item ID: %s", sc.ID) } } }) t.Run("with all invalid IDs", func(t *testing.T) { salesCounts, err := repo.GetSalesCount(ctx, []string{"badid1", "badid2"}) // 因無法轉換成 ObjectID,預期會回傳 ErrInvalidObjectID assert.Error(t, err) assert.Equal(t, ErrInvalidObjectID, err) assert.Nil(t, salesCounts) }) }