512 lines
13 KiB
Go
512 lines
13 KiB
Go
package repository
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"math"
|
||
|
||
"backend/pkg/library/cassandra"
|
||
"backend/pkg/post/domain/entity"
|
||
"backend/pkg/post/domain/post"
|
||
domainRepo "backend/pkg/post/domain/repository"
|
||
|
||
"github.com/gocql/gocql"
|
||
)
|
||
|
||
// PostRepositoryParam 定義 PostRepository 的初始化參數
|
||
type PostRepositoryParam struct {
|
||
DB *cassandra.DB
|
||
Keyspace string
|
||
}
|
||
|
||
// PostRepository 實作 domain repository 介面
|
||
type PostRepository struct {
|
||
repo cassandra.Repository[*entity.Post]
|
||
db *cassandra.DB
|
||
keyspace string
|
||
}
|
||
|
||
// NewPostRepository 創建新的 PostRepository
|
||
func NewPostRepository(param PostRepositoryParam) domainRepo.PostRepository {
|
||
repo, err := cassandra.NewRepository[*entity.Post](param.DB, param.Keyspace)
|
||
if err != nil {
|
||
panic(fmt.Sprintf("failed to create post repository: %v", err))
|
||
}
|
||
|
||
keyspace := param.Keyspace
|
||
if keyspace == "" {
|
||
keyspace = param.DB.GetDefaultKeyspace()
|
||
}
|
||
|
||
return &PostRepository{
|
||
repo: repo,
|
||
db: param.DB,
|
||
keyspace: keyspace,
|
||
}
|
||
}
|
||
|
||
// Insert 插入單筆貼文
|
||
func (r *PostRepository) Insert(ctx context.Context, data *entity.Post) error {
|
||
if data == nil {
|
||
return ErrInvalidInput
|
||
}
|
||
|
||
// 驗證資料
|
||
if err := data.Validate(); err != nil {
|
||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
||
}
|
||
|
||
// 設置時間戳
|
||
data.SetTimestamps()
|
||
|
||
// 如果是新貼文,生成 ID
|
||
if data.IsNew() {
|
||
data.ID = gocql.TimeUUID()
|
||
}
|
||
|
||
// 如果狀態是 published,設置發布時間
|
||
if data.Status == post.PostStatusPublished && data.PublishedAt == nil {
|
||
now := data.CreatedAt
|
||
data.PublishedAt = &now
|
||
}
|
||
|
||
return r.repo.Insert(ctx, data)
|
||
}
|
||
|
||
// FindOne 根據 ID 查詢單筆貼文
|
||
func (r *PostRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Post, error) {
|
||
var zeroUUID gocql.UUID
|
||
if id == zeroUUID {
|
||
return nil, ErrInvalidInput
|
||
}
|
||
|
||
post, err := r.repo.Get(ctx, id)
|
||
if err != nil {
|
||
if cassandra.IsNotFound(err) {
|
||
return nil, ErrNotFound
|
||
}
|
||
return nil, fmt.Errorf("failed to find post: %w", err)
|
||
}
|
||
|
||
return post, nil
|
||
}
|
||
|
||
// Update 更新貼文
|
||
func (r *PostRepository) Update(ctx context.Context, data *entity.Post) error {
|
||
if data == nil {
|
||
return ErrInvalidInput
|
||
}
|
||
|
||
// 驗證資料
|
||
if err := data.Validate(); err != nil {
|
||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
||
}
|
||
|
||
// 更新時間戳
|
||
data.SetTimestamps()
|
||
|
||
return r.repo.Update(ctx, data)
|
||
}
|
||
|
||
// Delete 刪除貼文(軟刪除)
|
||
func (r *PostRepository) Delete(ctx context.Context, id gocql.UUID) error {
|
||
var zeroUUID gocql.UUID
|
||
if id == zeroUUID {
|
||
return ErrInvalidInput
|
||
}
|
||
|
||
// 先查詢貼文
|
||
post, err := r.FindOne(ctx, id)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 軟刪除:標記為已刪除
|
||
post.Delete()
|
||
return r.Update(ctx, post)
|
||
}
|
||
|
||
// FindByAuthorUID 根據作者 UID 查詢貼文
|
||
func (r *PostRepository) FindByAuthorUID(ctx context.Context, authorUID string, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
|
||
if authorUID == "" {
|
||
return nil, 0, ErrInvalidInput
|
||
}
|
||
|
||
// 構建查詢
|
||
query := r.repo.Query().Where(cassandra.Eq("author_uid", authorUID))
|
||
|
||
// 添加狀態過濾
|
||
if params != nil && params.Status != nil {
|
||
query = query.Where(cassandra.Eq("status", *params.Status))
|
||
}
|
||
|
||
// 添加排序
|
||
orderBy := "created_at"
|
||
if params != nil && params.OrderBy != "" {
|
||
orderBy = params.OrderBy
|
||
}
|
||
order := cassandra.DESC
|
||
if params != nil && params.OrderDirection == "ASC" {
|
||
order = cassandra.ASC
|
||
}
|
||
query = query.OrderBy(orderBy, order)
|
||
|
||
// 添加分頁
|
||
pageSize := int64(20)
|
||
if params != nil && params.PageSize > 0 {
|
||
pageSize = params.PageSize
|
||
}
|
||
pageIndex := int64(1)
|
||
if params != nil && params.PageIndex > 0 {
|
||
pageIndex = params.PageIndex
|
||
}
|
||
limit := int(pageSize)
|
||
query = query.Limit(limit)
|
||
|
||
// 執行查詢
|
||
var posts []*entity.Post
|
||
if err := query.Scan(ctx, &posts); err != nil {
|
||
return nil, 0, fmt.Errorf("failed to query posts: %w", err)
|
||
}
|
||
|
||
result := posts
|
||
|
||
// 計算總數(簡化實作,實際應該使用 COUNT 查詢)
|
||
total := int64(len(posts))
|
||
if params != nil && params.PageIndex > 1 {
|
||
// 這裡應該執行 COUNT 查詢,但為了簡化,我們假設有更多結果
|
||
total = pageSize * pageIndex
|
||
}
|
||
|
||
return result, total, nil
|
||
}
|
||
|
||
// FindByCategoryID 根據分類 ID 查詢貼文
|
||
func (r *PostRepository) FindByCategoryID(ctx context.Context, categoryID gocql.UUID, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
|
||
var zeroUUID gocql.UUID
|
||
if categoryID == zeroUUID {
|
||
return nil, 0, ErrInvalidInput
|
||
}
|
||
|
||
// 構建查詢
|
||
query := r.repo.Query().Where(cassandra.Eq("category_id", categoryID))
|
||
|
||
// 添加狀態過濾
|
||
if params != nil && params.Status != nil {
|
||
query = query.Where(cassandra.Eq("status", *params.Status))
|
||
}
|
||
|
||
// 添加排序和分頁(類似 FindByAuthorUID)
|
||
orderBy := "created_at"
|
||
if params != nil && params.OrderBy != "" {
|
||
orderBy = params.OrderBy
|
||
}
|
||
order := cassandra.DESC
|
||
if params != nil && params.OrderDirection == "ASC" {
|
||
order = cassandra.ASC
|
||
}
|
||
query = query.OrderBy(orderBy, order)
|
||
|
||
pageSize := int64(20)
|
||
if params != nil && params.PageSize > 0 {
|
||
pageSize = params.PageSize
|
||
}
|
||
limit := int(pageSize)
|
||
query = query.Limit(limit)
|
||
|
||
var posts []*entity.Post
|
||
if err := query.Scan(ctx, &posts); err != nil {
|
||
return nil, 0, fmt.Errorf("failed to query posts: %w", err)
|
||
}
|
||
|
||
result := posts
|
||
|
||
total := int64(len(posts))
|
||
return result, total, nil
|
||
}
|
||
|
||
// FindByTag 根據標籤查詢貼文
|
||
func (r *PostRepository) FindByTag(ctx context.Context, tagName string, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
|
||
if tagName == "" {
|
||
return nil, 0, ErrInvalidInput
|
||
}
|
||
|
||
// 構建查詢(注意:Cassandra 的集合查詢需要使用 CONTAINS,這裡簡化處理)
|
||
// 實際實作中,可能需要使用 SAI 索引或 Materialized View
|
||
query := r.repo.Query()
|
||
|
||
// 添加狀態過濾
|
||
if params != nil && params.Status != nil {
|
||
query = query.Where(cassandra.Eq("status", *params.Status))
|
||
}
|
||
|
||
// 添加排序和分頁
|
||
orderBy := "created_at"
|
||
if params != nil && params.OrderBy != "" {
|
||
orderBy = params.OrderBy
|
||
}
|
||
order := cassandra.DESC
|
||
if params != nil && params.OrderDirection == "ASC" {
|
||
order = cassandra.ASC
|
||
}
|
||
query = query.OrderBy(orderBy, order)
|
||
|
||
pageSize := int64(20)
|
||
if params != nil && params.PageSize > 0 {
|
||
pageSize = params.PageSize
|
||
}
|
||
limit := int(pageSize)
|
||
query = query.Limit(limit)
|
||
|
||
var posts []*entity.Post
|
||
if err := query.Scan(ctx, &posts); err != nil {
|
||
return nil, 0, fmt.Errorf("failed to query posts: %w", err)
|
||
}
|
||
|
||
// 過濾包含指定標籤的貼文
|
||
filtered := make([]*entity.Post, 0)
|
||
for _, p := range posts {
|
||
for _, tag := range p.Tags {
|
||
if tag == tagName {
|
||
filtered = append(filtered, p)
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
total := int64(len(filtered))
|
||
return filtered, total, nil
|
||
}
|
||
|
||
// FindPinnedPosts 查詢置頂貼文
|
||
func (r *PostRepository) FindPinnedPosts(ctx context.Context, limit int64) ([]*entity.Post, error) {
|
||
query := r.repo.Query().
|
||
Where(cassandra.Eq("is_pinned", true)).
|
||
Where(cassandra.Eq("status", post.PostStatusPublished)).
|
||
OrderBy("pinned_at", cassandra.DESC).
|
||
Limit(int(limit))
|
||
|
||
var posts []*entity.Post
|
||
if err := query.Scan(ctx, &posts); err != nil {
|
||
return nil, fmt.Errorf("failed to query pinned posts: %w", err)
|
||
}
|
||
|
||
return posts, nil
|
||
}
|
||
|
||
// FindByStatus 根據狀態查詢貼文
|
||
func (r *PostRepository) FindByStatus(ctx context.Context, status post.Status, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
|
||
query := r.repo.Query().Where(cassandra.Eq("status", status))
|
||
|
||
// 添加排序和分頁
|
||
orderBy := "created_at"
|
||
if params != nil && params.OrderBy != "" {
|
||
orderBy = params.OrderBy
|
||
}
|
||
order := cassandra.DESC
|
||
if params != nil && params.OrderDirection == "ASC" {
|
||
order = cassandra.ASC
|
||
}
|
||
query = query.OrderBy(orderBy, order)
|
||
|
||
pageSize := int64(20)
|
||
if params != nil && params.PageSize > 0 {
|
||
pageSize = params.PageSize
|
||
}
|
||
limit := int(pageSize)
|
||
query = query.Limit(limit)
|
||
|
||
var posts []*entity.Post
|
||
if err := query.Scan(ctx, &posts); err != nil {
|
||
return nil, 0, fmt.Errorf("failed to query posts: %w", err)
|
||
}
|
||
|
||
result := posts
|
||
|
||
total := int64(len(posts))
|
||
return result, total, nil
|
||
}
|
||
|
||
// IncrementLikeCount 增加按讚數(使用 counter 原子操作避免競爭條件)
|
||
// 注意:like_count 欄位必須是 counter 類型
|
||
func (r *PostRepository) IncrementLikeCount(ctx context.Context, postID gocql.UUID) error {
|
||
var zeroUUID gocql.UUID
|
||
if postID == zeroUUID {
|
||
return ErrInvalidInput
|
||
}
|
||
|
||
var zeroPost entity.Post
|
||
tableName := zeroPost.TableName()
|
||
if r.keyspace == "" {
|
||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||
}
|
||
|
||
stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count + 1 WHERE id = ?", r.keyspace, tableName)
|
||
query := r.db.GetSession().Query(stmt, nil).
|
||
WithContext(ctx).
|
||
Consistency(gocql.Quorum).
|
||
Bind(postID)
|
||
|
||
if err := query.ExecRelease(); err != nil {
|
||
return fmt.Errorf("failed to increment like count: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// DecrementLikeCount 減少按讚數(使用 counter 原子操作避免競爭條件)
|
||
// 注意:like_count 欄位必須是 counter 類型
|
||
func (r *PostRepository) DecrementLikeCount(ctx context.Context, postID gocql.UUID) error {
|
||
var zeroUUID gocql.UUID
|
||
if postID == zeroUUID {
|
||
return ErrInvalidInput
|
||
}
|
||
|
||
var zeroPost entity.Post
|
||
tableName := zeroPost.TableName()
|
||
if r.keyspace == "" {
|
||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||
}
|
||
|
||
stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count - 1 WHERE id = ?", r.keyspace, tableName)
|
||
query := r.db.GetSession().Query(stmt, nil).
|
||
WithContext(ctx).
|
||
Consistency(gocql.Quorum).
|
||
Bind(postID)
|
||
|
||
if err := query.ExecRelease(); err != nil {
|
||
return fmt.Errorf("failed to decrement like count: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// IncrementCommentCount 增加評論數(使用 counter 原子操作避免競爭條件)
|
||
// 注意:comment_count 欄位必須是 counter 類型
|
||
func (r *PostRepository) IncrementCommentCount(ctx context.Context, postID gocql.UUID) error {
|
||
var zeroUUID gocql.UUID
|
||
if postID == zeroUUID {
|
||
return ErrInvalidInput
|
||
}
|
||
|
||
var zeroPost entity.Post
|
||
tableName := zeroPost.TableName()
|
||
if r.keyspace == "" {
|
||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||
}
|
||
|
||
stmt := fmt.Sprintf("UPDATE %s.%s SET comment_count = comment_count + 1 WHERE id = ?", r.keyspace, tableName)
|
||
query := r.db.GetSession().Query(stmt, nil).
|
||
WithContext(ctx).
|
||
Consistency(gocql.Quorum).
|
||
Bind(postID)
|
||
|
||
if err := query.ExecRelease(); err != nil {
|
||
return fmt.Errorf("failed to increment comment count: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// DecrementCommentCount 減少評論數(使用 counter 原子操作避免競爭條件)
|
||
// 注意:comment_count 欄位必須是 counter 類型
|
||
func (r *PostRepository) DecrementCommentCount(ctx context.Context, postID gocql.UUID) error {
|
||
var zeroUUID gocql.UUID
|
||
if postID == zeroUUID {
|
||
return ErrInvalidInput
|
||
}
|
||
|
||
var zeroPost entity.Post
|
||
tableName := zeroPost.TableName()
|
||
if r.keyspace == "" {
|
||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||
}
|
||
|
||
stmt := fmt.Sprintf("UPDATE %s.%s SET comment_count = comment_count - 1 WHERE id = ?", r.keyspace, tableName)
|
||
query := r.db.GetSession().Query(stmt, nil).
|
||
WithContext(ctx).
|
||
Consistency(gocql.Quorum).
|
||
Bind(postID)
|
||
|
||
if err := query.ExecRelease(); err != nil {
|
||
return fmt.Errorf("failed to decrement comment count: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// IncrementViewCount 增加瀏覽數(使用 counter 原子操作避免競爭條件)
|
||
// 注意:view_count 欄位必須是 counter 類型
|
||
func (r *PostRepository) IncrementViewCount(ctx context.Context, postID gocql.UUID) error {
|
||
var zeroUUID gocql.UUID
|
||
if postID == zeroUUID {
|
||
return ErrInvalidInput
|
||
}
|
||
|
||
var zeroPost entity.Post
|
||
tableName := zeroPost.TableName()
|
||
if r.keyspace == "" {
|
||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||
}
|
||
|
||
stmt := fmt.Sprintf("UPDATE %s.%s SET view_count = view_count + 1 WHERE id = ?", r.keyspace, tableName)
|
||
query := r.db.GetSession().Query(stmt, nil).
|
||
WithContext(ctx).
|
||
Consistency(gocql.Quorum).
|
||
Bind(postID)
|
||
|
||
if err := query.ExecRelease(); err != nil {
|
||
return fmt.Errorf("failed to increment view count: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// UpdateStatus 更新貼文狀態
|
||
func (r *PostRepository) UpdateStatus(ctx context.Context, postID gocql.UUID, status post.Status) error {
|
||
postEntity, err := r.FindOne(ctx, postID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
postEntity.Status = status
|
||
publishedStatus := post.PostStatusPublished
|
||
if status == publishedStatus && postEntity.PublishedAt == nil {
|
||
now := postEntity.UpdatedAt
|
||
postEntity.PublishedAt = &now
|
||
}
|
||
|
||
return r.Update(ctx, postEntity)
|
||
}
|
||
|
||
// PinPost 置頂貼文
|
||
func (r *PostRepository) PinPost(ctx context.Context, postID gocql.UUID) error {
|
||
post, err := r.FindOne(ctx, postID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
post.Pin()
|
||
return r.Update(ctx, post)
|
||
}
|
||
|
||
// UnpinPost 取消置頂
|
||
func (r *PostRepository) UnpinPost(ctx context.Context, postID gocql.UUID) error {
|
||
post, err := r.FindOne(ctx, postID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
post.Unpin()
|
||
return r.Update(ctx, post)
|
||
}
|
||
|
||
// calculateTotalPages returns the number of pages of size pageSize needed to
// hold total items, rounding up. A non-positive pageSize yields 0.
func calculateTotalPages(total, pageSize int64) int64 {
	if pageSize > 0 {
		pages := math.Ceil(float64(total) / float64(pageSize))
		return int64(pages)
	}
	return 0
}
|
||
|