From fdc0799fccb3c632d26743bbd0f75b3a4f89a03d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=80=A7=E9=A9=8A?= Date: Wed, 19 Mar 2025 19:28:07 +0800 Subject: [PATCH] feat: product statustucs --- pkg/domain/entity/product_statistics.go | 2 + pkg/domain/redis.go | 9 +- pkg/domain/repository/product_statistics.go | 32 +- pkg/repository/product.go | 8 +- pkg/repository/product_statistics.go | 286 ++++++++++ pkg/repository/product_statistics_test.go | 562 ++++++++++++++++++++ 6 files changed, 895 insertions(+), 4 deletions(-) create mode 100644 pkg/repository/product_statistics.go create mode 100644 pkg/repository/product_statistics_test.go diff --git a/pkg/domain/entity/product_statistics.go b/pkg/domain/entity/product_statistics.go index 3672d93..d139faa 100644 --- a/pkg/domain/entity/product_statistics.go +++ b/pkg/domain/entity/product_statistics.go @@ -12,6 +12,8 @@ type ProductStatistics struct { AverageRatingUpdateTime int64 `bson:"average_rating_time"` // 更新評價的時間 FansCount uint64 `bson:"fans_count"` // 追蹤數量 FansCountUpdateTime int64 `bson:"fans_count_update_time"` // 更新追蹤的時間 + UpdatedAt int64 `bson:"updated_at"` // 更新時間 + CreatedAt int64 `bson:"created_at"` // 創建時間 } func (p *ProductStatistics) CollectionName() string { diff --git a/pkg/domain/redis.go b/pkg/domain/redis.go index b4e6467..fba0c3f 100644 --- a/pkg/domain/redis.go +++ b/pkg/domain/redis.go @@ -15,8 +15,9 @@ func (key RedisKey) With(s ...string) RedisKey { } const ( - GetProductRedisKey RedisKey = "get" - GetProductItemRedisKey RedisKey = "get_item" + GetProductRedisKey RedisKey = "get" + GetProductItemRedisKey RedisKey = "get_item" + GetProductStatisticsRedisKey RedisKey = "statistics" ) func GetProductRK(id string) string { @@ -26,3 +27,7 @@ func GetProductRK(id string) string { func GetProductItemRK(id string) string { return GetProductItemRedisKey.With(id).ToString() } + +func GetProductStatisticsRK(id string) string { + return GetProductStatisticsRedisKey.With(id).ToString() +} diff --git a/pkg/domain/repository/product_statistics.go b/pkg/domain/repository/product_statistics.go index 04d5457..e7f3621 100644 --- a/pkg/domain/repository/product_statistics.go +++ b/pkg/domain/repository/product_statistics.go @@ -1,3 +1,33 @@ package repository -type ProductStatisticsRepo interface{} +import ( + "code.30cm.net/digimon/app-cloudep-product-service/pkg/domain/entity" + "context" + "go.mongodb.org/mongo-driver/mongo" +) + +type ProductStatisticsRepo interface { + // Create 新增一筆產品統計資料 + Create(ctx context.Context, stats *entity.ProductStatistics) error + // GetByID 根據內部 ID 取得統計資料 + GetByID(ctx context.Context, id string) (*entity.ProductStatistics, error) + // GetByProductID 根據產品 ID 取得統計資料 + GetByProductID(ctx context.Context, productID string) (*entity.ProductStatistics, error) + // IncOrders 新增訂單數 + IncOrders(ctx context.Context, productID string, count int64) error + // DecOrders 減少訂單數。-> 退貨時專用 + DecOrders(ctx context.Context, productID string, count int64) error + // UpdateAverageRating 只更新綜合評價及其更新時間 + UpdateAverageRating(ctx context.Context, productID string, averageRating float64) error + // IncFansCount 新增粉絲數 + IncFansCount(ctx context.Context, productID string, fansCount uint64) error + // DecFansCount 減少粉絲數。-> 退貨時專用 + DecFansCount(ctx context.Context, productID string, fansCount uint64) error + // Delete 刪除統計資料 + Delete(ctx context.Context, id string) error + ProductStatisticsIndex +} + +type ProductStatisticsIndex interface { + Index20250317001UP(ctx context.Context) (*mongo.Cursor, error) +} diff --git a/pkg/repository/product.go b/pkg/repository/product.go index 19a9953..cf8c0d3 100644 --- a/pkg/repository/product.go +++ b/pkg/repository/product.go @@ -163,7 +163,13 @@ func (repo *ProductRepository) Delete(ctx context.Context, id string) error { return err } - _, err = repo.DB.DeleteOne(ctx, domain.GetProductRK(id), item) + oid, err := primitive.ObjectIDFromHex(id) + if err != nil { + return ErrInvalidObjectID + } + + filter := bson.M{"_id": oid} + _, err = repo.DB.DeleteOne(ctx, domain.GetProductRK(id), filter) if err != nil { return err } diff --git a/pkg/repository/product_statistics.go b/pkg/repository/product_statistics.go new file mode 100644 index 0000000..3433bd0 --- /dev/null +++ b/pkg/repository/product_statistics.go @@ -0,0 +1,286 @@ +package repository + +import ( + "code.30cm.net/digimon/app-cloudep-product-service/pkg/domain" + "code.30cm.net/digimon/app-cloudep-product-service/pkg/domain/entity" + "code.30cm.net/digimon/app-cloudep-product-service/pkg/domain/repository" + mgo "code.30cm.net/digimon/library-go/mongo" + "context" + "errors" + "fmt" + "github.com/zeromicro/go-zero/core/stores/cache" + "github.com/zeromicro/go-zero/core/stores/mon" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "time" +) + +type ProductStatisticsRepositoryParam struct { + Conf *mgo.Conf + CacheConf cache.CacheConf + DBOpts []mon.Option + CacheOpts []cache.Option +} + +type ProductStatisticsRepository struct { + DB mgo.DocumentDBWithCacheUseCase +} + +func NewProductStatisticsRepository(param ProductStatisticsRepositoryParam) repository.ProductStatisticsRepo { + e := entity.ProductStatistics{} + documentDB, err := mgo.MustDocumentDBWithCache( + param.Conf, + e.CollectionName(), + param.CacheConf, + param.DBOpts, + param.CacheOpts, + ) + if err != nil { + panic(err) + } + + return &ProductStatisticsRepository{ + DB: documentDB, + } +} + +func (repo *ProductStatisticsRepository) Create(ctx context.Context, data *entity.ProductStatistics) error { + if data.ID.IsZero() { + now := time.Now().UTC().UnixNano() + data.ID = primitive.NewObjectID() + data.CreatedAt = now + data.UpdatedAt = now + } + rk := domain.GetProductStatisticsRK(data.ID.Hex()) + _, err := repo.DB.InsertOne(ctx, rk, data) + + productKey := domain.GetProductStatisticsRK(data.ProductID) + _ = repo.DB.SetCache(productKey, data) + + return err +} + +func (repo *ProductStatisticsRepository) GetByID(ctx context.Context, id string) (*entity.ProductStatistics, error) { + oid, err := primitive.ObjectIDFromHex(id) + if err != nil { + return nil, err + } + + var result *entity.ProductStatistics + err = repo.DB.FindOne(ctx, domain.GetProductStatisticsRK(id), &result, bson.M{"_id": oid}) + switch { + case err == nil: + return result, nil + case errors.Is(err, mon.ErrNotFound): + return nil, ErrNotFound + default: + return nil, err + } +} + +func (repo *ProductStatisticsRepository) GetByProductID(ctx context.Context, productID string) (*entity.ProductStatistics, error) { + var result *entity.ProductStatistics + err := repo.DB.FindOne(ctx, domain.GetProductStatisticsRK(productID), &result, bson.M{"product_id": productID}) + switch { + case err == nil: + return result, nil + case errors.Is(err, mon.ErrNotFound): + return nil, ErrNotFound + default: + return nil, err + } +} + +func (repo *ProductStatisticsRepository) IncOrders(ctx context.Context, productID string, count int64) error { + filter := bson.M{"product_id": productID} + update := bson.M{"$inc": bson.M{"total_orders": count}} + + rk := domain.GetProductStatisticsRK(productID) + _, err := repo.DB.UpdateOne(ctx, rk, filter, update) + if err != nil { + return fmt.Errorf("failed to decrease stock: %w", err) + } + + id, err := repo.getIDByProductID(ctx, productID) + if err != nil { + return fmt.Errorf("failed to get product_id by id: %w", err) + } + + repo.clearCache(ctx, id, productID) + + return nil +} + +func (repo *ProductStatisticsRepository) DecOrders(ctx context.Context, productID string, count int64) error { + filter := bson.M{"product_id": productID, "total_orders": bson.M{"$gte": count}} + update := bson.M{"$inc": bson.M{"total_orders": -count}} + + rk := domain.GetProductStatisticsRK(productID) + _, err := repo.DB.UpdateOne(ctx, rk, filter, update) + if err != nil { + return fmt.Errorf("failed to decrease stock: %w", err) + } + + id, err := repo.getIDByProductID(ctx, productID) + if err != nil { + return fmt.Errorf("failed to get product_id by id: %w", err) + } + + repo.clearCache(ctx, id, productID) + + return nil +} + +func (repo *ProductStatisticsRepository) UpdateAverageRating(ctx context.Context, productID string, averageRating float64) error { + filter := bson.M{"product_id": productID} + now := time.Now().UnixNano() + update := bson.M{ + "$set": bson.M{ + "average_rating": averageRating, + "average_rating_time": now, + "updated_at": now, + }, + } + + _, err := repo.DB.UpdateOne(ctx, domain.GetProductStatisticsRK(productID), filter, update) + if err != nil { + return err + } + + id, err := repo.getIDByProductID(ctx, productID) + if err != nil { + return fmt.Errorf("failed to get product_id by id: %w", err) + } + + repo.clearCache(ctx, id, productID) + + return nil +} + +func (repo *ProductStatisticsRepository) IncFansCount(ctx context.Context, productID string, fansCount uint64) error { + filter := bson.M{"product_id": productID} + now := time.Now().UnixNano() + update := bson.M{ + "$inc": bson.M{"fans_count": fansCount}, + "$set": bson.M{ + "fans_count_update_time": now, + "updated_at": now, + }, + } + + rk := domain.GetProductStatisticsRK(productID) + _, err := repo.DB.UpdateOne(ctx, rk, filter, update) + if err != nil { + return fmt.Errorf("failed to increment fans count: %w", err) + } + + id, err := repo.getIDByProductID(ctx, productID) + if err != nil { + return fmt.Errorf("failed to get product_id by id: %w", err) + } + + repo.clearCache(ctx, id, productID) + + return nil +} + +func (repo *ProductStatisticsRepository) DecFansCount(ctx context.Context, productID string, fansCount uint64) error { + // 只允許在 fans_count 大於或等於欲扣減值時進行扣減 + filter := bson.M{"product_id": productID, "fans_count": bson.M{"$gte": fansCount}} + now := time.Now().UnixNano() + update := bson.M{ + "$inc": bson.M{"fans_count": -int64(fansCount)}, + "$set": bson.M{ + "fans_count_update_time": now, + "updated_at": now, + }, + } + + rk := domain.GetProductStatisticsRK(productID) + _, err := repo.DB.UpdateOne(ctx, rk, filter, update) + if err != nil { + return fmt.Errorf("failed to decrement fans count: %w", err) + } + + id, err := repo.getIDByProductID(ctx, productID) + if err != nil { + return fmt.Errorf("failed to get product_id by id: %w", err) + } + + repo.clearCache(ctx, id, productID) + + return nil +} + +func (repo *ProductStatisticsRepository) Delete(ctx context.Context, id string) error { + oid, err := primitive.ObjectIDFromHex(id) + if err != nil { + return ErrInvalidObjectID + } + productID, err := repo.getProductIDByID(ctx, id) + if err != nil { + return fmt.Errorf("failed to get product_id by id: %w", err) + } + + filter := bson.M{"_id": oid} + _, err = repo.DB.DeleteOne(ctx, domain.GetProductStatisticsRK(id), filter) + if err != nil { + return err + } + + repo.clearCache(ctx, id, productID) + + return nil +} + +func (repo *ProductStatisticsRepository) Index20250317001UP(ctx context.Context) (*mongo.Cursor, error) { + // 等價於 db.account.createIndex({"product_id": 1}) + repo.DB.PopulateIndex(ctx, "product_id", 1, true) + + return repo.DB.GetClient().Indexes().List(ctx) +} + +// 快取輔助函數 +// clearCache 同時刪除 product_id 與 _id 兩個 cache key +func (repo *ProductStatisticsRepository) clearCache(ctx context.Context, id, productID string) { + keys := []string{ + domain.GetProductStatisticsRK(productID), + domain.GetProductStatisticsRK(id), + } + for _, key := range keys { + _ = repo.DB.DelCache(ctx, key) + } +} + +func (repo *ProductStatisticsRepository) getIDByProductID(ctx context.Context, productID string) (string, error) { + filter := bson.M{"product_id": productID} + var e entity.ProductStatistics + projection := bson.M{"_id": 1} + opts := options.FindOne().SetProjection(projection) + err := repo.DB.GetClient().FindOne(ctx, &e, filter, opts) + if err != nil { + return "", fmt.Errorf("failed to set projection: %w", err) + } + + return e.ID.Hex(), nil +} + +func (repo *ProductStatisticsRepository) getProductIDByID(ctx context.Context, id string) (string, error) { + oid, err := primitive.ObjectIDFromHex(id) + if err != nil { + return "", err + } + + filter := bson.M{"_id": oid} + var e entity.ProductStatistics + projection := bson.M{"product_id": 1} + opts := options.FindOne().SetProjection(projection) + err = repo.DB.GetClient().FindOne(ctx, &e, filter, opts) + if err != nil { + return "", fmt.Errorf("failed to set projection: %w", err) + } + + return e.ProductID, nil +} diff --git a/pkg/repository/product_statistics_test.go b/pkg/repository/product_statistics_test.go new file mode 100644 index 0000000..fd13a90 --- /dev/null +++ b/pkg/repository/product_statistics_test.go @@ -0,0 +1,562 @@ +package repository + +import ( + "code.30cm.net/digimon/app-cloudep-product-service/pkg/domain/entity" + "code.30cm.net/digimon/app-cloudep-product-service/pkg/domain/repository" + mgo "code.30cm.net/digimon/library-go/mongo" + "context" + "fmt" + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zeromicro/go-zero/core/stores/cache" + "github.com/zeromicro/go-zero/core/stores/mon" + "github.com/zeromicro/go-zero/core/stores/redis" + "testing" + "time" +) + +func SetupTestProductStatisticsRepo(db string) (repository.ProductStatisticsRepo, func(), error) { + h, p, tearDown, err := startMongoContainer() + if err != nil { + return nil, nil, err + } + s, _ := miniredis.Run() + + conf := &mgo.Conf{ + Schema: Schema, + Host: fmt.Sprintf("%s:%s", h, p), + Database: db, + MaxStaleness: 300, + MaxPoolSize: 100, + MinPoolSize: 100, + MaxConnIdleTime: 300, + Compressors: []string{}, + EnableStandardReadWriteSplitMode: false, + ConnectTimeoutMs: 3000, + } + + cacheConf := cache.CacheConf{ + cache.NodeConf{ + RedisConf: redis.RedisConf{ + Host: s.Addr(), + Type: redis.NodeType, + }, + Weight: 100, + }, + } + + cacheOpts := []cache.Option{ + cache.WithExpiry(1000 * time.Microsecond), + cache.WithNotFoundExpiry(1000 * time.Microsecond), + } + + param := ProductStatisticsRepositoryParam{ + Conf: conf, + CacheConf: cacheConf, + CacheOpts: cacheOpts, + DBOpts: []mon.Option{ + mgo.SetCustomDecimalType(), + mgo.InitMongoOptions(*conf), + }, + } + + repo := NewProductStatisticsRepository(param) + _, _ = repo.Index20250317001UP(context.Background()) + + return repo, tearDown, nil +} + +func TestCreateProductStatistics(t *testing.T) { + // 假設有 SetupTestProductStatisticsRepository 可用來建立測試用的 ProductStatisticsRepository + repo, tearDown, err := SetupTestProductStatisticsRepo("testDB") + require.NoError(t, err) + defer tearDown() + + ctx := context.Background() + + // 定義多筆測試資料(不包含 ID、CreatedAt、UpdatedAt,由 Create 自動填入) + statsList := []*entity.ProductStatistics{ + { + ProductID: "prod-001", + Orders: 100, + AverageRating: 4.5, + FansCount: 200, + }, + { + ProductID: "prod-002", + Orders: 50, + AverageRating: 3.8, + FansCount: 150, + }, + } + + // 逐筆呼叫 Create 新增資料 + for _, ps := range statsList { + err := repo.Create(ctx, ps) + assert.NoError(t, err, "Create should not return error") + } + + // 驗證每筆資料的自動欄位與內容 + for _, ps := range statsList { + // 檢查 ID 與時間欄位是否有自動填入 + assert.False(t, ps.ID.IsZero(), "ID should be generated") + assert.NotZero(t, ps.CreatedAt, "CreatedAt should be set") + assert.NotZero(t, ps.UpdatedAt, "UpdatedAt should be set") + + // 查詢 DB 確認資料是否存在且欄位值正確 + result, err := repo.GetByID(ctx, ps.ID.Hex()) + assert.NoError(t, err, "GetByID should not return error") + assert.Equal(t, ps.ProductID, result.ProductID, "ProductID should match") + assert.Equal(t, ps.Orders, result.Orders, "Orders should match") + assert.Equal(t, ps.AverageRating, result.AverageRating, "AverageRating should match") + assert.Equal(t, ps.FansCount, result.FansCount, "FansCount should match") + } + +} + +func TestGetProductStatisticsByProductID(t *testing.T) { + repo, tearDown, err := SetupTestProductStatisticsRepo("testDB") + require.NoError(t, err) + defer tearDown() + + ctx := context.Background() + + stats1 := &entity.ProductStatistics{ + ProductID: "prod-001", + Orders: 100, + AverageRating: 4.5, + FansCount: 200, + } + stats2 := &entity.ProductStatistics{ + ProductID: "prod-002", + Orders: 50, + AverageRating: 3.8, + FansCount: 150, + } + + err = repo.Create(ctx, stats1) + require.NoError(t, err) + err = repo.Create(ctx, stats2) + require.NoError(t, err) + + // 定義 table-driven 測試案例 + tests := []struct { + name string + inputProductID string + expectedErr error + expectedStatistics *entity.ProductStatistics + }{ + { + name: "record exists - prod-001", + inputProductID: "prod-001", + expectedErr: nil, + expectedStatistics: stats1, + }, + { + name: "record exists - prod-002", + inputProductID: "prod-002", + expectedErr: nil, + expectedStatistics: stats2, + }, + { + name: "record not found", + inputProductID: "non-existent", + expectedErr: ErrNotFound, + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + result, err := repo.GetByProductID(ctx, tt.inputProductID) + if tt.expectedErr != nil { + require.Error(t, err) + assert.Equal(t, tt.expectedErr, err) + assert.Nil(t, result) + } else { + require.NoError(t, err) + // 驗證回傳的資料是否符合預期 + assert.Equal(t, tt.expectedStatistics.ProductID, result.ProductID) + assert.Equal(t, tt.expectedStatistics.Orders, result.Orders) + assert.Equal(t, tt.expectedStatistics.AverageRating, result.AverageRating) + assert.Equal(t, tt.expectedStatistics.FansCount, result.FansCount) + // 其他自動產生欄位如 ID、CreatedAt、UpdatedAt 也可檢查非零 + assert.False(t, result.ID.IsZero(), "ID should be generated") + assert.NotZero(t, result.CreatedAt, "CreatedAt should be set") + assert.NotZero(t, result.UpdatedAt, "UpdatedAt should be set") + } + }) + } +} + +func TestIncOrders(t *testing.T) { + repo, tearDown, err := SetupTestProductStatisticsRepo("testDB") + require.NoError(t, err) + defer tearDown() + + ctx := context.Background() + + // 插入一筆測試資料:初始 Orders 為 100 + stats := &entity.ProductStatistics{ + ProductID: "prod-001", + Orders: 100, + AverageRating: 4.5, + FansCount: 200, + } + err = repo.Create(ctx, stats) + require.NoError(t, err) + + tests := []struct { + name string + productID string + increment int64 + expectedOrders uint64 // 預期的 Orders 值 + expectErr bool + }{ + { + name: "Valid increment", + productID: "prod-001", + increment: 10, + expectedOrders: 110, + expectErr: false, + }, + { + name: "Non-existent product", + productID: "prod-not-exist", + increment: 5, + // 當產品不存在時,應回傳錯誤,不檢查 Orders + expectErr: true, + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + err := repo.IncOrders(ctx, tt.productID, tt.increment) + if tt.expectErr { + require.Error(t, err) + // 可進一步檢查錯誤訊息中是否包含特定關鍵字 + assert.Contains(t, err.Error(), "failed to get product_id by id") + } else { + require.NoError(t, err) + // 若成功,利用 GetByProductID 取得最新資料,檢查 Orders 是否正確更新 + ps, err := repo.GetByProductID(ctx, tt.productID) + require.NoError(t, err) + assert.Equal(t, tt.expectedOrders, ps.Orders) + } + }) + } +} + +func TestDecOrders(t *testing.T) { + repo, tearDown, err := SetupTestProductStatisticsRepo("testDB") + require.NoError(t, err) + defer tearDown() + + ctx := context.Background() + + // 插入一筆測試資料,初始 Orders 為 100 + stats := &entity.ProductStatistics{ + ProductID: "prod-001", + Orders: 100, + AverageRating: 4.5, + FansCount: 200, + } + err = repo.Create(ctx, stats) + require.NoError(t, err) + + tests := []struct { + name string + productID string + decrement int64 + expectedOrders uint64 // 減少成功後預期的 Orders 值 + expectErr bool + }{ + { + name: "Valid decrease", + productID: "prod-001", + decrement: 30, + expectedOrders: 70, // 100 - 30 + expectErr: false, + }, + { + name: "Insufficient orders", + productID: "prod-001", + decrement: 150, // 超過現有數量 + expectedOrders: 70, // 100 - 30 + expectErr: false, + }, + { + name: "Non-existent product", + productID: "prod-not-exist", + decrement: 10, + expectErr: true, + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + err := repo.DecOrders(ctx, tt.productID, tt.decrement) + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + // 成功減少後,利用 GetByProductID 取得最新資料 + updated, err := repo.GetByProductID(ctx, tt.productID) + require.NoError(t, err) + assert.Equal(t, tt.expectedOrders, updated.Orders) + } + }) + } +} + +func TestUpdateAverageRating(t *testing.T) { + repo, tearDown, err := SetupTestProductStatisticsRepo("testDB") + require.NoError(t, err) + defer tearDown() + + ctx := context.Background() + + // 插入一筆測試資料,初始 AverageRating 為 4.0 + stats := &entity.ProductStatistics{ + ProductID: "prod-001", + Orders: 100, + AverageRating: 4.0, + FansCount: 50, + } + err = repo.Create(ctx, stats) + require.NoError(t, err) + + tests := []struct { + name string + productID string + newRating float64 + expectErr bool + expectedValue float64 // 預期更新後的 AverageRating + }{ + { + name: "Update existing product", + productID: "prod-001", + newRating: 4.8, + expectErr: false, + expectedValue: 4.8, + }, + { + name: "Non-existent product", + productID: "prod-nonexist", + newRating: 3.5, + expectErr: true, + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + err := repo.UpdateAverageRating(ctx, tt.productID, tt.newRating) + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + // 取得更新後的資料 + updated, err := repo.GetByProductID(ctx, tt.productID) + require.NoError(t, err) + assert.Equal(t, tt.expectedValue, updated.AverageRating, "AverageRating should be updated") + // 驗證更新時間不為 0 + assert.NotZero(t, updated.AverageRatingUpdateTime, "AverageRatingUpdateTime should be set") + assert.NotZero(t, updated.UpdatedAt, "UpdatedAt should be set") + } + }) + } +} + +func TestIncFansCount(t *testing.T) { + repo, tearDown, err := SetupTestProductStatisticsRepo("testDB") + require.NoError(t, err) + defer tearDown() + + ctx := context.Background() + + // 插入一筆測試資料,初始 FansCount 為 200 + stats := &entity.ProductStatistics{ + ProductID: "prod-001", + Orders: 100, + AverageRating: 4.0, + FansCount: 200, + } + err = repo.Create(ctx, stats) + require.NoError(t, err) + + tests := []struct { + name string + productID string + incCount uint64 + expectedCount uint64 // 預期更新後的 FansCount + expectErr bool + }{ + { + name: "Valid increment", + productID: "prod-001", + incCount: 50, + expectedCount: 250, // 200 + 50 + expectErr: false, + }, + { + name: "Non-existent product", + productID: "non-existent", + incCount: 10, + expectErr: true, + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + err := repo.IncFansCount(ctx, tt.productID, tt.incCount) + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + // 取得更新後的資料,驗證 FansCount 是否正確更新 + updated, err := repo.GetByProductID(ctx, tt.productID) + require.NoError(t, err) + assert.Equal(t, tt.expectedCount, updated.FansCount, "FansCount should be incremented correctly") + assert.NotZero(t, updated.FansCountUpdateTime, "FansCountUpdateTime should be set") + assert.NotZero(t, updated.UpdatedAt, "UpdatedAt should be set") + } + }) + } +} + +func TestDecFansCount(t *testing.T) { + repo, tearDown, err := SetupTestProductStatisticsRepo("testDB") + require.NoError(t, err) + defer tearDown() + + ctx := context.Background() + + // 先建立兩筆測試資料 + + // 測試 valid case:初始 FansCount 為 200 + statsValid := &entity.ProductStatistics{ + ProductID: "prod-valid", + Orders: 100, + AverageRating: 4.5, + FansCount: 200, + } + err = repo.Create(ctx, statsValid) + require.NoError(t, err) + + // 測試 insufficient case:初始 FansCount 為 30 + statsInsufficient := &entity.ProductStatistics{ + ProductID: "prod-insufficient", + Orders: 50, + AverageRating: 3.8, + FansCount: 30, + } + err = repo.Create(ctx, statsInsufficient) + require.NoError(t, err) + + tests := []struct { + name string + productID string + decrement uint64 + expectedFans uint64 // 預期更新後的 FansCount (僅 valid case) + expectErr bool + }{ + { + name: "Valid decrement", + productID: "prod-valid", + decrement: 50, + expectedFans: 150, // 200 - 50 + expectErr: false, + }, + { + name: "Insufficient fans", + productID: "prod-insufficient", + decrement: 50, + expectedFans: 30, // 扣超過就跟原本一樣不會變 + expectErr: false, + }, + { + name: "Non-existent product", + productID: "prod-nonexistent", + decrement: 10, + expectErr: true, + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + err := repo.DecFansCount(ctx, tt.productID, tt.decrement) + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + // 取得更新後的資料,驗證 FansCount 是否正確更新 + updated, err := repo.GetByProductID(ctx, tt.productID) + require.NoError(t, err) + assert.Equal(t, tt.expectedFans, updated.FansCount, "FansCount should be decremented correctly") + } + }) + } +} + +func TestDeleteProductStatistics(t *testing.T) { + repo, tearDown, err := SetupTestProductStatisticsRepo("testDB") + require.NoError(t, err) + defer tearDown() + + ctx := context.Background() + + // 插入一筆測試資料 + stats := &entity.ProductStatistics{ + ProductID: "prod-001", + Orders: 100, + AverageRating: 4.5, + FansCount: 200, + } + err = repo.Create(ctx, stats) + require.NoError(t, err) + + tests := []struct { + name string + id string + expectErr error // 預期錯誤 + checkDelete bool // 若為 true,刪除成功後進行資料查詢驗證 + }{ + { + name: "Delete existing record", + id: stats.ID.Hex(), + expectErr: nil, + checkDelete: true, + }, + { + name: "Invalid ObjectID format", + id: "invalid-id", + expectErr: ErrInvalidObjectID, + checkDelete: false, + }, + } + + for _, tc := range tests { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + err := repo.Delete(ctx, tc.id) + if tc.expectErr != nil { + require.Error(t, err) + assert.Equal(t, tc.expectErr, err) + } else { + require.NoError(t, err) + if tc.checkDelete { + // 刪除成功後,透過 GetByID 應查無資料 + _, err := repo.GetByID(ctx, tc.id) + require.Error(t, err) + assert.Equal(t, ErrNotFound, err) + } + } + }) + } +}