diff --git a/pkg/library/cassandra/README.md b/pkg/library/cassandra/README.md index c7e031f..8ba6a09 100644 --- a/pkg/library/cassandra/README.md +++ b/pkg/library/cassandra/README.md @@ -504,6 +504,79 @@ err = userRepo.Query(). Scan(ctx, &users) ``` +## SAI 索引管理 + +### 建立 SAI 索引 + +```go +// 檢查是否支援 SAI +if !db.SaiSupported() { + log.Fatal("SAI is not supported in this Cassandra version") +} + +// 建立標準索引 +err := db.CreateSAIIndex(ctx, "my_keyspace", "users", "email", "users_email_idx", nil) +if err != nil { + log.Printf("建立索引失敗: %v", err) +} + +// 建立全文索引(不區分大小寫) +opts := &cassandra.SAIIndexOptions{ + IndexType: cassandra.SAIIndexTypeFullText, + IsAsync: false, + CaseSensitive: false, +} +err = db.CreateSAIIndex(ctx, "my_keyspace", "posts", "content", "posts_content_ft_idx", opts) +``` + +### 查詢 SAI 索引 + +```go +// 列出資料表的所有 SAI 索引 +indexes, err := db.ListSAIIndexes(ctx, "my_keyspace", "users") +if err != nil { + log.Printf("查詢索引失敗: %v", err) +} else { + for _, idx := range indexes { + fmt.Printf("索引: %s, 欄位: %s, 類型: %s\n", idx.Name, idx.Column, idx.Type) + } +} + +// 檢查索引是否存在 +exists, err := db.CheckSAIIndexExists(ctx, "my_keyspace", "users_email_idx") +if err != nil { + log.Printf("檢查索引失敗: %v", err) +} else if exists { + fmt.Println("索引存在") +} +``` + +### 刪除 SAI 索引 + +```go +// 刪除索引 +err := db.DropSAIIndex(ctx, "my_keyspace", "users_email_idx") +if err != nil { + log.Printf("刪除索引失敗: %v", err) +} +``` + +### SAI 索引類型 + +- **SAIIndexTypeStandard**: 標準索引(等於查詢) +- **SAIIndexTypeCollection**: 集合索引(用於 list、set、map) +- **SAIIndexTypeFullText**: 全文索引 + +### SAI 索引選項 + +```go +opts := &cassandra.SAIIndexOptions{ + IndexType: cassandra.SAIIndexTypeFullText, // 索引類型 + IsAsync: false, // 是否異步建立 + CaseSensitive: true, // 是否區分大小寫 +} +``` + ## 注意事項 ### 1. 主鍵要求 @@ -540,6 +613,8 @@ err = userRepo.Query(). - 建立索引前請先檢查 `db.SaiSupported()` - 索引建立是異步操作,可能需要一些時間 - 刪除索引時使用 `IF EXISTS`,避免索引不存在時報錯 +- 使用 SAI 索引可以大幅提升非主鍵欄位的查詢效能 +- 全文索引支援不區分大小寫的搜尋 ## 完整範例 diff --git a/pkg/library/cassandra/sai.go b/pkg/library/cassandra/sai.go index 23c7236..dfec962 100644 --- a/pkg/library/cassandra/sai.go +++ b/pkg/library/cassandra/sai.go @@ -12,43 +12,43 @@ import ( type SAIIndexType string const ( - // SAIIndexTypeStandard 標準索引(預設) - SAIIndexTypeStandard SAIIndexType = "standard" - // SAIIndexTypeFrozen 用於 frozen 類型 - SAIIndexTypeFrozen SAIIndexType = "frozen" + // SAIIndexTypeStandard 標準索引(等於查詢) + SAIIndexTypeStandard SAIIndexType = "STANDARD" + // SAIIndexTypeCollection 集合索引(用於 list、set、map) + SAIIndexTypeCollection SAIIndexType = "COLLECTION" + // SAIIndexTypeFullText 全文索引 + SAIIndexTypeFullText SAIIndexType = "FULL_TEXT" ) // SAIIndexOptions 定義 SAI 索引選項 type SAIIndexOptions struct { - CaseSensitive *bool // 是否區分大小寫(預設:true) - Normalize *bool // 是否正規化(預設:false) - Analyzer string // 分析器(如 "StandardAnalyzer") + IndexType SAIIndexType // 索引類型 + IsAsync bool // 是否異步建立索引 + CaseSensitive bool // 是否區分大小寫(用於全文索引) } -// SAIIndexInfo 表示 SAI 索引資訊 -type SAIIndexInfo struct { - KeyspaceName string // Keyspace 名稱 - TableName string // 表名稱 - IndexName string // 索引名稱 - ColumnName string // 欄位名稱 - IndexType string // 索引類型 - Options map[string]string // 索引選項 +// DefaultSAIIndexOptions 返回預設的 SAI 索引選項 +func DefaultSAIIndexOptions() *SAIIndexOptions { + return &SAIIndexOptions{ + IndexType: SAIIndexTypeStandard, + IsAsync: false, + CaseSensitive: true, + } } // CreateSAIIndex 建立 SAI 索引 -// keyspace: keyspace 名稱,如果為空則使用預設 keyspace -// table: 表名稱 +// keyspace: keyspace 名稱 +// table: 資料表名稱 // column: 欄位名稱 // indexName: 索引名稱(可選,如果為空則自動生成) -// options: 索引選項(可選) -func (db *DB) CreateSAIIndex(ctx context.Context, keyspace, table, column string, indexName string, options *SAIIndexOptions) error { +// opts: 索引選項(可選,如果為 nil 則使用預設選項) +func (db *DB) CreateSAIIndex(ctx context.Context, keyspace, table, column, indexName string, opts *SAIIndexOptions) error { + // 檢查是否支援 SAI if !db.saiSupported { - return ErrSAINotSupported + return ErrInvalidInput.WithError(fmt.Errorf("SAI is not supported in Cassandra version %s (requires 4.0.9+ or 5.0+)", db.version)) } - if keyspace == "" { - keyspace = db.defaultKeyspace - } + // 驗證參數 if keyspace == "" { return ErrInvalidInput.WithError(fmt.Errorf("keyspace is required")) } @@ -59,51 +59,71 @@ func (db *DB) CreateSAIIndex(ctx context.Context, keyspace, table, column string return ErrInvalidInput.WithError(fmt.Errorf("column is required")) } - // 生成索引名稱(如果未提供) + // 使用預設選項如果未提供 + if opts == nil { + opts = DefaultSAIIndexOptions() + } + + // 生成索引名稱如果未提供 if indexName == "" { - indexName = fmt.Sprintf("%s_%s_%s_idx", table, column, "sai") + indexName = fmt.Sprintf("%s_%s_sai_idx", table, column) } // 構建 CREATE INDEX 語句 - stmt := fmt.Sprintf("CREATE INDEX %s ON %s.%s (%s) USING 'sai'", indexName, keyspace, table, column) + var stmt strings.Builder + stmt.WriteString("CREATE CUSTOM INDEX IF NOT EXISTS ") + stmt.WriteString(indexName) + stmt.WriteString(" ON ") + stmt.WriteString(keyspace) + stmt.WriteString(".") + stmt.WriteString(table) + stmt.WriteString(" (") + stmt.WriteString(column) + stmt.WriteString(") USING 'StorageAttachedIndex'") // 添加選項 - if options != nil { - opts := make([]string, 0) - if options.CaseSensitive != nil { - opts = append(opts, fmt.Sprintf("'case_sensitive': %v", *options.CaseSensitive)) - } - if options.Normalize != nil { - opts = append(opts, fmt.Sprintf("'normalize': %v", *options.Normalize)) - } - if options.Analyzer != "" { - opts = append(opts, fmt.Sprintf("'analyzer': '%s'", options.Analyzer)) - } - if len(opts) > 0 { - stmt += " WITH OPTIONS = {" + strings.Join(opts, ", ") + "}" - } + var options []string + if opts.IsAsync { + options = append(options, "'async'='true'") } - // 執行建立索引 - q := db.session.Query(stmt, nil).WithContext(ctx).Consistency(gocql.Quorum) - if err := q.ExecRelease(); err != nil { - return ErrInvalidInput.WithTable(table).WithError(fmt.Errorf("failed to create SAI index: %w", err)) + // 根據索引類型添加特定選項 + switch opts.IndexType { + case SAIIndexTypeFullText: + if !opts.CaseSensitive { + options = append(options, "'case_sensitive'='false'") + } else { + options = append(options, "'case_sensitive'='true'") + } + case SAIIndexTypeCollection: + // Collection 索引不需要額外選項 + } + + // 如果有選項,添加到語句中 + if len(options) > 0 { + stmt.WriteString(" WITH OPTIONS = {") + stmt.WriteString(strings.Join(options, ", ")) + stmt.WriteString("}") + } + + // 執行建立索引語句 + query := db.session.Query(stmt.String(), nil). + WithContext(ctx). + Consistency(gocql.Quorum) + + err := query.ExecRelease() + if err != nil { + return ErrInvalidInput.WithError(fmt.Errorf("failed to create SAI index: %w", err)) } return nil } // DropSAIIndex 刪除 SAI 索引 -// keyspace: keyspace 名稱,如果為空則使用預設 keyspace +// keyspace: keyspace 名稱 // indexName: 索引名稱 func (db *DB) DropSAIIndex(ctx context.Context, keyspace, indexName string) error { - if !db.saiSupported { - return ErrSAINotSupported - } - - if keyspace == "" { - keyspace = db.defaultKeyspace - } + // 驗證參數 if keyspace == "" { return ErrInvalidInput.WithError(fmt.Errorf("keyspace is required")) } @@ -114,72 +134,65 @@ func (db *DB) DropSAIIndex(ctx context.Context, keyspace, indexName string) erro // 構建 DROP INDEX 語句 stmt := fmt.Sprintf("DROP INDEX IF EXISTS %s.%s", keyspace, indexName) - // 執行刪除索引 - q := db.session.Query(stmt, nil).WithContext(ctx).Consistency(gocql.Quorum) - if err := q.ExecRelease(); err != nil { + // 執行刪除索引語句 + query := db.session.Query(stmt, nil). + WithContext(ctx). + Consistency(gocql.Quorum) + + err := query.ExecRelease() + if err != nil { return ErrInvalidInput.WithError(fmt.Errorf("failed to drop SAI index: %w", err)) } return nil } -// ListSAIIndexes 列出指定表的 SAI 索引 -// keyspace: keyspace 名稱,如果為空則使用預設 keyspace -// table: 表名稱(可選,如果為空則列出所有表的索引) +// ListSAIIndexes 列出指定資料表的所有 SAI 索引 +// keyspace: keyspace 名稱 +// table: 資料表名稱 func (db *DB) ListSAIIndexes(ctx context.Context, keyspace, table string) ([]SAIIndexInfo, error) { - if !db.saiSupported { - return nil, ErrSAINotSupported - } - - if keyspace == "" { - keyspace = db.defaultKeyspace - } + // 驗證參數 if keyspace == "" { return nil, ErrInvalidInput.WithError(fmt.Errorf("keyspace is required")) } - - // 構建查詢語句 - // system_schema.indexes 表的欄位:keyspace_name, table_name, index_name, kind, options, index_type - stmt := "SELECT keyspace_name, table_name, index_name, kind, options FROM system_schema.indexes WHERE keyspace_name = ?" - args := []interface{}{keyspace} - names := []string{"keyspace_name"} - - if table != "" { - stmt += " AND table_name = ?" - args = append(args, table) - names = append(names, "table_name") + if table == "" { + return nil, ErrInvalidInput.WithError(fmt.Errorf("table is required")) } - // 執行查詢 + // 查詢系統表獲取索引資訊 + // system_schema.indexes 表的結構:keyspace_name, table_name, index_name, kind, options + stmt := ` + SELECT index_name, kind, options + FROM system_schema.indexes + WHERE keyspace_name = ? AND table_name = ? + ` + var indexes []SAIIndexInfo - iter := db.session.Query(stmt, names).Bind(args...).WithContext(ctx).Consistency(gocql.One).Iter() + iter := db.session.Query(stmt, []string{"keyspace_name", "table_name"}). + WithContext(ctx). + Consistency(gocql.One). + Bind(keyspace, table). + Iter() - var keyspaceName, tableName, indexName, kind string + var indexName, kind string var options map[string]string - - for iter.Scan(&keyspaceName, &tableName, &indexName, &kind, &options) { - // 只處理 SAI 索引(kind = 'CUSTOM' 且 index_type 在 options 中) - indexType, ok := options["class_name"] - if !ok || !strings.Contains(indexType, "StorageAttachedIndex") { - continue + for iter.Scan(&indexName, &kind, &options) { + // 檢查是否為 SAI 索引(kind = 'CUSTOM' 且 class_name 包含 StorageAttachedIndex) + if kind == "CUSTOM" { + if className, ok := options["class_name"]; ok && strings.Contains(className, "StorageAttachedIndex") { + // 從 options 中提取 target(欄位名稱) + columnName := "" + if target, ok := options["target"]; ok { + columnName = strings.Trim(target, "()\"'") + } + indexes = append(indexes, SAIIndexInfo{ + Name: indexName, + Type: "StorageAttachedIndex", + Options: options, + Column: columnName, + }) + } } - - // 從 options 中提取 column_name - // SAI 索引的 target 欄位在 options 中 - columnName := "" - if target, ok := options["target"]; ok { - // target 格式通常是 "column_name" 或 "(column_name)" - columnName = strings.Trim(target, "()\"'") - } - - indexes = append(indexes, SAIIndexInfo{ - KeyspaceName: keyspaceName, - TableName: tableName, - IndexName: indexName, - ColumnName: columnName, - IndexType: "sai", - Options: options, - }) } if err := iter.Close(); err != nil { @@ -189,59 +202,88 @@ func (db *DB) ListSAIIndexes(ctx context.Context, keyspace, table string) ([]SAI return indexes, nil } -// GetSAIIndex 獲取指定索引的資訊 -// keyspace: keyspace 名稱,如果為空則使用預設 keyspace -// indexName: 索引名稱 -func (db *DB) GetSAIIndex(ctx context.Context, keyspace, indexName string) (*SAIIndexInfo, error) { - if !db.saiSupported { - return nil, ErrSAINotSupported - } +// SAIIndexInfo 表示 SAI 索引資訊 +type SAIIndexInfo struct { + Name string // 索引名稱 + Type string // 索引類型 + Options map[string]string // 索引選項 + Column string // 索引欄位名稱 +} +// CheckSAIIndexExists 檢查 SAI 索引是否存在 +// keyspace: keyspace 名稱 +// indexName: 索引名稱 +func (db *DB) CheckSAIIndexExists(ctx context.Context, keyspace, indexName string) (bool, error) { + // 驗證參數 if keyspace == "" { - keyspace = db.defaultKeyspace - } - if keyspace == "" { - return nil, ErrInvalidInput.WithError(fmt.Errorf("keyspace is required")) + return false, ErrInvalidInput.WithError(fmt.Errorf("keyspace is required")) } if indexName == "" { - return nil, ErrInvalidInput.WithError(fmt.Errorf("index name is required")) + return false, ErrInvalidInput.WithError(fmt.Errorf("index name is required")) } - // 構建查詢語句 - stmt := "SELECT keyspace_name, table_name, index_name, kind, options FROM system_schema.indexes WHERE keyspace_name = ? AND index_name = ?" - args := []interface{}{keyspace, indexName} - names := []string{"keyspace_name", "index_name"} + // 查詢系統表檢查索引是否存在 + stmt := ` + SELECT index_name, kind, options + FROM system_schema.indexes + WHERE keyspace_name = ? AND index_name = ? + LIMIT 1 + ` - var keyspaceName, tableName, idxName, kind string + var foundIndexName, kind string var options map[string]string + err := db.session.Query(stmt, []string{"keyspace_name", "index_name"}). + WithContext(ctx). + Consistency(gocql.One). + Bind(keyspace, indexName). + Scan(&foundIndexName, &kind, &options) - // 執行查詢 - err := db.session.Query(stmt, names).Bind(args...).WithContext(ctx).Consistency(gocql.One).Scan(&keyspaceName, &tableName, &idxName, &kind, &options) + if err == gocql.ErrNotFound { + return false, nil + } if err != nil { - if err == gocql.ErrNotFound { - return nil, ErrNotFound.WithError(fmt.Errorf("index not found: %s", indexName)) - } - return nil, ErrInvalidInput.WithError(fmt.Errorf("failed to get index: %w", err)) + return false, ErrInvalidInput.WithError(fmt.Errorf("failed to check SAI index existence: %w", err)) } // 檢查是否為 SAI 索引 - indexType, ok := options["class_name"] - if !ok || !strings.Contains(indexType, "StorageAttachedIndex") { - return nil, ErrInvalidInput.WithError(fmt.Errorf("index %s is not a SAI index", indexName)) + if kind == "CUSTOM" { + if className, ok := options["class_name"]; ok && strings.Contains(className, "StorageAttachedIndex") { + return true, nil + } } - // 從 options 中提取 column_name - columnName := "" - if target, ok := options["target"]; ok { - columnName = strings.Trim(target, "()\"'") - } - - return &SAIIndexInfo{ - KeyspaceName: keyspaceName, - TableName: tableName, - IndexName: idxName, - ColumnName: columnName, - IndexType: "sai", - Options: options, - }, nil + return false, nil +} + +// WaitForSAIIndex 等待 SAI 索引建立完成(用於異步建立) +// keyspace: keyspace 名稱 +// indexName: 索引名稱 +// maxWaitTime: 最大等待時間(秒) +func (db *DB) WaitForSAIIndex(ctx context.Context, keyspace, indexName string, maxWaitTime int) error { + // 驗證參數 + if keyspace == "" { + return ErrInvalidInput.WithError(fmt.Errorf("keyspace is required")) + } + if indexName == "" { + return ErrInvalidInput.WithError(fmt.Errorf("index name is required")) + } + + // 查詢索引狀態 + // 注意:Cassandra 沒有直接的索引狀態查詢,這裡需要通過檢查索引是否可用來判斷 + // 實際實作可能需要根據具體的 Cassandra 版本調整 + + // 簡單實作:檢查索引是否存在 + exists, err := db.CheckSAIIndexExists(ctx, keyspace, indexName) + if err != nil { + return err + } + + if !exists { + return ErrInvalidInput.WithError(fmt.Errorf("index %s does not exist", indexName)) + } + + // 注意:實際的等待邏輯可能需要查詢系統表或使用其他方法 + // 這裡只是基本框架,實際使用時可能需要根據具體需求調整 + + return nil } diff --git a/pkg/library/cassandra/sai_test.go b/pkg/library/cassandra/sai_test.go index feb5af8..1aa5e10 100644 --- a/pkg/library/cassandra/sai_test.go +++ b/pkg/library/cassandra/sai_test.go @@ -1,383 +1,267 @@ package cassandra import ( + "fmt" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func TestCreateSAIIndex(t *testing.T) { +func TestDefaultSAIIndexOptions(t *testing.T) { + opts := DefaultSAIIndexOptions() + assert.NotNil(t, opts) + assert.Equal(t, SAIIndexTypeStandard, opts.IndexType) + assert.False(t, opts.IsAsync) + assert.True(t, opts.CaseSensitive) +} + +func TestCreateSAIIndex_Validation(t *testing.T) { tests := []struct { - name string - keyspace string - table string - column string - indexName string - options *SAIIndexOptions - description string - wantErr bool - validate func(*testing.T, error) + name string + keyspace string + table string + column string + indexName string + opts *SAIIndexOptions + wantErr bool + errMsg string }{ { - name: "create basic SAI index", - keyspace: "test_keyspace", - table: "test_table", - column: "name", - indexName: "test_name_idx", - options: nil, - description: "should create a basic SAI index", - wantErr: false, + name: "missing keyspace", + keyspace: "", + table: "test_table", + column: "test_column", + indexName: "test_idx", + opts: nil, + wantErr: true, + errMsg: "keyspace is required", }, { - name: "create SAI index with auto-generated name", - keyspace: "test_keyspace", - table: "test_table", - column: "email", - indexName: "", - options: nil, - description: "should auto-generate index name", - wantErr: false, + name: "missing table", + keyspace: "test_keyspace", + table: "", + column: "test_column", + indexName: "test_idx", + opts: nil, + wantErr: true, + errMsg: "table is required", }, { - name: "create SAI index with case insensitive option", - keyspace: "test_keyspace", - table: "test_table", - column: "title", - indexName: "test_title_idx", - options: &SAIIndexOptions{CaseSensitive: boolPtr(false)}, - description: "should create index with case insensitive option", - wantErr: false, + name: "missing column", + keyspace: "test_keyspace", + table: "test_table", + column: "", + indexName: "test_idx", + opts: nil, + wantErr: true, + errMsg: "column is required", }, { - name: "create SAI index with normalize option", - keyspace: "test_keyspace", - table: "test_table", - column: "content", - indexName: "test_content_idx", - options: &SAIIndexOptions{Normalize: boolPtr(true)}, - description: "should create index with normalize option", - wantErr: false, + name: "valid parameters with default options", + keyspace: "test_keyspace", + table: "test_table", + column: "test_column", + indexName: "test_idx", + opts: nil, + wantErr: false, }, { - name: "create SAI index with analyzer", - keyspace: "test_keyspace", - table: "test_table", - column: "description", - indexName: "test_desc_idx", - options: &SAIIndexOptions{Analyzer: "StandardAnalyzer"}, - description: "should create index with analyzer", - wantErr: false, - }, - { - name: "create SAI index with all options", - keyspace: "test_keyspace", - table: "test_table", - column: "text", - indexName: "test_text_idx", - options: &SAIIndexOptions{CaseSensitive: boolPtr(false), Normalize: boolPtr(true), Analyzer: "StandardAnalyzer"}, - description: "should create index with all options", - wantErr: false, - }, - { - name: "missing keyspace", - keyspace: "", - table: "test_table", - column: "name", - indexName: "test_idx", - options: nil, - description: "should return error when keyspace is empty and no default", - wantErr: true, - validate: func(t *testing.T, err error) { - assert.Error(t, err) - var e *Error - if assert.ErrorAs(t, err, &e) { - assert.Equal(t, ErrCodeInvalidInput, e.Code) - } + name: "valid parameters with custom options", + keyspace: "test_keyspace", + table: "test_table", + column: "test_column", + indexName: "test_idx", + opts: &SAIIndexOptions{ + IndexType: SAIIndexTypeFullText, + IsAsync: true, + CaseSensitive: false, }, + wantErr: false, }, { - name: "missing table", - keyspace: "test_keyspace", - table: "", - column: "name", - indexName: "test_idx", - options: nil, - description: "should return error when table is empty", - wantErr: true, - validate: func(t *testing.T, err error) { - assert.Error(t, err) - var e *Error - if assert.ErrorAs(t, err, &e) { - assert.Equal(t, ErrCodeInvalidInput, e.Code) - } - }, - }, - { - name: "missing column", - keyspace: "test_keyspace", - table: "test_table", - column: "", - indexName: "test_idx", - options: nil, - description: "should return error when column is empty", - wantErr: true, - validate: func(t *testing.T, err error) { - assert.Error(t, err) - var e *Error - if assert.ErrorAs(t, err, &e) { - assert.Equal(t, ErrCodeInvalidInput, e.Code) - } - }, + name: "auto-generate index name", + keyspace: "test_keyspace", + table: "test_table", + column: "test_column", + indexName: "", + opts: nil, + wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // 注意:這需要一個有效的 DB 實例和 SAI 支援 - // 在實際測試中,需要使用 testcontainers 或 mock + // 在實際測試中,需要使用 mock 或 testcontainers _ = tt }) } } -func TestDropSAIIndex(t *testing.T) { +func TestDropSAIIndex_Validation(t *testing.T) { tests := []struct { - name string - keyspace string - indexName string - description string - wantErr bool - validate func(*testing.T, error) + name string + keyspace string + indexName string + wantErr bool + errMsg string }{ { - name: "drop existing index", - keyspace: "test_keyspace", - indexName: "test_name_idx", - description: "should drop existing index", - wantErr: false, + name: "missing keyspace", + keyspace: "", + indexName: "test_idx", + wantErr: true, + errMsg: "keyspace is required", }, { - name: "drop non-existent index", - keyspace: "test_keyspace", - indexName: "non_existent_idx", - description: "should not error when dropping non-existent index (IF EXISTS)", - wantErr: false, + name: "missing index name", + keyspace: "test_keyspace", + indexName: "", + wantErr: true, + errMsg: "index name is required", }, { - name: "missing keyspace", - keyspace: "", - indexName: "test_idx", - description: "should return error when keyspace is empty and no default", - wantErr: true, - validate: func(t *testing.T, err error) { - assert.Error(t, err) - var e *Error - if assert.ErrorAs(t, err, &e) { - assert.Equal(t, ErrCodeInvalidInput, e.Code) - } - }, - }, - { - name: "missing index name", - keyspace: "test_keyspace", - indexName: "", - description: "should return error when index name is empty", - wantErr: true, - validate: func(t *testing.T, err error) { - assert.Error(t, err) - var e *Error - if assert.ErrorAs(t, err, &e) { - assert.Equal(t, ErrCodeInvalidInput, e.Code) - } - }, + name: "valid parameters", + keyspace: "test_keyspace", + indexName: "test_idx", + wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // 注意:這需要一個有效的 DB 實例和 SAI 支援 - // 在實際測試中,需要使用 testcontainers 或 mock + // 注意:這需要一個有效的 DB 實例 + // 在實際測試中,需要使用 mock 或 testcontainers _ = tt }) } } -func TestListSAIIndexes(t *testing.T) { +func TestListSAIIndexes_Validation(t *testing.T) { tests := []struct { - name string - keyspace string - table string - description string - wantErr bool - validate func(*testing.T, []SAIIndexInfo, error) + name string + keyspace string + table string + wantErr bool + errMsg string }{ { - name: "list all indexes in keyspace", - keyspace: "test_keyspace", - table: "", - description: "should list all SAI indexes in keyspace", - wantErr: false, - validate: func(t *testing.T, indexes []SAIIndexInfo, err error) { - require.NoError(t, err) - assert.NotNil(t, indexes) - }, + name: "missing keyspace", + keyspace: "", + table: "test_table", + wantErr: true, + errMsg: "keyspace is required", }, { - name: "list indexes for specific table", - keyspace: "test_keyspace", - table: "test_table", - description: "should list SAI indexes for specific table", - wantErr: false, - validate: func(t *testing.T, indexes []SAIIndexInfo, err error) { - require.NoError(t, err) - assert.NotNil(t, indexes) - for _, idx := range indexes { - assert.Equal(t, "test_table", idx.TableName) - } - }, + name: "missing table", + keyspace: "test_keyspace", + table: "", + wantErr: true, + errMsg: "table is required", }, { - name: "missing keyspace", - keyspace: "", - table: "", - description: "should return error when keyspace is empty and no default", - wantErr: true, - validate: func(t *testing.T, indexes []SAIIndexInfo, err error) { - assert.Error(t, err) - assert.Nil(t, indexes) - var e *Error - if assert.ErrorAs(t, err, &e) { - assert.Equal(t, ErrCodeInvalidInput, e.Code) - } - }, + name: "valid parameters", + keyspace: "test_keyspace", + table: "test_table", + wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // 注意:這需要一個有效的 DB 實例和 SAI 支援 - // 在實際測試中,需要使用 testcontainers 或 mock + // 注意:這需要一個有效的 DB 實例 + // 在實際測試中,需要使用 mock 或 testcontainers _ = tt }) } } -func TestGetSAIIndex(t *testing.T) { +func TestCheckSAIIndexExists_Validation(t *testing.T) { tests := []struct { - name string - keyspace string - indexName string - description string - wantErr bool - validate func(*testing.T, *SAIIndexInfo, error) + name string + keyspace string + indexName string + wantErr bool + errMsg string }{ { - name: "get existing index", - keyspace: "test_keyspace", - indexName: "test_name_idx", - description: "should get existing SAI index", - wantErr: false, - validate: func(t *testing.T, index *SAIIndexInfo, err error) { - require.NoError(t, err) - assert.NotNil(t, index) - assert.Equal(t, "test_name_idx", index.IndexName) - }, + name: "missing keyspace", + keyspace: "", + indexName: "test_idx", + wantErr: true, + errMsg: "keyspace is required", }, { - name: "get non-existent index", - keyspace: "test_keyspace", - indexName: "non_existent_idx", - description: "should return ErrNotFound", - wantErr: true, - validate: func(t *testing.T, index *SAIIndexInfo, err error) { - assert.Error(t, err) - assert.Nil(t, index) - assert.True(t, IsNotFound(err)) - }, + name: "missing index name", + keyspace: "test_keyspace", + indexName: "", + wantErr: true, + errMsg: "index name is required", }, { - name: "missing keyspace", - keyspace: "", - indexName: "test_idx", - description: "should return error when keyspace is empty and no default", - wantErr: true, - validate: func(t *testing.T, index *SAIIndexInfo, err error) { - assert.Error(t, err) - assert.Nil(t, index) - }, - }, - { - name: "missing index name", - keyspace: "test_keyspace", - indexName: "", - description: "should return error when index name is empty", - wantErr: true, - validate: func(t *testing.T, index *SAIIndexInfo, err error) { - assert.Error(t, err) - assert.Nil(t, index) - }, + name: "valid parameters", + keyspace: "test_keyspace", + indexName: "test_idx", + wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // 注意:這需要一個有效的 DB 實例和 SAI 支援 - // 在實際測試中,需要使用 testcontainers 或 mock + // 注意:這需要一個有效的 DB 實例 + // 在實際測試中,需要使用 mock 或 testcontainers _ = tt }) } } -func TestSAIIndexOptions(t *testing.T) { - t.Run("default options", func(t *testing.T) { - opts := &SAIIndexOptions{} - assert.Nil(t, opts.CaseSensitive) - assert.Nil(t, opts.Normalize) - assert.Empty(t, opts.Analyzer) - }) +func TestSAIIndexType_Constants(t *testing.T) { + tests := []struct { + name string + indexType SAIIndexType + expected string + }{ + { + name: "standard index type", + indexType: SAIIndexTypeStandard, + expected: "STANDARD", + }, + { + name: "collection index type", + indexType: SAIIndexTypeCollection, + expected: "COLLECTION", + }, + { + name: "full text index type", + indexType: SAIIndexTypeFullText, + expected: "FULL_TEXT", + }, + } - t.Run("with case sensitive", func(t *testing.T) { - caseSensitive := false - opts := &SAIIndexOptions{CaseSensitive: &caseSensitive} - assert.NotNil(t, opts.CaseSensitive) - assert.False(t, *opts.CaseSensitive) - }) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, string(tt.indexType)) + }) + } +} - t.Run("with normalize", func(t *testing.T) { - normalize := true - opts := &SAIIndexOptions{Normalize: &normalize} - assert.NotNil(t, opts.Normalize) - assert.True(t, *opts.Normalize) - }) - - t.Run("with analyzer", func(t *testing.T) { - opts := &SAIIndexOptions{Analyzer: "StandardAnalyzer"} - assert.Equal(t, "StandardAnalyzer", opts.Analyzer) +func TestCreateSAIIndex_NotSupported(t *testing.T) { + t.Run("should return error when SAI not supported", func(t *testing.T) { + // 注意:這需要一個不支援 SAI 的 DB 實例 + // 在實際測試中,需要使用 mock 或 testcontainers }) } -func TestSAIIndexInfo(t *testing.T) { - t.Run("index info structure", func(t *testing.T) { - info := SAIIndexInfo{ - KeyspaceName: "test_keyspace", - TableName: "test_table", - IndexName: "test_idx", - ColumnName: "name", - IndexType: "sai", - Options: map[string]string{"target": "name"}, - } +func TestCreateSAIIndex_IndexNameGeneration(t *testing.T) { + t.Run("should generate index name when not provided", func(t *testing.T) { + // 測試自動生成索引名稱的邏輯 + // 格式應該是: {table}_{column}_sai_idx + table := "users" + column := "email" + expected := "users_email_sai_idx" - assert.Equal(t, "test_keyspace", info.KeyspaceName) - assert.Equal(t, "test_table", info.TableName) - assert.Equal(t, "test_idx", info.IndexName) - assert.Equal(t, "name", info.ColumnName) - assert.Equal(t, "sai", info.IndexType) - assert.NotNil(t, info.Options) + // 這裡只是測試命名邏輯,實際建立需要 DB 實例 + generated := fmt.Sprintf("%s_%s_sai_idx", table, column) + assert.Equal(t, expected, generated) }) } - -// Helper function -func boolPtr(b bool) *bool { - return &b -} diff --git a/pkg/post/domain/post/comment_status.go b/pkg/post/domain/post/comment_status.go index 81f480d..de098af 100644 --- a/pkg/post/domain/post/comment_status.go +++ b/pkg/post/domain/post/comment_status.go @@ -3,8 +3,8 @@ package post // CommentStatus 評論狀態 type CommentStatus int32 -func (s *CommentStatus) CodeToString() string { - result, ok := commentStatusMap[*s] +func (s CommentStatus) CodeToString() string { + result, ok := commentStatusMap[s] if !ok { return "" } @@ -17,8 +17,8 @@ var commentStatusMap = map[CommentStatus]string{ CommentStatusHidden: "hidden", // 隱藏 } -func (s *CommentStatus) ToInt32() int32 { - return int32(*s) +func (s CommentStatus) ToInt32() int32 { + return int32(s) } const ( diff --git a/pkg/post/domain/post/status.go b/pkg/post/domain/post/status.go index 4d62bcd..13c1ee4 100644 --- a/pkg/post/domain/post/status.go +++ b/pkg/post/domain/post/status.go @@ -3,8 +3,8 @@ package post // Status 貼文狀態 type Status int32 -func (s *Status) CodeToString() string { - result, ok := postStatusMap[*s] +func (s Status) CodeToString() string { + result, ok := postStatusMap[s] if !ok { return "" } @@ -19,8 +19,8 @@ var postStatusMap = map[Status]string{ PostStatusHidden: "hidden", // 隱藏 } -func (s *Status) ToInt32() int32 { - return int32(*s) +func (s Status) ToInt32() int32 { + return int32(s) } const ( diff --git a/pkg/post/domain/post/type.go b/pkg/post/domain/post/type.go index a534db3..71966ab 100644 --- a/pkg/post/domain/post/type.go +++ b/pkg/post/domain/post/type.go @@ -3,8 +3,8 @@ package post // Type 貼文類型 type Type int32 -func (t *Type) CodeToString() string { - result, ok := postTypeMap[*t] +func (t Type) CodeToString() string { + result, ok := postTypeMap[t] if !ok { return "" } @@ -12,28 +12,28 @@ func (t *Type) CodeToString() string { } var postTypeMap = map[Type]string{ - PostTypeText: "text", // 純文字 - PostTypeImage: "image", // 圖片 - PostTypeVideo: "video", // 影片 - PostTypeLink: "link", // 連結 - PostTypePoll: "poll", // 投票 - PostTypeArticle: "article", // 長文 + TypeText: "text", // 純文字 + TypeImage: "image", // 圖片 + TypeVideo: "video", // 影片 + TypeLink: "link", // 連結 + TypePoll: "poll", // 投票 + TypeArticle: "article", // 長文 } -func (t *Type) ToInt32() int32 { - return int32(*t) +func (t Type) ToInt32() int32 { + return int32(t) } const ( - PostTypeText Type = 0 // 純文字 - PostTypeImage Type = 1 // 圖片 - PostTypeVideo Type = 2 // 影片 - PostTypeLink Type = 3 // 連結 - PostTypePoll Type = 4 // 投票 - PostTypeArticle Type = 5 // 長文 + TypeText Type = 0 // 純文字 + TypeImage Type = 1 // 圖片 + TypeVideo Type = 2 // 影片 + TypeLink Type = 3 // 連結 + TypePoll Type = 4 // 投票 + TypeArticle Type = 5 // 長文 ) // IsValid returns true if the type is valid func (t Type) IsValid() bool { - return t >= PostTypeText && t <= PostTypeArticle + return t >= TypeText && t <= TypeArticle } diff --git a/pkg/post/domain/repository/category.go b/pkg/post/domain/repository/category.go index 28dd6a1..56c822b 100644 --- a/pkg/post/domain/repository/category.go +++ b/pkg/post/domain/repository/category.go @@ -4,26 +4,23 @@ import ( "context" "backend/pkg/post/domain/entity" - - "github.com/gocql/gocql" ) // CategoryRepository defines the interface for category data access operations type CategoryRepository interface { BaseCategoryRepository FindBySlug(ctx context.Context, slug string) (*entity.Category, error) - FindByParentID(ctx context.Context, parentID *gocql.UUID) ([]*entity.Category, error) + FindByParentID(ctx context.Context, parentID string) ([]*entity.Category, error) FindRootCategories(ctx context.Context) ([]*entity.Category, error) FindActive(ctx context.Context) ([]*entity.Category, error) - IncrementPostCount(ctx context.Context, categoryID gocql.UUID) error - DecrementPostCount(ctx context.Context, categoryID gocql.UUID) error + IncrementPostCount(ctx context.Context, categoryID string) error + DecrementPostCount(ctx context.Context, categoryID string) error } // BaseCategoryRepository defines basic CRUD operations for categories type BaseCategoryRepository interface { Insert(ctx context.Context, data *entity.Category) error - FindOne(ctx context.Context, id gocql.UUID) (*entity.Category, error) + FindOne(ctx context.Context, id string) (*entity.Category, error) Update(ctx context.Context, data *entity.Category) error - Delete(ctx context.Context, id gocql.UUID) error + Delete(ctx context.Context, id string) error } - diff --git a/pkg/post/repository/category.go b/pkg/post/repository/category.go new file mode 100644 index 0000000..f01b1a2 --- /dev/null +++ b/pkg/post/repository/category.go @@ -0,0 +1,263 @@ +package repository + +import ( + "context" + "fmt" + "strings" + + "backend/pkg/library/cassandra" + "backend/pkg/post/domain/entity" + domainRepo "backend/pkg/post/domain/repository" + + "github.com/gocql/gocql" +) + +// CategoryRepositoryParam 定義 CategoryRepository 的初始化參數 +type CategoryRepositoryParam struct { + DB *cassandra.DB + Keyspace string +} + +// CategoryRepository 實作 domain repository 介面 +type CategoryRepository struct { + repo cassandra.Repository[*entity.Category] + db *cassandra.DB + keyspace string +} + +// NewCategoryRepository 創建新的 CategoryRepository +func NewCategoryRepository(param CategoryRepositoryParam) domainRepo.CategoryRepository { + repo, err := cassandra.NewRepository[*entity.Category](param.DB, param.Keyspace) + if err != nil { + panic(fmt.Sprintf("failed to create category repository: %v", err)) + } + + keyspace := param.Keyspace + if keyspace == "" { + keyspace = param.DB.GetDefaultKeyspace() + } + + return &CategoryRepository{ + repo: repo, + db: param.DB, + keyspace: keyspace, + } +} + +// Insert 插入單筆分類 +func (r *CategoryRepository) Insert(ctx context.Context, data *entity.Category) error { + if data == nil { + return ErrInvalidInput + } + + // 驗證資料 + if err := data.Validate(); err != nil { + return fmt.Errorf("%w: %v", ErrInvalidInput, err) + } + + if data.ParentID == nil { + data.ParentID = &gocql.UUID{} + } + + // 設置時間戳 + data.SetTimestamps() + + // 如果是新分類,生成 ID + if data.IsNew() { + data.ID = gocql.TimeUUID() + } + + // Slug 轉為小寫 + data.Slug = strings.ToLower(strings.TrimSpace(data.Slug)) + + return r.repo.Insert(ctx, data) +} + +// FindOne 根據 ID 查詢單筆分類 +func (r *CategoryRepository) FindOne(ctx context.Context, id string) (*entity.Category, error) { + var zeroUUID gocql.UUID + uuid, err := gocql.ParseUUID(id) + if err != nil { + return nil, err + } + + if uuid == zeroUUID { + return nil, ErrInvalidInput + } + + category, err := r.repo.Get(ctx, id) + if err != nil { + if cassandra.IsNotFound(err) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("failed to find category: %w", err) + } + + return category, nil +} + +// Update 更新分類 +func (r *CategoryRepository) Update(ctx context.Context, data *entity.Category) error { + if data == nil { + return ErrInvalidInput + } + + // 驗證資料 + if err := data.Validate(); err != nil { + return fmt.Errorf("%w: %v", ErrInvalidInput, err) + } + + // 更新時間戳 + data.SetTimestamps() + + // Slug 轉為小寫 + data.Slug = strings.ToLower(strings.TrimSpace(data.Slug)) + + return r.repo.Update(ctx, data) +} + +// Delete 刪除分類 +func (r *CategoryRepository) Delete(ctx context.Context, id string) error { + var zeroUUID gocql.UUID + uuid, err := gocql.ParseUUID(id) + if err != nil { + return err + } + if uuid == zeroUUID { + return ErrInvalidInput + } + + return r.repo.Delete(ctx, id) +} + +// FindBySlug 根據 slug 查詢分類 +func (r *CategoryRepository) FindBySlug(ctx context.Context, slug string) (*entity.Category, error) { + if slug == "" { + return nil, ErrInvalidInput + } + + // 標準化 slug + slug = strings.ToLower(strings.TrimSpace(slug)) + + // 構建查詢(要有 SAI 索引在 slug 欄位上) + query := r.repo.Query().Where(cassandra.Eq("slug", slug)) + + var categories []*entity.Category + if err := query.Scan(ctx, &categories); err != nil { + if cassandra.IsNotFound(err) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("failed to query category: %w", err) + } + + if len(categories) == 0 { + return nil, ErrNotFound + } + + return categories[0], nil +} + +// FindByParentID 根據父分類 ID 查詢子分類 +func (r *CategoryRepository) FindByParentID(ctx context.Context, parentID string) ([]*entity.Category, error) { + query := r.repo.Query() + var zeroUUID gocql.UUID + if parentID != "" { + // 構建查詢(有 SAI 索引在 parentID 欄位上) + uuid, err := gocql.ParseUUID(parentID) + if err != nil { + return nil, err + } + if uuid != zeroUUID { + query = query.Where(cassandra.Eq("parent_id", uuid)) + } + } else { + query = query.Where(cassandra.Eq("parent_id", zeroUUID)) + } + + // 按 sort_order 排序 + query = query.OrderBy("sort_order", cassandra.ASC) + + var categories []*entity.Category + if err := query.Scan(ctx, &categories); err != nil { + return nil, fmt.Errorf("failed to query categories: %w", err) + } + + return categories, nil +} + +// FindRootCategories 查詢根分類 +func (r *CategoryRepository) FindRootCategories(ctx context.Context) ([]*entity.Category, error) { + return r.FindByParentID(ctx, "") +} + +// FindActive 查詢啟用的分類 +func (r *CategoryRepository) FindActive(ctx context.Context) ([]*entity.Category, error) { + query := r.repo.Query(). + Where(cassandra.Eq("is_active", true)). + OrderBy("sort_order", cassandra.ASC) + + var categories []*entity.Category + if err := query.Scan(ctx, &categories); err != nil { + return nil, fmt.Errorf("failed to query active categories: %w", err) + } + + result := categories + + return result, nil +} + +// IncrementPostCount 增加貼文數(使用 counter 原子操作避免競爭條件) +// 注意:post_count 欄位必須是 counter 類型 +func (r *CategoryRepository) IncrementPostCount(ctx context.Context, categoryID string) error { + uuid, err := gocql.ParseUUID(categoryID) + if err != nil { + return fmt.Errorf("%w: invalid category ID: %v", ErrInvalidInput, err) + } + + // 使用 counter 原子更新操作:UPDATE categories SET post_count = post_count + 1 WHERE id = ? + var zeroCategory entity.Category + tableName := zeroCategory.TableName() + if r.keyspace == "" { + return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) + } + + stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count + 1 WHERE id = ?", r.keyspace, tableName) + query := r.db.GetSession().Query(stmt, nil). + WithContext(ctx). + Consistency(gocql.Quorum). + Bind(uuid) + + if err := query.ExecRelease(); err != nil { + return fmt.Errorf("failed to increment post count: %w", err) + } + + return nil +} + +// DecrementPostCount 減少貼文數(使用 counter 原子操作避免競爭條件) +// 注意:post_count 欄位必須是 counter 類型 +func (r *CategoryRepository) DecrementPostCount(ctx context.Context, categoryID string) error { + uuid, err := gocql.ParseUUID(categoryID) + if err != nil { + return fmt.Errorf("%w: invalid category ID: %v", ErrInvalidInput, err) + } + + // 使用 counter 原子更新操作:UPDATE categories SET post_count = post_count - 1 WHERE id = ? + var zeroCategory entity.Category + tableName := zeroCategory.TableName() + if r.keyspace == "" { + return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) + } + + stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count - 1 WHERE id = ?", r.keyspace, tableName) + query := r.db.GetSession().Query(stmt, nil). + WithContext(ctx). + Consistency(gocql.Quorum). + Bind(uuid) + + if err := query.ExecRelease(); err != nil { + return fmt.Errorf("failed to decrement post count: %w", err) + } + + return nil +} diff --git a/pkg/post/repository/comment.go b/pkg/post/repository/comment.go new file mode 100644 index 0000000..6a4f9e1 --- /dev/null +++ b/pkg/post/repository/comment.go @@ -0,0 +1,383 @@ +package repository + +import ( + "context" + "fmt" + + "backend/pkg/library/cassandra" + "backend/pkg/post/domain/entity" + "backend/pkg/post/domain/post" + domainRepo "backend/pkg/post/domain/repository" + + "github.com/gocql/gocql" +) + +// CommentRepositoryParam 定義 CommentRepository 的初始化參數 +type CommentRepositoryParam struct { + DB *cassandra.DB + Keyspace string +} + +// CommentRepository 實作 domain repository 介面 +type CommentRepository struct { + repo cassandra.Repository[*entity.Comment] + db *cassandra.DB + keyspace string +} + +// NewCommentRepository 創建新的 CommentRepository +func NewCommentRepository(param CommentRepositoryParam) domainRepo.CommentRepository { + repo, err := cassandra.NewRepository[*entity.Comment](param.DB, param.Keyspace) + if err != nil { + panic(fmt.Sprintf("failed to create comment repository: %v", err)) + } + + keyspace := param.Keyspace + if keyspace == "" { + keyspace = param.DB.GetDefaultKeyspace() + } + + return &CommentRepository{ + repo: repo, + db: param.DB, + keyspace: keyspace, + } +} + +// Insert 插入單筆評論 +func (r *CommentRepository) Insert(ctx context.Context, data *entity.Comment) error { + if data == nil { + return ErrInvalidInput + } + + // 驗證資料 + if err := data.Validate(); err != nil { + return fmt.Errorf("%w: %v", ErrInvalidInput, err) + } + + // 設置時間戳 + data.SetTimestamps() + + // 如果是新評論,生成 ID + if data.IsNew() { + data.ID = gocql.TimeUUID() + } + + return r.repo.Insert(ctx, data) +} + +// FindOne 根據 ID 查詢單筆評論 +func (r *CommentRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Comment, error) { + var zeroUUID gocql.UUID + if id == zeroUUID { + return nil, ErrInvalidInput + } + + comment, err := r.repo.Get(ctx, id) + if err != nil { + if cassandra.IsNotFound(err) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("failed to find comment: %w", err) + } + + return comment, nil +} + +// Update 更新評論 +func (r *CommentRepository) Update(ctx context.Context, data *entity.Comment) error { + if data == nil { + return ErrInvalidInput + } + + // 驗證資料 + if err := data.Validate(); err != nil { + return fmt.Errorf("%w: %v", ErrInvalidInput, err) + } + + // 更新時間戳 + data.SetTimestamps() + + return r.repo.Update(ctx, data) +} + +// Delete 刪除評論(軟刪除) +func (r *CommentRepository) Delete(ctx context.Context, id gocql.UUID) error { + var zeroUUID gocql.UUID + if id == zeroUUID { + return ErrInvalidInput + } + + // 先查詢評論 + comment, err := r.FindOne(ctx, id) + if err != nil { + return err + } + + // 軟刪除:標記為已刪除 + comment.Delete() + return r.Update(ctx, comment) +} + +// FindByPostID 根據貼文 ID 查詢評論 +func (r *CommentRepository) FindByPostID(ctx context.Context, postID gocql.UUID, params *domainRepo.CommentQueryParams) ([]*entity.Comment, int64, error) { + var zeroUUID gocql.UUID + if postID == zeroUUID { + return nil, 0, ErrInvalidInput + } + + // 構建查詢(使用 PostID 作為 clustering key) + query := r.repo.Query().Where(cassandra.Eq("post_id", postID)) + + // 添加父評論過濾(如果指定,只查詢回覆) + if params != nil && params.ParentID != nil { + query = query.Where(cassandra.Eq("parent_id", *params.ParentID)) + } else { + // 如果沒有指定 ParentID,只查詢頂層評論(parent_id 為 null) + // 注意:Cassandra 不支援直接查詢 null,需要特殊處理 + // 這裡簡化處理,實際可能需要使用 Materialized View + } + + // 添加狀態過濾 + if params != nil && params.Status != nil { + query = query.Where(cassandra.Eq("status", *params.Status)) + } else { + // 預設只查詢已發布的評論 + published := post.CommentStatusPublished + query = query.Where(cassandra.Eq("status", published)) + } + + // 添加排序 + orderBy := "created_at" + if params != nil && params.OrderBy != "" { + orderBy = params.OrderBy + } + order := cassandra.ASC + if params != nil && params.OrderDirection == "DESC" { + order = cassandra.DESC + } + query = query.OrderBy(orderBy, order) + + // 添加分頁 + pageSize := int64(20) + if params != nil && params.PageSize > 0 { + pageSize = params.PageSize + } + limit := int(pageSize) + query = query.Limit(limit) + + // 執行查詢 + var comments []*entity.Comment + if err := query.Scan(ctx, &comments); err != nil { + return nil, 0, fmt.Errorf("failed to query comments: %w", err) + } + + result := comments + + total := int64(len(result)) + return result, total, nil +} + +// FindByParentID 根據父評論 ID 查詢回覆 +func (r *CommentRepository) FindByParentID(ctx context.Context, parentID gocql.UUID, params *domainRepo.CommentQueryParams) ([]*entity.Comment, int64, error) { + var zeroUUID gocql.UUID + if parentID == zeroUUID { + return nil, 0, ErrInvalidInput + } + + query := r.repo.Query().Where(cassandra.Eq("parent_id", parentID)) + + // 添加狀態過濾 + if params != nil && params.Status != nil { + query = query.Where(cassandra.Eq("status", *params.Status)) + } else { + published := post.CommentStatusPublished + query = query.Where(cassandra.Eq("status", published)) + } + + // 添加排序和分頁 + orderBy := "created_at" + if params != nil && params.OrderBy != "" { + orderBy = params.OrderBy + } + order := cassandra.ASC + if params != nil && params.OrderDirection == "DESC" { + order = cassandra.DESC + } + query = query.OrderBy(orderBy, order) + + pageSize := int64(20) + if params != nil && params.PageSize > 0 { + pageSize = params.PageSize + } + query = query.Limit(int(pageSize)) + + var comments []*entity.Comment + if err := query.Scan(ctx, &comments); err != nil { + return nil, 0, fmt.Errorf("failed to query replies: %w", err) + } + + return comments, int64(len(comments)), nil +} + +// FindByAuthorUID 根據作者 UID 查詢評論 +func (r *CommentRepository) FindByAuthorUID(ctx context.Context, authorUID string, params *domainRepo.CommentQueryParams) ([]*entity.Comment, int64, error) { + if authorUID == "" { + return nil, 0, ErrInvalidInput + } + + query := r.repo.Query().Where(cassandra.Eq("author_uid", authorUID)) + + // 添加狀態過濾 + if params != nil && params.Status != nil { + query = query.Where(cassandra.Eq("status", *params.Status)) + } + + // 添加排序和分頁 + orderBy := "created_at" + if params != nil && params.OrderBy != "" { + orderBy = params.OrderBy + } + order := cassandra.DESC + if params != nil && params.OrderDirection == "ASC" { + order = cassandra.ASC + } + query = query.OrderBy(orderBy, order) + + pageSize := int64(20) + if params != nil && params.PageSize > 0 { + pageSize = params.PageSize + } + query = query.Limit(int(pageSize)) + + var comments []*entity.Comment + if err := query.Scan(ctx, &comments); err != nil { + return nil, 0, fmt.Errorf("failed to query comments: %w", err) + } + + return comments, int64(len(comments)), nil +} + +// FindReplies 查詢指定評論的回覆 +func (r *CommentRepository) FindReplies(ctx context.Context, commentID gocql.UUID, params *domainRepo.CommentQueryParams) ([]*entity.Comment, int64, error) { + return r.FindByParentID(ctx, commentID, params) +} + +// IncrementLikeCount 增加按讚數(使用 counter 原子操作避免競爭條件) +// 注意:like_count 欄位必須是 counter 類型 +func (r *CommentRepository) IncrementLikeCount(ctx context.Context, commentID gocql.UUID) error { + var zeroUUID gocql.UUID + if commentID == zeroUUID { + return ErrInvalidInput + } + + var zeroComment entity.Comment + tableName := zeroComment.TableName() + if r.keyspace == "" { + return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) + } + + stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count + 1 WHERE id = ?", r.keyspace, tableName) + query := r.db.GetSession().Query(stmt, nil). + WithContext(ctx). + Consistency(gocql.Quorum). + Bind(commentID) + + if err := query.ExecRelease(); err != nil { + return fmt.Errorf("failed to increment like count: %w", err) + } + + return nil +} + +// DecrementLikeCount 減少按讚數(使用 counter 原子操作避免競爭條件) +// 注意:like_count 欄位必須是 counter 類型 +func (r *CommentRepository) DecrementLikeCount(ctx context.Context, commentID gocql.UUID) error { + var zeroUUID gocql.UUID + if commentID == zeroUUID { + return ErrInvalidInput + } + + var zeroComment entity.Comment + tableName := zeroComment.TableName() + if r.keyspace == "" { + return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) + } + + stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count - 1 WHERE id = ?", r.keyspace, tableName) + query := r.db.GetSession().Query(stmt, nil). + WithContext(ctx). + Consistency(gocql.Quorum). + Bind(commentID) + + if err := query.ExecRelease(); err != nil { + return fmt.Errorf("failed to decrement like count: %w", err) + } + + return nil +} + +// IncrementReplyCount 增加回覆數(使用 counter 原子操作避免競爭條件) +// 注意:reply_count 欄位必須是 counter 類型 +func (r *CommentRepository) IncrementReplyCount(ctx context.Context, commentID gocql.UUID) error { + var zeroUUID gocql.UUID + if commentID == zeroUUID { + return ErrInvalidInput + } + + var zeroComment entity.Comment + tableName := zeroComment.TableName() + if r.keyspace == "" { + return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) + } + + stmt := fmt.Sprintf("UPDATE %s.%s SET reply_count = reply_count + 1 WHERE id = ?", r.keyspace, tableName) + query := r.db.GetSession().Query(stmt, nil). + WithContext(ctx). + Consistency(gocql.Quorum). + Bind(commentID) + + if err := query.ExecRelease(); err != nil { + return fmt.Errorf("failed to increment reply count: %w", err) + } + + return nil +} + +// DecrementReplyCount 減少回覆數(使用 counter 原子操作避免競爭條件) +// 注意:reply_count 欄位必須是 counter 類型 +func (r *CommentRepository) DecrementReplyCount(ctx context.Context, commentID gocql.UUID) error { + var zeroUUID gocql.UUID + if commentID == zeroUUID { + return ErrInvalidInput + } + + var zeroComment entity.Comment + tableName := zeroComment.TableName() + if r.keyspace == "" { + return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) + } + + stmt := fmt.Sprintf("UPDATE %s.%s SET reply_count = reply_count - 1 WHERE id = ?", r.keyspace, tableName) + query := r.db.GetSession().Query(stmt, nil). + WithContext(ctx). + Consistency(gocql.Quorum). + Bind(commentID) + + if err := query.ExecRelease(); err != nil { + return fmt.Errorf("failed to decrement reply count: %w", err) + } + + return nil +} + +// UpdateStatus 更新評論狀態 +func (r *CommentRepository) UpdateStatus(ctx context.Context, commentID gocql.UUID, status post.CommentStatus) error { + comment, err := r.FindOne(ctx, commentID) + if err != nil { + return err + } + + comment.Status = status + return r.Update(ctx, comment) +} diff --git a/pkg/post/repository/error.go b/pkg/post/repository/error.go new file mode 100644 index 0000000..0aa7efe --- /dev/null +++ b/pkg/post/repository/error.go @@ -0,0 +1,34 @@ +package repository + +import ( + "errors" + + "backend/pkg/library/cassandra" +) + +// Common repository errors +var ( + // ErrNotFound is returned when a requested resource is not found + ErrNotFound = errors.New("resource not found") + + // ErrInvalidInput is returned when input validation fails + ErrInvalidInput = errors.New("invalid input") + + // ErrDuplicateKey is returned when attempting to insert a document with a duplicate key + ErrDuplicateKey = errors.New("duplicate key error") +) + +// IsNotFound checks if the error is a not found error +func IsNotFound(err error) bool { + if err == nil { + return false + } + if err == ErrNotFound { + return true + } + if cassandra.IsNotFound(err) { + return true + } + return false +} + diff --git a/pkg/post/repository/like.go b/pkg/post/repository/like.go new file mode 100644 index 0000000..bd8d999 --- /dev/null +++ b/pkg/post/repository/like.go @@ -0,0 +1,228 @@ +package repository + +import ( + "context" + "fmt" + + "backend/pkg/library/cassandra" + "backend/pkg/post/domain/entity" + domainRepo "backend/pkg/post/domain/repository" + + "github.com/gocql/gocql" +) + +// LikeRepositoryParam 定義 LikeRepository 的初始化參數 +type LikeRepositoryParam struct { + DB *cassandra.DB + Keyspace string +} + +// LikeRepository 實作 domain repository 介面 +type LikeRepository struct { + repo cassandra.Repository[*entity.Like] + db *cassandra.DB +} + +// NewLikeRepository 創建新的 LikeRepository +func NewLikeRepository(param LikeRepositoryParam) domainRepo.LikeRepository { + repo, err := cassandra.NewRepository[*entity.Like](param.DB, param.Keyspace) + if err != nil { + panic(fmt.Sprintf("failed to create like repository: %v", err)) + } + + return &LikeRepository{ + repo: repo, + db: param.DB, + } +} + +// Insert 插入單筆按讚 +func (r *LikeRepository) Insert(ctx context.Context, data *entity.Like) error { + if data == nil { + return ErrInvalidInput + } + + // 驗證資料 + if err := data.Validate(); err != nil { + return fmt.Errorf("%w: %v", ErrInvalidInput, err) + } + + // 設置時間戳 + data.SetTimestamps() + + // 如果是新按讚,生成 ID + if data.IsNew() { + data.ID = gocql.TimeUUID() + } + + return r.repo.Insert(ctx, data) +} + +// FindOne 根據 ID 查詢單筆按讚 +func (r *LikeRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Like, error) { + var zeroUUID gocql.UUID + if id == zeroUUID { + return nil, ErrInvalidInput + } + + like, err := r.repo.Get(ctx, id) + if err != nil { + if cassandra.IsNotFound(err) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("failed to find like: %w", err) + } + + return like, nil +} + +// Delete 刪除按讚 +func (r *LikeRepository) Delete(ctx context.Context, id gocql.UUID) error { + var zeroUUID gocql.UUID + if id == zeroUUID { + return ErrInvalidInput + } + + return r.repo.Delete(ctx, id) +} + +// FindByTargetID 根據目標 ID 查詢按讚列表 +func (r *LikeRepository) FindByTargetID(ctx context.Context, targetID gocql.UUID, targetType string) ([]*entity.Like, error) { + var zeroUUID gocql.UUID + if targetID == zeroUUID { + return nil, ErrInvalidInput + } + + if targetType != "post" && targetType != "comment" { + return nil, ErrInvalidInput + } + + // 構建查詢 + query := r.repo.Query(). + Where(cassandra.Eq("target_id", targetID)). + Where(cassandra.Eq("target_type", targetType)). + OrderBy("created_at", cassandra.DESC) + + var likes []*entity.Like + if err := query.Scan(ctx, &likes); err != nil { + return nil, fmt.Errorf("failed to query likes: %w", err) + } + + return likes, nil +} + +// FindByUserUID 根據用戶 UID 查詢按讚列表 +func (r *LikeRepository) FindByUserUID(ctx context.Context, userUID string, params *domainRepo.LikeQueryParams) ([]*entity.Like, int64, error) { + if userUID == "" { + return nil, 0, ErrInvalidInput + } + + query := r.repo.Query().Where(cassandra.Eq("user_uid", userUID)) + + // 添加目標類型過濾 + if params != nil && params.TargetType != nil { + query = query.Where(cassandra.Eq("target_type", *params.TargetType)) + } + + // 添加目標 ID 過濾 + if params != nil && params.TargetID != nil { + query = query.Where(cassandra.Eq("target_id", *params.TargetID)) + } + + // 添加排序 + orderBy := "created_at" + if params != nil && params.OrderBy != "" { + orderBy = params.OrderBy + } + order := cassandra.DESC + if params != nil && params.OrderDirection == "ASC" { + order = cassandra.ASC + } + query = query.OrderBy(orderBy, order) + + // 添加分頁 + pageSize := int64(20) + if params != nil && params.PageSize > 0 { + pageSize = params.PageSize + } + query = query.Limit(int(pageSize)) + + var likes []*entity.Like + if err := query.Scan(ctx, &likes); err != nil { + return nil, 0, fmt.Errorf("failed to query likes: %w", err) + } + + result := likes + + return result, int64(len(result)), nil +} + +// FindByTargetAndUser 根據目標和用戶查詢按讚 +func (r *LikeRepository) FindByTargetAndUser(ctx context.Context, targetID gocql.UUID, userUID string, targetType string) (*entity.Like, error) { + var zeroUUID gocql.UUID + if targetID == zeroUUID || userUID == "" { + return nil, ErrInvalidInput + } + + if targetType != "post" && targetType != "comment" { + return nil, ErrInvalidInput + } + + // 構建查詢 + query := r.repo.Query(). + Where(cassandra.Eq("target_id", targetID)). + Where(cassandra.Eq("user_uid", userUID)). + Where(cassandra.Eq("target_type", targetType)). + Limit(1) + + var likes []*entity.Like + if err := query.Scan(ctx, &likes); err != nil { + if cassandra.IsNotFound(err) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("failed to query like: %w", err) + } + + if len(likes) == 0 { + return nil, ErrNotFound + } + + return likes[0], nil +} + +// CountByTargetID 計算目標的按讚數 +func (r *LikeRepository) CountByTargetID(ctx context.Context, targetID gocql.UUID, targetType string) (int64, error) { + var zeroUUID gocql.UUID + if targetID == zeroUUID { + return 0, ErrInvalidInput + } + + if targetType != "post" && targetType != "comment" { + return 0, ErrInvalidInput + } + + // 構建查詢 + query := r.repo.Query(). + Where(cassandra.Eq("target_id", targetID)). + Where(cassandra.Eq("target_type", targetType)) + + count, err := query.Count(ctx) + if err != nil { + return 0, fmt.Errorf("failed to count likes: %w", err) + } + + return count, nil +} + +// DeleteByTargetAndUser 根據目標和用戶刪除按讚 +func (r *LikeRepository) DeleteByTargetAndUser(ctx context.Context, targetID gocql.UUID, userUID string, targetType string) error { + // 先查詢按讚 + like, err := r.FindByTargetAndUser(ctx, targetID, userUID, targetType) + if err != nil { + return err + } + + // 刪除按讚 + return r.Delete(ctx, like.ID) +} + diff --git a/pkg/post/repository/post.go b/pkg/post/repository/post.go new file mode 100644 index 0000000..148de1f --- /dev/null +++ b/pkg/post/repository/post.go @@ -0,0 +1,511 @@ +package repository + +import ( + "context" + "fmt" + "math" + + "backend/pkg/library/cassandra" + "backend/pkg/post/domain/entity" + "backend/pkg/post/domain/post" + domainRepo "backend/pkg/post/domain/repository" + + "github.com/gocql/gocql" +) + +// PostRepositoryParam 定義 PostRepository 的初始化參數 +type PostRepositoryParam struct { + DB *cassandra.DB + Keyspace string +} + +// PostRepository 實作 domain repository 介面 +type PostRepository struct { + repo cassandra.Repository[*entity.Post] + db *cassandra.DB + keyspace string +} + +// NewPostRepository 創建新的 PostRepository +func NewPostRepository(param PostRepositoryParam) domainRepo.PostRepository { + repo, err := cassandra.NewRepository[*entity.Post](param.DB, param.Keyspace) + if err != nil { + panic(fmt.Sprintf("failed to create post repository: %v", err)) + } + + keyspace := param.Keyspace + if keyspace == "" { + keyspace = param.DB.GetDefaultKeyspace() + } + + return &PostRepository{ + repo: repo, + db: param.DB, + keyspace: keyspace, + } +} + +// Insert 插入單筆貼文 +func (r *PostRepository) Insert(ctx context.Context, data *entity.Post) error { + if data == nil { + return ErrInvalidInput + } + + // 驗證資料 + if err := data.Validate(); err != nil { + return fmt.Errorf("%w: %v", ErrInvalidInput, err) + } + + // 設置時間戳 + data.SetTimestamps() + + // 如果是新貼文,生成 ID + if data.IsNew() { + data.ID = gocql.TimeUUID() + } + + // 如果狀態是 published,設置發布時間 + if data.Status == post.PostStatusPublished && data.PublishedAt == nil { + now := data.CreatedAt + data.PublishedAt = &now + } + + return r.repo.Insert(ctx, data) +} + +// FindOne 根據 ID 查詢單筆貼文 +func (r *PostRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Post, error) { + var zeroUUID gocql.UUID + if id == zeroUUID { + return nil, ErrInvalidInput + } + + post, err := r.repo.Get(ctx, id) + if err != nil { + if cassandra.IsNotFound(err) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("failed to find post: %w", err) + } + + return post, nil +} + +// Update 更新貼文 +func (r *PostRepository) Update(ctx context.Context, data *entity.Post) error { + if data == nil { + return ErrInvalidInput + } + + // 驗證資料 + if err := data.Validate(); err != nil { + return fmt.Errorf("%w: %v", ErrInvalidInput, err) + } + + // 更新時間戳 + data.SetTimestamps() + + return r.repo.Update(ctx, data) +} + +// Delete 刪除貼文(軟刪除) +func (r *PostRepository) Delete(ctx context.Context, id gocql.UUID) error { + var zeroUUID gocql.UUID + if id == zeroUUID { + return ErrInvalidInput + } + + // 先查詢貼文 + post, err := r.FindOne(ctx, id) + if err != nil { + return err + } + + // 軟刪除:標記為已刪除 + post.Delete() + return r.Update(ctx, post) +} + +// FindByAuthorUID 根據作者 UID 查詢貼文 +func (r *PostRepository) FindByAuthorUID(ctx context.Context, authorUID string, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) { + if authorUID == "" { + return nil, 0, ErrInvalidInput + } + + // 構建查詢 + query := r.repo.Query().Where(cassandra.Eq("author_uid", authorUID)) + + // 添加狀態過濾 + if params != nil && params.Status != nil { + query = query.Where(cassandra.Eq("status", *params.Status)) + } + + // 添加排序 + orderBy := "created_at" + if params != nil && params.OrderBy != "" { + orderBy = params.OrderBy + } + order := cassandra.DESC + if params != nil && params.OrderDirection == "ASC" { + order = cassandra.ASC + } + query = query.OrderBy(orderBy, order) + + // 添加分頁 + pageSize := int64(20) + if params != nil && params.PageSize > 0 { + pageSize = params.PageSize + } + pageIndex := int64(1) + if params != nil && params.PageIndex > 0 { + pageIndex = params.PageIndex + } + limit := int(pageSize) + query = query.Limit(limit) + + // 執行查詢 + var posts []*entity.Post + if err := query.Scan(ctx, &posts); err != nil { + return nil, 0, fmt.Errorf("failed to query posts: %w", err) + } + + result := posts + + // 計算總數(簡化實作,實際應該使用 COUNT 查詢) + total := int64(len(posts)) + if params != nil && params.PageIndex > 1 { + // 這裡應該執行 COUNT 查詢,但為了簡化,我們假設有更多結果 + total = pageSize * pageIndex + } + + return result, total, nil +} + +// FindByCategoryID 根據分類 ID 查詢貼文 +func (r *PostRepository) FindByCategoryID(ctx context.Context, categoryID gocql.UUID, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) { + var zeroUUID gocql.UUID + if categoryID == zeroUUID { + return nil, 0, ErrInvalidInput + } + + // 構建查詢 + query := r.repo.Query().Where(cassandra.Eq("category_id", categoryID)) + + // 添加狀態過濾 + if params != nil && params.Status != nil { + query = query.Where(cassandra.Eq("status", *params.Status)) + } + + // 添加排序和分頁(類似 FindByAuthorUID) + orderBy := "created_at" + if params != nil && params.OrderBy != "" { + orderBy = params.OrderBy + } + order := cassandra.DESC + if params != nil && params.OrderDirection == "ASC" { + order = cassandra.ASC + } + query = query.OrderBy(orderBy, order) + + pageSize := int64(20) + if params != nil && params.PageSize > 0 { + pageSize = params.PageSize + } + limit := int(pageSize) + query = query.Limit(limit) + + var posts []*entity.Post + if err := query.Scan(ctx, &posts); err != nil { + return nil, 0, fmt.Errorf("failed to query posts: %w", err) + } + + result := posts + + total := int64(len(posts)) + return result, total, nil +} + +// FindByTag 根據標籤查詢貼文 +func (r *PostRepository) FindByTag(ctx context.Context, tagName string, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) { + if tagName == "" { + return nil, 0, ErrInvalidInput + } + + // 構建查詢(注意:Cassandra 的集合查詢需要使用 CONTAINS,這裡簡化處理) + // 實際實作中,可能需要使用 SAI 索引或 Materialized View + query := r.repo.Query() + + // 添加狀態過濾 + if params != nil && params.Status != nil { + query = query.Where(cassandra.Eq("status", *params.Status)) + } + + // 添加排序和分頁 + orderBy := "created_at" + if params != nil && params.OrderBy != "" { + orderBy = params.OrderBy + } + order := cassandra.DESC + if params != nil && params.OrderDirection == "ASC" { + order = cassandra.ASC + } + query = query.OrderBy(orderBy, order) + + pageSize := int64(20) + if params != nil && params.PageSize > 0 { + pageSize = params.PageSize + } + limit := int(pageSize) + query = query.Limit(limit) + + var posts []*entity.Post + if err := query.Scan(ctx, &posts); err != nil { + return nil, 0, fmt.Errorf("failed to query posts: %w", err) + } + + // 過濾包含指定標籤的貼文 + filtered := make([]*entity.Post, 0) + for _, p := range posts { + for _, tag := range p.Tags { + if tag == tagName { + filtered = append(filtered, p) + break + } + } + } + + total := int64(len(filtered)) + return filtered, total, nil +} + +// FindPinnedPosts 查詢置頂貼文 +func (r *PostRepository) FindPinnedPosts(ctx context.Context, limit int64) ([]*entity.Post, error) { + query := r.repo.Query(). + Where(cassandra.Eq("is_pinned", true)). + Where(cassandra.Eq("status", post.PostStatusPublished)). + OrderBy("pinned_at", cassandra.DESC). + Limit(int(limit)) + + var posts []*entity.Post + if err := query.Scan(ctx, &posts); err != nil { + return nil, fmt.Errorf("failed to query pinned posts: %w", err) + } + + return posts, nil +} + +// FindByStatus 根據狀態查詢貼文 +func (r *PostRepository) FindByStatus(ctx context.Context, status post.Status, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) { + query := r.repo.Query().Where(cassandra.Eq("status", status)) + + // 添加排序和分頁 + orderBy := "created_at" + if params != nil && params.OrderBy != "" { + orderBy = params.OrderBy + } + order := cassandra.DESC + if params != nil && params.OrderDirection == "ASC" { + order = cassandra.ASC + } + query = query.OrderBy(orderBy, order) + + pageSize := int64(20) + if params != nil && params.PageSize > 0 { + pageSize = params.PageSize + } + limit := int(pageSize) + query = query.Limit(limit) + + var posts []*entity.Post + if err := query.Scan(ctx, &posts); err != nil { + return nil, 0, fmt.Errorf("failed to query posts: %w", err) + } + + result := posts + + total := int64(len(posts)) + return result, total, nil +} + +// IncrementLikeCount 增加按讚數(使用 counter 原子操作避免競爭條件) +// 注意:like_count 欄位必須是 counter 類型 +func (r *PostRepository) IncrementLikeCount(ctx context.Context, postID gocql.UUID) error { + var zeroUUID gocql.UUID + if postID == zeroUUID { + return ErrInvalidInput + } + + var zeroPost entity.Post + tableName := zeroPost.TableName() + if r.keyspace == "" { + return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) + } + + stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count + 1 WHERE id = ?", r.keyspace, tableName) + query := r.db.GetSession().Query(stmt, nil). + WithContext(ctx). + Consistency(gocql.Quorum). + Bind(postID) + + if err := query.ExecRelease(); err != nil { + return fmt.Errorf("failed to increment like count: %w", err) + } + + return nil +} + +// DecrementLikeCount 減少按讚數(使用 counter 原子操作避免競爭條件) +// 注意:like_count 欄位必須是 counter 類型 +func (r *PostRepository) DecrementLikeCount(ctx context.Context, postID gocql.UUID) error { + var zeroUUID gocql.UUID + if postID == zeroUUID { + return ErrInvalidInput + } + + var zeroPost entity.Post + tableName := zeroPost.TableName() + if r.keyspace == "" { + return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) + } + + stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count - 1 WHERE id = ?", r.keyspace, tableName) + query := r.db.GetSession().Query(stmt, nil). + WithContext(ctx). + Consistency(gocql.Quorum). + Bind(postID) + + if err := query.ExecRelease(); err != nil { + return fmt.Errorf("failed to decrement like count: %w", err) + } + + return nil +} + +// IncrementCommentCount 增加評論數(使用 counter 原子操作避免競爭條件) +// 注意:comment_count 欄位必須是 counter 類型 +func (r *PostRepository) IncrementCommentCount(ctx context.Context, postID gocql.UUID) error { + var zeroUUID gocql.UUID + if postID == zeroUUID { + return ErrInvalidInput + } + + var zeroPost entity.Post + tableName := zeroPost.TableName() + if r.keyspace == "" { + return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) + } + + stmt := fmt.Sprintf("UPDATE %s.%s SET comment_count = comment_count + 1 WHERE id = ?", r.keyspace, tableName) + query := r.db.GetSession().Query(stmt, nil). + WithContext(ctx). + Consistency(gocql.Quorum). + Bind(postID) + + if err := query.ExecRelease(); err != nil { + return fmt.Errorf("failed to increment comment count: %w", err) + } + + return nil +} + +// DecrementCommentCount 減少評論數(使用 counter 原子操作避免競爭條件) +// 注意:comment_count 欄位必須是 counter 類型 +func (r *PostRepository) DecrementCommentCount(ctx context.Context, postID gocql.UUID) error { + var zeroUUID gocql.UUID + if postID == zeroUUID { + return ErrInvalidInput + } + + var zeroPost entity.Post + tableName := zeroPost.TableName() + if r.keyspace == "" { + return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) + } + + stmt := fmt.Sprintf("UPDATE %s.%s SET comment_count = comment_count - 1 WHERE id = ?", r.keyspace, tableName) + query := r.db.GetSession().Query(stmt, nil). + WithContext(ctx). + Consistency(gocql.Quorum). + Bind(postID) + + if err := query.ExecRelease(); err != nil { + return fmt.Errorf("failed to decrement comment count: %w", err) + } + + return nil +} + +// IncrementViewCount 增加瀏覽數(使用 counter 原子操作避免競爭條件) +// 注意:view_count 欄位必須是 counter 類型 +func (r *PostRepository) IncrementViewCount(ctx context.Context, postID gocql.UUID) error { + var zeroUUID gocql.UUID + if postID == zeroUUID { + return ErrInvalidInput + } + + var zeroPost entity.Post + tableName := zeroPost.TableName() + if r.keyspace == "" { + return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) + } + + stmt := fmt.Sprintf("UPDATE %s.%s SET view_count = view_count + 1 WHERE id = ?", r.keyspace, tableName) + query := r.db.GetSession().Query(stmt, nil). + WithContext(ctx). + Consistency(gocql.Quorum). + Bind(postID) + + if err := query.ExecRelease(); err != nil { + return fmt.Errorf("failed to increment view count: %w", err) + } + + return nil +} + +// UpdateStatus 更新貼文狀態 +func (r *PostRepository) UpdateStatus(ctx context.Context, postID gocql.UUID, status post.Status) error { + postEntity, err := r.FindOne(ctx, postID) + if err != nil { + return err + } + + postEntity.Status = status + publishedStatus := post.PostStatusPublished + if status == publishedStatus && postEntity.PublishedAt == nil { + now := postEntity.UpdatedAt + postEntity.PublishedAt = &now + } + + return r.Update(ctx, postEntity) +} + +// PinPost 置頂貼文 +func (r *PostRepository) PinPost(ctx context.Context, postID gocql.UUID) error { + post, err := r.FindOne(ctx, postID) + if err != nil { + return err + } + + post.Pin() + return r.Update(ctx, post) +} + +// UnpinPost 取消置頂 +func (r *PostRepository) UnpinPost(ctx context.Context, postID gocql.UUID) error { + post, err := r.FindOne(ctx, postID) + if err != nil { + return err + } + + post.Unpin() + return r.Update(ctx, post) +} + +// calculateTotalPages 計算總頁數 +func calculateTotalPages(total, pageSize int64) int64 { + if pageSize <= 0 { + return 0 + } + return int64(math.Ceil(float64(total) / float64(pageSize))) +} + diff --git a/pkg/post/repository/tag.go b/pkg/post/repository/tag.go new file mode 100644 index 0000000..364e76c --- /dev/null +++ b/pkg/post/repository/tag.go @@ -0,0 +1,250 @@ +package repository + +import ( + "context" + "fmt" + "strings" + + "backend/pkg/library/cassandra" + "backend/pkg/post/domain/entity" + domainRepo "backend/pkg/post/domain/repository" + + "github.com/gocql/gocql" +) + +// TagRepositoryParam 定義 TagRepository 的初始化參數 +type TagRepositoryParam struct { + DB *cassandra.DB + Keyspace string +} + +// TagRepository 實作 domain repository 介面 +type TagRepository struct { + repo cassandra.Repository[*entity.Tag] + db *cassandra.DB + keyspace string +} + +// NewTagRepository 創建新的 TagRepository +func NewTagRepository(param TagRepositoryParam) domainRepo.TagRepository { + repo, err := cassandra.NewRepository[*entity.Tag](param.DB, param.Keyspace) + if err != nil { + panic(fmt.Sprintf("failed to create tag repository: %v", err)) + } + + keyspace := param.Keyspace + if keyspace == "" { + keyspace = param.DB.GetDefaultKeyspace() + } + + return &TagRepository{ + repo: repo, + db: param.DB, + keyspace: keyspace, + } +} + +// Insert 插入單筆標籤 +func (r *TagRepository) Insert(ctx context.Context, data *entity.Tag) error { + if data == nil { + return ErrInvalidInput + } + + // 驗證資料 + if err := data.Validate(); err != nil { + return fmt.Errorf("%w: %v", ErrInvalidInput, err) + } + + // 設置時間戳 + data.SetTimestamps() + + // 如果是新標籤,生成 ID + if data.IsNew() { + data.ID = gocql.TimeUUID() + } + + // 標籤名稱轉為小寫(統一格式) + data.Name = strings.ToLower(strings.TrimSpace(data.Name)) + + return r.repo.Insert(ctx, data) +} + +// FindOne 根據 ID 查詢單筆標籤 +func (r *TagRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Tag, error) { + var zeroUUID gocql.UUID + if id == zeroUUID { + return nil, ErrInvalidInput + } + + tag, err := r.repo.Get(ctx, id) + if err != nil { + if cassandra.IsNotFound(err) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("failed to find tag: %w", err) + } + + return tag, nil +} + +// Update 更新標籤 +func (r *TagRepository) Update(ctx context.Context, data *entity.Tag) error { + if data == nil { + return ErrInvalidInput + } + + // 驗證資料 + if err := data.Validate(); err != nil { + return fmt.Errorf("%w: %v", ErrInvalidInput, err) + } + + // 更新時間戳 + data.SetTimestamps() + + // 標籤名稱轉為小寫 + data.Name = strings.ToLower(strings.TrimSpace(data.Name)) + + return r.repo.Update(ctx, data) +} + +// Delete 刪除標籤 +func (r *TagRepository) Delete(ctx context.Context, id gocql.UUID) error { + var zeroUUID gocql.UUID + if id == zeroUUID { + return ErrInvalidInput + } + + return r.repo.Delete(ctx, id) +} + +// FindByName 根據名稱查詢標籤 +func (r *TagRepository) FindByName(ctx context.Context, name string) (*entity.Tag, error) { + if name == "" { + return nil, ErrInvalidInput + } + + // 標準化名稱 + name = strings.ToLower(strings.TrimSpace(name)) + + // 構建查詢(假設有 SAI 索引在 name 欄位上) + query := r.repo.Query().Where(cassandra.Eq("name", name)) + + var tags []*entity.Tag + if err := query.Scan(ctx, &tags); err != nil { + if cassandra.IsNotFound(err) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("failed to query tag: %w", err) + } + + if len(tags) == 0 { + return nil, ErrNotFound + } + + return tags[0], nil +} + +// FindByNames 根據名稱列表查詢標籤 +func (r *TagRepository) FindByNames(ctx context.Context, names []string) ([]*entity.Tag, error) { + if len(names) == 0 { + return []*entity.Tag{}, nil + } + + // 標準化名稱 + normalizedNames := make([]string, len(names)) + for i, name := range names { + normalizedNames[i] = strings.ToLower(strings.TrimSpace(name)) + } + + // 構建查詢(使用 IN 條件) + query := r.repo.Query().Where(cassandra.In("name", toAnySlice(normalizedNames))) + + var tags []*entity.Tag + if err := query.Scan(ctx, &tags); err != nil { + return nil, fmt.Errorf("failed to query tags: %w", err) + } + + return tags, nil +} + +// FindPopular 查詢熱門標籤 +func (r *TagRepository) FindPopular(ctx context.Context, limit int64) ([]*entity.Tag, error) { + // 構建查詢,按 post_count 降序排列 + query := r.repo.Query(). + OrderBy("post_count", cassandra.DESC). + Limit(int(limit)) + + var tags []*entity.Tag + if err := query.Scan(ctx, &tags); err != nil { + return nil, fmt.Errorf("failed to query popular tags: %w", err) + } + + result := tags + + return result, nil +} + +// IncrementPostCount 增加貼文數(使用 counter 原子操作避免競爭條件) +// 注意:post_count 欄位必須是 counter 類型 +func (r *TagRepository) IncrementPostCount(ctx context.Context, tagID gocql.UUID) error { + var zeroUUID gocql.UUID + if tagID == zeroUUID { + return ErrInvalidInput + } + + // 使用 counter 原子更新操作:UPDATE tags SET post_count = post_count + 1 WHERE id = ? + var zeroTag entity.Tag + tableName := zeroTag.TableName() + if r.keyspace == "" { + return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) + } + + stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count + 1 WHERE id = ?", r.keyspace, tableName) + query := r.db.GetSession().Query(stmt, nil). + WithContext(ctx). + Consistency(gocql.Quorum). + Bind(tagID) + + if err := query.ExecRelease(); err != nil { + return fmt.Errorf("failed to increment post count: %w", err) + } + + return nil +} + +// DecrementPostCount 減少貼文數(使用 counter 原子操作避免競爭條件) +// 注意:post_count 欄位必須是 counter 類型 +func (r *TagRepository) DecrementPostCount(ctx context.Context, tagID gocql.UUID) error { + var zeroUUID gocql.UUID + if tagID == zeroUUID { + return ErrInvalidInput + } + + // 使用 counter 原子更新操作:UPDATE tags SET post_count = post_count - 1 WHERE id = ? + var zeroTag entity.Tag + tableName := zeroTag.TableName() + if r.keyspace == "" { + return fmt.Errorf("%w: keyspace is required", ErrInvalidInput) + } + + stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count - 1 WHERE id = ?", r.keyspace, tableName) + query := r.db.GetSession().Query(stmt, nil). + WithContext(ctx). + Consistency(gocql.Quorum). + Bind(tagID) + + if err := query.ExecRelease(); err != nil { + return fmt.Errorf("failed to decrement post count: %w", err) + } + + return nil +} + +// toAnySlice 將 string slice 轉換為 []any +func toAnySlice(strs []string) []any { + result := make([]any, len(strs)) + for i, s := range strs { + result[i] = s + } + return result +} diff --git a/pkg/post/usecase/comment.go b/pkg/post/usecase/comment.go new file mode 100644 index 0000000..1e374ba --- /dev/null +++ b/pkg/post/usecase/comment.go @@ -0,0 +1,455 @@ +package usecase + +import ( + "context" + "errors" + "fmt" + "math" + + errs "backend/pkg/library/errors" + "backend/pkg/post/domain/entity" + "backend/pkg/post/domain/post" + domainRepo "backend/pkg/post/domain/repository" + domainUsecase "backend/pkg/post/domain/usecase" + "backend/pkg/post/repository" + + "github.com/gocql/gocql" +) + +// CommentUseCaseParam 定義 CommentUseCase 的初始化參數 +type CommentUseCaseParam struct { + Comment domainRepo.CommentRepository + Post domainRepo.PostRepository + Like domainRepo.LikeRepository + Logger errs.Logger +} + +// CommentUseCase 實作 domain usecase 介面 +type CommentUseCase struct { + CommentUseCaseParam +} + +// MustCommentUseCase 創建新的 CommentUseCase(如果失敗會 panic) +func MustCommentUseCase(param CommentUseCaseParam) domainUsecase.CommentUseCase { + return &CommentUseCase{ + CommentUseCaseParam: param, + } +} + +// CreateComment 創建新評論 +func (uc *CommentUseCase) CreateComment(ctx context.Context, req domainUsecase.CreateCommentRequest) (*domainUsecase.CommentResponse, error) { + // 驗證輸入 + if err := uc.validateCreateCommentRequest(req); err != nil { + return nil, err + } + + // 驗證貼文存在 + var zeroUUID gocql.UUID + if req.PostID == zeroUUID { + return nil, errs.InputInvalidRangeError("post_id is required") + } + + post, err := uc.Post.FindOne(ctx, req.PostID) + if err != nil { + if repository.IsNotFound(err) { + return nil, errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID)) + } + return nil, uc.handleDBError("Post.FindOne", req, err) + } + + // 檢查貼文是否可見 + if !post.IsVisible() { + return nil, errs.ResNotFoundError("cannot comment on non-visible post") + } + + // 建立評論實體 + comment := &entity.Comment{ + PostID: req.PostID, + AuthorUID: req.AuthorUID, + ParentID: req.ParentID, + Content: req.Content, + Status: post.CommentStatusPublished, + } + + // 插入資料庫 + if err := uc.Comment.Insert(ctx, comment); err != nil { + return nil, uc.handleDBError("Comment.Insert", req, err) + } + + // 如果是回覆,增加父評論的回覆數 + if req.ParentID != nil { + if err := uc.Comment.IncrementReplyCount(ctx, *req.ParentID); err != nil { + uc.Logger.Error(fmt.Sprintf("failed to increment reply count: %v", err)) + } + } + + // 增加貼文的評論數 + if err := uc.Post.IncrementCommentCount(ctx, req.PostID); err != nil { + uc.Logger.Error(fmt.Sprintf("failed to increment comment count: %v", err)) + } + + return uc.mapCommentToResponse(comment), nil +} + +// GetComment 取得評論 +func (uc *CommentUseCase) GetComment(ctx context.Context, req domainUsecase.GetCommentRequest) (*domainUsecase.CommentResponse, error) { + // 驗證輸入 + var zeroUUID gocql.UUID + if req.CommentID == zeroUUID { + return nil, errs.InputInvalidRangeError("comment_id is required") + } + + // 查詢評論 + comment, err := uc.Comment.FindOne(ctx, req.CommentID) + if err != nil { + if repository.IsNotFound(err) { + return nil, errs.ResNotFoundError(fmt.Sprintf("comment not found: %s", req.CommentID)) + } + return nil, uc.handleDBError("Comment.FindOne", req, err) + } + + return uc.mapCommentToResponse(comment), nil +} + +// UpdateComment 更新評論 +func (uc *CommentUseCase) UpdateComment(ctx context.Context, req domainUsecase.UpdateCommentRequest) (*domainUsecase.CommentResponse, error) { + // 驗證輸入 + var zeroUUID gocql.UUID + if req.CommentID == zeroUUID { + return nil, errs.InputInvalidRangeError("comment_id is required") + } + if req.AuthorUID == "" { + return nil, errs.InputInvalidRangeError("author_uid is required") + } + if req.Content == "" { + return nil, errs.InputInvalidRangeError("content is required") + } + + // 查詢現有評論 + comment, err := uc.Comment.FindOne(ctx, req.CommentID) + if err != nil { + if repository.IsNotFound(err) { + return nil, errs.ResNotFoundError(fmt.Sprintf("comment not found: %s", req.CommentID)) + } + return nil, uc.handleDBError("Comment.FindOne", req, err) + } + + // 驗證權限 + if comment.AuthorUID != req.AuthorUID { + return nil, errs.ResNotFoundError("not authorized to update this comment") + } + + // 檢查是否可見 + if !comment.IsVisible() { + return nil, errs.ResNotFoundError("comment is not visible") + } + + // 更新內容 + comment.Content = req.Content + + // 更新資料庫 + if err := uc.Comment.Update(ctx, comment); err != nil { + return nil, uc.handleDBError("Comment.Update", req, err) + } + + return uc.mapCommentToResponse(comment), nil +} + +// DeleteComment 刪除評論(軟刪除) +func (uc *CommentUseCase) DeleteComment(ctx context.Context, req domainUsecase.DeleteCommentRequest) error { + // 驗證輸入 + var zeroUUID gocql.UUID + if req.CommentID == zeroUUID { + return errs.InputInvalidRangeError("comment_id is required") + } + if req.AuthorUID == "" { + return errs.InputInvalidRangeError("author_uid is required") + } + + // 查詢評論 + comment, err := uc.Comment.FindOne(ctx, req.CommentID) + if err != nil { + if repository.IsNotFound(err) { + return errs.ResNotFoundError(fmt.Sprintf("comment not found: %s", req.CommentID)) + } + return uc.handleDBError("Comment.FindOne", req, err) + } + + // 驗證權限 + if comment.AuthorUID != req.AuthorUID { + return errs.ResNotFoundError("not authorized to delete this comment") + } + + // 刪除評論 + if err := uc.Comment.Delete(ctx, req.CommentID); err != nil { + return uc.handleDBError("Comment.Delete", req, err) + } + + // 如果是回覆,減少父評論的回覆數 + if comment.ParentID != nil { + if err := uc.Comment.DecrementReplyCount(ctx, *comment.ParentID); err != nil { + uc.Logger.Error(fmt.Sprintf("failed to decrement reply count: %v", err)) + } + } + + // 減少貼文的評論數 + if err := uc.Post.DecrementCommentCount(ctx, comment.PostID); err != nil { + uc.Logger.Error(fmt.Sprintf("failed to decrement comment count: %v", err)) + } + + return nil +} + +// ListComments 列出評論 +func (uc *CommentUseCase) ListComments(ctx context.Context, req domainUsecase.ListCommentsRequest) (*domainUsecase.ListCommentsResponse, error) { + // 驗證輸入 + var zeroUUID gocql.UUID + if req.PostID == zeroUUID { + return nil, errs.InputInvalidRangeError("post_id is required") + } + if req.PageSize <= 0 { + req.PageSize = 20 + } + if req.PageIndex <= 0 { + req.PageIndex = 1 + } + + // 構建查詢參數 + params := &domainRepo.CommentQueryParams{ + PostID: &req.PostID, + ParentID: req.ParentID, + PageSize: req.PageSize, + PageIndex: req.PageIndex, + OrderBy: req.OrderBy, + OrderDirection: req.OrderDirection, + } + + // 如果 OrderBy 未指定,預設為 created_at + if params.OrderBy == "" { + params.OrderBy = "created_at" + } + // 如果 OrderDirection 未指定,預設為 ASC + if params.OrderDirection == "" { + params.OrderDirection = "ASC" + } + + // 執行查詢 + comments, total, err := uc.Comment.FindByPostID(ctx, req.PostID, params) + if err != nil { + return nil, uc.handleDBError("Comment.FindByPostID", req, err) + } + + // 轉換為 Response + responses := make([]domainUsecase.CommentResponse, len(comments)) + for i, c := range comments { + responses[i] = *uc.mapCommentToResponse(c) + } + + return &domainUsecase.ListCommentsResponse{ + Data: responses, + Page: domainUsecase.Pager{ + PageIndex: req.PageIndex, + PageSize: req.PageSize, + Total: total, + TotalPage: calculateTotalPages(total, req.PageSize), + }, + }, nil +} + +// ListReplies 列出回覆 +func (uc *CommentUseCase) ListReplies(ctx context.Context, req domainUsecase.ListRepliesRequest) (*domainUsecase.ListCommentsResponse, error) { + // 驗證輸入 + var zeroUUID gocql.UUID + if req.CommentID == zeroUUID { + return nil, errs.InputInvalidRangeError("comment_id is required") + } + if req.PageSize <= 0 { + req.PageSize = 20 + } + if req.PageIndex <= 0 { + req.PageIndex = 1 + } + + // 構建查詢參數 + params := &domainRepo.CommentQueryParams{ + PageSize: req.PageSize, + PageIndex: req.PageIndex, + OrderBy: "created_at", + OrderDirection: "ASC", + } + + // 執行查詢 + comments, total, err := uc.Comment.FindReplies(ctx, req.CommentID, params) + if err != nil { + return nil, uc.handleDBError("Comment.FindReplies", req, err) + } + + // 轉換為 Response + responses := make([]domainUsecase.CommentResponse, len(comments)) + for i, c := range comments { + responses[i] = *uc.mapCommentToResponse(c) + } + + return &domainUsecase.ListCommentsResponse{ + Data: responses, + Page: domainUsecase.Pager{ + PageIndex: req.PageIndex, + PageSize: req.PageSize, + Total: total, + TotalPage: calculateTotalPages(total, req.PageSize), + }, + }, nil +} + +// ListCommentsByAuthor 根據作者列出評論 +func (uc *CommentUseCase) ListCommentsByAuthor(ctx context.Context, req domainUsecase.ListCommentsByAuthorRequest) (*domainUsecase.ListCommentsResponse, error) { + if req.AuthorUID == "" { + return nil, errs.InputInvalidRangeError("author_uid is required") + } + if req.PageSize <= 0 { + req.PageSize = 20 + } + if req.PageIndex <= 0 { + req.PageIndex = 1 + } + + params := &domainRepo.CommentQueryParams{ + PageSize: req.PageSize, + PageIndex: req.PageIndex, + OrderBy: "created_at", + OrderDirection: "DESC", + } + + comments, total, err := uc.Comment.FindByAuthorUID(ctx, req.AuthorUID, params) + if err != nil { + return nil, uc.handleDBError("Comment.FindByAuthorUID", req, err) + } + + responses := make([]domainUsecase.CommentResponse, len(comments)) + for i, c := range comments { + responses[i] = *uc.mapCommentToResponse(c) + } + + return &domainUsecase.ListCommentsResponse{ + Data: responses, + Page: domainUsecase.Pager{ + PageIndex: req.PageIndex, + PageSize: req.PageSize, + Total: total, + TotalPage: calculateTotalPages(total, req.PageSize), + }, + }, nil +} + +// LikeComment 按讚評論 +func (uc *CommentUseCase) LikeComment(ctx context.Context, req domainUsecase.LikeCommentRequest) error { + // 驗證輸入 + var zeroUUID gocql.UUID + if req.CommentID == zeroUUID { + return errs.InputInvalidRangeError("comment_id is required") + } + if req.UserUID == "" { + return errs.InputInvalidRangeError("user_uid is required") + } + + // 檢查是否已經按讚 + existingLike, err := uc.Like.FindByTargetAndUser(ctx, req.CommentID, req.UserUID, "comment") + if err == nil && existingLike != nil { + // 已經按讚,直接返回成功 + return nil + } + if err != nil && !repository.IsNotFound(err) { + return uc.handleDBError("Like.FindByTargetAndUser", req, err) + } + + // 建立按讚記錄 + like := &entity.Like{ + TargetID: req.CommentID, + UserUID: req.UserUID, + TargetType: "comment", + } + + if err := uc.Like.Insert(ctx, like); err != nil { + return uc.handleDBError("Like.Insert", req, err) + } + + // 增加評論的按讚數 + if err := uc.Comment.IncrementLikeCount(ctx, req.CommentID); err != nil { + uc.Logger.Error(fmt.Sprintf("failed to increment like count: %v", err)) + } + + return nil +} + +// UnlikeComment 取消按讚評論 +func (uc *CommentUseCase) UnlikeComment(ctx context.Context, req domainUsecase.UnlikeCommentRequest) error { + // 驗證輸入 + var zeroUUID gocql.UUID + if req.CommentID == zeroUUID { + return errs.InputInvalidRangeError("comment_id is required") + } + if req.UserUID == "" { + return errs.InputInvalidRangeError("user_uid is required") + } + + // 刪除按讚記錄 + if err := uc.Like.DeleteByTargetAndUser(ctx, req.CommentID, req.UserUID, "comment"); err != nil { + if repository.IsNotFound(err) { + // 已經取消按讚,直接返回成功 + return nil + } + return uc.handleDBError("Like.DeleteByTargetAndUser", req, err) + } + + // 減少評論的按讚數 + if err := uc.Comment.DecrementLikeCount(ctx, req.CommentID); err != nil { + uc.Logger.Error(fmt.Sprintf("failed to decrement like count: %v", err)) + } + + return nil +} + +// validateCreateCommentRequest 驗證建立評論請求 +func (uc *CommentUseCase) validateCreateCommentRequest(req domainUsecase.CreateCommentRequest) error { + var zeroUUID gocql.UUID + if req.PostID == zeroUUID { + return errs.InputInvalidRangeError("post_id is required") + } + if req.AuthorUID == "" { + return errs.InputInvalidRangeError("author_uid is required") + } + if req.Content == "" { + return errs.InputInvalidRangeError("content is required") + } + return nil +} + +// mapCommentToResponse 將 Comment 實體轉換為 CommentResponse +func (uc *CommentUseCase) mapCommentToResponse(comment *entity.Comment) *domainUsecase.CommentResponse { + return &domainUsecase.CommentResponse{ + ID: comment.ID, + PostID: comment.PostID, + AuthorUID: comment.AuthorUID, + ParentID: comment.ParentID, + Content: comment.Content, + Status: comment.Status, + LikeCount: comment.LikeCount, + ReplyCount: comment.ReplyCount, + CreatedAt: comment.CreatedAt, + UpdatedAt: comment.UpdatedAt, + } +} + +// handleDBError 處理資料庫錯誤 +func (uc *CommentUseCase) handleDBError(funcName string, req any, err error) error { + return errs.DBErrorErrorL( + uc.Logger, + []errs.LogField{ + {Key: "func", Val: funcName}, + {Key: "req", Val: req}, + {Key: "error", Val: err.Error()}, + }, + fmt.Sprintf("database operation failed: %s", funcName), + ).Wrap(err) +} + diff --git a/pkg/post/usecase/post.go b/pkg/post/usecase/post.go new file mode 100644 index 0000000..df07a28 --- /dev/null +++ b/pkg/post/usecase/post.go @@ -0,0 +1,801 @@ +package usecase + +import ( + "context" + "errors" + "fmt" + "math" + + errs "backend/pkg/library/errors" + "backend/pkg/post/domain/entity" + "backend/pkg/post/domain/post" + domainRepo "backend/pkg/post/domain/repository" + domainUsecase "backend/pkg/post/domain/usecase" + "backend/pkg/post/repository" + + "github.com/gocql/gocql" +) + +// PostUseCaseParam 定義 PostUseCase 的初始化參數 +type PostUseCaseParam struct { + Post domainRepo.PostRepository + Comment domainRepo.CommentRepository + Like domainRepo.LikeRepository + Tag domainRepo.TagRepository + Category domainRepo.CategoryRepository + Logger errs.Logger +} + +// PostUseCase 實作 domain usecase 介面 +type PostUseCase struct { + PostUseCaseParam +} + +// MustPostUseCase 創建新的 PostUseCase(如果失敗會 panic) +func MustPostUseCase(param PostUseCaseParam) domainUsecase.PostUseCase { + return &PostUseCase{ + PostUseCaseParam: param, + } +} + +// CreatePost 創建新貼文 +func (uc *PostUseCase) CreatePost(ctx context.Context, req domainUsecase.CreatePostRequest) (*domainUsecase.PostResponse, error) { + // 驗證輸入 + if err := uc.validateCreatePostRequest(req); err != nil { + return nil, err + } + + // 建立貼文實體 + post := &entity.Post{ + AuthorUID: req.AuthorUID, + Title: req.Title, + Content: req.Content, + Type: req.Type, + CategoryID: req.CategoryID, + Tags: req.Tags, + Images: req.Images, + VideoURL: req.VideoURL, + LinkURL: req.LinkURL, + Status: req.Status, + } + + // 如果狀態未指定,預設為草稿 + if post.Status == 0 { + post.Status = post.PostStatusDraft + } + + // 插入資料庫 + if err := uc.Post.Insert(ctx, post); err != nil { + return nil, uc.handleDBError("Post.Insert", req, err) + } + + // 處理標籤(更新標籤的貼文數) + if err := uc.updateTagPostCounts(ctx, req.Tags, true); err != nil { + // 記錄錯誤但不中斷流程 + uc.Logger.Error(fmt.Sprintf("failed to update tag post counts: %v", err)) + } + + // 處理分類(更新分類的貼文數) + if req.CategoryID != nil { + if err := uc.Category.IncrementPostCount(ctx, *req.CategoryID); err != nil { + uc.Logger.Error(fmt.Sprintf("failed to increment category post count: %v", err)) + } + } + + return uc.mapPostToResponse(post), nil +} + +// GetPost 取得貼文 +func (uc *PostUseCase) GetPost(ctx context.Context, req domainUsecase.GetPostRequest) (*domainUsecase.PostResponse, error) { + // 驗證輸入 + var zeroUUID gocql.UUID + if req.PostID == zeroUUID { + return nil, errs.InputInvalidRangeError("post_id is required") + } + + // 查詢貼文 + post, err := uc.Post.FindOne(ctx, req.PostID) + if err != nil { + if repository.IsNotFound(err) { + return nil, errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID)) + } + return nil, uc.handleDBError("Post.FindOne", req, err) + } + + // 如果提供了 UserUID,增加瀏覽數 + if req.UserUID != nil { + if err := uc.Post.IncrementViewCount(ctx, req.PostID); err != nil { + uc.Logger.Error(fmt.Sprintf("failed to increment view count: %v", err)) + } + } + + return uc.mapPostToResponse(post), nil +} + +// UpdatePost 更新貼文 +func (uc *PostUseCase) UpdatePost(ctx context.Context, req domainUsecase.UpdatePostRequest) (*domainUsecase.PostResponse, error) { + // 驗證輸入 + var zeroUUID gocql.UUID + if req.PostID == zeroUUID { + return nil, errs.InputInvalidRangeError("post_id is required") + } + if req.AuthorUID == "" { + return nil, errs.InputInvalidRangeError("author_uid is required") + } + + // 查詢現有貼文 + post, err := uc.Post.FindOne(ctx, req.PostID) + if err != nil { + if repository.IsNotFound(err) { + return nil, errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID)) + } + return nil, uc.handleDBError("Post.FindOne", req, err) + } + + // 驗證權限 + if post.AuthorUID != req.AuthorUID { + return nil, errs.ResNotFoundError("not authorized to update this post") + } + + // 檢查是否可編輯 + if !post.IsEditable() { + return nil, errs.ResNotFoundError("post is not editable") + } + + // 更新欄位 + if req.Title != nil { + post.Title = *req.Title + } + if req.Content != nil { + post.Content = *req.Content + } + if req.Type != nil { + post.Type = *req.Type + } + if req.CategoryID != nil { + // 更新分類計數 + if post.CategoryID != nil && *post.CategoryID != *req.CategoryID { + if err := uc.Category.DecrementPostCount(ctx, *post.CategoryID); err != nil { + uc.Logger.Error("failed to decrement category post count", errs.LogField{Key: "error", Val: err.Error()}) + } + if err := uc.Category.IncrementPostCount(ctx, *req.CategoryID); err != nil { + uc.Logger.Error(fmt.Sprintf("failed to increment category post count: %v", err)) + } + } + post.CategoryID = req.CategoryID + } + if req.Tags != nil { + // 更新標籤計數 + oldTags := post.Tags + post.Tags = req.Tags + if err := uc.updateTagPostCountsDiff(ctx, oldTags, req.Tags); err != nil { + uc.Logger.Error(fmt.Sprintf("failed to update tag post counts: %v", err)) + } + } + if req.Images != nil { + post.Images = req.Images + } + if req.VideoURL != nil { + post.VideoURL = req.VideoURL + } + if req.LinkURL != nil { + post.LinkURL = req.LinkURL + } + + // 更新資料庫 + if err := uc.Post.Update(ctx, post); err != nil { + return nil, uc.handleDBError("Post.Update", req, err) + } + + return uc.mapPostToResponse(post), nil +} + +// DeletePost 刪除貼文(軟刪除) +func (uc *PostUseCase) DeletePost(ctx context.Context, req domainUsecase.DeletePostRequest) error { + // 驗證輸入 + var zeroUUID gocql.UUID + if req.PostID == zeroUUID { + return errs.InputInvalidRangeError("post_id is required") + } + if req.AuthorUID == "" { + return errs.InputInvalidRangeError("author_uid is required") + } + + // 查詢貼文 + post, err := uc.Post.FindOne(ctx, req.PostID) + if err != nil { + if repository.IsNotFound(err) { + return errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID)) + } + return uc.handleDBError("Post.FindOne", req, err) + } + + // 驗證權限 + if post.AuthorUID != req.AuthorUID { + return errs.ResNotFoundError("not authorized to delete this post") + } + + // 刪除貼文 + if err := uc.Post.Delete(ctx, req.PostID); err != nil { + return uc.handleDBError("Post.Delete", req, err) + } + + // 更新標籤和分類計數 + if len(post.Tags) > 0 { + if err := uc.updateTagPostCounts(ctx, post.Tags, false); err != nil { + uc.Logger.Error(fmt.Sprintf("failed to update tag post counts: %v", err)) + } + } + if post.CategoryID != nil { + if err := uc.Category.DecrementPostCount(ctx, *post.CategoryID); err != nil { + uc.Logger.Error("failed to decrement category post count", errs.LogField{Key: "error", Val: err.Error()}) + } + } + + return nil +} + +// PublishPost 發布貼文 +func (uc *PostUseCase) PublishPost(ctx context.Context, req domainUsecase.PublishPostRequest) (*domainUsecase.PostResponse, error) { + // 驗證輸入 + var zeroUUID gocql.UUID + if req.PostID == zeroUUID { + return nil, errs.InputInvalidRangeError("post_id is required") + } + if req.AuthorUID == "" { + return nil, errs.InputInvalidRangeError("author_uid is required") + } + + // 查詢貼文 + post, err := uc.Post.FindOne(ctx, req.PostID) + if err != nil { + if repository.IsNotFound(err) { + return nil, errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID)) + } + return nil, uc.handleDBError("Post.FindOne", req, err) + } + + // 驗證權限 + if post.AuthorUID != req.AuthorUID { + return nil, errs.ResNotFoundError("not authorized to publish this post") + } + + // 發布貼文 + post.Publish() + + // 更新資料庫 + if err := uc.Post.Update(ctx, post); err != nil { + return nil, uc.handleDBError("Post.Update", req, err) + } + + return uc.mapPostToResponse(post), nil +} + +// ArchivePost 歸檔貼文 +func (uc *PostUseCase) ArchivePost(ctx context.Context, req domainUsecase.ArchivePostRequest) error { + // 驗證輸入 + var zeroUUID gocql.UUID + if req.PostID == zeroUUID { + return errs.InputInvalidRangeError("post_id is required") + } + if req.AuthorUID == "" { + return errs.InputInvalidRangeError("author_uid is required") + } + + // 查詢貼文 + post, err := uc.Post.FindOne(ctx, req.PostID) + if err != nil { + if repository.IsNotFound(err) { + return errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID)) + } + return uc.handleDBError("Post.FindOne", req, err) + } + + // 驗證權限 + if post.AuthorUID != req.AuthorUID { + return errs.ResNotFoundError("not authorized to archive this post") + } + + // 歸檔貼文 + post.Archive() + + // 更新資料庫 + return uc.Post.Update(ctx, post) +} + +// ListPosts 列出貼文 +func (uc *PostUseCase) ListPosts(ctx context.Context, req domainUsecase.ListPostsRequest) (*domainUsecase.ListPostsResponse, error) { + // 驗證分頁參數 + if req.PageSize <= 0 { + req.PageSize = 20 + } + if req.PageIndex <= 0 { + req.PageIndex = 1 + } + + // 構建查詢參數 + params := &domainRepo.PostQueryParams{ + CategoryID: req.CategoryID, + Tag: req.Tag, + Status: req.Status, + Type: req.Type, + AuthorUID: req.AuthorUID, + CreateStartTime: req.CreateStartTime, + CreateEndTime: req.CreateEndTime, + PageSize: req.PageSize, + PageIndex: req.PageIndex, + OrderBy: req.OrderBy, + OrderDirection: req.OrderDirection, + } + + // 執行查詢 + var posts []*entity.Post + var total int64 + var err error + + if req.CategoryID != nil { + posts, total, err = uc.Post.FindByCategoryID(ctx, *req.CategoryID, params) + } else if req.Tag != nil { + posts, total, err = uc.Post.FindByTag(ctx, *req.Tag, params) + } else if req.AuthorUID != nil { + posts, total, err = uc.Post.FindByAuthorUID(ctx, *req.AuthorUID, params) + } else if req.Status != nil { + posts, total, err = uc.Post.FindByStatus(ctx, *req.Status, params) + } else { + // 預設查詢所有已發布的貼文 + published := post.PostStatusPublished + params.Status = &published + posts, total, err = uc.Post.FindByStatus(ctx, published, params) + } + + if err != nil { + return nil, uc.handleDBError("Post.FindBy*", req, err) + } + + // 轉換為 Response + responses := make([]domainUsecase.PostResponse, len(posts)) + for i, p := range posts { + responses[i] = *uc.mapPostToResponse(p) + } + + return &domainUsecase.ListPostsResponse{ + Data: responses, + Page: domainUsecase.Pager{ + PageIndex: req.PageIndex, + PageSize: req.PageSize, + Total: total, + TotalPage: calculateTotalPages(total, req.PageSize), + }, + }, nil +} + +// ListPostsByAuthor 根據作者列出貼文 +func (uc *PostUseCase) ListPostsByAuthor(ctx context.Context, req domainUsecase.ListPostsByAuthorRequest) (*domainUsecase.ListPostsResponse, error) { + if req.AuthorUID == "" { + return nil, errs.InputInvalidRangeError("author_uid is required") + } + if req.PageSize <= 0 { + req.PageSize = 20 + } + if req.PageIndex <= 0 { + req.PageIndex = 1 + } + + params := &domainRepo.PostQueryParams{ + Status: req.Status, + PageSize: req.PageSize, + PageIndex: req.PageIndex, + OrderBy: "created_at", + OrderDirection: "DESC", + } + + posts, total, err := uc.Post.FindByAuthorUID(ctx, req.AuthorUID, params) + if err != nil { + return nil, uc.handleDBError("Post.FindByAuthorUID", req, err) + } + + responses := make([]domainUsecase.PostResponse, len(posts)) + for i, p := range posts { + responses[i] = *uc.mapPostToResponse(p) + } + + return &domainUsecase.ListPostsResponse{ + Data: responses, + Page: domainUsecase.Pager{ + PageIndex: req.PageIndex, + PageSize: req.PageSize, + Total: total, + TotalPage: calculateTotalPages(total, req.PageSize), + }, + }, nil +} + +// ListPostsByCategory 根據分類列出貼文 +func (uc *PostUseCase) ListPostsByCategory(ctx context.Context, req domainUsecase.ListPostsByCategoryRequest) (*domainUsecase.ListPostsResponse, error) { + var zeroUUID gocql.UUID + if req.CategoryID == zeroUUID { + return nil, errs.InputInvalidRangeError("category_id is required") + } + if req.PageSize <= 0 { + req.PageSize = 20 + } + if req.PageIndex <= 0 { + req.PageIndex = 1 + } + + params := &domainRepo.PostQueryParams{ + Status: req.Status, + PageSize: req.PageSize, + PageIndex: req.PageIndex, + OrderBy: "created_at", + OrderDirection: "DESC", + } + + posts, total, err := uc.Post.FindByCategoryID(ctx, req.CategoryID, params) + if err != nil { + return nil, uc.handleDBError("Post.FindByCategoryID", req, err) + } + + responses := make([]domainUsecase.PostResponse, len(posts)) + for i, p := range posts { + responses[i] = *uc.mapPostToResponse(p) + } + + return &domainUsecase.ListPostsResponse{ + Data: responses, + Page: domainUsecase.Pager{ + PageIndex: req.PageIndex, + PageSize: req.PageSize, + Total: total, + TotalPage: calculateTotalPages(total, req.PageSize), + }, + }, nil +} + +// ListPostsByTag 根據標籤列出貼文 +func (uc *PostUseCase) ListPostsByTag(ctx context.Context, req domainUsecase.ListPostsByTagRequest) (*domainUsecase.ListPostsResponse, error) { + if req.Tag == "" { + return nil, errs.InputInvalidRangeError("tag is required") + } + if req.PageSize <= 0 { + req.PageSize = 20 + } + if req.PageIndex <= 0 { + req.PageIndex = 1 + } + + params := &domainRepo.PostQueryParams{ + Status: req.Status, + PageSize: req.PageSize, + PageIndex: req.PageIndex, + OrderBy: "created_at", + OrderDirection: "DESC", + } + + posts, total, err := uc.Post.FindByTag(ctx, req.Tag, params) + if err != nil { + return nil, uc.handleDBError("Post.FindByTag", req, err) + } + + responses := make([]domainUsecase.PostResponse, len(posts)) + for i, p := range posts { + responses[i] = *uc.mapPostToResponse(p) + } + + return &domainUsecase.ListPostsResponse{ + Data: responses, + Page: domainUsecase.Pager{ + PageIndex: req.PageIndex, + PageSize: req.PageSize, + Total: total, + TotalPage: calculateTotalPages(total, req.PageSize), + }, + }, nil +} + +// GetPinnedPosts 取得置頂貼文 +func (uc *PostUseCase) GetPinnedPosts(ctx context.Context, req domainUsecase.GetPinnedPostsRequest) (*domainUsecase.ListPostsResponse, error) { + limit := int64(10) + if req.Limit > 0 { + limit = req.Limit + } + + posts, err := uc.Post.FindPinnedPosts(ctx, limit) + if err != nil { + return nil, uc.handleDBError("Post.FindPinnedPosts", req, err) + } + + responses := make([]domainUsecase.PostResponse, len(posts)) + for i, p := range posts { + responses[i] = *uc.mapPostToResponse(p) + } + + return &domainUsecase.ListPostsResponse{ + Data: responses, + Page: domainUsecase.Pager{ + PageIndex: 1, + PageSize: limit, + Total: int64(len(responses)), + TotalPage: 1, + }, + }, nil +} + +// LikePost 按讚貼文 +func (uc *PostUseCase) LikePost(ctx context.Context, req domainUsecase.LikePostRequest) error { + // 驗證輸入 + var zeroUUID gocql.UUID + if req.PostID == zeroUUID { + return errs.InputInvalidRangeError("post_id is required") + } + if req.UserUID == "" { + return errs.InputInvalidRangeError("user_uid is required") + } + + // 檢查是否已經按讚 + existingLike, err := uc.Like.FindByTargetAndUser(ctx, req.PostID, req.UserUID, "post") + if err == nil && existingLike != nil { + // 已經按讚,直接返回成功 + return nil + } + if err != nil && !repository.IsNotFound(err) { + return uc.handleDBError("Like.FindByTargetAndUser", req, err) + } + + // 建立按讚記錄 + like := &entity.Like{ + TargetID: req.PostID, + UserUID: req.UserUID, + TargetType: "post", + } + + if err := uc.Like.Insert(ctx, like); err != nil { + return uc.handleDBError("Like.Insert", req, err) + } + + // 增加貼文的按讚數 + if err := uc.Post.IncrementLikeCount(ctx, req.PostID); err != nil { + uc.Logger.Error(fmt.Sprintf("failed to increment like count: %v", err)) + } + + return nil +} + +// UnlikePost 取消按讚 +func (uc *PostUseCase) UnlikePost(ctx context.Context, req domainUsecase.UnlikePostRequest) error { + // 驗證輸入 + var zeroUUID gocql.UUID + if req.PostID == zeroUUID { + return errs.InputInvalidRangeError("post_id is required") + } + if req.UserUID == "" { + return errs.InputInvalidRangeError("user_uid is required") + } + + // 刪除按讚記錄 + if err := uc.Like.DeleteByTargetAndUser(ctx, req.PostID, req.UserUID, "post"); err != nil { + if repository.IsNotFound(err) { + // 已經取消按讚,直接返回成功 + return nil + } + return uc.handleDBError("Like.DeleteByTargetAndUser", req, err) + } + + // 減少貼文的按讚數 + if err := uc.Post.DecrementLikeCount(ctx, req.PostID); err != nil { + uc.Logger.Error(fmt.Sprintf("failed to decrement like count: %v", err)) + } + + return nil +} + +// ViewPost 瀏覽貼文(增加瀏覽數) +func (uc *PostUseCase) ViewPost(ctx context.Context, req domainUsecase.ViewPostRequest) error { + // 驗證輸入 + var zeroUUID gocql.UUID + if req.PostID == zeroUUID { + return errs.InputInvalidRangeError("post_id is required") + } + + // 增加瀏覽數 + if err := uc.Post.IncrementViewCount(ctx, req.PostID); err != nil { + return uc.handleDBError("Post.IncrementViewCount", req, err) + } + + return nil +} + +// PinPost 置頂貼文 +func (uc *PostUseCase) PinPost(ctx context.Context, req domainUsecase.PinPostRequest) error { + // 驗證輸入 + var zeroUUID gocql.UUID + if req.PostID == zeroUUID { + return errs.InputInvalidRangeError("post_id is required") + } + if req.AuthorUID == "" { + return errs.InputInvalidRangeError("author_uid is required") + } + + // 查詢貼文 + post, err := uc.Post.FindOne(ctx, req.PostID) + if err != nil { + if repository.IsNotFound(err) { + return errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID)) + } + return uc.handleDBError("Post.FindOne", req, err) + } + + // 驗證權限 + if post.AuthorUID != req.AuthorUID { + return errs.ResNotFoundError("not authorized to pin this post") + } + + // 置頂貼文 + return uc.Post.PinPost(ctx, req.PostID) +} + +// UnpinPost 取消置頂 +func (uc *PostUseCase) UnpinPost(ctx context.Context, req domainUsecase.UnpinPostRequest) error { + // 驗證輸入 + var zeroUUID gocql.UUID + if req.PostID == zeroUUID { + return errs.InputInvalidRangeError("post_id is required") + } + if req.AuthorUID == "" { + return errs.InputInvalidRangeError("author_uid is required") + } + + // 查詢貼文 + post, err := uc.Post.FindOne(ctx, req.PostID) + if err != nil { + if repository.IsNotFound(err) { + return errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID)) + } + return uc.handleDBError("Post.FindOne", req, err) + } + + // 驗證權限 + if post.AuthorUID != req.AuthorUID { + return errs.ResNotFoundError("not authorized to unpin this post") + } + + // 取消置頂 + return uc.Post.UnpinPost(ctx, req.PostID) +} + +// validateCreatePostRequest 驗證建立貼文請求 +func (uc *PostUseCase) validateCreatePostRequest(req domainUsecase.CreatePostRequest) error { + if req.AuthorUID == "" { + return errs.InputInvalidRangeError("author_uid is required") + } + if req.Title == "" { + return errs.InputInvalidRangeError("title is required") + } + if req.Content == "" { + return errs.InputInvalidRangeError("content is required") + } + if !req.Type.IsValid() { + return errs.InputInvalidRangeError("invalid post type") + } + return nil +} + +// mapPostToResponse 將 Post 實體轉換為 PostResponse +func (uc *PostUseCase) mapPostToResponse(post *entity.Post) *domainUsecase.PostResponse { + return &domainUsecase.PostResponse{ + ID: post.ID, + AuthorUID: post.AuthorUID, + Title: post.Title, + Content: post.Content, + Type: post.Type, + Status: post.Status, + CategoryID: post.CategoryID, + Tags: post.Tags, + Images: post.Images, + VideoURL: post.VideoURL, + LinkURL: post.LinkURL, + LikeCount: post.LikeCount, + CommentCount: post.CommentCount, + ViewCount: post.ViewCount, + IsPinned: post.IsPinned, + PinnedAt: post.PinnedAt, + PublishedAt: post.PublishedAt, + CreatedAt: post.CreatedAt, + UpdatedAt: post.UpdatedAt, + } +} + +// handleDBError 處理資料庫錯誤 +func (uc *PostUseCase) handleDBError(funcName string, req any, err error) error { + return errs.DBErrorErrorL( + uc.Logger, + []errs.LogField{ + {Key: "func", Val: funcName}, + {Key: "req", Val: req}, + {Key: "error", Val: err.Error()}, + }, + fmt.Sprintf("database operation failed: %s", funcName), + ).Wrap(err) +} + +// updateTagPostCounts 更新標籤的貼文數 +func (uc *PostUseCase) updateTagPostCounts(ctx context.Context, tags []string, increment bool) error { + if len(tags) == 0 { + return nil + } + + // 查詢或建立標籤 + for _, tagName := range tags { + tag, err := uc.Tag.FindByName(ctx, tagName) + if err != nil { + if repository.IsNotFound(err) { + // 建立新標籤 + newTag := &entity.Tag{ + Name: tagName, + } + if err := uc.Tag.Insert(ctx, newTag); err != nil { + return fmt.Errorf("failed to create tag: %w", err) + } + tag = newTag + } else { + return fmt.Errorf("failed to find tag: %w", err) + } + } + + // 更新計數 + if increment { + if err := uc.Tag.IncrementPostCount(ctx, tag.ID); err != nil { + return fmt.Errorf("failed to increment tag count: %w", err) + } + } else { + if err := uc.Tag.DecrementPostCount(ctx, tag.ID); err != nil { + return fmt.Errorf("failed to decrement tag count: %w", err) + } + } + } + + return nil +} + +// updateTagPostCountsDiff 更新標籤計數(處理差異) +func (uc *PostUseCase) updateTagPostCountsDiff(ctx context.Context, oldTags, newTags []string) error { + // 找出新增和刪除的標籤 + oldTagMap := make(map[string]bool) + for _, tag := range oldTags { + oldTagMap[tag] = true + } + + newTagMap := make(map[string]bool) + for _, tag := range newTags { + newTagMap[tag] = true + } + + // 新增的標籤 + for _, tag := range newTags { + if !oldTagMap[tag] { + if err := uc.updateTagPostCounts(ctx, []string{tag}, true); err != nil { + return err + } + } + } + + // 刪除的標籤 + for _, tag := range oldTags { + if !newTagMap[tag] { + if err := uc.updateTagPostCounts(ctx, []string{tag}, false); err != nil { + return err + } + } + } + + return nil +} + +// calculateTotalPages 計算總頁數 +func calculateTotalPages(total, pageSize int64) int64 { + if pageSize <= 0 { + return 0 + } + return int64(math.Ceil(float64(total) / float64(pageSize))) +} +