801 lines
20 KiB
Go
801 lines
20 KiB
Go
package repository
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"testing"
|
||
"time"
|
||
|
||
"code.30cm.net/digimon/app-cloudep-product-service/pkg/domain/entity"
|
||
"code.30cm.net/digimon/app-cloudep-product-service/pkg/domain/product"
|
||
"code.30cm.net/digimon/app-cloudep-product-service/pkg/domain/repository"
|
||
mgo "code.30cm.net/digimon/library-go/mongo"
|
||
"github.com/alicebob/miniredis/v2"
|
||
"github.com/shopspring/decimal"
|
||
"github.com/stretchr/testify/assert"
|
||
"github.com/stretchr/testify/require"
|
||
"github.com/zeromicro/go-zero/core/stores/cache"
|
||
"github.com/zeromicro/go-zero/core/stores/mon"
|
||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||
"google.golang.org/protobuf/proto"
|
||
)
|
||
|
||
func SetupTestProductItemRepository(db string) (repository.ProductItemRepository, func(), error) {
|
||
h, p, tearDown, err := startMongoContainer()
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
s, _ := miniredis.Run()
|
||
|
||
conf := &mgo.Conf{
|
||
Schema: Schema,
|
||
Host: fmt.Sprintf("%s:%s", h, p),
|
||
Database: db,
|
||
MaxStaleness: 300,
|
||
MaxPoolSize: 100,
|
||
MinPoolSize: 100,
|
||
MaxConnIdleTime: 300,
|
||
Compressors: []string{},
|
||
EnableStandardReadWriteSplitMode: false,
|
||
ConnectTimeoutMs: 3000,
|
||
}
|
||
|
||
cacheConf := cache.CacheConf{
|
||
cache.NodeConf{
|
||
RedisConf: redis.RedisConf{
|
||
Host: s.Addr(),
|
||
Type: redis.NodeType,
|
||
},
|
||
Weight: 100,
|
||
},
|
||
}
|
||
|
||
cacheOpts := []cache.Option{
|
||
cache.WithExpiry(1000 * time.Microsecond),
|
||
cache.WithNotFoundExpiry(1000 * time.Microsecond),
|
||
}
|
||
|
||
param := ProductItemRepositoryParam{
|
||
Conf: conf,
|
||
CacheConf: cacheConf,
|
||
CacheOpts: cacheOpts,
|
||
DBOpts: []mon.Option{
|
||
mgo.SetCustomDecimalType(),
|
||
mgo.InitMongoOptions(*conf),
|
||
},
|
||
}
|
||
|
||
repo := NewProductItemRepository(param)
|
||
_, _ = repo.Index20250317001UP(context.Background())
|
||
|
||
return repo, tearDown, nil
|
||
}
|
||
|
||
func TestInsertProductItems(t *testing.T) {
|
||
model, tearDown, err := SetupTestProductItemRepository("testDB")
|
||
defer tearDown()
|
||
assert.NoError(t, err)
|
||
|
||
// 建立多筆 items
|
||
items := []entity.ProductItems{
|
||
{
|
||
ReferenceID: primitive.NewObjectID().Hex(),
|
||
Name: "Item A",
|
||
},
|
||
{
|
||
ReferenceID: primitive.NewObjectID().Hex(),
|
||
Name: "Item B",
|
||
},
|
||
}
|
||
|
||
// 呼叫插入
|
||
err = model.Insert(context.Background(), items)
|
||
assert.NoError(t, err)
|
||
|
||
// 驗證插入是否成功(逐筆查詢)
|
||
for _, item := range items {
|
||
// 檢查 ID 是否自動填入
|
||
assert.False(t, item.ID.IsZero(), "ID should be generated")
|
||
assert.NotZero(t, item.CreatedAt)
|
||
assert.NotZero(t, item.UpdatedAt)
|
||
|
||
// 查詢 DB 確認存在
|
||
var result *entity.ProductItems
|
||
result, err = model.FindByID(context.Background(), item.ID.Hex())
|
||
assert.NoError(t, err)
|
||
assert.Equal(t, item.Name, result.Name)
|
||
}
|
||
|
||
// 🧪 空陣列插入測試(不應報錯)
|
||
t.Run("Insert empty slice", func(t *testing.T) {
|
||
err := model.Insert(context.Background(), []entity.ProductItems{})
|
||
assert.NoError(t, err)
|
||
})
|
||
}
|
||
|
||
func TestDeleteProductItem(t *testing.T) {
|
||
model, tearDown, err := SetupTestProductItemRepository("testDB")
|
||
defer tearDown()
|
||
assert.NoError(t, err)
|
||
|
||
// 建立多筆 items
|
||
items := []entity.ProductItems{
|
||
{
|
||
ReferenceID: primitive.NewObjectID().Hex(),
|
||
Name: "Item A",
|
||
},
|
||
{
|
||
ReferenceID: primitive.NewObjectID().Hex(),
|
||
Name: "Item B",
|
||
},
|
||
}
|
||
|
||
// 呼叫插入
|
||
err = model.Insert(context.Background(), items)
|
||
assert.NoError(t, err)
|
||
|
||
tests := []struct {
|
||
name string
|
||
inputID string
|
||
expectErr error
|
||
checkAfter bool // 是否需要確認資料已刪除
|
||
}{
|
||
{
|
||
name: "Delete existing product",
|
||
inputID: items[0].ID.Hex(),
|
||
expectErr: nil,
|
||
checkAfter: true,
|
||
},
|
||
{
|
||
name: "Delete non-existing product",
|
||
inputID: primitive.NewObjectID().Hex(),
|
||
expectErr: nil,
|
||
checkAfter: false,
|
||
},
|
||
{
|
||
name: "Invalid ObjectID format",
|
||
inputID: "not-an-object-id",
|
||
expectErr: ErrInvalidObjectID,
|
||
checkAfter: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
err := model.Delete(context.Background(), []string{tt.inputID})
|
||
if tt.expectErr != nil {
|
||
assert.ErrorIs(t, err, tt.expectErr)
|
||
} else {
|
||
assert.NoError(t, err)
|
||
}
|
||
|
||
// 驗證資料是否真的刪除了(僅當需要)
|
||
if tt.checkAfter {
|
||
_, err := model.FindByID(context.Background(), tt.inputID)
|
||
assert.ErrorIs(t, err, ErrNotFound)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestListProductItem(t *testing.T) {
|
||
model, tearDown, err := SetupTestProductItemRepository("testDB")
|
||
defer tearDown()
|
||
assert.NoError(t, err)
|
||
|
||
now := time.Now().UnixNano()
|
||
rfcID := primitive.NewObjectID().Hex()
|
||
// 插入測試資料
|
||
item1 := entity.ProductItems{
|
||
ID: primitive.NewObjectID(),
|
||
ReferenceID: rfcID,
|
||
Name: "Item A",
|
||
IsFree: true,
|
||
Status: product.StatusActive,
|
||
CreatedAt: now,
|
||
UpdatedAt: now,
|
||
}
|
||
item2 := entity.ProductItems{
|
||
ID: primitive.NewObjectID(),
|
||
ReferenceID: rfcID,
|
||
Name: "Item B",
|
||
IsFree: false,
|
||
Status: product.StatusInactive,
|
||
CreatedAt: now,
|
||
UpdatedAt: now,
|
||
}
|
||
item3 := entity.ProductItems{
|
||
ID: primitive.NewObjectID(),
|
||
ReferenceID: rfcID,
|
||
Name: "Item C",
|
||
IsFree: true,
|
||
Status: product.StatusActive,
|
||
CreatedAt: now,
|
||
UpdatedAt: now,
|
||
}
|
||
|
||
err = model.Insert(context.Background(), []entity.ProductItems{item1, item2, item3})
|
||
assert.NoError(t, err)
|
||
|
||
tests := []struct {
|
||
name string
|
||
params repository.ProductItemQueryParams
|
||
expectCount int64
|
||
expectIDs []primitive.ObjectID
|
||
}{
|
||
{
|
||
name: "Filter by ReferenceID",
|
||
params: repository.ProductItemQueryParams{
|
||
ReferenceID: ptr(rfcID),
|
||
PageSize: 10,
|
||
PageIndex: 1,
|
||
},
|
||
expectCount: 3,
|
||
expectIDs: []primitive.ObjectID{item3.ID, item2.ID, item1.ID},
|
||
},
|
||
{
|
||
name: "Filter by IsFree = true",
|
||
params: repository.ProductItemQueryParams{
|
||
IsFree: ptr(true),
|
||
PageSize: 10,
|
||
PageIndex: 1,
|
||
},
|
||
expectCount: 2,
|
||
expectIDs: []primitive.ObjectID{item3.ID, item1.ID},
|
||
},
|
||
{
|
||
name: "Filter by Status = 2",
|
||
params: repository.ProductItemQueryParams{
|
||
Status: ptr(product.StatusInactive),
|
||
PageSize: 10,
|
||
PageIndex: 1,
|
||
},
|
||
expectCount: 1,
|
||
expectIDs: []primitive.ObjectID{item2.ID},
|
||
},
|
||
{
|
||
name: "Filter by ItemIDs",
|
||
params: repository.ProductItemQueryParams{
|
||
ItemID: []string{item1.ID.Hex(), item2.ID.Hex()},
|
||
PageSize: 10,
|
||
PageIndex: 1,
|
||
},
|
||
expectCount: 2,
|
||
expectIDs: []primitive.ObjectID{item2.ID, item1.ID},
|
||
},
|
||
{
|
||
name: "Pagination works",
|
||
params: repository.ProductItemQueryParams{
|
||
PageSize: 1,
|
||
PageIndex: 2,
|
||
},
|
||
expectCount: 3,
|
||
expectIDs: []primitive.ObjectID{item2.ID},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
results, count, err := model.ListProductItem(context.Background(), tt.params)
|
||
assert.NoError(t, err)
|
||
assert.Equal(t, tt.expectCount, count)
|
||
|
||
var gotIDs []primitive.ObjectID
|
||
for _, r := range results {
|
||
gotIDs = append(gotIDs, r.ID)
|
||
}
|
||
|
||
assert.ElementsMatch(t, tt.expectIDs, gotIDs)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestDeleteByReferenceID(t *testing.T) {
|
||
model, tearDown, err := SetupTestProductItemRepository("testDB")
|
||
defer tearDown()
|
||
assert.NoError(t, err)
|
||
|
||
ctx := context.Background()
|
||
now := time.Now().UnixNano()
|
||
|
||
// 插入測試資料
|
||
refID := primitive.NewObjectID().Hex()
|
||
item1 := entity.ProductItems{
|
||
ID: primitive.NewObjectID(),
|
||
ReferenceID: refID,
|
||
Name: "Item A",
|
||
CreatedAt: now,
|
||
UpdatedAt: now,
|
||
}
|
||
item2 := entity.ProductItems{
|
||
ID: primitive.NewObjectID(),
|
||
ReferenceID: refID,
|
||
Name: "Item B",
|
||
CreatedAt: now,
|
||
UpdatedAt: now,
|
||
}
|
||
itemOther := entity.ProductItems{
|
||
ID: primitive.NewObjectID(),
|
||
ReferenceID: primitive.NewObjectID().Hex(),
|
||
Name: "Should not be deleted",
|
||
CreatedAt: now,
|
||
UpdatedAt: now,
|
||
}
|
||
|
||
err = model.Insert(ctx, []entity.ProductItems{item1, item2, itemOther})
|
||
assert.NoError(t, err)
|
||
|
||
tests := []struct {
|
||
name string
|
||
referenceID string
|
||
expectDeleted []primitive.ObjectID
|
||
expectRemained []primitive.ObjectID
|
||
}{
|
||
{
|
||
name: "Delete existing reference_id items",
|
||
referenceID: refID,
|
||
expectDeleted: []primitive.ObjectID{item1.ID, item2.ID},
|
||
expectRemained: []primitive.ObjectID{itemOther.ID},
|
||
},
|
||
{
|
||
name: "Delete non-existent reference_id",
|
||
referenceID: "no-match-ref",
|
||
expectDeleted: nil,
|
||
expectRemained: []primitive.ObjectID{itemOther.ID},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
err := model.DeleteByReferenceID(ctx, tt.referenceID)
|
||
assert.NoError(t, err)
|
||
|
||
// 檢查指定應被刪除的項目是否真的被刪除
|
||
for _, id := range tt.expectDeleted {
|
||
_, err = model.FindByID(ctx, id.Hex())
|
||
assert.Error(t, err, "Expected item to be deleted: "+id.Hex())
|
||
}
|
||
|
||
// 檢查不應刪除的項目仍存在
|
||
for _, id := range tt.expectRemained {
|
||
_, err = model.FindByID(ctx, id.Hex())
|
||
assert.NoError(t, err, "Expected item to remain: "+id.Hex())
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestUpdateProductItem(t *testing.T) {
|
||
model, tearDown, err := SetupTestProductItemRepository("testDB")
|
||
defer tearDown()
|
||
assert.NoError(t, err)
|
||
|
||
ctx := context.Background()
|
||
|
||
// 建立一筆測試資料
|
||
item := entity.ProductItems{
|
||
ID: primitive.NewObjectID(),
|
||
ReferenceID: primitive.NewObjectID().Hex(),
|
||
Name: "Original Name",
|
||
Stock: 10,
|
||
Price: decimal.NewFromInt(500),
|
||
IsUnLimit: false,
|
||
IsFree: false,
|
||
CreatedAt: time.Now().UnixNano(),
|
||
UpdatedAt: time.Now().UnixNano(),
|
||
}
|
||
err = model.Insert(ctx, []entity.ProductItems{item})
|
||
assert.NoError(t, err)
|
||
|
||
tests := []struct {
|
||
name string
|
||
id string
|
||
update *repository.ProductUpdateItem
|
||
expectErr error
|
||
validate func(t *testing.T, updated *entity.ProductItems)
|
||
}{
|
||
{
|
||
name: "Update is_un_limit and is_free",
|
||
id: item.ID.Hex(),
|
||
update: &repository.ProductUpdateItem{
|
||
IsUnLimit: ptr(true),
|
||
IsFree: ptr(true),
|
||
},
|
||
expectErr: nil,
|
||
validate: func(t *testing.T, updated *entity.ProductItems) {
|
||
assert.True(t, updated.IsUnLimit)
|
||
assert.True(t, updated.IsFree)
|
||
},
|
||
},
|
||
{
|
||
name: "Update SKU and TimeSeries",
|
||
id: item.ID.Hex(),
|
||
update: &repository.ProductUpdateItem{
|
||
SKU: ptr("SKU-XYZ-001"),
|
||
TimeSeries: ptr(product.TimeSeriesTenMinutes),
|
||
},
|
||
expectErr: nil,
|
||
validate: func(t *testing.T, updated *entity.ProductItems) {
|
||
assert.Equal(t, "SKU-XYZ-001", updated.SKU)
|
||
assert.Equal(t, product.TimeSeriesTenMinutes, updated.TimeSeries)
|
||
},
|
||
},
|
||
{
|
||
name: "Update Media field",
|
||
id: item.ID.Hex(),
|
||
update: &repository.ProductUpdateItem{
|
||
Media: []entity.Media{
|
||
{Sort: 1, Type: "image", URL: "https://example.com/img.jpg"},
|
||
},
|
||
},
|
||
expectErr: nil,
|
||
validate: func(t *testing.T, updated *entity.ProductItems) {
|
||
assert.Len(t, updated.Media, 1)
|
||
assert.Equal(t, "image", updated.Media[0].Type)
|
||
},
|
||
},
|
||
{
|
||
name: "Update CustomFields",
|
||
id: item.ID.Hex(),
|
||
update: &repository.ProductUpdateItem{
|
||
CustomFields: []entity.CustomFields{
|
||
{Key: "color", Value: "red"},
|
||
},
|
||
},
|
||
expectErr: nil,
|
||
validate: func(t *testing.T, updated *entity.ProductItems) {
|
||
assert.Len(t, updated.CustomFields, 1)
|
||
assert.Equal(t, "color", updated.CustomFields[0].Key)
|
||
},
|
||
},
|
||
{
|
||
name: "Update Freight",
|
||
id: item.ID.Hex(),
|
||
update: &repository.ProductUpdateItem{
|
||
Freight: []entity.CustomFields{
|
||
{
|
||
Key: "color",
|
||
Value: "red",
|
||
},
|
||
},
|
||
},
|
||
expectErr: nil,
|
||
validate: func(t *testing.T, updated *entity.ProductItems) {
|
||
assert.Equal(t, "color", updated.Freight[0].Key)
|
||
assert.Equal(t, "red", updated.Freight[0].Value)
|
||
},
|
||
},
|
||
{
|
||
name: "Update name field",
|
||
id: item.ID.Hex(),
|
||
update: &repository.ProductUpdateItem{
|
||
Name: ptr("Updated Name"),
|
||
},
|
||
expectErr: nil,
|
||
validate: func(t *testing.T, updated *entity.ProductItems) {
|
||
assert.Equal(t, "Updated Name", updated.Name)
|
||
},
|
||
},
|
||
{
|
||
name: "Update stock and price",
|
||
id: item.ID.Hex(),
|
||
update: &repository.ProductUpdateItem{
|
||
Stock: proto.Int64(99),
|
||
Price: ptr(decimal.NewFromInt(999)),
|
||
},
|
||
expectErr: nil,
|
||
validate: func(t *testing.T, updated *entity.ProductItems) {
|
||
assert.Equal(t, uint64(99), updated.Stock)
|
||
assert.Equal(t, "999", updated.Price.String())
|
||
},
|
||
},
|
||
{
|
||
name: "Invalid ObjectID",
|
||
id: "not-an-id",
|
||
update: &repository.ProductUpdateItem{Name: ptr("Invalid")},
|
||
expectErr: ErrInvalidObjectID,
|
||
},
|
||
{
|
||
name: "Empty update struct",
|
||
id: item.ID.Hex(),
|
||
update: &repository.ProductUpdateItem{}, // no fields
|
||
expectErr: fmt.Errorf("no fields to update"),
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
err := model.Update(ctx, tt.id, tt.update)
|
||
if tt.expectErr != nil {
|
||
assert.Error(t, err)
|
||
assert.Contains(t, err.Error(), tt.expectErr.Error())
|
||
return
|
||
}
|
||
assert.NoError(t, err)
|
||
|
||
updated, err := model.FindByID(ctx, item.ID.Hex())
|
||
assert.NoError(t, err)
|
||
|
||
if tt.validate != nil {
|
||
tt.validate(t, updated)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestUpdateStatus_TableDriven(t *testing.T) {
|
||
repo, tearDown, err := SetupTestProductItemRepository("testDB")
|
||
require.NoError(t, err)
|
||
defer tearDown()
|
||
|
||
ctx := context.Background()
|
||
|
||
// Insert a sample product item.
|
||
now := time.Now().UnixNano()
|
||
item := entity.ProductItems{
|
||
ID: primitive.NewObjectID(),
|
||
ReferenceID: primitive.NewObjectID().Hex(),
|
||
Name: "Test Item",
|
||
Status: product.StatusInactive, // initial status
|
||
CreatedAt: now,
|
||
UpdatedAt: now,
|
||
}
|
||
err = repo.Insert(ctx, []entity.ProductItems{item})
|
||
require.NoError(t, err)
|
||
|
||
tests := []struct {
|
||
name string
|
||
id string
|
||
newStatus product.ItemStatus
|
||
expectErr error
|
||
check bool // whether to verify the update in DB
|
||
}{
|
||
{
|
||
name: "Valid update",
|
||
id: item.ID.Hex(),
|
||
newStatus: product.StatusActive,
|
||
expectErr: nil,
|
||
check: true,
|
||
},
|
||
{
|
||
name: "Invalid ObjectID",
|
||
id: "invalid-id",
|
||
newStatus: product.StatusActive,
|
||
expectErr: ErrInvalidObjectID,
|
||
check: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
err := repo.UpdateStatus(ctx, tt.id, tt.newStatus)
|
||
if tt.expectErr != nil {
|
||
assert.ErrorIs(t, err, tt.expectErr)
|
||
} else {
|
||
assert.NoError(t, err)
|
||
}
|
||
|
||
// If expected to check update, verify that the product's status is updated.
|
||
if tt.check {
|
||
updated, err := repo.FindByID(ctx, tt.id)
|
||
assert.NoError(t, err)
|
||
assert.Equal(t, tt.newStatus, updated.Status)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestIncSalesCount(t *testing.T) {
|
||
repo, tearDown, err := SetupTestProductItemRepository("testDB")
|
||
require.NoError(t, err)
|
||
defer tearDown()
|
||
|
||
ctx := context.Background()
|
||
|
||
// Insert a sample product item with initial SalesCount = 0.
|
||
now := time.Now().UnixNano()
|
||
item := entity.ProductItems{
|
||
ID: primitive.NewObjectID(),
|
||
ReferenceID: primitive.NewObjectID().Hex(),
|
||
Name: "Sales Count Test Item",
|
||
SalesCount: 0,
|
||
CreatedAt: now,
|
||
UpdatedAt: now,
|
||
}
|
||
err = repo.Insert(ctx, []entity.ProductItems{item})
|
||
require.NoError(t, err)
|
||
|
||
tests := []struct {
|
||
name string
|
||
id string
|
||
count int64
|
||
expectErr error
|
||
check bool // whether to verify the updated sales count in the DB.
|
||
}{
|
||
{
|
||
name: "Increment sales count by 5",
|
||
id: item.ID.Hex(),
|
||
count: 5,
|
||
expectErr: nil,
|
||
check: true,
|
||
},
|
||
{
|
||
name: "Invalid ObjectID",
|
||
id: "invalid-id",
|
||
count: 3,
|
||
expectErr: ErrInvalidObjectID,
|
||
check: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
err := repo.IncSalesCount(ctx, tt.id, tt.count)
|
||
if tt.expectErr != nil {
|
||
assert.ErrorIs(t, err, tt.expectErr)
|
||
} else {
|
||
assert.NoError(t, err)
|
||
}
|
||
|
||
// If check is true, verify that the sales_count is updated correctly.
|
||
if tt.check {
|
||
updated, err := repo.FindByID(ctx, tt.id)
|
||
assert.NoError(t, err)
|
||
// Since initial SalesCount was 0, after increment it should equal tt.count.
|
||
assert.Equal(t, uint64(tt.count), updated.SalesCount)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestDecSalesCount(t *testing.T) {
|
||
repo, tearDown, err := SetupTestProductItemRepository("testDB")
|
||
require.NoError(t, err)
|
||
defer tearDown()
|
||
|
||
ctx := context.Background()
|
||
now := time.Now().UnixNano()
|
||
|
||
// Insert an item with an initial SalesCount of 10.
|
||
item := entity.ProductItems{
|
||
ID: primitive.NewObjectID(),
|
||
ReferenceID: primitive.NewObjectID().Hex(),
|
||
Name: "Dec Sales Count Test Item",
|
||
SalesCount: 10,
|
||
CreatedAt: now,
|
||
UpdatedAt: now,
|
||
}
|
||
|
||
// Insert an item with SalesCount equal to 0 (to test underflow behavior).
|
||
zeroItem := entity.ProductItems{
|
||
ID: primitive.NewObjectID(),
|
||
ReferenceID: primitive.NewObjectID().Hex(),
|
||
Name: "Zero Sales Count Item",
|
||
SalesCount: 0,
|
||
CreatedAt: now,
|
||
UpdatedAt: now,
|
||
}
|
||
err = repo.Insert(ctx, []entity.ProductItems{item, zeroItem})
|
||
require.NoError(t, err)
|
||
|
||
tests := []struct {
|
||
name string
|
||
id string
|
||
decCount int64
|
||
expectErr error
|
||
check bool // whether to verify the updated sales count in the DB
|
||
expected uint64 // expected SalesCount if check is true
|
||
}{
|
||
{
|
||
name: "Valid decrement from 10",
|
||
id: item.ID.Hex(),
|
||
decCount: 3,
|
||
expectErr: nil,
|
||
check: true,
|
||
expected: 7,
|
||
},
|
||
{
|
||
name: "Invalid ObjectID",
|
||
id: "invalid-id",
|
||
decCount: 3,
|
||
expectErr: ErrInvalidObjectID,
|
||
check: false,
|
||
},
|
||
{
|
||
name: "Decrement from zero (should not allow underflow)",
|
||
id: zeroItem.ID.Hex(),
|
||
decCount: 5,
|
||
expectErr: nil,
|
||
check: true,
|
||
expected: 0,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
err := repo.DecSalesCount(ctx, tt.id, tt.decCount)
|
||
if tt.expectErr != nil {
|
||
assert.Error(t, err)
|
||
// Check that the error message contains "failed to decrease stock" (for the underflow case)
|
||
if tt.expectErr.Error() != ErrInvalidObjectID.Error() {
|
||
assert.Contains(t, err.Error(), "failed to decrease stock")
|
||
} else {
|
||
assert.ErrorIs(t, err, tt.expectErr)
|
||
}
|
||
} else {
|
||
assert.NoError(t, err)
|
||
}
|
||
|
||
// If expected to check, verify the updated SalesCount.
|
||
if tt.check {
|
||
updated, err := repo.FindByID(ctx, tt.id)
|
||
assert.NoError(t, err)
|
||
assert.Equal(t, tt.expected, updated.SalesCount)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestGetSalesCount(t *testing.T) {
|
||
// 取得測試 repository
|
||
repo, tearDown, err := SetupTestProductItemRepository("testDB")
|
||
require.NoError(t, err)
|
||
defer tearDown()
|
||
|
||
ctx := context.Background()
|
||
|
||
// 預先建立測試資料,設定各項目的 SalesCount 值
|
||
testItems := []entity.ProductItems{
|
||
{
|
||
ReferenceID: primitive.NewObjectID().Hex(),
|
||
Name: "Test Item 1",
|
||
SalesCount: 5,
|
||
},
|
||
{
|
||
ReferenceID: primitive.NewObjectID().Hex(),
|
||
Name: "Test Item 2",
|
||
SalesCount: 15,
|
||
},
|
||
}
|
||
|
||
// 插入測試資料
|
||
err = repo.Insert(ctx, testItems)
|
||
require.NoError(t, err)
|
||
|
||
// 建立一組包含有效與無效 ID 的字串陣列
|
||
validIDs := []string{
|
||
testItems[0].ID.Hex(),
|
||
testItems[1].ID.Hex(),
|
||
}
|
||
// 在陣列前面加上一個無法轉換的 ID
|
||
mixedIDs := append([]string{"invalidID"}, validIDs...)
|
||
|
||
t.Run("with valid and invalid IDs", func(t *testing.T) {
|
||
salesCounts, err := repo.GetSalesCount(ctx, mixedIDs)
|
||
require.NoError(t, err)
|
||
// 預期只會回傳有效的兩筆資料
|
||
require.Len(t, salesCounts, len(validIDs))
|
||
|
||
// 驗證每筆資料的 SalesCount 是否正確
|
||
for _, sc := range salesCounts {
|
||
switch sc.ID {
|
||
case testItems[0].ID.Hex():
|
||
assert.Equal(t, uint64(5), sc.Count)
|
||
case testItems[1].ID.Hex():
|
||
assert.Equal(t, uint64(15), sc.Count)
|
||
default:
|
||
t.Errorf("unexpected product item ID: %s", sc.ID)
|
||
}
|
||
}
|
||
})
|
||
|
||
t.Run("with all invalid IDs", func(t *testing.T) {
|
||
salesCounts, err := repo.GetSalesCount(ctx, []string{"badid1", "badid2"})
|
||
// 因無法轉換成 ObjectID,預期會回傳 ErrInvalidObjectID
|
||
assert.Error(t, err)
|
||
assert.Equal(t, ErrInvalidObjectID, err)
|
||
assert.Nil(t, salesCounts)
|
||
})
|
||
}
|