package repository import ( "context" "fmt" "math" "backend/pkg/library/cassandra" "backend/pkg/post/domain/entity" "backend/pkg/post/domain/post" domainRepo "backend/pkg/post/domain/repository" "github.com/gocql/gocql" ) // PostRepositoryParam 定義 PostRepository 的初始化參數 type PostRepositoryParam struct { DB *cassandra.DB Keyspace string } // PostRepository 實作 domain repository 介面 type PostRepository struct { repo cassandra.Repository[*entity.Post] db *cassandra.DB keyspace string } // NewPostRepository 創建新的 PostRepository func NewPostRepository(param PostRepositoryParam) domainRepo.PostRepository { repo, err := cassandra.NewRepository[*entity.Post](param.DB, param.Keyspace) if err != nil { panic(fmt.Sprintf("failed to create post repository: %v", err)) } keyspace := param.Keyspace if keyspace == "" { keyspace = param.DB.GetDefaultKeyspace() } return &PostRepository{ repo: repo, db: param.DB, keyspace: keyspace, } } // Insert 插入單筆貼文 func (r *PostRepository) Insert(ctx context.Context, data *entity.Post) error { if data == nil { return ErrInvalidInput } // 驗證資料 if err := data.Validate(); err != nil { return fmt.Errorf("%w: %v", ErrInvalidInput, err) } // 設置時間戳 data.SetTimestamps() // 如果是新貼文,生成 ID if data.IsNew() { data.ID = gocql.TimeUUID() } // 如果狀態是 published,設置發布時間 if data.Status == post.PostStatusPublished && data.PublishedAt == nil { now := data.CreatedAt data.PublishedAt = &now } return r.repo.Insert(ctx, data) } // FindOne 根據 ID 查詢單筆貼文 func (r *PostRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Post, error) { var zeroUUID gocql.UUID if id == zeroUUID { return nil, ErrInvalidInput } post, err := r.repo.Get(ctx, id) if err != nil { if cassandra.IsNotFound(err) { return nil, ErrNotFound } return nil, fmt.Errorf("failed to find post: %w", err) } return post, nil } // Update 更新貼文 func (r *PostRepository) Update(ctx context.Context, data *entity.Post) error { if data == nil { return ErrInvalidInput } // 驗證資料 if err := data.Validate(); err != nil { return fmt.Errorf("%w: %v", ErrInvalidInput, err) } // 更新時間戳 data.SetTimestamps() return r.repo.Update(ctx, data) } // Delete 刪除貼文(軟刪除) func (r *PostRepository) Delete(ctx context.Context, id gocql.UUID) error { var zeroUUID gocql.UUID if id == zeroUUID { return ErrInvalidInput } // 先查詢貼文 post, err := r.FindOne(ctx, id) if err != nil { return err } // 軟刪除:標記為已刪除 post.Delete() return r.Update(ctx, post) } // FindByAuthorUID 根據作者 UID 查詢貼文 func (r *PostRepository) FindByAuthorUID(ctx context.Context, authorUID string, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) { if authorUID == "" { return nil, 0, ErrInvalidInput } // 構建查詢 query := r.repo.Query().Where(cassandra.Eq("author_uid", authorUID)) // 添加狀態過濾 if params != nil && params.Status != nil { query = query.Where(cassandra.Eq("status", *params.Status)) } // 添加排序 orderBy := "created_at" if params != nil && params.OrderBy != "" { orderBy = params.OrderBy } order := cassandra.DESC if params != nil && params.OrderDirection == "ASC" { order = cassandra.ASC } query = query.OrderBy(orderBy, order) // 添加分頁 pageSize := int64(20) if params != nil && params.PageSize > 0 { pageSize = params.PageSize } pageIndex := int64(1) if params != nil && params.PageIndex > 0 { pageIndex = params.PageIndex } limit := int(pageSize) query = query.Limit(limit) // 執行查詢 var posts []*entity.Post if err := query.Scan(ctx, &posts); err != nil { return nil, 0, fmt.Errorf("failed to query posts: %w", err) } result := posts // 計算總數(簡化實作,實際應該使用 COUNT 查詢) total := int64(len(posts)) if params != nil && params.PageIndex > 1 { // 這裡應該執行 COUNT 查詢,但為了簡化,我們假設有更多結果 total = pageSize * pageIndex } return result, total, nil } // FindByCategoryID 根據分類 ID 查詢貼文 func (r *PostRepository) FindByCategoryID(ctx context.Context, categoryID gocql.UUID, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) { var zeroUUID gocql.UUID if categoryID == zeroUUID { return nil, 0, ErrInvalidInput } // 構建查詢 query := r.repo.Query().Where(cassandra.Eq("category_id", categoryID)) // 添加狀態過濾 if params != nil && params.Status != nil { query = query.Where(cassandra.Eq("status", *params.Status)) } // 添加排序和分頁(類似 FindByAuthorUID) orderBy := "created_at" if params != nil && params.OrderBy != "" { orderBy = params.OrderBy } order := cassandra.DESC if params != nil && params.OrderDirection == "ASC" { order = cassandra.ASC } query = query.OrderBy(orderBy, order) pageSize := int64(20) if params != nil && params.PageSize > 0 { pageSize = params.PageSize } limit := int(pageSize) query = query.Limit(limit) var posts []*entity.Post if err := query.Scan(ctx, &posts); err != nil { return nil, 0, fmt.Errorf("failed to query posts: %w", err) } result := posts total := int64(len(posts)) return result, total, nil } // FindByTag 根據標籤查詢貼文 func (r *PostRepository) FindByTag(ctx context.Context, tagName string, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) { if tagName == "" { return nil, 0, ErrInvalidInput } // 構建查詢(注意:Cassandra 的集合查詢需要使用 CONTAINS,這裡簡化處理) // 實際實作中,可能需要使用 SAI 索引或 Materialized View query := r.repo.Query() // 添加狀態過濾 if params != nil && params.Status != nil { query = query.Where(cassandra.Eq("status", *params.Status)) } // 添加排序和分頁 orderBy := "created_at" if params != nil && params.OrderBy != "" { orderBy = params.OrderBy } order := cassandra.DESC if params != nil && params.OrderDirection == "ASC" { order = cassandra.ASC } query = query.OrderBy(orderBy, order) pageSize := int64(20) if params != nil && params.PageSize > 0 { pageSize = params.PageSize } limit := int(pageSize) query = query.Limit(limit) var posts []*entity.Post if err := query.Scan(ctx, &posts); err != nil { return nil, 0, fmt.Errorf("failed to query posts: %w", err) } // 過濾包含指定標籤的貼文 filtered := make([]*entity.Post, 0) for _, p := range posts { for _, tag := range p.Tags { if tag == tagName { filtered = append(filtered, p) break } } } total := int64(len(filtered)) return filtered, total, nil } // FindPinnedPosts 查詢置頂貼文 func (r *PostRepository) FindPinnedPosts(ctx context.Context, limit int64) ([]*entity.Post, error) { query := r.repo.Query(). Where(cassandra.Eq("is_pinned", true)). Where(cassandra.Eq("status", post.PostStatusPublished)). OrderBy("pinned_at", cassandra.DESC). Limit(int(limit)) var posts []*entity.Post if err := query.Scan(ctx, &posts); err != nil { return nil, fmt.Errorf("failed to query pinned posts: %w", err) } return posts, nil } // FindByStatus 根據狀態查詢貼文 func (r *PostRepository) FindByStatus(ctx context.Context, status post.Status, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) { query := r.repo.Query().Where(cassandra.Eq("status", status)) // 添加排序和分頁 orderBy := "created_at" if params != nil && params.OrderBy != "" { orderBy = params.OrderBy } order := cassandra.DESC if params != nil && params.OrderDirection == "ASC" { order = cassandra.ASC } query = query.OrderBy(orderBy, order) pageSize := int64(20) if params != nil && params.PageSize > 0 { pageSize = params.PageSize } limit := int(pageSize) query = query.Limit(limit) var posts []*entity.Post if err := query.Scan(ctx, &posts); err != nil { return nil, 0, fmt.Errorf("failed to query posts: %w", err) } result := posts total := int64(len(posts)) return result, total, nil } // IncrementLikeCount 增加按讚數(使用 counter 原子操作避免競爭條件) // 注意:like_count 欄位必須是 counter 類型 func (r *PostRepository) IncrementLikeCount(ctx context.Context, postID gocql.UUID) error { var zeroUUID gocql.UUID if postID == zeroUUID { return ErrInvalidInput } var zeroPost entity.Post tableName := zeroPost.TableName() if r.keyspace == "" { return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) } stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count + 1 WHERE id = ?", r.keyspace, tableName) query := r.db.GetSession().Query(stmt, nil). WithContext(ctx). Consistency(gocql.Quorum). Bind(postID) if err := query.ExecRelease(); err != nil { return fmt.Errorf("failed to increment like count: %w", err) } return nil } // DecrementLikeCount 減少按讚數(使用 counter 原子操作避免競爭條件) // 注意:like_count 欄位必須是 counter 類型 func (r *PostRepository) DecrementLikeCount(ctx context.Context, postID gocql.UUID) error { var zeroUUID gocql.UUID if postID == zeroUUID { return ErrInvalidInput } var zeroPost entity.Post tableName := zeroPost.TableName() if r.keyspace == "" { return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) } stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count - 1 WHERE id = ?", r.keyspace, tableName) query := r.db.GetSession().Query(stmt, nil). WithContext(ctx). Consistency(gocql.Quorum). Bind(postID) if err := query.ExecRelease(); err != nil { return fmt.Errorf("failed to decrement like count: %w", err) } return nil } // IncrementCommentCount 增加評論數(使用 counter 原子操作避免競爭條件) // 注意:comment_count 欄位必須是 counter 類型 func (r *PostRepository) IncrementCommentCount(ctx context.Context, postID gocql.UUID) error { var zeroUUID gocql.UUID if postID == zeroUUID { return ErrInvalidInput } var zeroPost entity.Post tableName := zeroPost.TableName() if r.keyspace == "" { return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) } stmt := fmt.Sprintf("UPDATE %s.%s SET comment_count = comment_count + 1 WHERE id = ?", r.keyspace, tableName) query := r.db.GetSession().Query(stmt, nil). WithContext(ctx). Consistency(gocql.Quorum). Bind(postID) if err := query.ExecRelease(); err != nil { return fmt.Errorf("failed to increment comment count: %w", err) } return nil } // DecrementCommentCount 減少評論數(使用 counter 原子操作避免競爭條件) // 注意:comment_count 欄位必須是 counter 類型 func (r *PostRepository) DecrementCommentCount(ctx context.Context, postID gocql.UUID) error { var zeroUUID gocql.UUID if postID == zeroUUID { return ErrInvalidInput } var zeroPost entity.Post tableName := zeroPost.TableName() if r.keyspace == "" { return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) } stmt := fmt.Sprintf("UPDATE %s.%s SET comment_count = comment_count - 1 WHERE id = ?", r.keyspace, tableName) query := r.db.GetSession().Query(stmt, nil). WithContext(ctx). Consistency(gocql.Quorum). Bind(postID) if err := query.ExecRelease(); err != nil { return fmt.Errorf("failed to decrement comment count: %w", err) } return nil } // IncrementViewCount 增加瀏覽數(使用 counter 原子操作避免競爭條件) // 注意:view_count 欄位必須是 counter 類型 func (r *PostRepository) IncrementViewCount(ctx context.Context, postID gocql.UUID) error { var zeroUUID gocql.UUID if postID == zeroUUID { return ErrInvalidInput } var zeroPost entity.Post tableName := zeroPost.TableName() if r.keyspace == "" { return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) } stmt := fmt.Sprintf("UPDATE %s.%s SET view_count = view_count + 1 WHERE id = ?", r.keyspace, tableName) query := r.db.GetSession().Query(stmt, nil). WithContext(ctx). Consistency(gocql.Quorum). Bind(postID) if err := query.ExecRelease(); err != nil { return fmt.Errorf("failed to increment view count: %w", err) } return nil } // UpdateStatus 更新貼文狀態 func (r *PostRepository) UpdateStatus(ctx context.Context, postID gocql.UUID, status post.Status) error { postEntity, err := r.FindOne(ctx, postID) if err != nil { return err } postEntity.Status = status publishedStatus := post.PostStatusPublished if status == publishedStatus && postEntity.PublishedAt == nil { now := postEntity.UpdatedAt postEntity.PublishedAt = &now } return r.Update(ctx, postEntity) } // PinPost 置頂貼文 func (r *PostRepository) PinPost(ctx context.Context, postID gocql.UUID) error { post, err := r.FindOne(ctx, postID) if err != nil { return err } post.Pin() return r.Update(ctx, post) } // UnpinPost 取消置頂 func (r *PostRepository) UnpinPost(ctx context.Context, postID gocql.UUID) error { post, err := r.FindOne(ctx, postID) if err != nil { return err } post.Unpin() return r.Update(ctx, post) } // calculateTotalPages 計算總頁數 func calculateTotalPages(total, pageSize int64) int64 { if pageSize <= 0 { return 0 } return int64(math.Ceil(float64(total) / float64(pageSize))) }