backend/pkg/post/repository/category.go

264 lines
6.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"context"
"fmt"
"strings"
"backend/pkg/library/cassandra"
"backend/pkg/post/domain/entity"
domainRepo "backend/pkg/post/domain/repository"
"github.com/gocql/gocql"
)
// CategoryRepositoryParam bundles the dependencies needed to construct a
// CategoryRepository via NewCategoryRepository.
type CategoryRepositoryParam struct {
	// DB is the Cassandra handle. It backs both the generic entity repository
	// and the raw CQL statements issued for post_count updates.
	DB *cassandra.DB
	// Keyspace is the target keyspace. When empty, NewCategoryRepository
	// falls back to DB.GetDefaultKeyspace() for raw CQL statements.
	Keyspace string
}
// CategoryRepository is the Cassandra-backed implementation of the
// domainRepo.CategoryRepository interface.
type CategoryRepository struct {
	// repo performs generic CRUD and query building for categories.
	repo cassandra.Repository[*entity.Category]
	// db gives access to the raw session for statements the generic
	// repository does not express (the post_count counter updates).
	db *cassandra.DB
	// keyspace is the resolved keyspace used to qualify raw CQL table names.
	keyspace string
}
// NewCategoryRepository builds a CategoryRepository from param.
//
// The generic repository receives param.Keyspace unchanged. For the raw CQL
// statements this repository issues itself, an empty param.Keyspace is
// resolved to param.DB.GetDefaultKeyspace(). It panics if the generic
// repository cannot be created.
func NewCategoryRepository(param CategoryRepositoryParam) domainRepo.CategoryRepository {
	generic, err := cassandra.NewRepository[*entity.Category](param.DB, param.Keyspace)
	if err != nil {
		panic(fmt.Sprintf("failed to create category repository: %v", err))
	}

	resolved := param.Keyspace
	if len(resolved) == 0 {
		resolved = param.DB.GetDefaultKeyspace()
	}

	return &CategoryRepository{
		db:       param.DB,
		keyspace: resolved,
		repo:     generic,
	}
}
// Insert normalizes, validates and persists a new category.
//
// The slug is trimmed and lower-cased BEFORE validation. Validate therefore
// checks exactly the value that will be stored, and a whitespace-only slug
// cannot pass as non-empty and then be written as "". A nil ParentID is
// replaced by the zero UUID, which FindByParentID uses as the root sentinel.
// Timestamps are then set. When data.IsNew() reports true, a fresh
// time-based UUID is assigned as the ID.
//
// It returns ErrInvalidInput when data is nil, an error wrapping
// ErrInvalidInput when validation fails, and otherwise whatever the
// underlying repository insert returns.
func (r *CategoryRepository) Insert(ctx context.Context, data *entity.Category) error {
	if data == nil {
		return ErrInvalidInput
	}

	// Normalize first so the validated slug is the stored slug.
	data.Slug = strings.ToLower(strings.TrimSpace(data.Slug))

	if err := data.Validate(); err != nil {
		return fmt.Errorf("%w: %v", ErrInvalidInput, err)
	}

	// Root categories are stored with the zero UUID as parent so they can be
	// selected by parent_id.
	if data.ParentID == nil {
		data.ParentID = &gocql.UUID{}
	}

	data.SetTimestamps()
	// NOTE(review): IsNew is evaluated after SetTimestamps (original order
	// kept); confirm IsNew does not depend on the timestamp fields.
	if data.IsNew() {
		data.ID = gocql.TimeUUID()
	}

	return r.repo.Insert(ctx, data)
}
// FindOne 根據 ID 查詢單筆分類
func (r *CategoryRepository) FindOne(ctx context.Context, id string) (*entity.Category, error) {
var zeroUUID gocql.UUID
uuid, err := gocql.ParseUUID(id)
if err != nil {
return nil, err
}
if uuid == zeroUUID {
return nil, ErrInvalidInput
}
category, err := r.repo.Get(ctx, id)
if err != nil {
if cassandra.IsNotFound(err) {
return nil, ErrNotFound
}
return nil, fmt.Errorf("failed to find category: %w", err)
}
return category, nil
}
// Update normalizes, validates and persists changes to an existing category.
//
// The slug is trimmed and lower-cased BEFORE validation, so Validate checks
// exactly the value that will be stored. Timestamps are refreshed via
// SetTimestamps before writing.
//
// It returns ErrInvalidInput when data is nil, an error wrapping
// ErrInvalidInput when validation fails, and otherwise whatever the
// underlying repository update returns.
//
// NOTE(review): unlike Insert, a nil ParentID is not replaced by the zero
// UUID here. If r.repo.Update writes nil as NULL, an updated root category
// would stop matching FindRootCategories. Confirm how the generic repository
// handles nil fields before changing this.
func (r *CategoryRepository) Update(ctx context.Context, data *entity.Category) error {
	if data == nil {
		return ErrInvalidInput
	}

	// Normalize first so the validated slug is the stored slug.
	data.Slug = strings.ToLower(strings.TrimSpace(data.Slug))

	if err := data.Validate(); err != nil {
		return fmt.Errorf("%w: %v", ErrInvalidInput, err)
	}

	data.SetTimestamps()

	return r.repo.Update(ctx, data)
}
// Delete 刪除分類
func (r *CategoryRepository) Delete(ctx context.Context, id string) error {
var zeroUUID gocql.UUID
uuid, err := gocql.ParseUUID(id)
if err != nil {
return err
}
if uuid == zeroUUID {
return ErrInvalidInput
}
return r.repo.Delete(ctx, id)
}
// FindBySlug returns the category whose slug matches slug.
//
// The input is trimmed and lower-cased, the same normalization Insert and
// Update apply. The emptiness check runs AFTER normalization, so a
// whitespace-only slug is rejected with ErrInvalidInput instead of querying
// for slug = ''. ErrNotFound is returned when nothing matches.
//
// The query relies on a SAI index on the slug column.
// NOTE(review): if several rows share a slug, the first row returned by the
// query wins; uniqueness is not enforced here.
func (r *CategoryRepository) FindBySlug(ctx context.Context, slug string) (*entity.Category, error) {
	normalized := strings.ToLower(strings.TrimSpace(slug))
	if normalized == "" {
		return nil, ErrInvalidInput
	}

	query := r.repo.Query().Where(cassandra.Eq("slug", normalized))

	var categories []*entity.Category
	if err := query.Scan(ctx, &categories); err != nil {
		if cassandra.IsNotFound(err) {
			return nil, ErrNotFound
		}
		return nil, fmt.Errorf("failed to query category: %w", err)
	}
	if len(categories) == 0 {
		return nil, ErrNotFound
	}
	return categories[0], nil
}
// FindByParentID lists the direct children of parentID in ascending
// sort_order.
//
// Root categories are stored with the zero UUID as their parent (Insert
// defaults a nil ParentID to it). An empty parentID and an explicit zero
// UUID therefore both select the root categories. Previously an explicit
// zero UUID added no filter at all and returned every category.
//
// A malformed parentID yields an error wrapping ErrInvalidInput. Query
// failures are wrapped with context.
//
// The query relies on a SAI index on parent_id.
// NOTE(review): Cassandra only supports ORDER BY on clustering columns.
// Confirm that sort_order is one, or that the cassandra helper orders
// results client-side.
func (r *CategoryRepository) FindByParentID(ctx context.Context, parentID string) ([]*entity.Category, error) {
	var parent gocql.UUID // zero value == root sentinel
	if parentID != "" {
		parsed, err := gocql.ParseUUID(parentID)
		if err != nil {
			return nil, fmt.Errorf("%w: invalid parent ID: %v", ErrInvalidInput, err)
		}
		parent = parsed
	}

	query := r.repo.Query().
		Where(cassandra.Eq("parent_id", parent)).
		OrderBy("sort_order", cassandra.ASC)

	var categories []*entity.Category
	if err := query.Scan(ctx, &categories); err != nil {
		return nil, fmt.Errorf("failed to query categories: %w", err)
	}
	return categories, nil
}
// FindRootCategories 查詢根分類
func (r *CategoryRepository) FindRootCategories(ctx context.Context) ([]*entity.Category, error) {
return r.FindByParentID(ctx, "")
}
// FindActive lists every category whose is_active column is true, in
// ascending sort_order. Query failures are wrapped with context.
func (r *CategoryRepository) FindActive(ctx context.Context) ([]*entity.Category, error) {
	q := r.repo.Query().
		Where(cassandra.Eq("is_active", true)).
		OrderBy("sort_order", cassandra.ASC)

	var active []*entity.Category
	if err := q.Scan(ctx, &active); err != nil {
		return nil, fmt.Errorf("failed to query active categories: %w", err)
	}
	return active, nil
}
// IncrementPostCount adds 1 to the post_count counter of the category
// identified by categoryID.
//
// It returns an error wrapping ErrInvalidInput when categoryID is malformed
// or the zero UUID, or when no keyspace is configured.
func (r *CategoryRepository) IncrementPostCount(ctx context.Context, categoryID string) error {
	return r.adjustPostCount(ctx, categoryID, "+", "increment")
}

// DecrementPostCount subtracts 1 from the post_count counter of the category
// identified by categoryID.
//
// It returns an error wrapping ErrInvalidInput when categoryID is malformed
// or the zero UUID, or when no keyspace is configured.
func (r *CategoryRepository) DecrementPostCount(ctx context.Context, categoryID string) error {
	return r.adjustPostCount(ctx, categoryID, "-", "decrement")
}

// adjustPostCount runs
//
//	UPDATE <keyspace>.<table> SET post_count = post_count <op> 1 WHERE id = ?
//
// at QUORUM consistency for a single category. op must be "+" or "-". verb
// names the operation in the returned error message.
//
// A server-side counter update is used instead of read-modify-write so that
// concurrent writers cannot lose updates. This assumes post_count is a
// Cassandra counter column.
// NOTE(review): Cassandra only allows counter columns in tables whose other
// non-primary-key columns are also counters. Confirm that the categories
// schema permits this. Counter updates are not idempotent either, so a
// client-side retry after a timeout may apply the change twice.
func (r *CategoryRepository) adjustPostCount(ctx context.Context, categoryID, op, verb string) error {
	uuid, err := gocql.ParseUUID(categoryID)
	if err != nil {
		return fmt.Errorf("%w: invalid category ID: %v", ErrInvalidInput, err)
	}
	// The zero UUID is the root-parent sentinel, never a category ID. This
	// matches FindOne/Delete and avoids creating a counter row for it.
	if uuid == (gocql.UUID{}) {
		return fmt.Errorf("%w: invalid category ID: zero UUID", ErrInvalidInput)
	}
	if r.keyspace == "" {
		return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
	}

	var zeroCategory entity.Category
	stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count %s 1 WHERE id = ?",
		r.keyspace, zeroCategory.TableName(), op)

	query := r.db.GetSession().Query(stmt, nil).
		WithContext(ctx).
		Consistency(gocql.Quorum).
		Bind(uuid)
	if err := query.ExecRelease(); err != nil {
		return fmt.Errorf("failed to %s post count: %w", verb, err)
	}
	return nil
}