package repository import ( "context" "fmt" "strings" "backend/pkg/library/cassandra" "backend/pkg/post/domain/entity" domainRepo "backend/pkg/post/domain/repository" "github.com/gocql/gocql" ) // CategoryRepositoryParam 定義 CategoryRepository 的初始化參數 type CategoryRepositoryParam struct { DB *cassandra.DB Keyspace string } // CategoryRepository 實作 domain repository 介面 type CategoryRepository struct { repo cassandra.Repository[*entity.Category] db *cassandra.DB keyspace string } // NewCategoryRepository 創建新的 CategoryRepository func NewCategoryRepository(param CategoryRepositoryParam) domainRepo.CategoryRepository { repo, err := cassandra.NewRepository[*entity.Category](param.DB, param.Keyspace) if err != nil { panic(fmt.Sprintf("failed to create category repository: %v", err)) } keyspace := param.Keyspace if keyspace == "" { keyspace = param.DB.GetDefaultKeyspace() } return &CategoryRepository{ repo: repo, db: param.DB, keyspace: keyspace, } } // Insert 插入單筆分類 func (r *CategoryRepository) Insert(ctx context.Context, data *entity.Category) error { if data == nil { return ErrInvalidInput } // 驗證資料 if err := data.Validate(); err != nil { return fmt.Errorf("%w: %v", ErrInvalidInput, err) } if data.ParentID == nil { data.ParentID = &gocql.UUID{} } // 設置時間戳 data.SetTimestamps() // 如果是新分類,生成 ID if data.IsNew() { data.ID = gocql.TimeUUID() } // Slug 轉為小寫 data.Slug = strings.ToLower(strings.TrimSpace(data.Slug)) return r.repo.Insert(ctx, data) } // FindOne 根據 ID 查詢單筆分類 func (r *CategoryRepository) FindOne(ctx context.Context, id string) (*entity.Category, error) { var zeroUUID gocql.UUID uuid, err := gocql.ParseUUID(id) if err != nil { return nil, err } if uuid == zeroUUID { return nil, ErrInvalidInput } category, err := r.repo.Get(ctx, id) if err != nil { if cassandra.IsNotFound(err) { return nil, ErrNotFound } return nil, fmt.Errorf("failed to find category: %w", err) } return category, nil } // Update 更新分類 func (r *CategoryRepository) Update(ctx context.Context, data *entity.Category) error { if data == nil { return ErrInvalidInput } // 驗證資料 if err := data.Validate(); err != nil { return fmt.Errorf("%w: %v", ErrInvalidInput, err) } // 更新時間戳 data.SetTimestamps() // Slug 轉為小寫 data.Slug = strings.ToLower(strings.TrimSpace(data.Slug)) return r.repo.Update(ctx, data) } // Delete 刪除分類 func (r *CategoryRepository) Delete(ctx context.Context, id string) error { var zeroUUID gocql.UUID uuid, err := gocql.ParseUUID(id) if err != nil { return err } if uuid == zeroUUID { return ErrInvalidInput } return r.repo.Delete(ctx, id) } // FindBySlug 根據 slug 查詢分類 func (r *CategoryRepository) FindBySlug(ctx context.Context, slug string) (*entity.Category, error) { if slug == "" { return nil, ErrInvalidInput } // 標準化 slug slug = strings.ToLower(strings.TrimSpace(slug)) // 構建查詢(要有 SAI 索引在 slug 欄位上) query := r.repo.Query().Where(cassandra.Eq("slug", slug)) var categories []*entity.Category if err := query.Scan(ctx, &categories); err != nil { if cassandra.IsNotFound(err) { return nil, ErrNotFound } return nil, fmt.Errorf("failed to query category: %w", err) } if len(categories) == 0 { return nil, ErrNotFound } return categories[0], nil } // FindByParentID 根據父分類 ID 查詢子分類 func (r *CategoryRepository) FindByParentID(ctx context.Context, parentID string) ([]*entity.Category, error) { query := r.repo.Query() var zeroUUID gocql.UUID if parentID != "" { // 構建查詢(有 SAI 索引在 parentID 欄位上) uuid, err := gocql.ParseUUID(parentID) if err != nil { return nil, err } if uuid != zeroUUID { query = query.Where(cassandra.Eq("parent_id", uuid)) } } else { query = query.Where(cassandra.Eq("parent_id", zeroUUID)) } // 按 sort_order 排序 query = query.OrderBy("sort_order", cassandra.ASC) var categories []*entity.Category if err := query.Scan(ctx, &categories); err != nil { return nil, fmt.Errorf("failed to query categories: %w", err) } return categories, nil } // FindRootCategories 查詢根分類 func (r *CategoryRepository) FindRootCategories(ctx context.Context) ([]*entity.Category, error) { return r.FindByParentID(ctx, "") } // FindActive 查詢啟用的分類 func (r *CategoryRepository) FindActive(ctx context.Context) ([]*entity.Category, error) { query := r.repo.Query(). Where(cassandra.Eq("is_active", true)). OrderBy("sort_order", cassandra.ASC) var categories []*entity.Category if err := query.Scan(ctx, &categories); err != nil { return nil, fmt.Errorf("failed to query active categories: %w", err) } result := categories return result, nil } // IncrementPostCount 增加貼文數(使用 counter 原子操作避免競爭條件) // 注意:post_count 欄位必須是 counter 類型 func (r *CategoryRepository) IncrementPostCount(ctx context.Context, categoryID string) error { uuid, err := gocql.ParseUUID(categoryID) if err != nil { return fmt.Errorf("%w: invalid category ID: %v", ErrInvalidInput, err) } // 使用 counter 原子更新操作:UPDATE categories SET post_count = post_count + 1 WHERE id = ? var zeroCategory entity.Category tableName := zeroCategory.TableName() if r.keyspace == "" { return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) } stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count + 1 WHERE id = ?", r.keyspace, tableName) query := r.db.GetSession().Query(stmt, nil). WithContext(ctx). Consistency(gocql.Quorum). Bind(uuid) if err := query.ExecRelease(); err != nil { return fmt.Errorf("failed to increment post count: %w", err) } return nil } // DecrementPostCount 減少貼文數(使用 counter 原子操作避免競爭條件) // 注意:post_count 欄位必須是 counter 類型 func (r *CategoryRepository) DecrementPostCount(ctx context.Context, categoryID string) error { uuid, err := gocql.ParseUUID(categoryID) if err != nil { return fmt.Errorf("%w: invalid category ID: %v", ErrInvalidInput, err) } // 使用 counter 原子更新操作:UPDATE categories SET post_count = post_count - 1 WHERE id = ? var zeroCategory entity.Category tableName := zeroCategory.TableName() if r.keyspace == "" { return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) } stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count - 1 WHERE id = ?", r.keyspace, tableName) query := r.db.GetSession().Query(stmt, nil). WithContext(ctx). Consistency(gocql.Quorum). Bind(uuid) if err := query.ExecRelease(); err != nil { return fmt.Errorf("failed to decrement post count: %w", err) } return nil }