feat: add cassandra lib

This commit is contained in:
王性驊 2025-11-19 17:06:44 +08:00
parent 1786e7c690
commit 97d6c8f499
15 changed files with 3389 additions and 466 deletions

View File

@ -504,6 +504,79 @@ err = userRepo.Query().
Scan(ctx, &users) Scan(ctx, &users)
``` ```
## SAI 索引管理
### 建立 SAI 索引
```go
// 檢查是否支援 SAI
if !db.SaiSupported() {
log.Fatal("SAI is not supported in this Cassandra version")
}
// 建立標準索引
err := db.CreateSAIIndex(ctx, "my_keyspace", "users", "email", "users_email_idx", nil)
if err != nil {
log.Printf("建立索引失敗: %v", err)
}
// 建立全文索引(不區分大小寫)
opts := &cassandra.SAIIndexOptions{
IndexType: cassandra.SAIIndexTypeFullText,
IsAsync: false,
CaseSensitive: false,
}
err = db.CreateSAIIndex(ctx, "my_keyspace", "posts", "content", "posts_content_ft_idx", opts)
```
### 查詢 SAI 索引
```go
// 列出資料表的所有 SAI 索引
indexes, err := db.ListSAIIndexes(ctx, "my_keyspace", "users")
if err != nil {
log.Printf("查詢索引失敗: %v", err)
} else {
for _, idx := range indexes {
fmt.Printf("索引: %s, 欄位: %s, 類型: %s\n", idx.Name, idx.Column, idx.Type)
}
}
// 檢查索引是否存在
exists, err := db.CheckSAIIndexExists(ctx, "my_keyspace", "users_email_idx")
if err != nil {
log.Printf("檢查索引失敗: %v", err)
} else if exists {
fmt.Println("索引存在")
}
```
### 刪除 SAI 索引
```go
// 刪除索引
err := db.DropSAIIndex(ctx, "my_keyspace", "users_email_idx")
if err != nil {
log.Printf("刪除索引失敗: %v", err)
}
```
### SAI 索引類型
- **SAIIndexTypeStandard**: 標準索引(等於查詢)
- **SAIIndexTypeCollection**: 集合索引(用於 list、set、map
- **SAIIndexTypeFullText**: 全文索引
### SAI 索引選項
```go
opts := &cassandra.SAIIndexOptions{
IndexType: cassandra.SAIIndexTypeFullText, // 索引類型
IsAsync: false, // 是否異步建立
CaseSensitive: true, // 是否區分大小寫
}
```
## 注意事項 ## 注意事項
### 1. 主鍵要求 ### 1. 主鍵要求
@ -540,6 +613,8 @@ err = userRepo.Query().
- 建立索引前請先檢查 `db.SaiSupported()` - 建立索引前請先檢查 `db.SaiSupported()`
- 索引建立是異步操作,可能需要一些時間 - 索引建立是異步操作,可能需要一些時間
- 刪除索引時使用 `IF EXISTS`,避免索引不存在時報錯 - 刪除索引時使用 `IF EXISTS`,避免索引不存在時報錯
- 使用 SAI 索引可以大幅提升非主鍵欄位的查詢效能
- 全文索引支援不區分大小寫的搜尋
## 完整範例 ## 完整範例

View File

@ -12,43 +12,43 @@ import (
type SAIIndexType string type SAIIndexType string
const ( const (
// SAIIndexTypeStandard 標準索引(預設) // SAIIndexTypeStandard 標準索引(等於查詢)
SAIIndexTypeStandard SAIIndexType = "standard" SAIIndexTypeStandard SAIIndexType = "STANDARD"
// SAIIndexTypeFrozen 用於 frozen 類型 // SAIIndexTypeCollection 集合索引(用於 list、set、map
SAIIndexTypeFrozen SAIIndexType = "frozen" SAIIndexTypeCollection SAIIndexType = "COLLECTION"
// SAIIndexTypeFullText 全文索引
SAIIndexTypeFullText SAIIndexType = "FULL_TEXT"
) )
// SAIIndexOptions 定義 SAI 索引選項 // SAIIndexOptions 定義 SAI 索引選項
type SAIIndexOptions struct { type SAIIndexOptions struct {
CaseSensitive *bool // 是否區分大小寫預設true IndexType SAIIndexType // 索引類型
Normalize *bool // 是否正規化預設false IsAsync bool // 是否異步建立索引
Analyzer string // 分析器(如 "StandardAnalyzer" CaseSensitive bool // 是否區分大小寫(用於全文索引
} }
// SAIIndexInfo 表示 SAI 索引資訊 // DefaultSAIIndexOptions 返回預設的 SAI 索引選項
type SAIIndexInfo struct { func DefaultSAIIndexOptions() *SAIIndexOptions {
KeyspaceName string // Keyspace 名稱 return &SAIIndexOptions{
TableName string // 表名稱 IndexType: SAIIndexTypeStandard,
IndexName string // 索引名稱 IsAsync: false,
ColumnName string // 欄位名稱 CaseSensitive: true,
IndexType string // 索引類型 }
Options map[string]string // 索引選項
} }
// CreateSAIIndex 建立 SAI 索引 // CreateSAIIndex 建立 SAI 索引
// keyspace: keyspace 名稱,如果為空則使用預設 keyspace // keyspace: keyspace 名稱
// table: 表名稱 // table: 資料表名稱
// column: 欄位名稱 // column: 欄位名稱
// indexName: 索引名稱(可選,如果為空則自動生成) // indexName: 索引名稱(可選,如果為空則自動生成)
// options: 索引選項(可選) // opts: 索引選項(可選,如果為 nil 則使用預設選項)
func (db *DB) CreateSAIIndex(ctx context.Context, keyspace, table, column string, indexName string, options *SAIIndexOptions) error { func (db *DB) CreateSAIIndex(ctx context.Context, keyspace, table, column, indexName string, opts *SAIIndexOptions) error {
// 檢查是否支援 SAI
if !db.saiSupported { if !db.saiSupported {
return ErrSAINotSupported return ErrInvalidInput.WithError(fmt.Errorf("SAI is not supported in Cassandra version %s (requires 4.0.9+ or 5.0+)", db.version))
} }
if keyspace == "" { // 驗證參數
keyspace = db.defaultKeyspace
}
if keyspace == "" { if keyspace == "" {
return ErrInvalidInput.WithError(fmt.Errorf("keyspace is required")) return ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
} }
@ -59,51 +59,71 @@ func (db *DB) CreateSAIIndex(ctx context.Context, keyspace, table, column string
return ErrInvalidInput.WithError(fmt.Errorf("column is required")) return ErrInvalidInput.WithError(fmt.Errorf("column is required"))
} }
// 生成索引名稱(如果未提供) // 使用預設選項如果未提供
if opts == nil {
opts = DefaultSAIIndexOptions()
}
// 生成索引名稱如果未提供
if indexName == "" { if indexName == "" {
indexName = fmt.Sprintf("%s_%s_%s_idx", table, column, "sai") indexName = fmt.Sprintf("%s_%s_sai_idx", table, column)
} }
// 構建 CREATE INDEX 語句 // 構建 CREATE INDEX 語句
stmt := fmt.Sprintf("CREATE INDEX %s ON %s.%s (%s) USING 'sai'", indexName, keyspace, table, column) var stmt strings.Builder
stmt.WriteString("CREATE CUSTOM INDEX IF NOT EXISTS ")
stmt.WriteString(indexName)
stmt.WriteString(" ON ")
stmt.WriteString(keyspace)
stmt.WriteString(".")
stmt.WriteString(table)
stmt.WriteString(" (")
stmt.WriteString(column)
stmt.WriteString(") USING 'StorageAttachedIndex'")
// 添加選項 // 添加選項
if options != nil { var options []string
opts := make([]string, 0) if opts.IsAsync {
if options.CaseSensitive != nil { options = append(options, "'async'='true'")
opts = append(opts, fmt.Sprintf("'case_sensitive': %v", *options.CaseSensitive))
}
if options.Normalize != nil {
opts = append(opts, fmt.Sprintf("'normalize': %v", *options.Normalize))
}
if options.Analyzer != "" {
opts = append(opts, fmt.Sprintf("'analyzer': '%s'", options.Analyzer))
}
if len(opts) > 0 {
stmt += " WITH OPTIONS = {" + strings.Join(opts, ", ") + "}"
}
} }
// 執行建立索引 // 根據索引類型添加特定選項
q := db.session.Query(stmt, nil).WithContext(ctx).Consistency(gocql.Quorum) switch opts.IndexType {
if err := q.ExecRelease(); err != nil { case SAIIndexTypeFullText:
return ErrInvalidInput.WithTable(table).WithError(fmt.Errorf("failed to create SAI index: %w", err)) if !opts.CaseSensitive {
options = append(options, "'case_sensitive'='false'")
} else {
options = append(options, "'case_sensitive'='true'")
}
case SAIIndexTypeCollection:
// Collection 索引不需要額外選項
}
// 如果有選項,添加到語句中
if len(options) > 0 {
stmt.WriteString(" WITH OPTIONS = {")
stmt.WriteString(strings.Join(options, ", "))
stmt.WriteString("}")
}
// 執行建立索引語句
query := db.session.Query(stmt.String(), nil).
WithContext(ctx).
Consistency(gocql.Quorum)
err := query.ExecRelease()
if err != nil {
return ErrInvalidInput.WithError(fmt.Errorf("failed to create SAI index: %w", err))
} }
return nil return nil
} }
// DropSAIIndex 刪除 SAI 索引 // DropSAIIndex 刪除 SAI 索引
// keyspace: keyspace 名稱,如果為空則使用預設 keyspace // keyspace: keyspace 名稱
// indexName: 索引名稱 // indexName: 索引名稱
func (db *DB) DropSAIIndex(ctx context.Context, keyspace, indexName string) error { func (db *DB) DropSAIIndex(ctx context.Context, keyspace, indexName string) error {
if !db.saiSupported { // 驗證參數
return ErrSAINotSupported
}
if keyspace == "" {
keyspace = db.defaultKeyspace
}
if keyspace == "" { if keyspace == "" {
return ErrInvalidInput.WithError(fmt.Errorf("keyspace is required")) return ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
} }
@ -114,73 +134,66 @@ func (db *DB) DropSAIIndex(ctx context.Context, keyspace, indexName string) erro
// 構建 DROP INDEX 語句 // 構建 DROP INDEX 語句
stmt := fmt.Sprintf("DROP INDEX IF EXISTS %s.%s", keyspace, indexName) stmt := fmt.Sprintf("DROP INDEX IF EXISTS %s.%s", keyspace, indexName)
// 執行刪除索引 // 執行刪除索引語句
q := db.session.Query(stmt, nil).WithContext(ctx).Consistency(gocql.Quorum) query := db.session.Query(stmt, nil).
if err := q.ExecRelease(); err != nil { WithContext(ctx).
Consistency(gocql.Quorum)
err := query.ExecRelease()
if err != nil {
return ErrInvalidInput.WithError(fmt.Errorf("failed to drop SAI index: %w", err)) return ErrInvalidInput.WithError(fmt.Errorf("failed to drop SAI index: %w", err))
} }
return nil return nil
} }
// ListSAIIndexes 列出指定表的 SAI 索引 // ListSAIIndexes 列出指定資料表的所有 SAI 索引
// keyspace: keyspace 名稱,如果為空則使用預設 keyspace // keyspace: keyspace 名稱
// table: 表名稱(可選,如果為空則列出所有表的索引) // table: 資料表名稱
func (db *DB) ListSAIIndexes(ctx context.Context, keyspace, table string) ([]SAIIndexInfo, error) { func (db *DB) ListSAIIndexes(ctx context.Context, keyspace, table string) ([]SAIIndexInfo, error) {
if !db.saiSupported { // 驗證參數
return nil, ErrSAINotSupported
}
if keyspace == "" {
keyspace = db.defaultKeyspace
}
if keyspace == "" { if keyspace == "" {
return nil, ErrInvalidInput.WithError(fmt.Errorf("keyspace is required")) return nil, ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
} }
if table == "" {
// 構建查詢語句 return nil, ErrInvalidInput.WithError(fmt.Errorf("table is required"))
// system_schema.indexes 表的欄位keyspace_name, table_name, index_name, kind, options, index_type
stmt := "SELECT keyspace_name, table_name, index_name, kind, options FROM system_schema.indexes WHERE keyspace_name = ?"
args := []interface{}{keyspace}
names := []string{"keyspace_name"}
if table != "" {
stmt += " AND table_name = ?"
args = append(args, table)
names = append(names, "table_name")
} }
// 執行查詢 // 查詢系統表獲取索引資訊
// system_schema.indexes 表的結構keyspace_name, table_name, index_name, kind, options
stmt := `
SELECT index_name, kind, options
FROM system_schema.indexes
WHERE keyspace_name = ? AND table_name = ?
`
var indexes []SAIIndexInfo var indexes []SAIIndexInfo
iter := db.session.Query(stmt, names).Bind(args...).WithContext(ctx).Consistency(gocql.One).Iter() iter := db.session.Query(stmt, []string{"keyspace_name", "table_name"}).
WithContext(ctx).
Consistency(gocql.One).
Bind(keyspace, table).
Iter()
var keyspaceName, tableName, indexName, kind string var indexName, kind string
var options map[string]string var options map[string]string
for iter.Scan(&indexName, &kind, &options) {
for iter.Scan(&keyspaceName, &tableName, &indexName, &kind, &options) { // 檢查是否為 SAI 索引kind = 'CUSTOM' 且 class_name 包含 StorageAttachedIndex
// 只處理 SAI 索引kind = 'CUSTOM' 且 index_type 在 options 中) if kind == "CUSTOM" {
indexType, ok := options["class_name"] if className, ok := options["class_name"]; ok && strings.Contains(className, "StorageAttachedIndex") {
if !ok || !strings.Contains(indexType, "StorageAttachedIndex") { // 從 options 中提取 target欄位名稱
continue
}
// 從 options 中提取 column_name
// SAI 索引的 target 欄位在 options 中
columnName := "" columnName := ""
if target, ok := options["target"]; ok { if target, ok := options["target"]; ok {
// target 格式通常是 "column_name" 或 "(column_name)"
columnName = strings.Trim(target, "()\"'") columnName = strings.Trim(target, "()\"'")
} }
indexes = append(indexes, SAIIndexInfo{ indexes = append(indexes, SAIIndexInfo{
KeyspaceName: keyspaceName, Name: indexName,
TableName: tableName, Type: "StorageAttachedIndex",
IndexName: indexName,
ColumnName: columnName,
IndexType: "sai",
Options: options, Options: options,
Column: columnName,
}) })
} }
}
}
if err := iter.Close(); err != nil { if err := iter.Close(); err != nil {
return nil, ErrInvalidInput.WithError(fmt.Errorf("failed to list SAI indexes: %w", err)) return nil, ErrInvalidInput.WithError(fmt.Errorf("failed to list SAI indexes: %w", err))
@ -189,59 +202,88 @@ func (db *DB) ListSAIIndexes(ctx context.Context, keyspace, table string) ([]SAI
return indexes, nil return indexes, nil
} }
// GetSAIIndex 獲取指定索引的資訊 // SAIIndexInfo 表示 SAI 索引資訊
// keyspace: keyspace 名稱,如果為空則使用預設 keyspace type SAIIndexInfo struct {
// indexName: 索引名稱 Name string // 索引名稱
func (db *DB) GetSAIIndex(ctx context.Context, keyspace, indexName string) (*SAIIndexInfo, error) { Type string // 索引類型
if !db.saiSupported { Options map[string]string // 索引選項
return nil, ErrSAINotSupported Column string // 索引欄位名稱
} }
// CheckSAIIndexExists 檢查 SAI 索引是否存在
// keyspace: keyspace 名稱
// indexName: 索引名稱
func (db *DB) CheckSAIIndexExists(ctx context.Context, keyspace, indexName string) (bool, error) {
// 驗證參數
if keyspace == "" { if keyspace == "" {
keyspace = db.defaultKeyspace return false, ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
}
if keyspace == "" {
return nil, ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
} }
if indexName == "" { if indexName == "" {
return nil, ErrInvalidInput.WithError(fmt.Errorf("index name is required")) return false, ErrInvalidInput.WithError(fmt.Errorf("index name is required"))
} }
// 構建查詢語句 // 查詢系統表檢查索引是否存在
stmt := "SELECT keyspace_name, table_name, index_name, kind, options FROM system_schema.indexes WHERE keyspace_name = ? AND index_name = ?" stmt := `
args := []interface{}{keyspace, indexName} SELECT index_name, kind, options
names := []string{"keyspace_name", "index_name"} FROM system_schema.indexes
WHERE keyspace_name = ? AND index_name = ?
LIMIT 1
`
var keyspaceName, tableName, idxName, kind string var foundIndexName, kind string
var options map[string]string var options map[string]string
err := db.session.Query(stmt, []string{"keyspace_name", "index_name"}).
WithContext(ctx).
Consistency(gocql.One).
Bind(keyspace, indexName).
Scan(&foundIndexName, &kind, &options)
// 執行查詢
err := db.session.Query(stmt, names).Bind(args...).WithContext(ctx).Consistency(gocql.One).Scan(&keyspaceName, &tableName, &idxName, &kind, &options)
if err != nil {
if err == gocql.ErrNotFound { if err == gocql.ErrNotFound {
return nil, ErrNotFound.WithError(fmt.Errorf("index not found: %s", indexName)) return false, nil
} }
return nil, ErrInvalidInput.WithError(fmt.Errorf("failed to get index: %w", err)) if err != nil {
return false, ErrInvalidInput.WithError(fmt.Errorf("failed to check SAI index existence: %w", err))
} }
// 檢查是否為 SAI 索引 // 檢查是否為 SAI 索引
indexType, ok := options["class_name"] if kind == "CUSTOM" {
if !ok || !strings.Contains(indexType, "StorageAttachedIndex") { if className, ok := options["class_name"]; ok && strings.Contains(className, "StorageAttachedIndex") {
return nil, ErrInvalidInput.WithError(fmt.Errorf("index %s is not a SAI index", indexName)) return true, nil
}
} }
// 從 options 中提取 column_name return false, nil
columnName := "" }
if target, ok := options["target"]; ok {
columnName = strings.Trim(target, "()\"'") // WaitForSAIIndex 等待 SAI 索引建立完成(用於異步建立)
} // keyspace: keyspace 名稱
// indexName: 索引名稱
return &SAIIndexInfo{ // maxWaitTime: 最大等待時間(秒)
KeyspaceName: keyspaceName, func (db *DB) WaitForSAIIndex(ctx context.Context, keyspace, indexName string, maxWaitTime int) error {
TableName: tableName, // 驗證參數
IndexName: idxName, if keyspace == "" {
ColumnName: columnName, return ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
IndexType: "sai", }
Options: options, if indexName == "" {
}, nil return ErrInvalidInput.WithError(fmt.Errorf("index name is required"))
}
// 查詢索引狀態
// 注意Cassandra 沒有直接的索引狀態查詢,這裡需要通過檢查索引是否可用來判斷
// 實際實作可能需要根據具體的 Cassandra 版本調整
// 簡單實作:檢查索引是否存在
exists, err := db.CheckSAIIndexExists(ctx, keyspace, indexName)
if err != nil {
return err
}
if !exists {
return ErrInvalidInput.WithError(fmt.Errorf("index %s does not exist", indexName))
}
// 注意:實際的等待邏輯可能需要查詢系統表或使用其他方法
// 這裡只是基本框架,實際使用時可能需要根據具體需求調整
return nil
} }

View File

@ -1,117 +1,50 @@
package cassandra package cassandra
import ( import (
"fmt"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestCreateSAIIndex(t *testing.T) { func TestDefaultSAIIndexOptions(t *testing.T) {
opts := DefaultSAIIndexOptions()
assert.NotNil(t, opts)
assert.Equal(t, SAIIndexTypeStandard, opts.IndexType)
assert.False(t, opts.IsAsync)
assert.True(t, opts.CaseSensitive)
}
func TestCreateSAIIndex_Validation(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
keyspace string keyspace string
table string table string
column string column string
indexName string indexName string
options *SAIIndexOptions opts *SAIIndexOptions
description string
wantErr bool wantErr bool
validate func(*testing.T, error) errMsg string
}{ }{
{
name: "create basic SAI index",
keyspace: "test_keyspace",
table: "test_table",
column: "name",
indexName: "test_name_idx",
options: nil,
description: "should create a basic SAI index",
wantErr: false,
},
{
name: "create SAI index with auto-generated name",
keyspace: "test_keyspace",
table: "test_table",
column: "email",
indexName: "",
options: nil,
description: "should auto-generate index name",
wantErr: false,
},
{
name: "create SAI index with case insensitive option",
keyspace: "test_keyspace",
table: "test_table",
column: "title",
indexName: "test_title_idx",
options: &SAIIndexOptions{CaseSensitive: boolPtr(false)},
description: "should create index with case insensitive option",
wantErr: false,
},
{
name: "create SAI index with normalize option",
keyspace: "test_keyspace",
table: "test_table",
column: "content",
indexName: "test_content_idx",
options: &SAIIndexOptions{Normalize: boolPtr(true)},
description: "should create index with normalize option",
wantErr: false,
},
{
name: "create SAI index with analyzer",
keyspace: "test_keyspace",
table: "test_table",
column: "description",
indexName: "test_desc_idx",
options: &SAIIndexOptions{Analyzer: "StandardAnalyzer"},
description: "should create index with analyzer",
wantErr: false,
},
{
name: "create SAI index with all options",
keyspace: "test_keyspace",
table: "test_table",
column: "text",
indexName: "test_text_idx",
options: &SAIIndexOptions{CaseSensitive: boolPtr(false), Normalize: boolPtr(true), Analyzer: "StandardAnalyzer"},
description: "should create index with all options",
wantErr: false,
},
{ {
name: "missing keyspace", name: "missing keyspace",
keyspace: "", keyspace: "",
table: "test_table", table: "test_table",
column: "name", column: "test_column",
indexName: "test_idx", indexName: "test_idx",
options: nil, opts: nil,
description: "should return error when keyspace is empty and no default",
wantErr: true, wantErr: true,
validate: func(t *testing.T, err error) { errMsg: "keyspace is required",
assert.Error(t, err)
var e *Error
if assert.ErrorAs(t, err, &e) {
assert.Equal(t, ErrCodeInvalidInput, e.Code)
}
},
}, },
{ {
name: "missing table", name: "missing table",
keyspace: "test_keyspace", keyspace: "test_keyspace",
table: "", table: "",
column: "name", column: "test_column",
indexName: "test_idx", indexName: "test_idx",
options: nil, opts: nil,
description: "should return error when table is empty",
wantErr: true, wantErr: true,
validate: func(t *testing.T, err error) { errMsg: "table is required",
assert.Error(t, err)
var e *Error
if assert.ErrorAs(t, err, &e) {
assert.Equal(t, ErrCodeInvalidInput, e.Code)
}
},
}, },
{ {
name: "missing column", name: "missing column",
@ -119,265 +52,216 @@ func TestCreateSAIIndex(t *testing.T) {
table: "test_table", table: "test_table",
column: "", column: "",
indexName: "test_idx", indexName: "test_idx",
options: nil, opts: nil,
description: "should return error when column is empty",
wantErr: true, wantErr: true,
validate: func(t *testing.T, err error) { errMsg: "column is required",
assert.Error(t, err)
var e *Error
if assert.ErrorAs(t, err, &e) {
assert.Equal(t, ErrCodeInvalidInput, e.Code)
}
}, },
{
name: "valid parameters with default options",
keyspace: "test_keyspace",
table: "test_table",
column: "test_column",
indexName: "test_idx",
opts: nil,
wantErr: false,
},
{
name: "valid parameters with custom options",
keyspace: "test_keyspace",
table: "test_table",
column: "test_column",
indexName: "test_idx",
opts: &SAIIndexOptions{
IndexType: SAIIndexTypeFullText,
IsAsync: true,
CaseSensitive: false,
},
wantErr: false,
},
{
name: "auto-generate index name",
keyspace: "test_keyspace",
table: "test_table",
column: "test_column",
indexName: "",
opts: nil,
wantErr: false,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// 注意:這需要一個有效的 DB 實例和 SAI 支援 // 注意:這需要一個有效的 DB 實例和 SAI 支援
// 在實際測試中,需要使用 testcontainers 或 mock // 在實際測試中,需要使用 mock 或 testcontainers
_ = tt _ = tt
}) })
} }
} }
func TestDropSAIIndex(t *testing.T) { func TestDropSAIIndex_Validation(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
keyspace string keyspace string
indexName string indexName string
description string
wantErr bool wantErr bool
validate func(*testing.T, error) errMsg string
}{ }{
{
name: "drop existing index",
keyspace: "test_keyspace",
indexName: "test_name_idx",
description: "should drop existing index",
wantErr: false,
},
{
name: "drop non-existent index",
keyspace: "test_keyspace",
indexName: "non_existent_idx",
description: "should not error when dropping non-existent index (IF EXISTS)",
wantErr: false,
},
{ {
name: "missing keyspace", name: "missing keyspace",
keyspace: "", keyspace: "",
indexName: "test_idx", indexName: "test_idx",
description: "should return error when keyspace is empty and no default",
wantErr: true, wantErr: true,
validate: func(t *testing.T, err error) { errMsg: "keyspace is required",
assert.Error(t, err)
var e *Error
if assert.ErrorAs(t, err, &e) {
assert.Equal(t, ErrCodeInvalidInput, e.Code)
}
},
}, },
{ {
name: "missing index name", name: "missing index name",
keyspace: "test_keyspace", keyspace: "test_keyspace",
indexName: "", indexName: "",
description: "should return error when index name is empty",
wantErr: true, wantErr: true,
validate: func(t *testing.T, err error) { errMsg: "index name is required",
assert.Error(t, err)
var e *Error
if assert.ErrorAs(t, err, &e) {
assert.Equal(t, ErrCodeInvalidInput, e.Code)
}
}, },
{
name: "valid parameters",
keyspace: "test_keyspace",
indexName: "test_idx",
wantErr: false,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// 注意:這需要一個有效的 DB 實例和 SAI 支援 // 注意:這需要一個有效的 DB 實例
// 在實際測試中,需要使用 testcontainers 或 mock // 在實際測試中,需要使用 mock 或 testcontainers
_ = tt _ = tt
}) })
} }
} }
func TestListSAIIndexes(t *testing.T) { func TestListSAIIndexes_Validation(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
keyspace string keyspace string
table string table string
description string
wantErr bool wantErr bool
validate func(*testing.T, []SAIIndexInfo, error) errMsg string
}{ }{
{
name: "list all indexes in keyspace",
keyspace: "test_keyspace",
table: "",
description: "should list all SAI indexes in keyspace",
wantErr: false,
validate: func(t *testing.T, indexes []SAIIndexInfo, err error) {
require.NoError(t, err)
assert.NotNil(t, indexes)
},
},
{
name: "list indexes for specific table",
keyspace: "test_keyspace",
table: "test_table",
description: "should list SAI indexes for specific table",
wantErr: false,
validate: func(t *testing.T, indexes []SAIIndexInfo, err error) {
require.NoError(t, err)
assert.NotNil(t, indexes)
for _, idx := range indexes {
assert.Equal(t, "test_table", idx.TableName)
}
},
},
{ {
name: "missing keyspace", name: "missing keyspace",
keyspace: "", keyspace: "",
table: "", table: "test_table",
description: "should return error when keyspace is empty and no default",
wantErr: true, wantErr: true,
validate: func(t *testing.T, indexes []SAIIndexInfo, err error) { errMsg: "keyspace is required",
assert.Error(t, err)
assert.Nil(t, indexes)
var e *Error
if assert.ErrorAs(t, err, &e) {
assert.Equal(t, ErrCodeInvalidInput, e.Code)
}
}, },
{
name: "missing table",
keyspace: "test_keyspace",
table: "",
wantErr: true,
errMsg: "table is required",
},
{
name: "valid parameters",
keyspace: "test_keyspace",
table: "test_table",
wantErr: false,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// 注意:這需要一個有效的 DB 實例和 SAI 支援 // 注意:這需要一個有效的 DB 實例
// 在實際測試中,需要使用 testcontainers 或 mock // 在實際測試中,需要使用 mock 或 testcontainers
_ = tt _ = tt
}) })
} }
} }
func TestGetSAIIndex(t *testing.T) { func TestCheckSAIIndexExists_Validation(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
keyspace string keyspace string
indexName string indexName string
description string
wantErr bool wantErr bool
validate func(*testing.T, *SAIIndexInfo, error) errMsg string
}{ }{
{
name: "get existing index",
keyspace: "test_keyspace",
indexName: "test_name_idx",
description: "should get existing SAI index",
wantErr: false,
validate: func(t *testing.T, index *SAIIndexInfo, err error) {
require.NoError(t, err)
assert.NotNil(t, index)
assert.Equal(t, "test_name_idx", index.IndexName)
},
},
{
name: "get non-existent index",
keyspace: "test_keyspace",
indexName: "non_existent_idx",
description: "should return ErrNotFound",
wantErr: true,
validate: func(t *testing.T, index *SAIIndexInfo, err error) {
assert.Error(t, err)
assert.Nil(t, index)
assert.True(t, IsNotFound(err))
},
},
{ {
name: "missing keyspace", name: "missing keyspace",
keyspace: "", keyspace: "",
indexName: "test_idx", indexName: "test_idx",
description: "should return error when keyspace is empty and no default",
wantErr: true, wantErr: true,
validate: func(t *testing.T, index *SAIIndexInfo, err error) { errMsg: "keyspace is required",
assert.Error(t, err)
assert.Nil(t, index)
},
}, },
{ {
name: "missing index name", name: "missing index name",
keyspace: "test_keyspace", keyspace: "test_keyspace",
indexName: "", indexName: "",
description: "should return error when index name is empty",
wantErr: true, wantErr: true,
validate: func(t *testing.T, index *SAIIndexInfo, err error) { errMsg: "index name is required",
assert.Error(t, err)
assert.Nil(t, index)
}, },
{
name: "valid parameters",
keyspace: "test_keyspace",
indexName: "test_idx",
wantErr: false,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// 注意:這需要一個有效的 DB 實例和 SAI 支援 // 注意:這需要一個有效的 DB 實例
// 在實際測試中,需要使用 testcontainers 或 mock // 在實際測試中,需要使用 mock 或 testcontainers
_ = tt _ = tt
}) })
} }
} }
func TestSAIIndexOptions(t *testing.T) { func TestSAIIndexType_Constants(t *testing.T) {
t.Run("default options", func(t *testing.T) { tests := []struct {
opts := &SAIIndexOptions{} name string
assert.Nil(t, opts.CaseSensitive) indexType SAIIndexType
assert.Nil(t, opts.Normalize) expected string
assert.Empty(t, opts.Analyzer) }{
}) {
name: "standard index type",
t.Run("with case sensitive", func(t *testing.T) { indexType: SAIIndexTypeStandard,
caseSensitive := false expected: "STANDARD",
opts := &SAIIndexOptions{CaseSensitive: &caseSensitive} },
assert.NotNil(t, opts.CaseSensitive) {
assert.False(t, *opts.CaseSensitive) name: "collection index type",
}) indexType: SAIIndexTypeCollection,
expected: "COLLECTION",
t.Run("with normalize", func(t *testing.T) { },
normalize := true {
opts := &SAIIndexOptions{Normalize: &normalize} name: "full text index type",
assert.NotNil(t, opts.Normalize) indexType: SAIIndexTypeFullText,
assert.True(t, *opts.Normalize) expected: "FULL_TEXT",
}) },
t.Run("with analyzer", func(t *testing.T) {
opts := &SAIIndexOptions{Analyzer: "StandardAnalyzer"}
assert.Equal(t, "StandardAnalyzer", opts.Analyzer)
})
}
func TestSAIIndexInfo(t *testing.T) {
t.Run("index info structure", func(t *testing.T) {
info := SAIIndexInfo{
KeyspaceName: "test_keyspace",
TableName: "test_table",
IndexName: "test_idx",
ColumnName: "name",
IndexType: "sai",
Options: map[string]string{"target": "name"},
} }
assert.Equal(t, "test_keyspace", info.KeyspaceName) for _, tt := range tests {
assert.Equal(t, "test_table", info.TableName) t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, "test_idx", info.IndexName) assert.Equal(t, tt.expected, string(tt.indexType))
assert.Equal(t, "name", info.ColumnName) })
assert.Equal(t, "sai", info.IndexType) }
assert.NotNil(t, info.Options) }
func TestCreateSAIIndex_NotSupported(t *testing.T) {
t.Run("should return error when SAI not supported", func(t *testing.T) {
// 注意:這需要一個不支援 SAI 的 DB 實例
// 在實際測試中,需要使用 mock 或 testcontainers
}) })
} }
// Helper function func TestCreateSAIIndex_IndexNameGeneration(t *testing.T) {
func boolPtr(b bool) *bool { t.Run("should generate index name when not provided", func(t *testing.T) {
return &b // 測試自動生成索引名稱的邏輯
// 格式應該是: {table}_{column}_sai_idx
table := "users"
column := "email"
expected := "users_email_sai_idx"
// 這裡只是測試命名邏輯,實際建立需要 DB 實例
generated := fmt.Sprintf("%s_%s_sai_idx", table, column)
assert.Equal(t, expected, generated)
})
} }

View File

@ -3,8 +3,8 @@ package post
// CommentStatus 評論狀態 // CommentStatus 評論狀態
type CommentStatus int32 type CommentStatus int32
func (s *CommentStatus) CodeToString() string { func (s CommentStatus) CodeToString() string {
result, ok := commentStatusMap[*s] result, ok := commentStatusMap[s]
if !ok { if !ok {
return "" return ""
} }
@ -17,8 +17,8 @@ var commentStatusMap = map[CommentStatus]string{
CommentStatusHidden: "hidden", // 隱藏 CommentStatusHidden: "hidden", // 隱藏
} }
func (s *CommentStatus) ToInt32() int32 { func (s CommentStatus) ToInt32() int32 {
return int32(*s) return int32(s)
} }
const ( const (

View File

@ -3,8 +3,8 @@ package post
// Status 貼文狀態 // Status 貼文狀態
type Status int32 type Status int32
func (s *Status) CodeToString() string { func (s Status) CodeToString() string {
result, ok := postStatusMap[*s] result, ok := postStatusMap[s]
if !ok { if !ok {
return "" return ""
} }
@ -19,8 +19,8 @@ var postStatusMap = map[Status]string{
PostStatusHidden: "hidden", // 隱藏 PostStatusHidden: "hidden", // 隱藏
} }
func (s *Status) ToInt32() int32 { func (s Status) ToInt32() int32 {
return int32(*s) return int32(s)
} }
const ( const (

View File

@ -3,8 +3,8 @@ package post
// Type 貼文類型 // Type 貼文類型
type Type int32 type Type int32
func (t *Type) CodeToString() string { func (t Type) CodeToString() string {
result, ok := postTypeMap[*t] result, ok := postTypeMap[t]
if !ok { if !ok {
return "" return ""
} }
@ -12,28 +12,28 @@ func (t *Type) CodeToString() string {
} }
var postTypeMap = map[Type]string{ var postTypeMap = map[Type]string{
PostTypeText: "text", // 純文字 TypeText: "text", // 純文字
PostTypeImage: "image", // 圖片 TypeImage: "image", // 圖片
PostTypeVideo: "video", // 影片 TypeVideo: "video", // 影片
PostTypeLink: "link", // 連結 TypeLink: "link", // 連結
PostTypePoll: "poll", // 投票 TypePoll: "poll", // 投票
PostTypeArticle: "article", // 長文 TypeArticle: "article", // 長文
} }
func (t *Type) ToInt32() int32 { func (t Type) ToInt32() int32 {
return int32(*t) return int32(t)
} }
const ( const (
PostTypeText Type = 0 // 純文字 TypeText Type = 0 // 純文字
PostTypeImage Type = 1 // 圖片 TypeImage Type = 1 // 圖片
PostTypeVideo Type = 2 // 影片 TypeVideo Type = 2 // 影片
PostTypeLink Type = 3 // 連結 TypeLink Type = 3 // 連結
PostTypePoll Type = 4 // 投票 TypePoll Type = 4 // 投票
PostTypeArticle Type = 5 // 長文 TypeArticle Type = 5 // 長文
) )
// IsValid returns true if the type is valid // IsValid returns true if the type is valid
func (t Type) IsValid() bool { func (t Type) IsValid() bool {
return t >= PostTypeText && t <= PostTypeArticle return t >= TypeText && t <= TypeArticle
} }

View File

@ -4,26 +4,23 @@ import (
"context" "context"
"backend/pkg/post/domain/entity" "backend/pkg/post/domain/entity"
"github.com/gocql/gocql"
) )
// CategoryRepository defines the interface for category data access operations // CategoryRepository defines the interface for category data access operations
type CategoryRepository interface { type CategoryRepository interface {
BaseCategoryRepository BaseCategoryRepository
FindBySlug(ctx context.Context, slug string) (*entity.Category, error) FindBySlug(ctx context.Context, slug string) (*entity.Category, error)
FindByParentID(ctx context.Context, parentID *gocql.UUID) ([]*entity.Category, error) FindByParentID(ctx context.Context, parentID string) ([]*entity.Category, error)
FindRootCategories(ctx context.Context) ([]*entity.Category, error) FindRootCategories(ctx context.Context) ([]*entity.Category, error)
FindActive(ctx context.Context) ([]*entity.Category, error) FindActive(ctx context.Context) ([]*entity.Category, error)
IncrementPostCount(ctx context.Context, categoryID gocql.UUID) error IncrementPostCount(ctx context.Context, categoryID string) error
DecrementPostCount(ctx context.Context, categoryID gocql.UUID) error DecrementPostCount(ctx context.Context, categoryID string) error
} }
// BaseCategoryRepository defines basic CRUD operations for categories // BaseCategoryRepository defines basic CRUD operations for categories
type BaseCategoryRepository interface { type BaseCategoryRepository interface {
Insert(ctx context.Context, data *entity.Category) error Insert(ctx context.Context, data *entity.Category) error
FindOne(ctx context.Context, id gocql.UUID) (*entity.Category, error) FindOne(ctx context.Context, id string) (*entity.Category, error)
Update(ctx context.Context, data *entity.Category) error Update(ctx context.Context, data *entity.Category) error
Delete(ctx context.Context, id gocql.UUID) error Delete(ctx context.Context, id string) error
} }

View File

@ -0,0 +1,263 @@
package repository
import (
"context"
"fmt"
"strings"
"backend/pkg/library/cassandra"
"backend/pkg/post/domain/entity"
domainRepo "backend/pkg/post/domain/repository"
"github.com/gocql/gocql"
)
// CategoryRepositoryParam 定義 CategoryRepository 的初始化參數
type CategoryRepositoryParam struct {
DB *cassandra.DB
Keyspace string
}
// CategoryRepository 實作 domain repository 介面
type CategoryRepository struct {
repo cassandra.Repository[*entity.Category]
db *cassandra.DB
keyspace string
}
// NewCategoryRepository 創建新的 CategoryRepository
func NewCategoryRepository(param CategoryRepositoryParam) domainRepo.CategoryRepository {
repo, err := cassandra.NewRepository[*entity.Category](param.DB, param.Keyspace)
if err != nil {
panic(fmt.Sprintf("failed to create category repository: %v", err))
}
keyspace := param.Keyspace
if keyspace == "" {
keyspace = param.DB.GetDefaultKeyspace()
}
return &CategoryRepository{
repo: repo,
db: param.DB,
keyspace: keyspace,
}
}
// Insert 插入單筆分類
func (r *CategoryRepository) Insert(ctx context.Context, data *entity.Category) error {
if data == nil {
return ErrInvalidInput
}
// 驗證資料
if err := data.Validate(); err != nil {
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
}
if data.ParentID == nil {
data.ParentID = &gocql.UUID{}
}
// 設置時間戳
data.SetTimestamps()
// 如果是新分類,生成 ID
if data.IsNew() {
data.ID = gocql.TimeUUID()
}
// Slug 轉為小寫
data.Slug = strings.ToLower(strings.TrimSpace(data.Slug))
return r.repo.Insert(ctx, data)
}
// FindOne 根據 ID 查詢單筆分類
func (r *CategoryRepository) FindOne(ctx context.Context, id string) (*entity.Category, error) {
var zeroUUID gocql.UUID
uuid, err := gocql.ParseUUID(id)
if err != nil {
return nil, err
}
if uuid == zeroUUID {
return nil, ErrInvalidInput
}
category, err := r.repo.Get(ctx, id)
if err != nil {
if cassandra.IsNotFound(err) {
return nil, ErrNotFound
}
return nil, fmt.Errorf("failed to find category: %w", err)
}
return category, nil
}
// Update 更新分類
func (r *CategoryRepository) Update(ctx context.Context, data *entity.Category) error {
if data == nil {
return ErrInvalidInput
}
// 驗證資料
if err := data.Validate(); err != nil {
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
}
// 更新時間戳
data.SetTimestamps()
// Slug 轉為小寫
data.Slug = strings.ToLower(strings.TrimSpace(data.Slug))
return r.repo.Update(ctx, data)
}
// Delete 刪除分類
func (r *CategoryRepository) Delete(ctx context.Context, id string) error {
var zeroUUID gocql.UUID
uuid, err := gocql.ParseUUID(id)
if err != nil {
return err
}
if uuid == zeroUUID {
return ErrInvalidInput
}
return r.repo.Delete(ctx, id)
}
// FindBySlug 根據 slug 查詢分類
func (r *CategoryRepository) FindBySlug(ctx context.Context, slug string) (*entity.Category, error) {
if slug == "" {
return nil, ErrInvalidInput
}
// 標準化 slug
slug = strings.ToLower(strings.TrimSpace(slug))
// 構建查詢(要有 SAI 索引在 slug 欄位上)
query := r.repo.Query().Where(cassandra.Eq("slug", slug))
var categories []*entity.Category
if err := query.Scan(ctx, &categories); err != nil {
if cassandra.IsNotFound(err) {
return nil, ErrNotFound
}
return nil, fmt.Errorf("failed to query category: %w", err)
}
if len(categories) == 0 {
return nil, ErrNotFound
}
return categories[0], nil
}
// FindByParentID 根據父分類 ID 查詢子分類
func (r *CategoryRepository) FindByParentID(ctx context.Context, parentID string) ([]*entity.Category, error) {
query := r.repo.Query()
var zeroUUID gocql.UUID
if parentID != "" {
// 構建查詢(有 SAI 索引在 parentID 欄位上)
uuid, err := gocql.ParseUUID(parentID)
if err != nil {
return nil, err
}
if uuid != zeroUUID {
query = query.Where(cassandra.Eq("parent_id", uuid))
}
} else {
query = query.Where(cassandra.Eq("parent_id", zeroUUID))
}
// 按 sort_order 排序
query = query.OrderBy("sort_order", cassandra.ASC)
var categories []*entity.Category
if err := query.Scan(ctx, &categories); err != nil {
return nil, fmt.Errorf("failed to query categories: %w", err)
}
return categories, nil
}
// FindRootCategories 查詢根分類
func (r *CategoryRepository) FindRootCategories(ctx context.Context) ([]*entity.Category, error) {
return r.FindByParentID(ctx, "")
}
// FindActive 查詢啟用的分類
func (r *CategoryRepository) FindActive(ctx context.Context) ([]*entity.Category, error) {
query := r.repo.Query().
Where(cassandra.Eq("is_active", true)).
OrderBy("sort_order", cassandra.ASC)
var categories []*entity.Category
if err := query.Scan(ctx, &categories); err != nil {
return nil, fmt.Errorf("failed to query active categories: %w", err)
}
result := categories
return result, nil
}
// IncrementPostCount 增加貼文數(使用 counter 原子操作避免競爭條件)
// 注意post_count 欄位必須是 counter 類型
func (r *CategoryRepository) IncrementPostCount(ctx context.Context, categoryID string) error {
uuid, err := gocql.ParseUUID(categoryID)
if err != nil {
return fmt.Errorf("%w: invalid category ID: %v", ErrInvalidInput, err)
}
// 使用 counter 原子更新操作UPDATE categories SET post_count = post_count + 1 WHERE id = ?
var zeroCategory entity.Category
tableName := zeroCategory.TableName()
if r.keyspace == "" {
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
}
stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count + 1 WHERE id = ?", r.keyspace, tableName)
query := r.db.GetSession().Query(stmt, nil).
WithContext(ctx).
Consistency(gocql.Quorum).
Bind(uuid)
if err := query.ExecRelease(); err != nil {
return fmt.Errorf("failed to increment post count: %w", err)
}
return nil
}
// DecrementPostCount 減少貼文數(使用 counter 原子操作避免競爭條件)
// 注意post_count 欄位必須是 counter 類型
func (r *CategoryRepository) DecrementPostCount(ctx context.Context, categoryID string) error {
uuid, err := gocql.ParseUUID(categoryID)
if err != nil {
return fmt.Errorf("%w: invalid category ID: %v", ErrInvalidInput, err)
}
// 使用 counter 原子更新操作UPDATE categories SET post_count = post_count - 1 WHERE id = ?
var zeroCategory entity.Category
tableName := zeroCategory.TableName()
if r.keyspace == "" {
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
}
stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count - 1 WHERE id = ?", r.keyspace, tableName)
query := r.db.GetSession().Query(stmt, nil).
WithContext(ctx).
Consistency(gocql.Quorum).
Bind(uuid)
if err := query.ExecRelease(); err != nil {
return fmt.Errorf("failed to decrement post count: %w", err)
}
return nil
}

View File

@ -0,0 +1,383 @@
package repository
import (
"context"
"fmt"
"backend/pkg/library/cassandra"
"backend/pkg/post/domain/entity"
"backend/pkg/post/domain/post"
domainRepo "backend/pkg/post/domain/repository"
"github.com/gocql/gocql"
)
// CommentRepositoryParam 定義 CommentRepository 的初始化參數
type CommentRepositoryParam struct {
DB *cassandra.DB
Keyspace string
}
// CommentRepository 實作 domain repository 介面
type CommentRepository struct {
repo cassandra.Repository[*entity.Comment]
db *cassandra.DB
keyspace string
}
// NewCommentRepository 創建新的 CommentRepository
func NewCommentRepository(param CommentRepositoryParam) domainRepo.CommentRepository {
repo, err := cassandra.NewRepository[*entity.Comment](param.DB, param.Keyspace)
if err != nil {
panic(fmt.Sprintf("failed to create comment repository: %v", err))
}
keyspace := param.Keyspace
if keyspace == "" {
keyspace = param.DB.GetDefaultKeyspace()
}
return &CommentRepository{
repo: repo,
db: param.DB,
keyspace: keyspace,
}
}
// Insert 插入單筆評論
func (r *CommentRepository) Insert(ctx context.Context, data *entity.Comment) error {
if data == nil {
return ErrInvalidInput
}
// 驗證資料
if err := data.Validate(); err != nil {
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
}
// 設置時間戳
data.SetTimestamps()
// 如果是新評論,生成 ID
if data.IsNew() {
data.ID = gocql.TimeUUID()
}
return r.repo.Insert(ctx, data)
}
// FindOne 根據 ID 查詢單筆評論
func (r *CommentRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Comment, error) {
var zeroUUID gocql.UUID
if id == zeroUUID {
return nil, ErrInvalidInput
}
comment, err := r.repo.Get(ctx, id)
if err != nil {
if cassandra.IsNotFound(err) {
return nil, ErrNotFound
}
return nil, fmt.Errorf("failed to find comment: %w", err)
}
return comment, nil
}
// Update 更新評論
func (r *CommentRepository) Update(ctx context.Context, data *entity.Comment) error {
if data == nil {
return ErrInvalidInput
}
// 驗證資料
if err := data.Validate(); err != nil {
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
}
// 更新時間戳
data.SetTimestamps()
return r.repo.Update(ctx, data)
}
// Delete 刪除評論(軟刪除)
func (r *CommentRepository) Delete(ctx context.Context, id gocql.UUID) error {
var zeroUUID gocql.UUID
if id == zeroUUID {
return ErrInvalidInput
}
// 先查詢評論
comment, err := r.FindOne(ctx, id)
if err != nil {
return err
}
// 軟刪除:標記為已刪除
comment.Delete()
return r.Update(ctx, comment)
}
// FindByPostID 根據貼文 ID 查詢評論
func (r *CommentRepository) FindByPostID(ctx context.Context, postID gocql.UUID, params *domainRepo.CommentQueryParams) ([]*entity.Comment, int64, error) {
var zeroUUID gocql.UUID
if postID == zeroUUID {
return nil, 0, ErrInvalidInput
}
// 構建查詢(使用 PostID 作為 clustering key
query := r.repo.Query().Where(cassandra.Eq("post_id", postID))
// 添加父評論過濾(如果指定,只查詢回覆)
if params != nil && params.ParentID != nil {
query = query.Where(cassandra.Eq("parent_id", *params.ParentID))
} else {
// 如果沒有指定 ParentID只查詢頂層評論parent_id 為 null
// 注意Cassandra 不支援直接查詢 null需要特殊處理
// 這裡簡化處理,實際可能需要使用 Materialized View
}
// 添加狀態過濾
if params != nil && params.Status != nil {
query = query.Where(cassandra.Eq("status", *params.Status))
} else {
// 預設只查詢已發布的評論
published := post.CommentStatusPublished
query = query.Where(cassandra.Eq("status", published))
}
// 添加排序
orderBy := "created_at"
if params != nil && params.OrderBy != "" {
orderBy = params.OrderBy
}
order := cassandra.ASC
if params != nil && params.OrderDirection == "DESC" {
order = cassandra.DESC
}
query = query.OrderBy(orderBy, order)
// 添加分頁
pageSize := int64(20)
if params != nil && params.PageSize > 0 {
pageSize = params.PageSize
}
limit := int(pageSize)
query = query.Limit(limit)
// 執行查詢
var comments []*entity.Comment
if err := query.Scan(ctx, &comments); err != nil {
return nil, 0, fmt.Errorf("failed to query comments: %w", err)
}
result := comments
total := int64(len(result))
return result, total, nil
}
// FindByParentID 根據父評論 ID 查詢回覆
func (r *CommentRepository) FindByParentID(ctx context.Context, parentID gocql.UUID, params *domainRepo.CommentQueryParams) ([]*entity.Comment, int64, error) {
var zeroUUID gocql.UUID
if parentID == zeroUUID {
return nil, 0, ErrInvalidInput
}
query := r.repo.Query().Where(cassandra.Eq("parent_id", parentID))
// 添加狀態過濾
if params != nil && params.Status != nil {
query = query.Where(cassandra.Eq("status", *params.Status))
} else {
published := post.CommentStatusPublished
query = query.Where(cassandra.Eq("status", published))
}
// 添加排序和分頁
orderBy := "created_at"
if params != nil && params.OrderBy != "" {
orderBy = params.OrderBy
}
order := cassandra.ASC
if params != nil && params.OrderDirection == "DESC" {
order = cassandra.DESC
}
query = query.OrderBy(orderBy, order)
pageSize := int64(20)
if params != nil && params.PageSize > 0 {
pageSize = params.PageSize
}
query = query.Limit(int(pageSize))
var comments []*entity.Comment
if err := query.Scan(ctx, &comments); err != nil {
return nil, 0, fmt.Errorf("failed to query replies: %w", err)
}
return comments, int64(len(comments)), nil
}
// FindByAuthorUID 根據作者 UID 查詢評論
func (r *CommentRepository) FindByAuthorUID(ctx context.Context, authorUID string, params *domainRepo.CommentQueryParams) ([]*entity.Comment, int64, error) {
if authorUID == "" {
return nil, 0, ErrInvalidInput
}
query := r.repo.Query().Where(cassandra.Eq("author_uid", authorUID))
// 添加狀態過濾
if params != nil && params.Status != nil {
query = query.Where(cassandra.Eq("status", *params.Status))
}
// 添加排序和分頁
orderBy := "created_at"
if params != nil && params.OrderBy != "" {
orderBy = params.OrderBy
}
order := cassandra.DESC
if params != nil && params.OrderDirection == "ASC" {
order = cassandra.ASC
}
query = query.OrderBy(orderBy, order)
pageSize := int64(20)
if params != nil && params.PageSize > 0 {
pageSize = params.PageSize
}
query = query.Limit(int(pageSize))
var comments []*entity.Comment
if err := query.Scan(ctx, &comments); err != nil {
return nil, 0, fmt.Errorf("failed to query comments: %w", err)
}
return comments, int64(len(comments)), nil
}
// FindReplies 查詢指定評論的回覆
func (r *CommentRepository) FindReplies(ctx context.Context, commentID gocql.UUID, params *domainRepo.CommentQueryParams) ([]*entity.Comment, int64, error) {
return r.FindByParentID(ctx, commentID, params)
}
// IncrementLikeCount 增加按讚數(使用 counter 原子操作避免競爭條件)
// 注意like_count 欄位必須是 counter 類型
func (r *CommentRepository) IncrementLikeCount(ctx context.Context, commentID gocql.UUID) error {
var zeroUUID gocql.UUID
if commentID == zeroUUID {
return ErrInvalidInput
}
var zeroComment entity.Comment
tableName := zeroComment.TableName()
if r.keyspace == "" {
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
}
stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count + 1 WHERE id = ?", r.keyspace, tableName)
query := r.db.GetSession().Query(stmt, nil).
WithContext(ctx).
Consistency(gocql.Quorum).
Bind(commentID)
if err := query.ExecRelease(); err != nil {
return fmt.Errorf("failed to increment like count: %w", err)
}
return nil
}
// DecrementLikeCount 減少按讚數(使用 counter 原子操作避免競爭條件)
// 注意like_count 欄位必須是 counter 類型
func (r *CommentRepository) DecrementLikeCount(ctx context.Context, commentID gocql.UUID) error {
var zeroUUID gocql.UUID
if commentID == zeroUUID {
return ErrInvalidInput
}
var zeroComment entity.Comment
tableName := zeroComment.TableName()
if r.keyspace == "" {
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
}
stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count - 1 WHERE id = ?", r.keyspace, tableName)
query := r.db.GetSession().Query(stmt, nil).
WithContext(ctx).
Consistency(gocql.Quorum).
Bind(commentID)
if err := query.ExecRelease(); err != nil {
return fmt.Errorf("failed to decrement like count: %w", err)
}
return nil
}
// IncrementReplyCount 增加回覆數(使用 counter 原子操作避免競爭條件)
// 注意reply_count 欄位必須是 counter 類型
func (r *CommentRepository) IncrementReplyCount(ctx context.Context, commentID gocql.UUID) error {
var zeroUUID gocql.UUID
if commentID == zeroUUID {
return ErrInvalidInput
}
var zeroComment entity.Comment
tableName := zeroComment.TableName()
if r.keyspace == "" {
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
}
stmt := fmt.Sprintf("UPDATE %s.%s SET reply_count = reply_count + 1 WHERE id = ?", r.keyspace, tableName)
query := r.db.GetSession().Query(stmt, nil).
WithContext(ctx).
Consistency(gocql.Quorum).
Bind(commentID)
if err := query.ExecRelease(); err != nil {
return fmt.Errorf("failed to increment reply count: %w", err)
}
return nil
}
// DecrementReplyCount 減少回覆數(使用 counter 原子操作避免競爭條件)
// 注意reply_count 欄位必須是 counter 類型
func (r *CommentRepository) DecrementReplyCount(ctx context.Context, commentID gocql.UUID) error {
var zeroUUID gocql.UUID
if commentID == zeroUUID {
return ErrInvalidInput
}
var zeroComment entity.Comment
tableName := zeroComment.TableName()
if r.keyspace == "" {
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
}
stmt := fmt.Sprintf("UPDATE %s.%s SET reply_count = reply_count - 1 WHERE id = ?", r.keyspace, tableName)
query := r.db.GetSession().Query(stmt, nil).
WithContext(ctx).
Consistency(gocql.Quorum).
Bind(commentID)
if err := query.ExecRelease(); err != nil {
return fmt.Errorf("failed to decrement reply count: %w", err)
}
return nil
}
// UpdateStatus 更新評論狀態
func (r *CommentRepository) UpdateStatus(ctx context.Context, commentID gocql.UUID, status post.CommentStatus) error {
comment, err := r.FindOne(ctx, commentID)
if err != nil {
return err
}
comment.Status = status
return r.Update(ctx, comment)
}

View File

@ -0,0 +1,34 @@
package repository
import (
"errors"
"backend/pkg/library/cassandra"
)
// Common repository errors
var (
// ErrNotFound is returned when a requested resource is not found
ErrNotFound = errors.New("resource not found")
// ErrInvalidInput is returned when input validation fails
ErrInvalidInput = errors.New("invalid input")
// ErrDuplicateKey is returned when attempting to insert a document with a duplicate key
ErrDuplicateKey = errors.New("duplicate key error")
)
// IsNotFound checks if the error is a not found error
func IsNotFound(err error) bool {
if err == nil {
return false
}
if err == ErrNotFound {
return true
}
if cassandra.IsNotFound(err) {
return true
}
return false
}

228
pkg/post/repository/like.go Normal file
View File

@ -0,0 +1,228 @@
package repository
import (
"context"
"fmt"
"backend/pkg/library/cassandra"
"backend/pkg/post/domain/entity"
domainRepo "backend/pkg/post/domain/repository"
"github.com/gocql/gocql"
)
// LikeRepositoryParam 定義 LikeRepository 的初始化參數
type LikeRepositoryParam struct {
DB *cassandra.DB
Keyspace string
}
// LikeRepository 實作 domain repository 介面
type LikeRepository struct {
repo cassandra.Repository[*entity.Like]
db *cassandra.DB
}
// NewLikeRepository 創建新的 LikeRepository
func NewLikeRepository(param LikeRepositoryParam) domainRepo.LikeRepository {
repo, err := cassandra.NewRepository[*entity.Like](param.DB, param.Keyspace)
if err != nil {
panic(fmt.Sprintf("failed to create like repository: %v", err))
}
return &LikeRepository{
repo: repo,
db: param.DB,
}
}
// Insert 插入單筆按讚
func (r *LikeRepository) Insert(ctx context.Context, data *entity.Like) error {
if data == nil {
return ErrInvalidInput
}
// 驗證資料
if err := data.Validate(); err != nil {
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
}
// 設置時間戳
data.SetTimestamps()
// 如果是新按讚,生成 ID
if data.IsNew() {
data.ID = gocql.TimeUUID()
}
return r.repo.Insert(ctx, data)
}
// FindOne 根據 ID 查詢單筆按讚
func (r *LikeRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Like, error) {
var zeroUUID gocql.UUID
if id == zeroUUID {
return nil, ErrInvalidInput
}
like, err := r.repo.Get(ctx, id)
if err != nil {
if cassandra.IsNotFound(err) {
return nil, ErrNotFound
}
return nil, fmt.Errorf("failed to find like: %w", err)
}
return like, nil
}
// Delete 刪除按讚
func (r *LikeRepository) Delete(ctx context.Context, id gocql.UUID) error {
var zeroUUID gocql.UUID
if id == zeroUUID {
return ErrInvalidInput
}
return r.repo.Delete(ctx, id)
}
// FindByTargetID 根據目標 ID 查詢按讚列表
func (r *LikeRepository) FindByTargetID(ctx context.Context, targetID gocql.UUID, targetType string) ([]*entity.Like, error) {
var zeroUUID gocql.UUID
if targetID == zeroUUID {
return nil, ErrInvalidInput
}
if targetType != "post" && targetType != "comment" {
return nil, ErrInvalidInput
}
// 構建查詢
query := r.repo.Query().
Where(cassandra.Eq("target_id", targetID)).
Where(cassandra.Eq("target_type", targetType)).
OrderBy("created_at", cassandra.DESC)
var likes []*entity.Like
if err := query.Scan(ctx, &likes); err != nil {
return nil, fmt.Errorf("failed to query likes: %w", err)
}
return likes, nil
}
// FindByUserUID 根據用戶 UID 查詢按讚列表
func (r *LikeRepository) FindByUserUID(ctx context.Context, userUID string, params *domainRepo.LikeQueryParams) ([]*entity.Like, int64, error) {
if userUID == "" {
return nil, 0, ErrInvalidInput
}
query := r.repo.Query().Where(cassandra.Eq("user_uid", userUID))
// 添加目標類型過濾
if params != nil && params.TargetType != nil {
query = query.Where(cassandra.Eq("target_type", *params.TargetType))
}
// 添加目標 ID 過濾
if params != nil && params.TargetID != nil {
query = query.Where(cassandra.Eq("target_id", *params.TargetID))
}
// 添加排序
orderBy := "created_at"
if params != nil && params.OrderBy != "" {
orderBy = params.OrderBy
}
order := cassandra.DESC
if params != nil && params.OrderDirection == "ASC" {
order = cassandra.ASC
}
query = query.OrderBy(orderBy, order)
// 添加分頁
pageSize := int64(20)
if params != nil && params.PageSize > 0 {
pageSize = params.PageSize
}
query = query.Limit(int(pageSize))
var likes []*entity.Like
if err := query.Scan(ctx, &likes); err != nil {
return nil, 0, fmt.Errorf("failed to query likes: %w", err)
}
result := likes
return result, int64(len(result)), nil
}
// FindByTargetAndUser 根據目標和用戶查詢按讚
func (r *LikeRepository) FindByTargetAndUser(ctx context.Context, targetID gocql.UUID, userUID string, targetType string) (*entity.Like, error) {
var zeroUUID gocql.UUID
if targetID == zeroUUID || userUID == "" {
return nil, ErrInvalidInput
}
if targetType != "post" && targetType != "comment" {
return nil, ErrInvalidInput
}
// 構建查詢
query := r.repo.Query().
Where(cassandra.Eq("target_id", targetID)).
Where(cassandra.Eq("user_uid", userUID)).
Where(cassandra.Eq("target_type", targetType)).
Limit(1)
var likes []*entity.Like
if err := query.Scan(ctx, &likes); err != nil {
if cassandra.IsNotFound(err) {
return nil, ErrNotFound
}
return nil, fmt.Errorf("failed to query like: %w", err)
}
if len(likes) == 0 {
return nil, ErrNotFound
}
return likes[0], nil
}
// CountByTargetID 計算目標的按讚數
func (r *LikeRepository) CountByTargetID(ctx context.Context, targetID gocql.UUID, targetType string) (int64, error) {
var zeroUUID gocql.UUID
if targetID == zeroUUID {
return 0, ErrInvalidInput
}
if targetType != "post" && targetType != "comment" {
return 0, ErrInvalidInput
}
// 構建查詢
query := r.repo.Query().
Where(cassandra.Eq("target_id", targetID)).
Where(cassandra.Eq("target_type", targetType))
count, err := query.Count(ctx)
if err != nil {
return 0, fmt.Errorf("failed to count likes: %w", err)
}
return count, nil
}
// DeleteByTargetAndUser 根據目標和用戶刪除按讚
func (r *LikeRepository) DeleteByTargetAndUser(ctx context.Context, targetID gocql.UUID, userUID string, targetType string) error {
// 先查詢按讚
like, err := r.FindByTargetAndUser(ctx, targetID, userUID, targetType)
if err != nil {
return err
}
// 刪除按讚
return r.Delete(ctx, like.ID)
}

511
pkg/post/repository/post.go Normal file
View File

@ -0,0 +1,511 @@
package repository
import (
"context"
"fmt"
"math"
"backend/pkg/library/cassandra"
"backend/pkg/post/domain/entity"
"backend/pkg/post/domain/post"
domainRepo "backend/pkg/post/domain/repository"
"github.com/gocql/gocql"
)
// PostRepositoryParam 定義 PostRepository 的初始化參數
type PostRepositoryParam struct {
DB *cassandra.DB
Keyspace string
}
// PostRepository 實作 domain repository 介面
type PostRepository struct {
repo cassandra.Repository[*entity.Post]
db *cassandra.DB
keyspace string
}
// NewPostRepository 創建新的 PostRepository
func NewPostRepository(param PostRepositoryParam) domainRepo.PostRepository {
repo, err := cassandra.NewRepository[*entity.Post](param.DB, param.Keyspace)
if err != nil {
panic(fmt.Sprintf("failed to create post repository: %v", err))
}
keyspace := param.Keyspace
if keyspace == "" {
keyspace = param.DB.GetDefaultKeyspace()
}
return &PostRepository{
repo: repo,
db: param.DB,
keyspace: keyspace,
}
}
// Insert 插入單筆貼文
func (r *PostRepository) Insert(ctx context.Context, data *entity.Post) error {
if data == nil {
return ErrInvalidInput
}
// 驗證資料
if err := data.Validate(); err != nil {
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
}
// 設置時間戳
data.SetTimestamps()
// 如果是新貼文,生成 ID
if data.IsNew() {
data.ID = gocql.TimeUUID()
}
// 如果狀態是 published設置發布時間
if data.Status == post.PostStatusPublished && data.PublishedAt == nil {
now := data.CreatedAt
data.PublishedAt = &now
}
return r.repo.Insert(ctx, data)
}
// FindOne 根據 ID 查詢單筆貼文
func (r *PostRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Post, error) {
var zeroUUID gocql.UUID
if id == zeroUUID {
return nil, ErrInvalidInput
}
post, err := r.repo.Get(ctx, id)
if err != nil {
if cassandra.IsNotFound(err) {
return nil, ErrNotFound
}
return nil, fmt.Errorf("failed to find post: %w", err)
}
return post, nil
}
// Update 更新貼文
func (r *PostRepository) Update(ctx context.Context, data *entity.Post) error {
if data == nil {
return ErrInvalidInput
}
// 驗證資料
if err := data.Validate(); err != nil {
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
}
// 更新時間戳
data.SetTimestamps()
return r.repo.Update(ctx, data)
}
// Delete 刪除貼文(軟刪除)
func (r *PostRepository) Delete(ctx context.Context, id gocql.UUID) error {
var zeroUUID gocql.UUID
if id == zeroUUID {
return ErrInvalidInput
}
// 先查詢貼文
post, err := r.FindOne(ctx, id)
if err != nil {
return err
}
// 軟刪除:標記為已刪除
post.Delete()
return r.Update(ctx, post)
}
// FindByAuthorUID 根據作者 UID 查詢貼文
func (r *PostRepository) FindByAuthorUID(ctx context.Context, authorUID string, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
if authorUID == "" {
return nil, 0, ErrInvalidInput
}
// 構建查詢
query := r.repo.Query().Where(cassandra.Eq("author_uid", authorUID))
// 添加狀態過濾
if params != nil && params.Status != nil {
query = query.Where(cassandra.Eq("status", *params.Status))
}
// 添加排序
orderBy := "created_at"
if params != nil && params.OrderBy != "" {
orderBy = params.OrderBy
}
order := cassandra.DESC
if params != nil && params.OrderDirection == "ASC" {
order = cassandra.ASC
}
query = query.OrderBy(orderBy, order)
// 添加分頁
pageSize := int64(20)
if params != nil && params.PageSize > 0 {
pageSize = params.PageSize
}
pageIndex := int64(1)
if params != nil && params.PageIndex > 0 {
pageIndex = params.PageIndex
}
limit := int(pageSize)
query = query.Limit(limit)
// 執行查詢
var posts []*entity.Post
if err := query.Scan(ctx, &posts); err != nil {
return nil, 0, fmt.Errorf("failed to query posts: %w", err)
}
result := posts
// 計算總數(簡化實作,實際應該使用 COUNT 查詢)
total := int64(len(posts))
if params != nil && params.PageIndex > 1 {
// 這裡應該執行 COUNT 查詢,但為了簡化,我們假設有更多結果
total = pageSize * pageIndex
}
return result, total, nil
}
// FindByCategoryID 根據分類 ID 查詢貼文
func (r *PostRepository) FindByCategoryID(ctx context.Context, categoryID gocql.UUID, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
var zeroUUID gocql.UUID
if categoryID == zeroUUID {
return nil, 0, ErrInvalidInput
}
// 構建查詢
query := r.repo.Query().Where(cassandra.Eq("category_id", categoryID))
// 添加狀態過濾
if params != nil && params.Status != nil {
query = query.Where(cassandra.Eq("status", *params.Status))
}
// 添加排序和分頁(類似 FindByAuthorUID
orderBy := "created_at"
if params != nil && params.OrderBy != "" {
orderBy = params.OrderBy
}
order := cassandra.DESC
if params != nil && params.OrderDirection == "ASC" {
order = cassandra.ASC
}
query = query.OrderBy(orderBy, order)
pageSize := int64(20)
if params != nil && params.PageSize > 0 {
pageSize = params.PageSize
}
limit := int(pageSize)
query = query.Limit(limit)
var posts []*entity.Post
if err := query.Scan(ctx, &posts); err != nil {
return nil, 0, fmt.Errorf("failed to query posts: %w", err)
}
result := posts
total := int64(len(posts))
return result, total, nil
}
// FindByTag 根據標籤查詢貼文
func (r *PostRepository) FindByTag(ctx context.Context, tagName string, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
if tagName == "" {
return nil, 0, ErrInvalidInput
}
// 構建查詢注意Cassandra 的集合查詢需要使用 CONTAINS這裡簡化處理
// 實際實作中,可能需要使用 SAI 索引或 Materialized View
query := r.repo.Query()
// 添加狀態過濾
if params != nil && params.Status != nil {
query = query.Where(cassandra.Eq("status", *params.Status))
}
// 添加排序和分頁
orderBy := "created_at"
if params != nil && params.OrderBy != "" {
orderBy = params.OrderBy
}
order := cassandra.DESC
if params != nil && params.OrderDirection == "ASC" {
order = cassandra.ASC
}
query = query.OrderBy(orderBy, order)
pageSize := int64(20)
if params != nil && params.PageSize > 0 {
pageSize = params.PageSize
}
limit := int(pageSize)
query = query.Limit(limit)
var posts []*entity.Post
if err := query.Scan(ctx, &posts); err != nil {
return nil, 0, fmt.Errorf("failed to query posts: %w", err)
}
// 過濾包含指定標籤的貼文
filtered := make([]*entity.Post, 0)
for _, p := range posts {
for _, tag := range p.Tags {
if tag == tagName {
filtered = append(filtered, p)
break
}
}
}
total := int64(len(filtered))
return filtered, total, nil
}
// FindPinnedPosts 查詢置頂貼文
func (r *PostRepository) FindPinnedPosts(ctx context.Context, limit int64) ([]*entity.Post, error) {
query := r.repo.Query().
Where(cassandra.Eq("is_pinned", true)).
Where(cassandra.Eq("status", post.PostStatusPublished)).
OrderBy("pinned_at", cassandra.DESC).
Limit(int(limit))
var posts []*entity.Post
if err := query.Scan(ctx, &posts); err != nil {
return nil, fmt.Errorf("failed to query pinned posts: %w", err)
}
return posts, nil
}
// FindByStatus 根據狀態查詢貼文
func (r *PostRepository) FindByStatus(ctx context.Context, status post.Status, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
query := r.repo.Query().Where(cassandra.Eq("status", status))
// 添加排序和分頁
orderBy := "created_at"
if params != nil && params.OrderBy != "" {
orderBy = params.OrderBy
}
order := cassandra.DESC
if params != nil && params.OrderDirection == "ASC" {
order = cassandra.ASC
}
query = query.OrderBy(orderBy, order)
pageSize := int64(20)
if params != nil && params.PageSize > 0 {
pageSize = params.PageSize
}
limit := int(pageSize)
query = query.Limit(limit)
var posts []*entity.Post
if err := query.Scan(ctx, &posts); err != nil {
return nil, 0, fmt.Errorf("failed to query posts: %w", err)
}
result := posts
total := int64(len(posts))
return result, total, nil
}
// IncrementLikeCount 增加按讚數(使用 counter 原子操作避免競爭條件)
// 注意like_count 欄位必須是 counter 類型
func (r *PostRepository) IncrementLikeCount(ctx context.Context, postID gocql.UUID) error {
var zeroUUID gocql.UUID
if postID == zeroUUID {
return ErrInvalidInput
}
var zeroPost entity.Post
tableName := zeroPost.TableName()
if r.keyspace == "" {
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
}
stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count + 1 WHERE id = ?", r.keyspace, tableName)
query := r.db.GetSession().Query(stmt, nil).
WithContext(ctx).
Consistency(gocql.Quorum).
Bind(postID)
if err := query.ExecRelease(); err != nil {
return fmt.Errorf("failed to increment like count: %w", err)
}
return nil
}
// DecrementLikeCount 減少按讚數(使用 counter 原子操作避免競爭條件)
// 注意like_count 欄位必須是 counter 類型
func (r *PostRepository) DecrementLikeCount(ctx context.Context, postID gocql.UUID) error {
var zeroUUID gocql.UUID
if postID == zeroUUID {
return ErrInvalidInput
}
var zeroPost entity.Post
tableName := zeroPost.TableName()
if r.keyspace == "" {
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
}
stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count - 1 WHERE id = ?", r.keyspace, tableName)
query := r.db.GetSession().Query(stmt, nil).
WithContext(ctx).
Consistency(gocql.Quorum).
Bind(postID)
if err := query.ExecRelease(); err != nil {
return fmt.Errorf("failed to decrement like count: %w", err)
}
return nil
}
// IncrementCommentCount 增加評論數(使用 counter 原子操作避免競爭條件)
// 注意comment_count 欄位必須是 counter 類型
func (r *PostRepository) IncrementCommentCount(ctx context.Context, postID gocql.UUID) error {
var zeroUUID gocql.UUID
if postID == zeroUUID {
return ErrInvalidInput
}
var zeroPost entity.Post
tableName := zeroPost.TableName()
if r.keyspace == "" {
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
}
stmt := fmt.Sprintf("UPDATE %s.%s SET comment_count = comment_count + 1 WHERE id = ?", r.keyspace, tableName)
query := r.db.GetSession().Query(stmt, nil).
WithContext(ctx).
Consistency(gocql.Quorum).
Bind(postID)
if err := query.ExecRelease(); err != nil {
return fmt.Errorf("failed to increment comment count: %w", err)
}
return nil
}
// DecrementCommentCount 減少評論數(使用 counter 原子操作避免競爭條件)
// 注意comment_count 欄位必須是 counter 類型
func (r *PostRepository) DecrementCommentCount(ctx context.Context, postID gocql.UUID) error {
var zeroUUID gocql.UUID
if postID == zeroUUID {
return ErrInvalidInput
}
var zeroPost entity.Post
tableName := zeroPost.TableName()
if r.keyspace == "" {
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
}
stmt := fmt.Sprintf("UPDATE %s.%s SET comment_count = comment_count - 1 WHERE id = ?", r.keyspace, tableName)
query := r.db.GetSession().Query(stmt, nil).
WithContext(ctx).
Consistency(gocql.Quorum).
Bind(postID)
if err := query.ExecRelease(); err != nil {
return fmt.Errorf("failed to decrement comment count: %w", err)
}
return nil
}
// IncrementViewCount 增加瀏覽數(使用 counter 原子操作避免競爭條件)
// 注意view_count 欄位必須是 counter 類型
func (r *PostRepository) IncrementViewCount(ctx context.Context, postID gocql.UUID) error {
var zeroUUID gocql.UUID
if postID == zeroUUID {
return ErrInvalidInput
}
var zeroPost entity.Post
tableName := zeroPost.TableName()
if r.keyspace == "" {
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
}
stmt := fmt.Sprintf("UPDATE %s.%s SET view_count = view_count + 1 WHERE id = ?", r.keyspace, tableName)
query := r.db.GetSession().Query(stmt, nil).
WithContext(ctx).
Consistency(gocql.Quorum).
Bind(postID)
if err := query.ExecRelease(); err != nil {
return fmt.Errorf("failed to increment view count: %w", err)
}
return nil
}
// UpdateStatus 更新貼文狀態
func (r *PostRepository) UpdateStatus(ctx context.Context, postID gocql.UUID, status post.Status) error {
postEntity, err := r.FindOne(ctx, postID)
if err != nil {
return err
}
postEntity.Status = status
publishedStatus := post.PostStatusPublished
if status == publishedStatus && postEntity.PublishedAt == nil {
now := postEntity.UpdatedAt
postEntity.PublishedAt = &now
}
return r.Update(ctx, postEntity)
}
// PinPost 置頂貼文
func (r *PostRepository) PinPost(ctx context.Context, postID gocql.UUID) error {
post, err := r.FindOne(ctx, postID)
if err != nil {
return err
}
post.Pin()
return r.Update(ctx, post)
}
// UnpinPost 取消置頂
func (r *PostRepository) UnpinPost(ctx context.Context, postID gocql.UUID) error {
post, err := r.FindOne(ctx, postID)
if err != nil {
return err
}
post.Unpin()
return r.Update(ctx, post)
}
// calculateTotalPages 計算總頁數
func calculateTotalPages(total, pageSize int64) int64 {
if pageSize <= 0 {
return 0
}
return int64(math.Ceil(float64(total) / float64(pageSize)))
}

250
pkg/post/repository/tag.go Normal file
View File

@ -0,0 +1,250 @@
package repository
import (
"context"
"fmt"
"strings"
"backend/pkg/library/cassandra"
"backend/pkg/post/domain/entity"
domainRepo "backend/pkg/post/domain/repository"
"github.com/gocql/gocql"
)
// TagRepositoryParam 定義 TagRepository 的初始化參數
type TagRepositoryParam struct {
DB *cassandra.DB
Keyspace string
}
// TagRepository 實作 domain repository 介面
type TagRepository struct {
repo cassandra.Repository[*entity.Tag]
db *cassandra.DB
keyspace string
}
// NewTagRepository 創建新的 TagRepository
func NewTagRepository(param TagRepositoryParam) domainRepo.TagRepository {
repo, err := cassandra.NewRepository[*entity.Tag](param.DB, param.Keyspace)
if err != nil {
panic(fmt.Sprintf("failed to create tag repository: %v", err))
}
keyspace := param.Keyspace
if keyspace == "" {
keyspace = param.DB.GetDefaultKeyspace()
}
return &TagRepository{
repo: repo,
db: param.DB,
keyspace: keyspace,
}
}
// Insert 插入單筆標籤
func (r *TagRepository) Insert(ctx context.Context, data *entity.Tag) error {
if data == nil {
return ErrInvalidInput
}
// 驗證資料
if err := data.Validate(); err != nil {
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
}
// 設置時間戳
data.SetTimestamps()
// 如果是新標籤,生成 ID
if data.IsNew() {
data.ID = gocql.TimeUUID()
}
// 標籤名稱轉為小寫(統一格式)
data.Name = strings.ToLower(strings.TrimSpace(data.Name))
return r.repo.Insert(ctx, data)
}
// FindOne 根據 ID 查詢單筆標籤
func (r *TagRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Tag, error) {
var zeroUUID gocql.UUID
if id == zeroUUID {
return nil, ErrInvalidInput
}
tag, err := r.repo.Get(ctx, id)
if err != nil {
if cassandra.IsNotFound(err) {
return nil, ErrNotFound
}
return nil, fmt.Errorf("failed to find tag: %w", err)
}
return tag, nil
}
// Update 更新標籤
func (r *TagRepository) Update(ctx context.Context, data *entity.Tag) error {
if data == nil {
return ErrInvalidInput
}
// 驗證資料
if err := data.Validate(); err != nil {
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
}
// 更新時間戳
data.SetTimestamps()
// 標籤名稱轉為小寫
data.Name = strings.ToLower(strings.TrimSpace(data.Name))
return r.repo.Update(ctx, data)
}
// Delete 刪除標籤
func (r *TagRepository) Delete(ctx context.Context, id gocql.UUID) error {
var zeroUUID gocql.UUID
if id == zeroUUID {
return ErrInvalidInput
}
return r.repo.Delete(ctx, id)
}
// FindByName 根據名稱查詢標籤
func (r *TagRepository) FindByName(ctx context.Context, name string) (*entity.Tag, error) {
if name == "" {
return nil, ErrInvalidInput
}
// 標準化名稱
name = strings.ToLower(strings.TrimSpace(name))
// 構建查詢(假設有 SAI 索引在 name 欄位上)
query := r.repo.Query().Where(cassandra.Eq("name", name))
var tags []*entity.Tag
if err := query.Scan(ctx, &tags); err != nil {
if cassandra.IsNotFound(err) {
return nil, ErrNotFound
}
return nil, fmt.Errorf("failed to query tag: %w", err)
}
if len(tags) == 0 {
return nil, ErrNotFound
}
return tags[0], nil
}
// FindByNames 根據名稱列表查詢標籤
func (r *TagRepository) FindByNames(ctx context.Context, names []string) ([]*entity.Tag, error) {
if len(names) == 0 {
return []*entity.Tag{}, nil
}
// 標準化名稱
normalizedNames := make([]string, len(names))
for i, name := range names {
normalizedNames[i] = strings.ToLower(strings.TrimSpace(name))
}
// 構建查詢(使用 IN 條件)
query := r.repo.Query().Where(cassandra.In("name", toAnySlice(normalizedNames)))
var tags []*entity.Tag
if err := query.Scan(ctx, &tags); err != nil {
return nil, fmt.Errorf("failed to query tags: %w", err)
}
return tags, nil
}
// FindPopular 查詢熱門標籤
func (r *TagRepository) FindPopular(ctx context.Context, limit int64) ([]*entity.Tag, error) {
// 構建查詢,按 post_count 降序排列
query := r.repo.Query().
OrderBy("post_count", cassandra.DESC).
Limit(int(limit))
var tags []*entity.Tag
if err := query.Scan(ctx, &tags); err != nil {
return nil, fmt.Errorf("failed to query popular tags: %w", err)
}
result := tags
return result, nil
}
// IncrementPostCount 增加貼文數(使用 counter 原子操作避免競爭條件)
// 注意post_count 欄位必須是 counter 類型
func (r *TagRepository) IncrementPostCount(ctx context.Context, tagID gocql.UUID) error {
var zeroUUID gocql.UUID
if tagID == zeroUUID {
return ErrInvalidInput
}
// 使用 counter 原子更新操作UPDATE tags SET post_count = post_count + 1 WHERE id = ?
var zeroTag entity.Tag
tableName := zeroTag.TableName()
if r.keyspace == "" {
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
}
stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count + 1 WHERE id = ?", r.keyspace, tableName)
query := r.db.GetSession().Query(stmt, nil).
WithContext(ctx).
Consistency(gocql.Quorum).
Bind(tagID)
if err := query.ExecRelease(); err != nil {
return fmt.Errorf("failed to increment post count: %w", err)
}
return nil
}
// DecrementPostCount 減少貼文數(使用 counter 原子操作避免競爭條件)
// 注意post_count 欄位必須是 counter 類型
func (r *TagRepository) DecrementPostCount(ctx context.Context, tagID gocql.UUID) error {
var zeroUUID gocql.UUID
if tagID == zeroUUID {
return ErrInvalidInput
}
// 使用 counter 原子更新操作UPDATE tags SET post_count = post_count - 1 WHERE id = ?
var zeroTag entity.Tag
tableName := zeroTag.TableName()
if r.keyspace == "" {
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
}
stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count - 1 WHERE id = ?", r.keyspace, tableName)
query := r.db.GetSession().Query(stmt, nil).
WithContext(ctx).
Consistency(gocql.Quorum).
Bind(tagID)
if err := query.ExecRelease(); err != nil {
return fmt.Errorf("failed to decrement post count: %w", err)
}
return nil
}
// toAnySlice 將 string slice 轉換為 []any
func toAnySlice(strs []string) []any {
result := make([]any, len(strs))
for i, s := range strs {
result[i] = s
}
return result
}

455
pkg/post/usecase/comment.go Normal file
View File

@ -0,0 +1,455 @@
package usecase
import (
"context"
"errors"
"fmt"
"math"
errs "backend/pkg/library/errors"
"backend/pkg/post/domain/entity"
"backend/pkg/post/domain/post"
domainRepo "backend/pkg/post/domain/repository"
domainUsecase "backend/pkg/post/domain/usecase"
"backend/pkg/post/repository"
"github.com/gocql/gocql"
)
// CommentUseCaseParam 定義 CommentUseCase 的初始化參數
type CommentUseCaseParam struct {
Comment domainRepo.CommentRepository
Post domainRepo.PostRepository
Like domainRepo.LikeRepository
Logger errs.Logger
}
// CommentUseCase 實作 domain usecase 介面
type CommentUseCase struct {
CommentUseCaseParam
}
// MustCommentUseCase 創建新的 CommentUseCase如果失敗會 panic
func MustCommentUseCase(param CommentUseCaseParam) domainUsecase.CommentUseCase {
return &CommentUseCase{
CommentUseCaseParam: param,
}
}
// CreateComment 創建新評論
func (uc *CommentUseCase) CreateComment(ctx context.Context, req domainUsecase.CreateCommentRequest) (*domainUsecase.CommentResponse, error) {
// 驗證輸入
if err := uc.validateCreateCommentRequest(req); err != nil {
return nil, err
}
// 驗證貼文存在
var zeroUUID gocql.UUID
if req.PostID == zeroUUID {
return nil, errs.InputInvalidRangeError("post_id is required")
}
post, err := uc.Post.FindOne(ctx, req.PostID)
if err != nil {
if repository.IsNotFound(err) {
return nil, errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
}
return nil, uc.handleDBError("Post.FindOne", req, err)
}
// 檢查貼文是否可見
if !post.IsVisible() {
return nil, errs.ResNotFoundError("cannot comment on non-visible post")
}
// 建立評論實體
comment := &entity.Comment{
PostID: req.PostID,
AuthorUID: req.AuthorUID,
ParentID: req.ParentID,
Content: req.Content,
Status: post.CommentStatusPublished,
}
// 插入資料庫
if err := uc.Comment.Insert(ctx, comment); err != nil {
return nil, uc.handleDBError("Comment.Insert", req, err)
}
// 如果是回覆,增加父評論的回覆數
if req.ParentID != nil {
if err := uc.Comment.IncrementReplyCount(ctx, *req.ParentID); err != nil {
uc.Logger.Error(fmt.Sprintf("failed to increment reply count: %v", err))
}
}
// 增加貼文的評論數
if err := uc.Post.IncrementCommentCount(ctx, req.PostID); err != nil {
uc.Logger.Error(fmt.Sprintf("failed to increment comment count: %v", err))
}
return uc.mapCommentToResponse(comment), nil
}
// GetComment 取得評論
func (uc *CommentUseCase) GetComment(ctx context.Context, req domainUsecase.GetCommentRequest) (*domainUsecase.CommentResponse, error) {
// 驗證輸入
var zeroUUID gocql.UUID
if req.CommentID == zeroUUID {
return nil, errs.InputInvalidRangeError("comment_id is required")
}
// 查詢評論
comment, err := uc.Comment.FindOne(ctx, req.CommentID)
if err != nil {
if repository.IsNotFound(err) {
return nil, errs.ResNotFoundError(fmt.Sprintf("comment not found: %s", req.CommentID))
}
return nil, uc.handleDBError("Comment.FindOne", req, err)
}
return uc.mapCommentToResponse(comment), nil
}
// UpdateComment 更新評論
func (uc *CommentUseCase) UpdateComment(ctx context.Context, req domainUsecase.UpdateCommentRequest) (*domainUsecase.CommentResponse, error) {
// 驗證輸入
var zeroUUID gocql.UUID
if req.CommentID == zeroUUID {
return nil, errs.InputInvalidRangeError("comment_id is required")
}
if req.AuthorUID == "" {
return nil, errs.InputInvalidRangeError("author_uid is required")
}
if req.Content == "" {
return nil, errs.InputInvalidRangeError("content is required")
}
// 查詢現有評論
comment, err := uc.Comment.FindOne(ctx, req.CommentID)
if err != nil {
if repository.IsNotFound(err) {
return nil, errs.ResNotFoundError(fmt.Sprintf("comment not found: %s", req.CommentID))
}
return nil, uc.handleDBError("Comment.FindOne", req, err)
}
// 驗證權限
if comment.AuthorUID != req.AuthorUID {
return nil, errs.ResNotFoundError("not authorized to update this comment")
}
// 檢查是否可見
if !comment.IsVisible() {
return nil, errs.ResNotFoundError("comment is not visible")
}
// 更新內容
comment.Content = req.Content
// 更新資料庫
if err := uc.Comment.Update(ctx, comment); err != nil {
return nil, uc.handleDBError("Comment.Update", req, err)
}
return uc.mapCommentToResponse(comment), nil
}
// DeleteComment 刪除評論(軟刪除)
func (uc *CommentUseCase) DeleteComment(ctx context.Context, req domainUsecase.DeleteCommentRequest) error {
// 驗證輸入
var zeroUUID gocql.UUID
if req.CommentID == zeroUUID {
return errs.InputInvalidRangeError("comment_id is required")
}
if req.AuthorUID == "" {
return errs.InputInvalidRangeError("author_uid is required")
}
// 查詢評論
comment, err := uc.Comment.FindOne(ctx, req.CommentID)
if err != nil {
if repository.IsNotFound(err) {
return errs.ResNotFoundError(fmt.Sprintf("comment not found: %s", req.CommentID))
}
return uc.handleDBError("Comment.FindOne", req, err)
}
// 驗證權限
if comment.AuthorUID != req.AuthorUID {
return errs.ResNotFoundError("not authorized to delete this comment")
}
// 刪除評論
if err := uc.Comment.Delete(ctx, req.CommentID); err != nil {
return uc.handleDBError("Comment.Delete", req, err)
}
// 如果是回覆,減少父評論的回覆數
if comment.ParentID != nil {
if err := uc.Comment.DecrementReplyCount(ctx, *comment.ParentID); err != nil {
uc.Logger.Error(fmt.Sprintf("failed to decrement reply count: %v", err))
}
}
// 減少貼文的評論數
if err := uc.Post.DecrementCommentCount(ctx, comment.PostID); err != nil {
uc.Logger.Error(fmt.Sprintf("failed to decrement comment count: %v", err))
}
return nil
}
// ListComments 列出評論
func (uc *CommentUseCase) ListComments(ctx context.Context, req domainUsecase.ListCommentsRequest) (*domainUsecase.ListCommentsResponse, error) {
// 驗證輸入
var zeroUUID gocql.UUID
if req.PostID == zeroUUID {
return nil, errs.InputInvalidRangeError("post_id is required")
}
if req.PageSize <= 0 {
req.PageSize = 20
}
if req.PageIndex <= 0 {
req.PageIndex = 1
}
// 構建查詢參數
params := &domainRepo.CommentQueryParams{
PostID: &req.PostID,
ParentID: req.ParentID,
PageSize: req.PageSize,
PageIndex: req.PageIndex,
OrderBy: req.OrderBy,
OrderDirection: req.OrderDirection,
}
// 如果 OrderBy 未指定,預設為 created_at
if params.OrderBy == "" {
params.OrderBy = "created_at"
}
// 如果 OrderDirection 未指定,預設為 ASC
if params.OrderDirection == "" {
params.OrderDirection = "ASC"
}
// 執行查詢
comments, total, err := uc.Comment.FindByPostID(ctx, req.PostID, params)
if err != nil {
return nil, uc.handleDBError("Comment.FindByPostID", req, err)
}
// 轉換為 Response
responses := make([]domainUsecase.CommentResponse, len(comments))
for i, c := range comments {
responses[i] = *uc.mapCommentToResponse(c)
}
return &domainUsecase.ListCommentsResponse{
Data: responses,
Page: domainUsecase.Pager{
PageIndex: req.PageIndex,
PageSize: req.PageSize,
Total: total,
TotalPage: calculateTotalPages(total, req.PageSize),
},
}, nil
}
// ListReplies 列出回覆
func (uc *CommentUseCase) ListReplies(ctx context.Context, req domainUsecase.ListRepliesRequest) (*domainUsecase.ListCommentsResponse, error) {
// 驗證輸入
var zeroUUID gocql.UUID
if req.CommentID == zeroUUID {
return nil, errs.InputInvalidRangeError("comment_id is required")
}
if req.PageSize <= 0 {
req.PageSize = 20
}
if req.PageIndex <= 0 {
req.PageIndex = 1
}
// 構建查詢參數
params := &domainRepo.CommentQueryParams{
PageSize: req.PageSize,
PageIndex: req.PageIndex,
OrderBy: "created_at",
OrderDirection: "ASC",
}
// 執行查詢
comments, total, err := uc.Comment.FindReplies(ctx, req.CommentID, params)
if err != nil {
return nil, uc.handleDBError("Comment.FindReplies", req, err)
}
// 轉換為 Response
responses := make([]domainUsecase.CommentResponse, len(comments))
for i, c := range comments {
responses[i] = *uc.mapCommentToResponse(c)
}
return &domainUsecase.ListCommentsResponse{
Data: responses,
Page: domainUsecase.Pager{
PageIndex: req.PageIndex,
PageSize: req.PageSize,
Total: total,
TotalPage: calculateTotalPages(total, req.PageSize),
},
}, nil
}
// ListCommentsByAuthor 根據作者列出評論
func (uc *CommentUseCase) ListCommentsByAuthor(ctx context.Context, req domainUsecase.ListCommentsByAuthorRequest) (*domainUsecase.ListCommentsResponse, error) {
if req.AuthorUID == "" {
return nil, errs.InputInvalidRangeError("author_uid is required")
}
if req.PageSize <= 0 {
req.PageSize = 20
}
if req.PageIndex <= 0 {
req.PageIndex = 1
}
params := &domainRepo.CommentQueryParams{
PageSize: req.PageSize,
PageIndex: req.PageIndex,
OrderBy: "created_at",
OrderDirection: "DESC",
}
comments, total, err := uc.Comment.FindByAuthorUID(ctx, req.AuthorUID, params)
if err != nil {
return nil, uc.handleDBError("Comment.FindByAuthorUID", req, err)
}
responses := make([]domainUsecase.CommentResponse, len(comments))
for i, c := range comments {
responses[i] = *uc.mapCommentToResponse(c)
}
return &domainUsecase.ListCommentsResponse{
Data: responses,
Page: domainUsecase.Pager{
PageIndex: req.PageIndex,
PageSize: req.PageSize,
Total: total,
TotalPage: calculateTotalPages(total, req.PageSize),
},
}, nil
}
// LikeComment 按讚評論
func (uc *CommentUseCase) LikeComment(ctx context.Context, req domainUsecase.LikeCommentRequest) error {
// 驗證輸入
var zeroUUID gocql.UUID
if req.CommentID == zeroUUID {
return errs.InputInvalidRangeError("comment_id is required")
}
if req.UserUID == "" {
return errs.InputInvalidRangeError("user_uid is required")
}
// 檢查是否已經按讚
existingLike, err := uc.Like.FindByTargetAndUser(ctx, req.CommentID, req.UserUID, "comment")
if err == nil && existingLike != nil {
// 已經按讚,直接返回成功
return nil
}
if err != nil && !repository.IsNotFound(err) {
return uc.handleDBError("Like.FindByTargetAndUser", req, err)
}
// 建立按讚記錄
like := &entity.Like{
TargetID: req.CommentID,
UserUID: req.UserUID,
TargetType: "comment",
}
if err := uc.Like.Insert(ctx, like); err != nil {
return uc.handleDBError("Like.Insert", req, err)
}
// 增加評論的按讚數
if err := uc.Comment.IncrementLikeCount(ctx, req.CommentID); err != nil {
uc.Logger.Error(fmt.Sprintf("failed to increment like count: %v", err))
}
return nil
}
// UnlikeComment 取消按讚評論
func (uc *CommentUseCase) UnlikeComment(ctx context.Context, req domainUsecase.UnlikeCommentRequest) error {
// 驗證輸入
var zeroUUID gocql.UUID
if req.CommentID == zeroUUID {
return errs.InputInvalidRangeError("comment_id is required")
}
if req.UserUID == "" {
return errs.InputInvalidRangeError("user_uid is required")
}
// 刪除按讚記錄
if err := uc.Like.DeleteByTargetAndUser(ctx, req.CommentID, req.UserUID, "comment"); err != nil {
if repository.IsNotFound(err) {
// 已經取消按讚,直接返回成功
return nil
}
return uc.handleDBError("Like.DeleteByTargetAndUser", req, err)
}
// 減少評論的按讚數
if err := uc.Comment.DecrementLikeCount(ctx, req.CommentID); err != nil {
uc.Logger.Error(fmt.Sprintf("failed to decrement like count: %v", err))
}
return nil
}
// validateCreateCommentRequest 驗證建立評論請求
func (uc *CommentUseCase) validateCreateCommentRequest(req domainUsecase.CreateCommentRequest) error {
var zeroUUID gocql.UUID
if req.PostID == zeroUUID {
return errs.InputInvalidRangeError("post_id is required")
}
if req.AuthorUID == "" {
return errs.InputInvalidRangeError("author_uid is required")
}
if req.Content == "" {
return errs.InputInvalidRangeError("content is required")
}
return nil
}
// mapCommentToResponse 將 Comment 實體轉換為 CommentResponse
func (uc *CommentUseCase) mapCommentToResponse(comment *entity.Comment) *domainUsecase.CommentResponse {
return &domainUsecase.CommentResponse{
ID: comment.ID,
PostID: comment.PostID,
AuthorUID: comment.AuthorUID,
ParentID: comment.ParentID,
Content: comment.Content,
Status: comment.Status,
LikeCount: comment.LikeCount,
ReplyCount: comment.ReplyCount,
CreatedAt: comment.CreatedAt,
UpdatedAt: comment.UpdatedAt,
}
}
// handleDBError 處理資料庫錯誤
func (uc *CommentUseCase) handleDBError(funcName string, req any, err error) error {
return errs.DBErrorErrorL(
uc.Logger,
[]errs.LogField{
{Key: "func", Val: funcName},
{Key: "req", Val: req},
{Key: "error", Val: err.Error()},
},
fmt.Sprintf("database operation failed: %s", funcName),
).Wrap(err)
}

801
pkg/post/usecase/post.go Normal file
View File

@ -0,0 +1,801 @@
package usecase
import (
"context"
"errors"
"fmt"
"math"
errs "backend/pkg/library/errors"
"backend/pkg/post/domain/entity"
"backend/pkg/post/domain/post"
domainRepo "backend/pkg/post/domain/repository"
domainUsecase "backend/pkg/post/domain/usecase"
"backend/pkg/post/repository"
"github.com/gocql/gocql"
)
// PostUseCaseParam 定義 PostUseCase 的初始化參數
type PostUseCaseParam struct {
Post domainRepo.PostRepository
Comment domainRepo.CommentRepository
Like domainRepo.LikeRepository
Tag domainRepo.TagRepository
Category domainRepo.CategoryRepository
Logger errs.Logger
}
// PostUseCase 實作 domain usecase 介面
type PostUseCase struct {
PostUseCaseParam
}
// MustPostUseCase 創建新的 PostUseCase如果失敗會 panic
func MustPostUseCase(param PostUseCaseParam) domainUsecase.PostUseCase {
return &PostUseCase{
PostUseCaseParam: param,
}
}
// CreatePost 創建新貼文
func (uc *PostUseCase) CreatePost(ctx context.Context, req domainUsecase.CreatePostRequest) (*domainUsecase.PostResponse, error) {
// 驗證輸入
if err := uc.validateCreatePostRequest(req); err != nil {
return nil, err
}
// 建立貼文實體
post := &entity.Post{
AuthorUID: req.AuthorUID,
Title: req.Title,
Content: req.Content,
Type: req.Type,
CategoryID: req.CategoryID,
Tags: req.Tags,
Images: req.Images,
VideoURL: req.VideoURL,
LinkURL: req.LinkURL,
Status: req.Status,
}
// 如果狀態未指定,預設為草稿
if post.Status == 0 {
post.Status = post.PostStatusDraft
}
// 插入資料庫
if err := uc.Post.Insert(ctx, post); err != nil {
return nil, uc.handleDBError("Post.Insert", req, err)
}
// 處理標籤(更新標籤的貼文數)
if err := uc.updateTagPostCounts(ctx, req.Tags, true); err != nil {
// 記錄錯誤但不中斷流程
uc.Logger.Error(fmt.Sprintf("failed to update tag post counts: %v", err))
}
// 處理分類(更新分類的貼文數)
if req.CategoryID != nil {
if err := uc.Category.IncrementPostCount(ctx, *req.CategoryID); err != nil {
uc.Logger.Error(fmt.Sprintf("failed to increment category post count: %v", err))
}
}
return uc.mapPostToResponse(post), nil
}
// GetPost 取得貼文
func (uc *PostUseCase) GetPost(ctx context.Context, req domainUsecase.GetPostRequest) (*domainUsecase.PostResponse, error) {
// 驗證輸入
var zeroUUID gocql.UUID
if req.PostID == zeroUUID {
return nil, errs.InputInvalidRangeError("post_id is required")
}
// 查詢貼文
post, err := uc.Post.FindOne(ctx, req.PostID)
if err != nil {
if repository.IsNotFound(err) {
return nil, errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
}
return nil, uc.handleDBError("Post.FindOne", req, err)
}
// 如果提供了 UserUID增加瀏覽數
if req.UserUID != nil {
if err := uc.Post.IncrementViewCount(ctx, req.PostID); err != nil {
uc.Logger.Error(fmt.Sprintf("failed to increment view count: %v", err))
}
}
return uc.mapPostToResponse(post), nil
}
// UpdatePost 更新貼文
func (uc *PostUseCase) UpdatePost(ctx context.Context, req domainUsecase.UpdatePostRequest) (*domainUsecase.PostResponse, error) {
// 驗證輸入
var zeroUUID gocql.UUID
if req.PostID == zeroUUID {
return nil, errs.InputInvalidRangeError("post_id is required")
}
if req.AuthorUID == "" {
return nil, errs.InputInvalidRangeError("author_uid is required")
}
// 查詢現有貼文
post, err := uc.Post.FindOne(ctx, req.PostID)
if err != nil {
if repository.IsNotFound(err) {
return nil, errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
}
return nil, uc.handleDBError("Post.FindOne", req, err)
}
// 驗證權限
if post.AuthorUID != req.AuthorUID {
return nil, errs.ResNotFoundError("not authorized to update this post")
}
// 檢查是否可編輯
if !post.IsEditable() {
return nil, errs.ResNotFoundError("post is not editable")
}
// 更新欄位
if req.Title != nil {
post.Title = *req.Title
}
if req.Content != nil {
post.Content = *req.Content
}
if req.Type != nil {
post.Type = *req.Type
}
if req.CategoryID != nil {
// 更新分類計數
if post.CategoryID != nil && *post.CategoryID != *req.CategoryID {
if err := uc.Category.DecrementPostCount(ctx, *post.CategoryID); err != nil {
uc.Logger.Error("failed to decrement category post count", errs.LogField{Key: "error", Val: err.Error()})
}
if err := uc.Category.IncrementPostCount(ctx, *req.CategoryID); err != nil {
uc.Logger.Error(fmt.Sprintf("failed to increment category post count: %v", err))
}
}
post.CategoryID = req.CategoryID
}
if req.Tags != nil {
// 更新標籤計數
oldTags := post.Tags
post.Tags = req.Tags
if err := uc.updateTagPostCountsDiff(ctx, oldTags, req.Tags); err != nil {
uc.Logger.Error(fmt.Sprintf("failed to update tag post counts: %v", err))
}
}
if req.Images != nil {
post.Images = req.Images
}
if req.VideoURL != nil {
post.VideoURL = req.VideoURL
}
if req.LinkURL != nil {
post.LinkURL = req.LinkURL
}
// 更新資料庫
if err := uc.Post.Update(ctx, post); err != nil {
return nil, uc.handleDBError("Post.Update", req, err)
}
return uc.mapPostToResponse(post), nil
}
// DeletePost 刪除貼文(軟刪除)
func (uc *PostUseCase) DeletePost(ctx context.Context, req domainUsecase.DeletePostRequest) error {
// 驗證輸入
var zeroUUID gocql.UUID
if req.PostID == zeroUUID {
return errs.InputInvalidRangeError("post_id is required")
}
if req.AuthorUID == "" {
return errs.InputInvalidRangeError("author_uid is required")
}
// 查詢貼文
post, err := uc.Post.FindOne(ctx, req.PostID)
if err != nil {
if repository.IsNotFound(err) {
return errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
}
return uc.handleDBError("Post.FindOne", req, err)
}
// 驗證權限
if post.AuthorUID != req.AuthorUID {
return errs.ResNotFoundError("not authorized to delete this post")
}
// 刪除貼文
if err := uc.Post.Delete(ctx, req.PostID); err != nil {
return uc.handleDBError("Post.Delete", req, err)
}
// 更新標籤和分類計數
if len(post.Tags) > 0 {
if err := uc.updateTagPostCounts(ctx, post.Tags, false); err != nil {
uc.Logger.Error(fmt.Sprintf("failed to update tag post counts: %v", err))
}
}
if post.CategoryID != nil {
if err := uc.Category.DecrementPostCount(ctx, *post.CategoryID); err != nil {
uc.Logger.Error("failed to decrement category post count", errs.LogField{Key: "error", Val: err.Error()})
}
}
return nil
}
// PublishPost 發布貼文
func (uc *PostUseCase) PublishPost(ctx context.Context, req domainUsecase.PublishPostRequest) (*domainUsecase.PostResponse, error) {
// 驗證輸入
var zeroUUID gocql.UUID
if req.PostID == zeroUUID {
return nil, errs.InputInvalidRangeError("post_id is required")
}
if req.AuthorUID == "" {
return nil, errs.InputInvalidRangeError("author_uid is required")
}
// 查詢貼文
post, err := uc.Post.FindOne(ctx, req.PostID)
if err != nil {
if repository.IsNotFound(err) {
return nil, errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
}
return nil, uc.handleDBError("Post.FindOne", req, err)
}
// 驗證權限
if post.AuthorUID != req.AuthorUID {
return nil, errs.ResNotFoundError("not authorized to publish this post")
}
// 發布貼文
post.Publish()
// 更新資料庫
if err := uc.Post.Update(ctx, post); err != nil {
return nil, uc.handleDBError("Post.Update", req, err)
}
return uc.mapPostToResponse(post), nil
}
// ArchivePost 歸檔貼文
func (uc *PostUseCase) ArchivePost(ctx context.Context, req domainUsecase.ArchivePostRequest) error {
// 驗證輸入
var zeroUUID gocql.UUID
if req.PostID == zeroUUID {
return errs.InputInvalidRangeError("post_id is required")
}
if req.AuthorUID == "" {
return errs.InputInvalidRangeError("author_uid is required")
}
// 查詢貼文
post, err := uc.Post.FindOne(ctx, req.PostID)
if err != nil {
if repository.IsNotFound(err) {
return errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
}
return uc.handleDBError("Post.FindOne", req, err)
}
// 驗證權限
if post.AuthorUID != req.AuthorUID {
return errs.ResNotFoundError("not authorized to archive this post")
}
// 歸檔貼文
post.Archive()
// 更新資料庫
return uc.Post.Update(ctx, post)
}
// ListPosts 列出貼文
func (uc *PostUseCase) ListPosts(ctx context.Context, req domainUsecase.ListPostsRequest) (*domainUsecase.ListPostsResponse, error) {
// 驗證分頁參數
if req.PageSize <= 0 {
req.PageSize = 20
}
if req.PageIndex <= 0 {
req.PageIndex = 1
}
// 構建查詢參數
params := &domainRepo.PostQueryParams{
CategoryID: req.CategoryID,
Tag: req.Tag,
Status: req.Status,
Type: req.Type,
AuthorUID: req.AuthorUID,
CreateStartTime: req.CreateStartTime,
CreateEndTime: req.CreateEndTime,
PageSize: req.PageSize,
PageIndex: req.PageIndex,
OrderBy: req.OrderBy,
OrderDirection: req.OrderDirection,
}
// 執行查詢
var posts []*entity.Post
var total int64
var err error
if req.CategoryID != nil {
posts, total, err = uc.Post.FindByCategoryID(ctx, *req.CategoryID, params)
} else if req.Tag != nil {
posts, total, err = uc.Post.FindByTag(ctx, *req.Tag, params)
} else if req.AuthorUID != nil {
posts, total, err = uc.Post.FindByAuthorUID(ctx, *req.AuthorUID, params)
} else if req.Status != nil {
posts, total, err = uc.Post.FindByStatus(ctx, *req.Status, params)
} else {
// 預設查詢所有已發布的貼文
published := post.PostStatusPublished
params.Status = &published
posts, total, err = uc.Post.FindByStatus(ctx, published, params)
}
if err != nil {
return nil, uc.handleDBError("Post.FindBy*", req, err)
}
// 轉換為 Response
responses := make([]domainUsecase.PostResponse, len(posts))
for i, p := range posts {
responses[i] = *uc.mapPostToResponse(p)
}
return &domainUsecase.ListPostsResponse{
Data: responses,
Page: domainUsecase.Pager{
PageIndex: req.PageIndex,
PageSize: req.PageSize,
Total: total,
TotalPage: calculateTotalPages(total, req.PageSize),
},
}, nil
}
// ListPostsByAuthor 根據作者列出貼文
func (uc *PostUseCase) ListPostsByAuthor(ctx context.Context, req domainUsecase.ListPostsByAuthorRequest) (*domainUsecase.ListPostsResponse, error) {
if req.AuthorUID == "" {
return nil, errs.InputInvalidRangeError("author_uid is required")
}
if req.PageSize <= 0 {
req.PageSize = 20
}
if req.PageIndex <= 0 {
req.PageIndex = 1
}
params := &domainRepo.PostQueryParams{
Status: req.Status,
PageSize: req.PageSize,
PageIndex: req.PageIndex,
OrderBy: "created_at",
OrderDirection: "DESC",
}
posts, total, err := uc.Post.FindByAuthorUID(ctx, req.AuthorUID, params)
if err != nil {
return nil, uc.handleDBError("Post.FindByAuthorUID", req, err)
}
responses := make([]domainUsecase.PostResponse, len(posts))
for i, p := range posts {
responses[i] = *uc.mapPostToResponse(p)
}
return &domainUsecase.ListPostsResponse{
Data: responses,
Page: domainUsecase.Pager{
PageIndex: req.PageIndex,
PageSize: req.PageSize,
Total: total,
TotalPage: calculateTotalPages(total, req.PageSize),
},
}, nil
}
// ListPostsByCategory 根據分類列出貼文
func (uc *PostUseCase) ListPostsByCategory(ctx context.Context, req domainUsecase.ListPostsByCategoryRequest) (*domainUsecase.ListPostsResponse, error) {
var zeroUUID gocql.UUID
if req.CategoryID == zeroUUID {
return nil, errs.InputInvalidRangeError("category_id is required")
}
if req.PageSize <= 0 {
req.PageSize = 20
}
if req.PageIndex <= 0 {
req.PageIndex = 1
}
params := &domainRepo.PostQueryParams{
Status: req.Status,
PageSize: req.PageSize,
PageIndex: req.PageIndex,
OrderBy: "created_at",
OrderDirection: "DESC",
}
posts, total, err := uc.Post.FindByCategoryID(ctx, req.CategoryID, params)
if err != nil {
return nil, uc.handleDBError("Post.FindByCategoryID", req, err)
}
responses := make([]domainUsecase.PostResponse, len(posts))
for i, p := range posts {
responses[i] = *uc.mapPostToResponse(p)
}
return &domainUsecase.ListPostsResponse{
Data: responses,
Page: domainUsecase.Pager{
PageIndex: req.PageIndex,
PageSize: req.PageSize,
Total: total,
TotalPage: calculateTotalPages(total, req.PageSize),
},
}, nil
}
// ListPostsByTag 根據標籤列出貼文
func (uc *PostUseCase) ListPostsByTag(ctx context.Context, req domainUsecase.ListPostsByTagRequest) (*domainUsecase.ListPostsResponse, error) {
if req.Tag == "" {
return nil, errs.InputInvalidRangeError("tag is required")
}
if req.PageSize <= 0 {
req.PageSize = 20
}
if req.PageIndex <= 0 {
req.PageIndex = 1
}
params := &domainRepo.PostQueryParams{
Status: req.Status,
PageSize: req.PageSize,
PageIndex: req.PageIndex,
OrderBy: "created_at",
OrderDirection: "DESC",
}
posts, total, err := uc.Post.FindByTag(ctx, req.Tag, params)
if err != nil {
return nil, uc.handleDBError("Post.FindByTag", req, err)
}
responses := make([]domainUsecase.PostResponse, len(posts))
for i, p := range posts {
responses[i] = *uc.mapPostToResponse(p)
}
return &domainUsecase.ListPostsResponse{
Data: responses,
Page: domainUsecase.Pager{
PageIndex: req.PageIndex,
PageSize: req.PageSize,
Total: total,
TotalPage: calculateTotalPages(total, req.PageSize),
},
}, nil
}
// GetPinnedPosts 取得置頂貼文
func (uc *PostUseCase) GetPinnedPosts(ctx context.Context, req domainUsecase.GetPinnedPostsRequest) (*domainUsecase.ListPostsResponse, error) {
limit := int64(10)
if req.Limit > 0 {
limit = req.Limit
}
posts, err := uc.Post.FindPinnedPosts(ctx, limit)
if err != nil {
return nil, uc.handleDBError("Post.FindPinnedPosts", req, err)
}
responses := make([]domainUsecase.PostResponse, len(posts))
for i, p := range posts {
responses[i] = *uc.mapPostToResponse(p)
}
return &domainUsecase.ListPostsResponse{
Data: responses,
Page: domainUsecase.Pager{
PageIndex: 1,
PageSize: limit,
Total: int64(len(responses)),
TotalPage: 1,
},
}, nil
}
// LikePost 按讚貼文
func (uc *PostUseCase) LikePost(ctx context.Context, req domainUsecase.LikePostRequest) error {
// 驗證輸入
var zeroUUID gocql.UUID
if req.PostID == zeroUUID {
return errs.InputInvalidRangeError("post_id is required")
}
if req.UserUID == "" {
return errs.InputInvalidRangeError("user_uid is required")
}
// 檢查是否已經按讚
existingLike, err := uc.Like.FindByTargetAndUser(ctx, req.PostID, req.UserUID, "post")
if err == nil && existingLike != nil {
// 已經按讚,直接返回成功
return nil
}
if err != nil && !repository.IsNotFound(err) {
return uc.handleDBError("Like.FindByTargetAndUser", req, err)
}
// 建立按讚記錄
like := &entity.Like{
TargetID: req.PostID,
UserUID: req.UserUID,
TargetType: "post",
}
if err := uc.Like.Insert(ctx, like); err != nil {
return uc.handleDBError("Like.Insert", req, err)
}
// 增加貼文的按讚數
if err := uc.Post.IncrementLikeCount(ctx, req.PostID); err != nil {
uc.Logger.Error(fmt.Sprintf("failed to increment like count: %v", err))
}
return nil
}
// UnlikePost 取消按讚
func (uc *PostUseCase) UnlikePost(ctx context.Context, req domainUsecase.UnlikePostRequest) error {
// 驗證輸入
var zeroUUID gocql.UUID
if req.PostID == zeroUUID {
return errs.InputInvalidRangeError("post_id is required")
}
if req.UserUID == "" {
return errs.InputInvalidRangeError("user_uid is required")
}
// 刪除按讚記錄
if err := uc.Like.DeleteByTargetAndUser(ctx, req.PostID, req.UserUID, "post"); err != nil {
if repository.IsNotFound(err) {
// 已經取消按讚,直接返回成功
return nil
}
return uc.handleDBError("Like.DeleteByTargetAndUser", req, err)
}
// 減少貼文的按讚數
if err := uc.Post.DecrementLikeCount(ctx, req.PostID); err != nil {
uc.Logger.Error(fmt.Sprintf("failed to decrement like count: %v", err))
}
return nil
}
// ViewPost 瀏覽貼文(增加瀏覽數)
func (uc *PostUseCase) ViewPost(ctx context.Context, req domainUsecase.ViewPostRequest) error {
// 驗證輸入
var zeroUUID gocql.UUID
if req.PostID == zeroUUID {
return errs.InputInvalidRangeError("post_id is required")
}
// 增加瀏覽數
if err := uc.Post.IncrementViewCount(ctx, req.PostID); err != nil {
return uc.handleDBError("Post.IncrementViewCount", req, err)
}
return nil
}
// PinPost 置頂貼文
func (uc *PostUseCase) PinPost(ctx context.Context, req domainUsecase.PinPostRequest) error {
// 驗證輸入
var zeroUUID gocql.UUID
if req.PostID == zeroUUID {
return errs.InputInvalidRangeError("post_id is required")
}
if req.AuthorUID == "" {
return errs.InputInvalidRangeError("author_uid is required")
}
// 查詢貼文
post, err := uc.Post.FindOne(ctx, req.PostID)
if err != nil {
if repository.IsNotFound(err) {
return errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
}
return uc.handleDBError("Post.FindOne", req, err)
}
// 驗證權限
if post.AuthorUID != req.AuthorUID {
return errs.ResNotFoundError("not authorized to pin this post")
}
// 置頂貼文
return uc.Post.PinPost(ctx, req.PostID)
}
// UnpinPost 取消置頂
func (uc *PostUseCase) UnpinPost(ctx context.Context, req domainUsecase.UnpinPostRequest) error {
// 驗證輸入
var zeroUUID gocql.UUID
if req.PostID == zeroUUID {
return errs.InputInvalidRangeError("post_id is required")
}
if req.AuthorUID == "" {
return errs.InputInvalidRangeError("author_uid is required")
}
// 查詢貼文
post, err := uc.Post.FindOne(ctx, req.PostID)
if err != nil {
if repository.IsNotFound(err) {
return errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
}
return uc.handleDBError("Post.FindOne", req, err)
}
// 驗證權限
if post.AuthorUID != req.AuthorUID {
return errs.ResNotFoundError("not authorized to unpin this post")
}
// 取消置頂
return uc.Post.UnpinPost(ctx, req.PostID)
}
// validateCreatePostRequest 驗證建立貼文請求
func (uc *PostUseCase) validateCreatePostRequest(req domainUsecase.CreatePostRequest) error {
if req.AuthorUID == "" {
return errs.InputInvalidRangeError("author_uid is required")
}
if req.Title == "" {
return errs.InputInvalidRangeError("title is required")
}
if req.Content == "" {
return errs.InputInvalidRangeError("content is required")
}
if !req.Type.IsValid() {
return errs.InputInvalidRangeError("invalid post type")
}
return nil
}
// mapPostToResponse 將 Post 實體轉換為 PostResponse
func (uc *PostUseCase) mapPostToResponse(post *entity.Post) *domainUsecase.PostResponse {
return &domainUsecase.PostResponse{
ID: post.ID,
AuthorUID: post.AuthorUID,
Title: post.Title,
Content: post.Content,
Type: post.Type,
Status: post.Status,
CategoryID: post.CategoryID,
Tags: post.Tags,
Images: post.Images,
VideoURL: post.VideoURL,
LinkURL: post.LinkURL,
LikeCount: post.LikeCount,
CommentCount: post.CommentCount,
ViewCount: post.ViewCount,
IsPinned: post.IsPinned,
PinnedAt: post.PinnedAt,
PublishedAt: post.PublishedAt,
CreatedAt: post.CreatedAt,
UpdatedAt: post.UpdatedAt,
}
}
// handleDBError 處理資料庫錯誤
func (uc *PostUseCase) handleDBError(funcName string, req any, err error) error {
return errs.DBErrorErrorL(
uc.Logger,
[]errs.LogField{
{Key: "func", Val: funcName},
{Key: "req", Val: req},
{Key: "error", Val: err.Error()},
},
fmt.Sprintf("database operation failed: %s", funcName),
).Wrap(err)
}
// updateTagPostCounts 更新標籤的貼文數
func (uc *PostUseCase) updateTagPostCounts(ctx context.Context, tags []string, increment bool) error {
if len(tags) == 0 {
return nil
}
// 查詢或建立標籤
for _, tagName := range tags {
tag, err := uc.Tag.FindByName(ctx, tagName)
if err != nil {
if repository.IsNotFound(err) {
// 建立新標籤
newTag := &entity.Tag{
Name: tagName,
}
if err := uc.Tag.Insert(ctx, newTag); err != nil {
return fmt.Errorf("failed to create tag: %w", err)
}
tag = newTag
} else {
return fmt.Errorf("failed to find tag: %w", err)
}
}
// 更新計數
if increment {
if err := uc.Tag.IncrementPostCount(ctx, tag.ID); err != nil {
return fmt.Errorf("failed to increment tag count: %w", err)
}
} else {
if err := uc.Tag.DecrementPostCount(ctx, tag.ID); err != nil {
return fmt.Errorf("failed to decrement tag count: %w", err)
}
}
}
return nil
}
// updateTagPostCountsDiff 更新標籤計數(處理差異)
func (uc *PostUseCase) updateTagPostCountsDiff(ctx context.Context, oldTags, newTags []string) error {
// 找出新增和刪除的標籤
oldTagMap := make(map[string]bool)
for _, tag := range oldTags {
oldTagMap[tag] = true
}
newTagMap := make(map[string]bool)
for _, tag := range newTags {
newTagMap[tag] = true
}
// 新增的標籤
for _, tag := range newTags {
if !oldTagMap[tag] {
if err := uc.updateTagPostCounts(ctx, []string{tag}, true); err != nil {
return err
}
}
}
// 刪除的標籤
for _, tag := range oldTags {
if !newTagMap[tag] {
if err := uc.updateTagPostCounts(ctx, []string{tag}, false); err != nil {
return err
}
}
}
return nil
}
// calculateTotalPages 計算總頁數
func calculateTotalPages(total, pageSize int64) int64 {
if pageSize <= 0 {
return 0
}
return int64(math.Ceil(float64(total) / float64(pageSize)))
}