backend/pkg/post/repository/post.go

512 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"context"
"fmt"
"math"
"backend/pkg/library/cassandra"
"backend/pkg/post/domain/entity"
"backend/pkg/post/domain/post"
domainRepo "backend/pkg/post/domain/repository"
"github.com/gocql/gocql"
)
// PostRepositoryParam holds the constructor arguments for NewPostRepository.
type PostRepositoryParam struct {
// DB is the Cassandra handle used both by the generic repository and for
// hand-written CQL statements.
DB *cassandra.DB
// Keyspace is passed to cassandra.NewRepository as given. When it is empty,
// NewPostRepository falls back to DB.GetDefaultKeyspace() for the CQL
// statements the repository builds itself.
Keyspace string
}
// PostRepository implements domainRepo.PostRepository on top of Cassandra.
type PostRepository struct {
// repo is the generic CRUD/query repository for entity.Post rows.
repo cassandra.Repository[*entity.Post]
// db exposes the raw session, used for the hand-written counter UPDATEs.
db *cassandra.DB
// keyspace is the resolved keyspace interpolated into hand-written CQL.
keyspace string
}
// NewPostRepository wires a PostRepository on top of param.DB.
//
// The generic repository receives param.Keyspace exactly as given. The keyspace
// stored for hand-written CQL falls back to param.DB.GetDefaultKeyspace() when
// param.Keyspace is empty.
//
// It panics if the generic repository cannot be created, so call it while
// wiring dependencies rather than on a request path.
func NewPostRepository(param PostRepositoryParam) domainRepo.PostRepository {
	generic, err := cassandra.NewRepository[*entity.Post](param.DB, param.Keyspace)
	if err != nil {
		panic(fmt.Sprintf("failed to create post repository: %v", err))
	}

	ks := param.Keyspace
	if len(ks) == 0 {
		ks = param.DB.GetDefaultKeyspace()
	}

	return &PostRepository{
		db:       param.DB,
		keyspace: ks,
		repo:     generic,
	}
}
// Insert validates data, stamps its timestamps, and writes it as a new row.
//
// When data.IsNew() reports true the post receives a fresh time-based UUID. A
// post inserted as published without a PublishedAt gets PublishedAt set to its
// CreatedAt. A nil post returns ErrInvalidInput; validation failures are
// returned wrapped with ErrInvalidInput.
func (r *PostRepository) Insert(ctx context.Context, data *entity.Post) error {
	if data == nil {
		return ErrInvalidInput
	}
	if verr := data.Validate(); verr != nil {
		return fmt.Errorf("%w: %v", ErrInvalidInput, verr)
	}

	// Timestamps are stamped before the IsNew check, matching the original order.
	data.SetTimestamps()
	if data.IsNew() {
		data.ID = gocql.TimeUUID()
	}

	needsPublishTime := data.Status == post.PostStatusPublished && data.PublishedAt == nil
	if needsPublishTime {
		publishedAt := data.CreatedAt
		data.PublishedAt = &publishedAt
	}

	return r.repo.Insert(ctx, data)
}
// FindOne loads the post with the given id.
//
// It returns ErrInvalidInput for the zero UUID and ErrNotFound when the
// repository reports no matching row; any other storage error is wrapped.
func (r *PostRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Post, error) {
	if id == (gocql.UUID{}) {
		return nil, ErrInvalidInput
	}

	// Named "found" rather than "post" to avoid shadowing the imported post package.
	found, err := r.repo.Get(ctx, id)
	switch {
	case err == nil:
		return found, nil
	case cassandra.IsNotFound(err):
		return nil, ErrNotFound
	default:
		return nil, fmt.Errorf("failed to find post: %w", err)
	}
}
// Update validates data, refreshes its timestamps via SetTimestamps, and
// writes it back through the generic repository.
//
// A nil post returns ErrInvalidInput; validation failures are returned wrapped
// with ErrInvalidInput.
func (r *PostRepository) Update(ctx context.Context, data *entity.Post) error {
	if data == nil {
		return ErrInvalidInput
	}

	err := data.Validate()
	if err != nil {
		return fmt.Errorf("%w: %v", ErrInvalidInput, err)
	}

	data.SetTimestamps()
	return r.repo.Update(ctx, data)
}
// Delete soft-deletes a post: it loads the row, marks it via entity.Post.Delete,
// and persists the change through Update. The row itself is not removed.
//
// It returns ErrInvalidInput for the zero UUID and propagates FindOne/Update errors.
func (r *PostRepository) Delete(ctx context.Context, id gocql.UUID) error {
	if id == (gocql.UUID{}) {
		return ErrInvalidInput
	}

	target, err := r.FindOne(ctx, id)
	if err != nil {
		return err
	}

	target.Delete()
	return r.Update(ctx, target)
}
// FindByAuthorUID returns one page of posts written by authorUID.
//
// params may be nil. Supported options are:
//   - Status: equality filter.
//   - OrderBy: sort column, default "created_at".
//   - OrderDirection: "ASC"; anything else means DESC.
//   - PageSize: default 20.
//   - PageIndex: 1-based, default 1.
//
// Cassandra has no OFFSET, so page N is produced by reading the first
// PageSize*PageIndex rows and returning the last window of them. The returned
// total is the number of rows read up to and including the requested page. It
// is a lower bound on the real count, not the result of a COUNT query.
//
// It returns ErrInvalidInput for an empty authorUID or a page window that
// overflows int.
func (r *PostRepository) FindByAuthorUID(ctx context.Context, authorUID string, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
	if authorUID == "" {
		return nil, 0, ErrInvalidInput
	}

	query := r.repo.Query().Where(cassandra.Eq("author_uid", authorUID))
	if params != nil && params.Status != nil {
		query = query.Where(cassandra.Eq("status", *params.Status))
	}

	orderBy := "created_at"
	if params != nil && params.OrderBy != "" {
		// NOTE(review): OrderBy reaches the query builder verbatim; if it can come
		// from end users, confirm the builder validates or quotes column names.
		orderBy = params.OrderBy
	}
	order := cassandra.DESC
	if params != nil && params.OrderDirection == "ASC" {
		order = cassandra.ASC
	}
	query = query.OrderBy(orderBy, order)

	pageSize := int64(20)
	if params != nil && params.PageSize > 0 {
		pageSize = params.PageSize
	}
	pageIndex := int64(1)
	if params != nil && params.PageIndex > 0 {
		pageIndex = params.PageIndex
	}
	// The read window pageSize*pageIndex must fit in an int for Limit.
	if pageIndex > int64(math.MaxInt)/pageSize {
		return nil, 0, fmt.Errorf("%w: page window too large", ErrInvalidInput)
	}
	query = query.Limit(int(pageSize * pageIndex))

	var posts []*entity.Post
	if err := query.Scan(ctx, &posts); err != nil {
		return nil, 0, fmt.Errorf("failed to query posts: %w", err)
	}

	// Skip the rows belonging to earlier pages; clamp when fewer rows exist.
	total := int64(len(posts))
	start := (pageIndex - 1) * pageSize
	if start > total {
		start = total
	}
	end := start + pageSize
	if end > total {
		end = total
	}
	return posts[start:end], total, nil
}
// FindByCategoryID returns one page of posts in the given category.
//
// params may be nil. Supported options are:
//   - Status: equality filter.
//   - OrderBy: sort column, default "created_at".
//   - OrderDirection: "ASC"; anything else means DESC.
//   - PageSize: default 20.
//   - PageIndex: 1-based, default 1. Previously ignored, so every page returned
//     page 1.
//
// Cassandra has no OFFSET, so page N is produced by reading the first
// PageSize*PageIndex rows and returning the last window of them. The returned
// total is the number of rows read up to and including the requested page.
//
// It returns ErrInvalidInput for the zero UUID or a page window that overflows int.
func (r *PostRepository) FindByCategoryID(ctx context.Context, categoryID gocql.UUID, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
	var zeroUUID gocql.UUID
	if categoryID == zeroUUID {
		return nil, 0, ErrInvalidInput
	}

	query := r.repo.Query().Where(cassandra.Eq("category_id", categoryID))
	if params != nil && params.Status != nil {
		query = query.Where(cassandra.Eq("status", *params.Status))
	}

	orderBy := "created_at"
	if params != nil && params.OrderBy != "" {
		// NOTE(review): OrderBy reaches the query builder verbatim; if it can come
		// from end users, confirm the builder validates or quotes column names.
		orderBy = params.OrderBy
	}
	order := cassandra.DESC
	if params != nil && params.OrderDirection == "ASC" {
		order = cassandra.ASC
	}
	query = query.OrderBy(orderBy, order)

	pageSize := int64(20)
	if params != nil && params.PageSize > 0 {
		pageSize = params.PageSize
	}
	pageIndex := int64(1)
	if params != nil && params.PageIndex > 0 {
		pageIndex = params.PageIndex
	}
	// The read window pageSize*pageIndex must fit in an int for Limit.
	if pageIndex > int64(math.MaxInt)/pageSize {
		return nil, 0, fmt.Errorf("%w: page window too large", ErrInvalidInput)
	}
	query = query.Limit(int(pageSize * pageIndex))

	var posts []*entity.Post
	if err := query.Scan(ctx, &posts); err != nil {
		return nil, 0, fmt.Errorf("failed to query posts: %w", err)
	}

	// Skip the rows belonging to earlier pages; clamp when fewer rows exist.
	total := int64(len(posts))
	start := (pageIndex - 1) * pageSize
	if start > total {
		start = total
	}
	end := start + pageSize
	if end > total {
		end = total
	}
	return posts[start:end], total, nil
}
// FindByTag returns posts whose Tags contain tagName.
//
// NOTE(review): the tag is not part of the CQL query. The query reads at most
// PageSize rows (default 20), optionally filtered by status and ordered by
// OrderBy/OrderDirection, and the tag match happens in memory afterwards. As a
// result:
//   - Matching posts outside that first window are never returned.
//   - PageIndex is ignored.
//   - The returned total counts only matches within the window.
//
// A real fix needs a CONTAINS predicate backed by an SAI/secondary index, or a
// tag-keyed lookup table.
//
// It returns ErrInvalidInput for an empty tagName.
func (r *PostRepository) FindByTag(ctx context.Context, tagName string, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
	if tagName == "" {
		return nil, 0, ErrInvalidInput
	}

	sortColumn := "created_at"
	direction := cassandra.DESC
	batchSize := int64(20)

	q := r.repo.Query()
	if params != nil {
		if params.Status != nil {
			q = q.Where(cassandra.Eq("status", *params.Status))
		}
		if params.OrderBy != "" {
			sortColumn = params.OrderBy
		}
		if params.OrderDirection == "ASC" {
			direction = cassandra.ASC
		}
		if params.PageSize > 0 {
			batchSize = params.PageSize
		}
	}
	q = q.OrderBy(sortColumn, direction)
	q = q.Limit(int(batchSize))

	var candidates []*entity.Post
	if err := q.Scan(ctx, &candidates); err != nil {
		return nil, 0, fmt.Errorf("failed to query posts: %w", err)
	}

	matches := make([]*entity.Post, 0)
nextPost:
	for _, candidate := range candidates {
		for _, tag := range candidate.Tags {
			if tag == tagName {
				matches = append(matches, candidate)
				continue nextPost
			}
		}
	}
	return matches, int64(len(matches)), nil
}
// FindPinnedPosts returns up to limit published posts with is_pinned set,
// ordered by pinned_at descending (most recently pinned first).
//
// NOTE(review): limit is converted to int and passed to Limit unchanged. CQL
// rejects LIMIT <= 0 unless the query builder omits it, so confirm how the
// builder treats non-positive values.
func (r *PostRepository) FindPinnedPosts(ctx context.Context, limit int64) ([]*entity.Post, error) {
	q := r.repo.Query()
	q = q.Where(cassandra.Eq("is_pinned", true))
	q = q.Where(cassandra.Eq("status", post.PostStatusPublished))
	q = q.OrderBy("pinned_at", cassandra.DESC)
	q = q.Limit(int(limit))

	var pinned []*entity.Post
	err := q.Scan(ctx, &pinned)
	if err != nil {
		return nil, fmt.Errorf("failed to query pinned posts: %w", err)
	}
	return pinned, nil
}
// FindByStatus returns one page of posts with the given status.
//
// params may be nil. Supported options are:
//   - OrderBy: sort column, default "created_at".
//   - OrderDirection: "ASC"; anything else means DESC.
//   - PageSize: default 20.
//   - PageIndex: 1-based, default 1. Previously ignored, so every page returned
//     page 1.
//
// Cassandra has no OFFSET, so page N is produced by reading the first
// PageSize*PageIndex rows and returning the last window of them. The returned
// total is the number of rows read up to and including the requested page.
//
// It returns ErrInvalidInput when the page window overflows int.
func (r *PostRepository) FindByStatus(ctx context.Context, status post.Status, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
	query := r.repo.Query().Where(cassandra.Eq("status", status))

	orderBy := "created_at"
	if params != nil && params.OrderBy != "" {
		// NOTE(review): OrderBy reaches the query builder verbatim; if it can come
		// from end users, confirm the builder validates or quotes column names.
		orderBy = params.OrderBy
	}
	order := cassandra.DESC
	if params != nil && params.OrderDirection == "ASC" {
		order = cassandra.ASC
	}
	query = query.OrderBy(orderBy, order)

	pageSize := int64(20)
	if params != nil && params.PageSize > 0 {
		pageSize = params.PageSize
	}
	pageIndex := int64(1)
	if params != nil && params.PageIndex > 0 {
		pageIndex = params.PageIndex
	}
	// The read window pageSize*pageIndex must fit in an int for Limit.
	if pageIndex > int64(math.MaxInt)/pageSize {
		return nil, 0, fmt.Errorf("%w: page window too large", ErrInvalidInput)
	}
	query = query.Limit(int(pageSize * pageIndex))

	var posts []*entity.Post
	if err := query.Scan(ctx, &posts); err != nil {
		return nil, 0, fmt.Errorf("failed to query posts: %w", err)
	}

	// Skip the rows belonging to earlier pages; clamp when fewer rows exist.
	total := int64(len(posts))
	start := (pageIndex - 1) * pageSize
	if start > total {
		start = total
	}
	end := start + pageSize
	if end > total {
		end = total
	}
	return posts[start:end], total, nil
}
// adjustCounter runs "UPDATE <keyspace>.<table> SET column = column op 1
// WHERE id = ?" for postID at QUORUM consistency.
//
// op is "+" or "-". action names the operation for the error message, which
// reads "failed to <action>: ...".
//
// It returns ErrInvalidInput for the zero UUID, or (wrapped) when the
// repository has no keyspace.
//
// NOTE(review): Cassandra only accepts "col = col ± n" on counter columns, and a
// table holding counter columns may not hold other non-key regular columns.
// Confirm that like_count, comment_count and view_count live in a schema that
// satisfies this.
func (r *PostRepository) adjustCounter(ctx context.Context, postID gocql.UUID, column, op, action string) error {
	if postID == (gocql.UUID{}) {
		return ErrInvalidInput
	}
	if r.keyspace == "" {
		return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
	}

	var zeroPost entity.Post
	stmt := fmt.Sprintf("UPDATE %s.%s SET %s = %s %s 1 WHERE id = ?",
		r.keyspace, zeroPost.TableName(), column, column, op)

	err := r.db.GetSession().Query(stmt, nil).
		WithContext(ctx).
		Consistency(gocql.Quorum).
		Bind(postID).
		ExecRelease()
	if err != nil {
		return fmt.Errorf("failed to %s: %w", action, err)
	}
	return nil
}

// IncrementLikeCount atomically adds 1 to the post's like_count counter column.
func (r *PostRepository) IncrementLikeCount(ctx context.Context, postID gocql.UUID) error {
	return r.adjustCounter(ctx, postID, "like_count", "+", "increment like count")
}

// DecrementLikeCount atomically subtracts 1 from the post's like_count counter column.
func (r *PostRepository) DecrementLikeCount(ctx context.Context, postID gocql.UUID) error {
	return r.adjustCounter(ctx, postID, "like_count", "-", "decrement like count")
}

// IncrementCommentCount atomically adds 1 to the post's comment_count counter column.
func (r *PostRepository) IncrementCommentCount(ctx context.Context, postID gocql.UUID) error {
	return r.adjustCounter(ctx, postID, "comment_count", "+", "increment comment count")
}

// DecrementCommentCount atomically subtracts 1 from the post's comment_count counter column.
func (r *PostRepository) DecrementCommentCount(ctx context.Context, postID gocql.UUID) error {
	return r.adjustCounter(ctx, postID, "comment_count", "-", "decrement comment count")
}

// IncrementViewCount atomically adds 1 to the post's view_count counter column.
func (r *PostRepository) IncrementViewCount(ctx context.Context, postID gocql.UUID) error {
	return r.adjustCounter(ctx, postID, "view_count", "+", "increment view count")
}
// UpdateStatus sets the status of the post identified by postID and persists
// it through Update.
//
// When the post moves to published and has no PublishedAt yet, PublishedAt is
// stamped. Timestamps are refreshed first, so the stamp reflects this update
// rather than the post's previous modification time as loaded from storage.
// Errors from FindOne (including ErrInvalidInput and ErrNotFound) and from
// Update are returned unchanged.
func (r *PostRepository) UpdateStatus(ctx context.Context, postID gocql.UUID, status post.Status) error {
	postEntity, err := r.FindOne(ctx, postID)
	if err != nil {
		return err
	}

	postEntity.Status = status
	if status == post.PostStatusPublished && postEntity.PublishedAt == nil {
		// UpdatedAt as loaded is the last write's time; refresh it before copying.
		// NOTE(review): assumes SetTimestamps stamps UpdatedAt with the current
		// time (Insert relies on it the same way for CreatedAt) — confirm in entity.Post.
		postEntity.SetTimestamps()
		publishedAt := postEntity.UpdatedAt
		postEntity.PublishedAt = &publishedAt
	}
	return r.Update(ctx, postEntity)
}
// PinPost loads the post, marks it pinned via entity.Post.Pin, and persists
// the change through Update. FindOne and Update errors are returned unchanged.
func (r *PostRepository) PinPost(ctx context.Context, postID gocql.UUID) error {
	p, err := r.FindOne(ctx, postID)
	if err == nil {
		p.Pin()
		err = r.Update(ctx, p)
	}
	return err
}
// UnpinPost loads the post, clears its pinned state via entity.Post.Unpin, and
// persists the change through Update. FindOne and Update errors are returned unchanged.
func (r *PostRepository) UnpinPost(ctx context.Context, postID gocql.UUID) error {
	p, err := r.FindOne(ctx, postID)
	if err == nil {
		p.Unpin()
		err = r.Update(ctx, p)
	}
	return err
}
// calculateTotalPages returns ceil(total / pageSize), or 0 when pageSize is not
// positive.
//
// The division goes through float64, which is exact for totals below 2^53.
func calculateTotalPages(total, pageSize int64) int64 {
	if pageSize > 0 {
		pages := math.Ceil(float64(total) / float64(pageSize))
		return int64(pages)
	}
	return 0
}