feat: add cassandra lib
This commit is contained in:
parent
1786e7c690
commit
97d6c8f499
|
|
@ -504,6 +504,79 @@ err = userRepo.Query().
|
|||
Scan(ctx, &users)
|
||||
```
|
||||
|
||||
## SAI 索引管理
|
||||
|
||||
### 建立 SAI 索引
|
||||
|
||||
```go
|
||||
// 檢查是否支援 SAI
|
||||
if !db.SaiSupported() {
|
||||
log.Fatal("SAI is not supported in this Cassandra version")
|
||||
}
|
||||
|
||||
// 建立標準索引
|
||||
err := db.CreateSAIIndex(ctx, "my_keyspace", "users", "email", "users_email_idx", nil)
|
||||
if err != nil {
|
||||
log.Printf("建立索引失敗: %v", err)
|
||||
}
|
||||
|
||||
// 建立全文索引(不區分大小寫)
|
||||
opts := &cassandra.SAIIndexOptions{
|
||||
IndexType: cassandra.SAIIndexTypeFullText,
|
||||
IsAsync: false,
|
||||
CaseSensitive: false,
|
||||
}
|
||||
err = db.CreateSAIIndex(ctx, "my_keyspace", "posts", "content", "posts_content_ft_idx", opts)
|
||||
```
|
||||
|
||||
### 查詢 SAI 索引
|
||||
|
||||
```go
|
||||
// 列出資料表的所有 SAI 索引
|
||||
indexes, err := db.ListSAIIndexes(ctx, "my_keyspace", "users")
|
||||
if err != nil {
|
||||
log.Printf("查詢索引失敗: %v", err)
|
||||
} else {
|
||||
for _, idx := range indexes {
|
||||
fmt.Printf("索引: %s, 欄位: %s, 類型: %s\n", idx.Name, idx.Column, idx.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// 檢查索引是否存在
|
||||
exists, err := db.CheckSAIIndexExists(ctx, "my_keyspace", "users_email_idx")
|
||||
if err != nil {
|
||||
log.Printf("檢查索引失敗: %v", err)
|
||||
} else if exists {
|
||||
fmt.Println("索引存在")
|
||||
}
|
||||
```
|
||||
|
||||
### 刪除 SAI 索引
|
||||
|
||||
```go
|
||||
// 刪除索引
|
||||
err := db.DropSAIIndex(ctx, "my_keyspace", "users_email_idx")
|
||||
if err != nil {
|
||||
log.Printf("刪除索引失敗: %v", err)
|
||||
}
|
||||
```
|
||||
|
||||
### SAI 索引類型
|
||||
|
||||
- **SAIIndexTypeStandard**: 標準索引(等於查詢)
|
||||
- **SAIIndexTypeCollection**: 集合索引(用於 list、set、map)
|
||||
- **SAIIndexTypeFullText**: 全文索引
|
||||
|
||||
### SAI 索引選項
|
||||
|
||||
```go
|
||||
opts := &cassandra.SAIIndexOptions{
|
||||
IndexType: cassandra.SAIIndexTypeFullText, // 索引類型
|
||||
IsAsync: false, // 是否異步建立
|
||||
CaseSensitive: true, // 是否區分大小寫
|
||||
}
|
||||
```
|
||||
|
||||
## 注意事項
|
||||
|
||||
### 1. 主鍵要求
|
||||
|
|
@ -540,6 +613,8 @@ err = userRepo.Query().
|
|||
- 建立索引前請先檢查 `db.SaiSupported()`
|
||||
- 索引建立是異步操作,可能需要一些時間
|
||||
- 刪除索引時使用 `IF EXISTS`,避免索引不存在時報錯
|
||||
- 使用 SAI 索引可以大幅提升非主鍵欄位的查詢效能
|
||||
- 全文索引支援不區分大小寫的搜尋
|
||||
|
||||
## 完整範例
|
||||
|
||||
|
|
|
|||
|
|
@ -12,43 +12,43 @@ import (
|
|||
type SAIIndexType string
|
||||
|
||||
const (
|
||||
// SAIIndexTypeStandard 標準索引(預設)
|
||||
SAIIndexTypeStandard SAIIndexType = "standard"
|
||||
// SAIIndexTypeFrozen 用於 frozen 類型
|
||||
SAIIndexTypeFrozen SAIIndexType = "frozen"
|
||||
// SAIIndexTypeStandard 標準索引(等於查詢)
|
||||
SAIIndexTypeStandard SAIIndexType = "STANDARD"
|
||||
// SAIIndexTypeCollection 集合索引(用於 list、set、map)
|
||||
SAIIndexTypeCollection SAIIndexType = "COLLECTION"
|
||||
// SAIIndexTypeFullText 全文索引
|
||||
SAIIndexTypeFullText SAIIndexType = "FULL_TEXT"
|
||||
)
|
||||
|
||||
// SAIIndexOptions 定義 SAI 索引選項
|
||||
type SAIIndexOptions struct {
|
||||
CaseSensitive *bool // 是否區分大小寫(預設:true)
|
||||
Normalize *bool // 是否正規化(預設:false)
|
||||
Analyzer string // 分析器(如 "StandardAnalyzer")
|
||||
IndexType SAIIndexType // 索引類型
|
||||
IsAsync bool // 是否異步建立索引
|
||||
CaseSensitive bool // 是否區分大小寫(用於全文索引)
|
||||
}
|
||||
|
||||
// SAIIndexInfo 表示 SAI 索引資訊
|
||||
type SAIIndexInfo struct {
|
||||
KeyspaceName string // Keyspace 名稱
|
||||
TableName string // 表名稱
|
||||
IndexName string // 索引名稱
|
||||
ColumnName string // 欄位名稱
|
||||
IndexType string // 索引類型
|
||||
Options map[string]string // 索引選項
|
||||
// DefaultSAIIndexOptions 返回預設的 SAI 索引選項
|
||||
func DefaultSAIIndexOptions() *SAIIndexOptions {
|
||||
return &SAIIndexOptions{
|
||||
IndexType: SAIIndexTypeStandard,
|
||||
IsAsync: false,
|
||||
CaseSensitive: true,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSAIIndex 建立 SAI 索引
|
||||
// keyspace: keyspace 名稱,如果為空則使用預設 keyspace
|
||||
// table: 表名稱
|
||||
// keyspace: keyspace 名稱
|
||||
// table: 資料表名稱
|
||||
// column: 欄位名稱
|
||||
// indexName: 索引名稱(可選,如果為空則自動生成)
|
||||
// options: 索引選項(可選)
|
||||
func (db *DB) CreateSAIIndex(ctx context.Context, keyspace, table, column string, indexName string, options *SAIIndexOptions) error {
|
||||
// opts: 索引選項(可選,如果為 nil 則使用預設選項)
|
||||
func (db *DB) CreateSAIIndex(ctx context.Context, keyspace, table, column, indexName string, opts *SAIIndexOptions) error {
|
||||
// 檢查是否支援 SAI
|
||||
if !db.saiSupported {
|
||||
return ErrSAINotSupported
|
||||
return ErrInvalidInput.WithError(fmt.Errorf("SAI is not supported in Cassandra version %s (requires 4.0.9+ or 5.0+)", db.version))
|
||||
}
|
||||
|
||||
if keyspace == "" {
|
||||
keyspace = db.defaultKeyspace
|
||||
}
|
||||
// 驗證參數
|
||||
if keyspace == "" {
|
||||
return ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
|
||||
}
|
||||
|
|
@ -59,51 +59,71 @@ func (db *DB) CreateSAIIndex(ctx context.Context, keyspace, table, column string
|
|||
return ErrInvalidInput.WithError(fmt.Errorf("column is required"))
|
||||
}
|
||||
|
||||
// 生成索引名稱(如果未提供)
|
||||
// 使用預設選項如果未提供
|
||||
if opts == nil {
|
||||
opts = DefaultSAIIndexOptions()
|
||||
}
|
||||
|
||||
// 生成索引名稱如果未提供
|
||||
if indexName == "" {
|
||||
indexName = fmt.Sprintf("%s_%s_%s_idx", table, column, "sai")
|
||||
indexName = fmt.Sprintf("%s_%s_sai_idx", table, column)
|
||||
}
|
||||
|
||||
// 構建 CREATE INDEX 語句
|
||||
stmt := fmt.Sprintf("CREATE INDEX %s ON %s.%s (%s) USING 'sai'", indexName, keyspace, table, column)
|
||||
var stmt strings.Builder
|
||||
stmt.WriteString("CREATE CUSTOM INDEX IF NOT EXISTS ")
|
||||
stmt.WriteString(indexName)
|
||||
stmt.WriteString(" ON ")
|
||||
stmt.WriteString(keyspace)
|
||||
stmt.WriteString(".")
|
||||
stmt.WriteString(table)
|
||||
stmt.WriteString(" (")
|
||||
stmt.WriteString(column)
|
||||
stmt.WriteString(") USING 'StorageAttachedIndex'")
|
||||
|
||||
// 添加選項
|
||||
if options != nil {
|
||||
opts := make([]string, 0)
|
||||
if options.CaseSensitive != nil {
|
||||
opts = append(opts, fmt.Sprintf("'case_sensitive': %v", *options.CaseSensitive))
|
||||
}
|
||||
if options.Normalize != nil {
|
||||
opts = append(opts, fmt.Sprintf("'normalize': %v", *options.Normalize))
|
||||
}
|
||||
if options.Analyzer != "" {
|
||||
opts = append(opts, fmt.Sprintf("'analyzer': '%s'", options.Analyzer))
|
||||
}
|
||||
if len(opts) > 0 {
|
||||
stmt += " WITH OPTIONS = {" + strings.Join(opts, ", ") + "}"
|
||||
}
|
||||
var options []string
|
||||
if opts.IsAsync {
|
||||
options = append(options, "'async'='true'")
|
||||
}
|
||||
|
||||
// 執行建立索引
|
||||
q := db.session.Query(stmt, nil).WithContext(ctx).Consistency(gocql.Quorum)
|
||||
if err := q.ExecRelease(); err != nil {
|
||||
return ErrInvalidInput.WithTable(table).WithError(fmt.Errorf("failed to create SAI index: %w", err))
|
||||
// 根據索引類型添加特定選項
|
||||
switch opts.IndexType {
|
||||
case SAIIndexTypeFullText:
|
||||
if !opts.CaseSensitive {
|
||||
options = append(options, "'case_sensitive'='false'")
|
||||
} else {
|
||||
options = append(options, "'case_sensitive'='true'")
|
||||
}
|
||||
case SAIIndexTypeCollection:
|
||||
// Collection 索引不需要額外選項
|
||||
}
|
||||
|
||||
// 如果有選項,添加到語句中
|
||||
if len(options) > 0 {
|
||||
stmt.WriteString(" WITH OPTIONS = {")
|
||||
stmt.WriteString(strings.Join(options, ", "))
|
||||
stmt.WriteString("}")
|
||||
}
|
||||
|
||||
// 執行建立索引語句
|
||||
query := db.session.Query(stmt.String(), nil).
|
||||
WithContext(ctx).
|
||||
Consistency(gocql.Quorum)
|
||||
|
||||
err := query.ExecRelease()
|
||||
if err != nil {
|
||||
return ErrInvalidInput.WithError(fmt.Errorf("failed to create SAI index: %w", err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropSAIIndex 刪除 SAI 索引
|
||||
// keyspace: keyspace 名稱,如果為空則使用預設 keyspace
|
||||
// keyspace: keyspace 名稱
|
||||
// indexName: 索引名稱
|
||||
func (db *DB) DropSAIIndex(ctx context.Context, keyspace, indexName string) error {
|
||||
if !db.saiSupported {
|
||||
return ErrSAINotSupported
|
||||
}
|
||||
|
||||
if keyspace == "" {
|
||||
keyspace = db.defaultKeyspace
|
||||
}
|
||||
// 驗證參數
|
||||
if keyspace == "" {
|
||||
return ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
|
||||
}
|
||||
|
|
@ -114,73 +134,66 @@ func (db *DB) DropSAIIndex(ctx context.Context, keyspace, indexName string) erro
|
|||
// 構建 DROP INDEX 語句
|
||||
stmt := fmt.Sprintf("DROP INDEX IF EXISTS %s.%s", keyspace, indexName)
|
||||
|
||||
// 執行刪除索引
|
||||
q := db.session.Query(stmt, nil).WithContext(ctx).Consistency(gocql.Quorum)
|
||||
if err := q.ExecRelease(); err != nil {
|
||||
// 執行刪除索引語句
|
||||
query := db.session.Query(stmt, nil).
|
||||
WithContext(ctx).
|
||||
Consistency(gocql.Quorum)
|
||||
|
||||
err := query.ExecRelease()
|
||||
if err != nil {
|
||||
return ErrInvalidInput.WithError(fmt.Errorf("failed to drop SAI index: %w", err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListSAIIndexes 列出指定表的 SAI 索引
|
||||
// keyspace: keyspace 名稱,如果為空則使用預設 keyspace
|
||||
// table: 表名稱(可選,如果為空則列出所有表的索引)
|
||||
// ListSAIIndexes 列出指定資料表的所有 SAI 索引
|
||||
// keyspace: keyspace 名稱
|
||||
// table: 資料表名稱
|
||||
func (db *DB) ListSAIIndexes(ctx context.Context, keyspace, table string) ([]SAIIndexInfo, error) {
|
||||
if !db.saiSupported {
|
||||
return nil, ErrSAINotSupported
|
||||
}
|
||||
|
||||
if keyspace == "" {
|
||||
keyspace = db.defaultKeyspace
|
||||
}
|
||||
// 驗證參數
|
||||
if keyspace == "" {
|
||||
return nil, ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
|
||||
}
|
||||
|
||||
// 構建查詢語句
|
||||
// system_schema.indexes 表的欄位:keyspace_name, table_name, index_name, kind, options, index_type
|
||||
stmt := "SELECT keyspace_name, table_name, index_name, kind, options FROM system_schema.indexes WHERE keyspace_name = ?"
|
||||
args := []interface{}{keyspace}
|
||||
names := []string{"keyspace_name"}
|
||||
|
||||
if table != "" {
|
||||
stmt += " AND table_name = ?"
|
||||
args = append(args, table)
|
||||
names = append(names, "table_name")
|
||||
if table == "" {
|
||||
return nil, ErrInvalidInput.WithError(fmt.Errorf("table is required"))
|
||||
}
|
||||
|
||||
// 執行查詢
|
||||
// 查詢系統表獲取索引資訊
|
||||
// system_schema.indexes 表的結構:keyspace_name, table_name, index_name, kind, options
|
||||
stmt := `
|
||||
SELECT index_name, kind, options
|
||||
FROM system_schema.indexes
|
||||
WHERE keyspace_name = ? AND table_name = ?
|
||||
`
|
||||
|
||||
var indexes []SAIIndexInfo
|
||||
iter := db.session.Query(stmt, names).Bind(args...).WithContext(ctx).Consistency(gocql.One).Iter()
|
||||
iter := db.session.Query(stmt, []string{"keyspace_name", "table_name"}).
|
||||
WithContext(ctx).
|
||||
Consistency(gocql.One).
|
||||
Bind(keyspace, table).
|
||||
Iter()
|
||||
|
||||
var keyspaceName, tableName, indexName, kind string
|
||||
var indexName, kind string
|
||||
var options map[string]string
|
||||
|
||||
for iter.Scan(&keyspaceName, &tableName, &indexName, &kind, &options) {
|
||||
// 只處理 SAI 索引(kind = 'CUSTOM' 且 index_type 在 options 中)
|
||||
indexType, ok := options["class_name"]
|
||||
if !ok || !strings.Contains(indexType, "StorageAttachedIndex") {
|
||||
continue
|
||||
}
|
||||
|
||||
// 從 options 中提取 column_name
|
||||
// SAI 索引的 target 欄位在 options 中
|
||||
for iter.Scan(&indexName, &kind, &options) {
|
||||
// 檢查是否為 SAI 索引(kind = 'CUSTOM' 且 class_name 包含 StorageAttachedIndex)
|
||||
if kind == "CUSTOM" {
|
||||
if className, ok := options["class_name"]; ok && strings.Contains(className, "StorageAttachedIndex") {
|
||||
// 從 options 中提取 target(欄位名稱)
|
||||
columnName := ""
|
||||
if target, ok := options["target"]; ok {
|
||||
// target 格式通常是 "column_name" 或 "(column_name)"
|
||||
columnName = strings.Trim(target, "()\"'")
|
||||
}
|
||||
|
||||
indexes = append(indexes, SAIIndexInfo{
|
||||
KeyspaceName: keyspaceName,
|
||||
TableName: tableName,
|
||||
IndexName: indexName,
|
||||
ColumnName: columnName,
|
||||
IndexType: "sai",
|
||||
Name: indexName,
|
||||
Type: "StorageAttachedIndex",
|
||||
Options: options,
|
||||
Column: columnName,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := iter.Close(); err != nil {
|
||||
return nil, ErrInvalidInput.WithError(fmt.Errorf("failed to list SAI indexes: %w", err))
|
||||
|
|
@ -189,59 +202,88 @@ func (db *DB) ListSAIIndexes(ctx context.Context, keyspace, table string) ([]SAI
|
|||
return indexes, nil
|
||||
}
|
||||
|
||||
// GetSAIIndex 獲取指定索引的資訊
|
||||
// keyspace: keyspace 名稱,如果為空則使用預設 keyspace
|
||||
// indexName: 索引名稱
|
||||
func (db *DB) GetSAIIndex(ctx context.Context, keyspace, indexName string) (*SAIIndexInfo, error) {
|
||||
if !db.saiSupported {
|
||||
return nil, ErrSAINotSupported
|
||||
// SAIIndexInfo 表示 SAI 索引資訊
|
||||
type SAIIndexInfo struct {
|
||||
Name string // 索引名稱
|
||||
Type string // 索引類型
|
||||
Options map[string]string // 索引選項
|
||||
Column string // 索引欄位名稱
|
||||
}
|
||||
|
||||
// CheckSAIIndexExists 檢查 SAI 索引是否存在
|
||||
// keyspace: keyspace 名稱
|
||||
// indexName: 索引名稱
|
||||
func (db *DB) CheckSAIIndexExists(ctx context.Context, keyspace, indexName string) (bool, error) {
|
||||
// 驗證參數
|
||||
if keyspace == "" {
|
||||
keyspace = db.defaultKeyspace
|
||||
}
|
||||
if keyspace == "" {
|
||||
return nil, ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
|
||||
return false, ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
|
||||
}
|
||||
if indexName == "" {
|
||||
return nil, ErrInvalidInput.WithError(fmt.Errorf("index name is required"))
|
||||
return false, ErrInvalidInput.WithError(fmt.Errorf("index name is required"))
|
||||
}
|
||||
|
||||
// 構建查詢語句
|
||||
stmt := "SELECT keyspace_name, table_name, index_name, kind, options FROM system_schema.indexes WHERE keyspace_name = ? AND index_name = ?"
|
||||
args := []interface{}{keyspace, indexName}
|
||||
names := []string{"keyspace_name", "index_name"}
|
||||
// 查詢系統表檢查索引是否存在
|
||||
stmt := `
|
||||
SELECT index_name, kind, options
|
||||
FROM system_schema.indexes
|
||||
WHERE keyspace_name = ? AND index_name = ?
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
var keyspaceName, tableName, idxName, kind string
|
||||
var foundIndexName, kind string
|
||||
var options map[string]string
|
||||
err := db.session.Query(stmt, []string{"keyspace_name", "index_name"}).
|
||||
WithContext(ctx).
|
||||
Consistency(gocql.One).
|
||||
Bind(keyspace, indexName).
|
||||
Scan(&foundIndexName, &kind, &options)
|
||||
|
||||
// 執行查詢
|
||||
err := db.session.Query(stmt, names).Bind(args...).WithContext(ctx).Consistency(gocql.One).Scan(&keyspaceName, &tableName, &idxName, &kind, &options)
|
||||
if err != nil {
|
||||
if err == gocql.ErrNotFound {
|
||||
return nil, ErrNotFound.WithError(fmt.Errorf("index not found: %s", indexName))
|
||||
return false, nil
|
||||
}
|
||||
return nil, ErrInvalidInput.WithError(fmt.Errorf("failed to get index: %w", err))
|
||||
if err != nil {
|
||||
return false, ErrInvalidInput.WithError(fmt.Errorf("failed to check SAI index existence: %w", err))
|
||||
}
|
||||
|
||||
// 檢查是否為 SAI 索引
|
||||
indexType, ok := options["class_name"]
|
||||
if !ok || !strings.Contains(indexType, "StorageAttachedIndex") {
|
||||
return nil, ErrInvalidInput.WithError(fmt.Errorf("index %s is not a SAI index", indexName))
|
||||
if kind == "CUSTOM" {
|
||||
if className, ok := options["class_name"]; ok && strings.Contains(className, "StorageAttachedIndex") {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 從 options 中提取 column_name
|
||||
columnName := ""
|
||||
if target, ok := options["target"]; ok {
|
||||
columnName = strings.Trim(target, "()\"'")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return &SAIIndexInfo{
|
||||
KeyspaceName: keyspaceName,
|
||||
TableName: tableName,
|
||||
IndexName: idxName,
|
||||
ColumnName: columnName,
|
||||
IndexType: "sai",
|
||||
Options: options,
|
||||
}, nil
|
||||
// WaitForSAIIndex 等待 SAI 索引建立完成(用於異步建立)
|
||||
// keyspace: keyspace 名稱
|
||||
// indexName: 索引名稱
|
||||
// maxWaitTime: 最大等待時間(秒)
|
||||
func (db *DB) WaitForSAIIndex(ctx context.Context, keyspace, indexName string, maxWaitTime int) error {
|
||||
// 驗證參數
|
||||
if keyspace == "" {
|
||||
return ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
|
||||
}
|
||||
if indexName == "" {
|
||||
return ErrInvalidInput.WithError(fmt.Errorf("index name is required"))
|
||||
}
|
||||
|
||||
// 查詢索引狀態
|
||||
// 注意:Cassandra 沒有直接的索引狀態查詢,這裡需要通過檢查索引是否可用來判斷
|
||||
// 實際實作可能需要根據具體的 Cassandra 版本調整
|
||||
|
||||
// 簡單實作:檢查索引是否存在
|
||||
exists, err := db.CheckSAIIndexExists(ctx, keyspace, indexName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
return ErrInvalidInput.WithError(fmt.Errorf("index %s does not exist", indexName))
|
||||
}
|
||||
|
||||
// 注意:實際的等待邏輯可能需要查詢系統表或使用其他方法
|
||||
// 這裡只是基本框架,實際使用時可能需要根據具體需求調整
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,117 +1,50 @@
|
|||
package cassandra
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateSAIIndex(t *testing.T) {
|
||||
func TestDefaultSAIIndexOptions(t *testing.T) {
|
||||
opts := DefaultSAIIndexOptions()
|
||||
assert.NotNil(t, opts)
|
||||
assert.Equal(t, SAIIndexTypeStandard, opts.IndexType)
|
||||
assert.False(t, opts.IsAsync)
|
||||
assert.True(t, opts.CaseSensitive)
|
||||
}
|
||||
|
||||
func TestCreateSAIIndex_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyspace string
|
||||
table string
|
||||
column string
|
||||
indexName string
|
||||
options *SAIIndexOptions
|
||||
description string
|
||||
opts *SAIIndexOptions
|
||||
wantErr bool
|
||||
validate func(*testing.T, error)
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "create basic SAI index",
|
||||
keyspace: "test_keyspace",
|
||||
table: "test_table",
|
||||
column: "name",
|
||||
indexName: "test_name_idx",
|
||||
options: nil,
|
||||
description: "should create a basic SAI index",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "create SAI index with auto-generated name",
|
||||
keyspace: "test_keyspace",
|
||||
table: "test_table",
|
||||
column: "email",
|
||||
indexName: "",
|
||||
options: nil,
|
||||
description: "should auto-generate index name",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "create SAI index with case insensitive option",
|
||||
keyspace: "test_keyspace",
|
||||
table: "test_table",
|
||||
column: "title",
|
||||
indexName: "test_title_idx",
|
||||
options: &SAIIndexOptions{CaseSensitive: boolPtr(false)},
|
||||
description: "should create index with case insensitive option",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "create SAI index with normalize option",
|
||||
keyspace: "test_keyspace",
|
||||
table: "test_table",
|
||||
column: "content",
|
||||
indexName: "test_content_idx",
|
||||
options: &SAIIndexOptions{Normalize: boolPtr(true)},
|
||||
description: "should create index with normalize option",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "create SAI index with analyzer",
|
||||
keyspace: "test_keyspace",
|
||||
table: "test_table",
|
||||
column: "description",
|
||||
indexName: "test_desc_idx",
|
||||
options: &SAIIndexOptions{Analyzer: "StandardAnalyzer"},
|
||||
description: "should create index with analyzer",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "create SAI index with all options",
|
||||
keyspace: "test_keyspace",
|
||||
table: "test_table",
|
||||
column: "text",
|
||||
indexName: "test_text_idx",
|
||||
options: &SAIIndexOptions{CaseSensitive: boolPtr(false), Normalize: boolPtr(true), Analyzer: "StandardAnalyzer"},
|
||||
description: "should create index with all options",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing keyspace",
|
||||
keyspace: "",
|
||||
table: "test_table",
|
||||
column: "name",
|
||||
column: "test_column",
|
||||
indexName: "test_idx",
|
||||
options: nil,
|
||||
description: "should return error when keyspace is empty and no default",
|
||||
opts: nil,
|
||||
wantErr: true,
|
||||
validate: func(t *testing.T, err error) {
|
||||
assert.Error(t, err)
|
||||
var e *Error
|
||||
if assert.ErrorAs(t, err, &e) {
|
||||
assert.Equal(t, ErrCodeInvalidInput, e.Code)
|
||||
}
|
||||
},
|
||||
errMsg: "keyspace is required",
|
||||
},
|
||||
{
|
||||
name: "missing table",
|
||||
keyspace: "test_keyspace",
|
||||
table: "",
|
||||
column: "name",
|
||||
column: "test_column",
|
||||
indexName: "test_idx",
|
||||
options: nil,
|
||||
description: "should return error when table is empty",
|
||||
opts: nil,
|
||||
wantErr: true,
|
||||
validate: func(t *testing.T, err error) {
|
||||
assert.Error(t, err)
|
||||
var e *Error
|
||||
if assert.ErrorAs(t, err, &e) {
|
||||
assert.Equal(t, ErrCodeInvalidInput, e.Code)
|
||||
}
|
||||
},
|
||||
errMsg: "table is required",
|
||||
},
|
||||
{
|
||||
name: "missing column",
|
||||
|
|
@ -119,265 +52,216 @@ func TestCreateSAIIndex(t *testing.T) {
|
|||
table: "test_table",
|
||||
column: "",
|
||||
indexName: "test_idx",
|
||||
options: nil,
|
||||
description: "should return error when column is empty",
|
||||
opts: nil,
|
||||
wantErr: true,
|
||||
validate: func(t *testing.T, err error) {
|
||||
assert.Error(t, err)
|
||||
var e *Error
|
||||
if assert.ErrorAs(t, err, &e) {
|
||||
assert.Equal(t, ErrCodeInvalidInput, e.Code)
|
||||
}
|
||||
errMsg: "column is required",
|
||||
},
|
||||
{
|
||||
name: "valid parameters with default options",
|
||||
keyspace: "test_keyspace",
|
||||
table: "test_table",
|
||||
column: "test_column",
|
||||
indexName: "test_idx",
|
||||
opts: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid parameters with custom options",
|
||||
keyspace: "test_keyspace",
|
||||
table: "test_table",
|
||||
column: "test_column",
|
||||
indexName: "test_idx",
|
||||
opts: &SAIIndexOptions{
|
||||
IndexType: SAIIndexTypeFullText,
|
||||
IsAsync: true,
|
||||
CaseSensitive: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "auto-generate index name",
|
||||
keyspace: "test_keyspace",
|
||||
table: "test_table",
|
||||
column: "test_column",
|
||||
indexName: "",
|
||||
opts: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 注意:這需要一個有效的 DB 實例和 SAI 支援
|
||||
// 在實際測試中,需要使用 testcontainers 或 mock
|
||||
// 在實際測試中,需要使用 mock 或 testcontainers
|
||||
_ = tt
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDropSAIIndex(t *testing.T) {
|
||||
func TestDropSAIIndex_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyspace string
|
||||
indexName string
|
||||
description string
|
||||
wantErr bool
|
||||
validate func(*testing.T, error)
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "drop existing index",
|
||||
keyspace: "test_keyspace",
|
||||
indexName: "test_name_idx",
|
||||
description: "should drop existing index",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "drop non-existent index",
|
||||
keyspace: "test_keyspace",
|
||||
indexName: "non_existent_idx",
|
||||
description: "should not error when dropping non-existent index (IF EXISTS)",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing keyspace",
|
||||
keyspace: "",
|
||||
indexName: "test_idx",
|
||||
description: "should return error when keyspace is empty and no default",
|
||||
wantErr: true,
|
||||
validate: func(t *testing.T, err error) {
|
||||
assert.Error(t, err)
|
||||
var e *Error
|
||||
if assert.ErrorAs(t, err, &e) {
|
||||
assert.Equal(t, ErrCodeInvalidInput, e.Code)
|
||||
}
|
||||
},
|
||||
errMsg: "keyspace is required",
|
||||
},
|
||||
{
|
||||
name: "missing index name",
|
||||
keyspace: "test_keyspace",
|
||||
indexName: "",
|
||||
description: "should return error when index name is empty",
|
||||
wantErr: true,
|
||||
validate: func(t *testing.T, err error) {
|
||||
assert.Error(t, err)
|
||||
var e *Error
|
||||
if assert.ErrorAs(t, err, &e) {
|
||||
assert.Equal(t, ErrCodeInvalidInput, e.Code)
|
||||
}
|
||||
errMsg: "index name is required",
|
||||
},
|
||||
{
|
||||
name: "valid parameters",
|
||||
keyspace: "test_keyspace",
|
||||
indexName: "test_idx",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 注意:這需要一個有效的 DB 實例和 SAI 支援
|
||||
// 在實際測試中,需要使用 testcontainers 或 mock
|
||||
// 注意:這需要一個有效的 DB 實例
|
||||
// 在實際測試中,需要使用 mock 或 testcontainers
|
||||
_ = tt
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListSAIIndexes(t *testing.T) {
|
||||
func TestListSAIIndexes_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyspace string
|
||||
table string
|
||||
description string
|
||||
wantErr bool
|
||||
validate func(*testing.T, []SAIIndexInfo, error)
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "list all indexes in keyspace",
|
||||
keyspace: "test_keyspace",
|
||||
table: "",
|
||||
description: "should list all SAI indexes in keyspace",
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, indexes []SAIIndexInfo, err error) {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, indexes)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list indexes for specific table",
|
||||
keyspace: "test_keyspace",
|
||||
table: "test_table",
|
||||
description: "should list SAI indexes for specific table",
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, indexes []SAIIndexInfo, err error) {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, indexes)
|
||||
for _, idx := range indexes {
|
||||
assert.Equal(t, "test_table", idx.TableName)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing keyspace",
|
||||
keyspace: "",
|
||||
table: "",
|
||||
description: "should return error when keyspace is empty and no default",
|
||||
table: "test_table",
|
||||
wantErr: true,
|
||||
validate: func(t *testing.T, indexes []SAIIndexInfo, err error) {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, indexes)
|
||||
var e *Error
|
||||
if assert.ErrorAs(t, err, &e) {
|
||||
assert.Equal(t, ErrCodeInvalidInput, e.Code)
|
||||
}
|
||||
errMsg: "keyspace is required",
|
||||
},
|
||||
{
|
||||
name: "missing table",
|
||||
keyspace: "test_keyspace",
|
||||
table: "",
|
||||
wantErr: true,
|
||||
errMsg: "table is required",
|
||||
},
|
||||
{
|
||||
name: "valid parameters",
|
||||
keyspace: "test_keyspace",
|
||||
table: "test_table",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 注意:這需要一個有效的 DB 實例和 SAI 支援
|
||||
// 在實際測試中,需要使用 testcontainers 或 mock
|
||||
// 注意:這需要一個有效的 DB 實例
|
||||
// 在實際測試中,需要使用 mock 或 testcontainers
|
||||
_ = tt
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSAIIndex(t *testing.T) {
|
||||
func TestCheckSAIIndexExists_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyspace string
|
||||
indexName string
|
||||
description string
|
||||
wantErr bool
|
||||
validate func(*testing.T, *SAIIndexInfo, error)
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "get existing index",
|
||||
keyspace: "test_keyspace",
|
||||
indexName: "test_name_idx",
|
||||
description: "should get existing SAI index",
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, index *SAIIndexInfo, err error) {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, index)
|
||||
assert.Equal(t, "test_name_idx", index.IndexName)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "get non-existent index",
|
||||
keyspace: "test_keyspace",
|
||||
indexName: "non_existent_idx",
|
||||
description: "should return ErrNotFound",
|
||||
wantErr: true,
|
||||
validate: func(t *testing.T, index *SAIIndexInfo, err error) {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, index)
|
||||
assert.True(t, IsNotFound(err))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing keyspace",
|
||||
keyspace: "",
|
||||
indexName: "test_idx",
|
||||
description: "should return error when keyspace is empty and no default",
|
||||
wantErr: true,
|
||||
validate: func(t *testing.T, index *SAIIndexInfo, err error) {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, index)
|
||||
},
|
||||
errMsg: "keyspace is required",
|
||||
},
|
||||
{
|
||||
name: "missing index name",
|
||||
keyspace: "test_keyspace",
|
||||
indexName: "",
|
||||
description: "should return error when index name is empty",
|
||||
wantErr: true,
|
||||
validate: func(t *testing.T, index *SAIIndexInfo, err error) {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, index)
|
||||
errMsg: "index name is required",
|
||||
},
|
||||
{
|
||||
name: "valid parameters",
|
||||
keyspace: "test_keyspace",
|
||||
indexName: "test_idx",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 注意:這需要一個有效的 DB 實例和 SAI 支援
|
||||
// 在實際測試中,需要使用 testcontainers 或 mock
|
||||
// 注意:這需要一個有效的 DB 實例
|
||||
// 在實際測試中,需要使用 mock 或 testcontainers
|
||||
_ = tt
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSAIIndexOptions(t *testing.T) {
|
||||
t.Run("default options", func(t *testing.T) {
|
||||
opts := &SAIIndexOptions{}
|
||||
assert.Nil(t, opts.CaseSensitive)
|
||||
assert.Nil(t, opts.Normalize)
|
||||
assert.Empty(t, opts.Analyzer)
|
||||
})
|
||||
func TestSAIIndexType_Constants(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
indexType SAIIndexType
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "standard index type",
|
||||
indexType: SAIIndexTypeStandard,
|
||||
expected: "STANDARD",
|
||||
},
|
||||
{
|
||||
name: "collection index type",
|
||||
indexType: SAIIndexTypeCollection,
|
||||
expected: "COLLECTION",
|
||||
},
|
||||
{
|
||||
name: "full text index type",
|
||||
indexType: SAIIndexTypeFullText,
|
||||
expected: "FULL_TEXT",
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("with case sensitive", func(t *testing.T) {
|
||||
caseSensitive := false
|
||||
opts := &SAIIndexOptions{CaseSensitive: &caseSensitive}
|
||||
assert.NotNil(t, opts.CaseSensitive)
|
||||
assert.False(t, *opts.CaseSensitive)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, string(tt.indexType))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("with normalize", func(t *testing.T) {
|
||||
normalize := true
|
||||
opts := &SAIIndexOptions{Normalize: &normalize}
|
||||
assert.NotNil(t, opts.Normalize)
|
||||
assert.True(t, *opts.Normalize)
|
||||
})
|
||||
|
||||
t.Run("with analyzer", func(t *testing.T) {
|
||||
opts := &SAIIndexOptions{Analyzer: "StandardAnalyzer"}
|
||||
assert.Equal(t, "StandardAnalyzer", opts.Analyzer)
|
||||
func TestCreateSAIIndex_NotSupported(t *testing.T) {
|
||||
t.Run("should return error when SAI not supported", func(t *testing.T) {
|
||||
// 注意:這需要一個不支援 SAI 的 DB 實例
|
||||
// 在實際測試中,需要使用 mock 或 testcontainers
|
||||
})
|
||||
}
|
||||
|
||||
func TestSAIIndexInfo(t *testing.T) {
|
||||
t.Run("index info structure", func(t *testing.T) {
|
||||
info := SAIIndexInfo{
|
||||
KeyspaceName: "test_keyspace",
|
||||
TableName: "test_table",
|
||||
IndexName: "test_idx",
|
||||
ColumnName: "name",
|
||||
IndexType: "sai",
|
||||
Options: map[string]string{"target": "name"},
|
||||
}
|
||||
func TestCreateSAIIndex_IndexNameGeneration(t *testing.T) {
|
||||
t.Run("should generate index name when not provided", func(t *testing.T) {
|
||||
// 測試自動生成索引名稱的邏輯
|
||||
// 格式應該是: {table}_{column}_sai_idx
|
||||
table := "users"
|
||||
column := "email"
|
||||
expected := "users_email_sai_idx"
|
||||
|
||||
assert.Equal(t, "test_keyspace", info.KeyspaceName)
|
||||
assert.Equal(t, "test_table", info.TableName)
|
||||
assert.Equal(t, "test_idx", info.IndexName)
|
||||
assert.Equal(t, "name", info.ColumnName)
|
||||
assert.Equal(t, "sai", info.IndexType)
|
||||
assert.NotNil(t, info.Options)
|
||||
// 這裡只是測試命名邏輯,實際建立需要 DB 實例
|
||||
generated := fmt.Sprintf("%s_%s_sai_idx", table, column)
|
||||
assert.Equal(t, expected, generated)
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function
|
||||
func boolPtr(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ package post
|
|||
// CommentStatus 評論狀態
|
||||
type CommentStatus int32
|
||||
|
||||
func (s *CommentStatus) CodeToString() string {
|
||||
result, ok := commentStatusMap[*s]
|
||||
func (s CommentStatus) CodeToString() string {
|
||||
result, ok := commentStatusMap[s]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
|
@ -17,8 +17,8 @@ var commentStatusMap = map[CommentStatus]string{
|
|||
CommentStatusHidden: "hidden", // 隱藏
|
||||
}
|
||||
|
||||
func (s *CommentStatus) ToInt32() int32 {
|
||||
return int32(*s)
|
||||
func (s CommentStatus) ToInt32() int32 {
|
||||
return int32(s)
|
||||
}
|
||||
|
||||
const (
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ package post
|
|||
// Status 貼文狀態
|
||||
type Status int32
|
||||
|
||||
func (s *Status) CodeToString() string {
|
||||
result, ok := postStatusMap[*s]
|
||||
func (s Status) CodeToString() string {
|
||||
result, ok := postStatusMap[s]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
|
@ -19,8 +19,8 @@ var postStatusMap = map[Status]string{
|
|||
PostStatusHidden: "hidden", // 隱藏
|
||||
}
|
||||
|
||||
func (s *Status) ToInt32() int32 {
|
||||
return int32(*s)
|
||||
func (s Status) ToInt32() int32 {
|
||||
return int32(s)
|
||||
}
|
||||
|
||||
const (
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ package post
|
|||
// Type 貼文類型
|
||||
type Type int32
|
||||
|
||||
func (t *Type) CodeToString() string {
|
||||
result, ok := postTypeMap[*t]
|
||||
func (t Type) CodeToString() string {
|
||||
result, ok := postTypeMap[t]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
|
@ -12,28 +12,28 @@ func (t *Type) CodeToString() string {
|
|||
}
|
||||
|
||||
var postTypeMap = map[Type]string{
|
||||
PostTypeText: "text", // 純文字
|
||||
PostTypeImage: "image", // 圖片
|
||||
PostTypeVideo: "video", // 影片
|
||||
PostTypeLink: "link", // 連結
|
||||
PostTypePoll: "poll", // 投票
|
||||
PostTypeArticle: "article", // 長文
|
||||
TypeText: "text", // 純文字
|
||||
TypeImage: "image", // 圖片
|
||||
TypeVideo: "video", // 影片
|
||||
TypeLink: "link", // 連結
|
||||
TypePoll: "poll", // 投票
|
||||
TypeArticle: "article", // 長文
|
||||
}
|
||||
|
||||
func (t *Type) ToInt32() int32 {
|
||||
return int32(*t)
|
||||
func (t Type) ToInt32() int32 {
|
||||
return int32(t)
|
||||
}
|
||||
|
||||
const (
|
||||
PostTypeText Type = 0 // 純文字
|
||||
PostTypeImage Type = 1 // 圖片
|
||||
PostTypeVideo Type = 2 // 影片
|
||||
PostTypeLink Type = 3 // 連結
|
||||
PostTypePoll Type = 4 // 投票
|
||||
PostTypeArticle Type = 5 // 長文
|
||||
TypeText Type = 0 // 純文字
|
||||
TypeImage Type = 1 // 圖片
|
||||
TypeVideo Type = 2 // 影片
|
||||
TypeLink Type = 3 // 連結
|
||||
TypePoll Type = 4 // 投票
|
||||
TypeArticle Type = 5 // 長文
|
||||
)
|
||||
|
||||
// IsValid returns true if the type is valid
|
||||
func (t Type) IsValid() bool {
|
||||
return t >= PostTypeText && t <= PostTypeArticle
|
||||
return t >= TypeText && t <= TypeArticle
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,26 +4,23 @@ import (
|
|||
"context"
|
||||
|
||||
"backend/pkg/post/domain/entity"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
// CategoryRepository defines the interface for category data access operations
|
||||
type CategoryRepository interface {
|
||||
BaseCategoryRepository
|
||||
FindBySlug(ctx context.Context, slug string) (*entity.Category, error)
|
||||
FindByParentID(ctx context.Context, parentID *gocql.UUID) ([]*entity.Category, error)
|
||||
FindByParentID(ctx context.Context, parentID string) ([]*entity.Category, error)
|
||||
FindRootCategories(ctx context.Context) ([]*entity.Category, error)
|
||||
FindActive(ctx context.Context) ([]*entity.Category, error)
|
||||
IncrementPostCount(ctx context.Context, categoryID gocql.UUID) error
|
||||
DecrementPostCount(ctx context.Context, categoryID gocql.UUID) error
|
||||
IncrementPostCount(ctx context.Context, categoryID string) error
|
||||
DecrementPostCount(ctx context.Context, categoryID string) error
|
||||
}
|
||||
|
||||
// BaseCategoryRepository defines basic CRUD operations for categories
|
||||
type BaseCategoryRepository interface {
|
||||
Insert(ctx context.Context, data *entity.Category) error
|
||||
FindOne(ctx context.Context, id gocql.UUID) (*entity.Category, error)
|
||||
FindOne(ctx context.Context, id string) (*entity.Category, error)
|
||||
Update(ctx context.Context, data *entity.Category) error
|
||||
Delete(ctx context.Context, id gocql.UUID) error
|
||||
Delete(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,263 @@
|
|||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"backend/pkg/library/cassandra"
|
||||
"backend/pkg/post/domain/entity"
|
||||
domainRepo "backend/pkg/post/domain/repository"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
// CategoryRepositoryParam 定義 CategoryRepository 的初始化參數
|
||||
type CategoryRepositoryParam struct {
|
||||
DB *cassandra.DB
|
||||
Keyspace string
|
||||
}
|
||||
|
||||
// CategoryRepository 實作 domain repository 介面
|
||||
type CategoryRepository struct {
|
||||
repo cassandra.Repository[*entity.Category]
|
||||
db *cassandra.DB
|
||||
keyspace string
|
||||
}
|
||||
|
||||
// NewCategoryRepository 創建新的 CategoryRepository
|
||||
func NewCategoryRepository(param CategoryRepositoryParam) domainRepo.CategoryRepository {
|
||||
repo, err := cassandra.NewRepository[*entity.Category](param.DB, param.Keyspace)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create category repository: %v", err))
|
||||
}
|
||||
|
||||
keyspace := param.Keyspace
|
||||
if keyspace == "" {
|
||||
keyspace = param.DB.GetDefaultKeyspace()
|
||||
}
|
||||
|
||||
return &CategoryRepository{
|
||||
repo: repo,
|
||||
db: param.DB,
|
||||
keyspace: keyspace,
|
||||
}
|
||||
}
|
||||
|
||||
// Insert 插入單筆分類
|
||||
func (r *CategoryRepository) Insert(ctx context.Context, data *entity.Category) error {
|
||||
if data == nil {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
// 驗證資料
|
||||
if err := data.Validate(); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
||||
}
|
||||
|
||||
if data.ParentID == nil {
|
||||
data.ParentID = &gocql.UUID{}
|
||||
}
|
||||
|
||||
// 設置時間戳
|
||||
data.SetTimestamps()
|
||||
|
||||
// 如果是新分類,生成 ID
|
||||
if data.IsNew() {
|
||||
data.ID = gocql.TimeUUID()
|
||||
}
|
||||
|
||||
// Slug 轉為小寫
|
||||
data.Slug = strings.ToLower(strings.TrimSpace(data.Slug))
|
||||
|
||||
return r.repo.Insert(ctx, data)
|
||||
}
|
||||
|
||||
// FindOne 根據 ID 查詢單筆分類
|
||||
func (r *CategoryRepository) FindOne(ctx context.Context, id string) (*entity.Category, error) {
|
||||
var zeroUUID gocql.UUID
|
||||
uuid, err := gocql.ParseUUID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if uuid == zeroUUID {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
|
||||
category, err := r.repo.Get(ctx, id)
|
||||
if err != nil {
|
||||
if cassandra.IsNotFound(err) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to find category: %w", err)
|
||||
}
|
||||
|
||||
return category, nil
|
||||
}
|
||||
|
||||
// Update 更新分類
|
||||
func (r *CategoryRepository) Update(ctx context.Context, data *entity.Category) error {
|
||||
if data == nil {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
// 驗證資料
|
||||
if err := data.Validate(); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
||||
}
|
||||
|
||||
// 更新時間戳
|
||||
data.SetTimestamps()
|
||||
|
||||
// Slug 轉為小寫
|
||||
data.Slug = strings.ToLower(strings.TrimSpace(data.Slug))
|
||||
|
||||
return r.repo.Update(ctx, data)
|
||||
}
|
||||
|
||||
// Delete 刪除分類
|
||||
func (r *CategoryRepository) Delete(ctx context.Context, id string) error {
|
||||
var zeroUUID gocql.UUID
|
||||
uuid, err := gocql.ParseUUID(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if uuid == zeroUUID {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
return r.repo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// FindBySlug 根據 slug 查詢分類
|
||||
func (r *CategoryRepository) FindBySlug(ctx context.Context, slug string) (*entity.Category, error) {
|
||||
if slug == "" {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
|
||||
// 標準化 slug
|
||||
slug = strings.ToLower(strings.TrimSpace(slug))
|
||||
|
||||
// 構建查詢(要有 SAI 索引在 slug 欄位上)
|
||||
query := r.repo.Query().Where(cassandra.Eq("slug", slug))
|
||||
|
||||
var categories []*entity.Category
|
||||
if err := query.Scan(ctx, &categories); err != nil {
|
||||
if cassandra.IsNotFound(err) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to query category: %w", err)
|
||||
}
|
||||
|
||||
if len(categories) == 0 {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
return categories[0], nil
|
||||
}
|
||||
|
||||
// FindByParentID 根據父分類 ID 查詢子分類
|
||||
func (r *CategoryRepository) FindByParentID(ctx context.Context, parentID string) ([]*entity.Category, error) {
|
||||
query := r.repo.Query()
|
||||
var zeroUUID gocql.UUID
|
||||
if parentID != "" {
|
||||
// 構建查詢(有 SAI 索引在 parentID 欄位上)
|
||||
uuid, err := gocql.ParseUUID(parentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if uuid != zeroUUID {
|
||||
query = query.Where(cassandra.Eq("parent_id", uuid))
|
||||
}
|
||||
} else {
|
||||
query = query.Where(cassandra.Eq("parent_id", zeroUUID))
|
||||
}
|
||||
|
||||
// 按 sort_order 排序
|
||||
query = query.OrderBy("sort_order", cassandra.ASC)
|
||||
|
||||
var categories []*entity.Category
|
||||
if err := query.Scan(ctx, &categories); err != nil {
|
||||
return nil, fmt.Errorf("failed to query categories: %w", err)
|
||||
}
|
||||
|
||||
return categories, nil
|
||||
}
|
||||
|
||||
// FindRootCategories 查詢根分類
|
||||
func (r *CategoryRepository) FindRootCategories(ctx context.Context) ([]*entity.Category, error) {
|
||||
return r.FindByParentID(ctx, "")
|
||||
}
|
||||
|
||||
// FindActive 查詢啟用的分類
|
||||
func (r *CategoryRepository) FindActive(ctx context.Context) ([]*entity.Category, error) {
|
||||
query := r.repo.Query().
|
||||
Where(cassandra.Eq("is_active", true)).
|
||||
OrderBy("sort_order", cassandra.ASC)
|
||||
|
||||
var categories []*entity.Category
|
||||
if err := query.Scan(ctx, &categories); err != nil {
|
||||
return nil, fmt.Errorf("failed to query active categories: %w", err)
|
||||
}
|
||||
|
||||
result := categories
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// IncrementPostCount 增加貼文數(使用 counter 原子操作避免競爭條件)
|
||||
// 注意:post_count 欄位必須是 counter 類型
|
||||
func (r *CategoryRepository) IncrementPostCount(ctx context.Context, categoryID string) error {
|
||||
uuid, err := gocql.ParseUUID(categoryID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: invalid category ID: %v", ErrInvalidInput, err)
|
||||
}
|
||||
|
||||
// 使用 counter 原子更新操作:UPDATE categories SET post_count = post_count + 1 WHERE id = ?
|
||||
var zeroCategory entity.Category
|
||||
tableName := zeroCategory.TableName()
|
||||
if r.keyspace == "" {
|
||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count + 1 WHERE id = ?", r.keyspace, tableName)
|
||||
query := r.db.GetSession().Query(stmt, nil).
|
||||
WithContext(ctx).
|
||||
Consistency(gocql.Quorum).
|
||||
Bind(uuid)
|
||||
|
||||
if err := query.ExecRelease(); err != nil {
|
||||
return fmt.Errorf("failed to increment post count: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecrementPostCount 減少貼文數(使用 counter 原子操作避免競爭條件)
|
||||
// 注意:post_count 欄位必須是 counter 類型
|
||||
func (r *CategoryRepository) DecrementPostCount(ctx context.Context, categoryID string) error {
|
||||
uuid, err := gocql.ParseUUID(categoryID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: invalid category ID: %v", ErrInvalidInput, err)
|
||||
}
|
||||
|
||||
// 使用 counter 原子更新操作:UPDATE categories SET post_count = post_count - 1 WHERE id = ?
|
||||
var zeroCategory entity.Category
|
||||
tableName := zeroCategory.TableName()
|
||||
if r.keyspace == "" {
|
||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count - 1 WHERE id = ?", r.keyspace, tableName)
|
||||
query := r.db.GetSession().Query(stmt, nil).
|
||||
WithContext(ctx).
|
||||
Consistency(gocql.Quorum).
|
||||
Bind(uuid)
|
||||
|
||||
if err := query.ExecRelease(); err != nil {
|
||||
return fmt.Errorf("failed to decrement post count: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,383 @@
|
|||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"backend/pkg/library/cassandra"
|
||||
"backend/pkg/post/domain/entity"
|
||||
"backend/pkg/post/domain/post"
|
||||
domainRepo "backend/pkg/post/domain/repository"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
// CommentRepositoryParam 定義 CommentRepository 的初始化參數
|
||||
type CommentRepositoryParam struct {
|
||||
DB *cassandra.DB
|
||||
Keyspace string
|
||||
}
|
||||
|
||||
// CommentRepository 實作 domain repository 介面
|
||||
type CommentRepository struct {
|
||||
repo cassandra.Repository[*entity.Comment]
|
||||
db *cassandra.DB
|
||||
keyspace string
|
||||
}
|
||||
|
||||
// NewCommentRepository 創建新的 CommentRepository
|
||||
func NewCommentRepository(param CommentRepositoryParam) domainRepo.CommentRepository {
|
||||
repo, err := cassandra.NewRepository[*entity.Comment](param.DB, param.Keyspace)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create comment repository: %v", err))
|
||||
}
|
||||
|
||||
keyspace := param.Keyspace
|
||||
if keyspace == "" {
|
||||
keyspace = param.DB.GetDefaultKeyspace()
|
||||
}
|
||||
|
||||
return &CommentRepository{
|
||||
repo: repo,
|
||||
db: param.DB,
|
||||
keyspace: keyspace,
|
||||
}
|
||||
}
|
||||
|
||||
// Insert 插入單筆評論
|
||||
func (r *CommentRepository) Insert(ctx context.Context, data *entity.Comment) error {
|
||||
if data == nil {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
// 驗證資料
|
||||
if err := data.Validate(); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
||||
}
|
||||
|
||||
// 設置時間戳
|
||||
data.SetTimestamps()
|
||||
|
||||
// 如果是新評論,生成 ID
|
||||
if data.IsNew() {
|
||||
data.ID = gocql.TimeUUID()
|
||||
}
|
||||
|
||||
return r.repo.Insert(ctx, data)
|
||||
}
|
||||
|
||||
// FindOne 根據 ID 查詢單筆評論
|
||||
func (r *CommentRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Comment, error) {
|
||||
var zeroUUID gocql.UUID
|
||||
if id == zeroUUID {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
|
||||
comment, err := r.repo.Get(ctx, id)
|
||||
if err != nil {
|
||||
if cassandra.IsNotFound(err) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to find comment: %w", err)
|
||||
}
|
||||
|
||||
return comment, nil
|
||||
}
|
||||
|
||||
// Update 更新評論
|
||||
func (r *CommentRepository) Update(ctx context.Context, data *entity.Comment) error {
|
||||
if data == nil {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
// 驗證資料
|
||||
if err := data.Validate(); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
||||
}
|
||||
|
||||
// 更新時間戳
|
||||
data.SetTimestamps()
|
||||
|
||||
return r.repo.Update(ctx, data)
|
||||
}
|
||||
|
||||
// Delete 刪除評論(軟刪除)
|
||||
func (r *CommentRepository) Delete(ctx context.Context, id gocql.UUID) error {
|
||||
var zeroUUID gocql.UUID
|
||||
if id == zeroUUID {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
// 先查詢評論
|
||||
comment, err := r.FindOne(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 軟刪除:標記為已刪除
|
||||
comment.Delete()
|
||||
return r.Update(ctx, comment)
|
||||
}
|
||||
|
||||
// FindByPostID 根據貼文 ID 查詢評論
|
||||
func (r *CommentRepository) FindByPostID(ctx context.Context, postID gocql.UUID, params *domainRepo.CommentQueryParams) ([]*entity.Comment, int64, error) {
|
||||
var zeroUUID gocql.UUID
|
||||
if postID == zeroUUID {
|
||||
return nil, 0, ErrInvalidInput
|
||||
}
|
||||
|
||||
// 構建查詢(使用 PostID 作為 clustering key)
|
||||
query := r.repo.Query().Where(cassandra.Eq("post_id", postID))
|
||||
|
||||
// 添加父評論過濾(如果指定,只查詢回覆)
|
||||
if params != nil && params.ParentID != nil {
|
||||
query = query.Where(cassandra.Eq("parent_id", *params.ParentID))
|
||||
} else {
|
||||
// 如果沒有指定 ParentID,只查詢頂層評論(parent_id 為 null)
|
||||
// 注意:Cassandra 不支援直接查詢 null,需要特殊處理
|
||||
// 這裡簡化處理,實際可能需要使用 Materialized View
|
||||
}
|
||||
|
||||
// 添加狀態過濾
|
||||
if params != nil && params.Status != nil {
|
||||
query = query.Where(cassandra.Eq("status", *params.Status))
|
||||
} else {
|
||||
// 預設只查詢已發布的評論
|
||||
published := post.CommentStatusPublished
|
||||
query = query.Where(cassandra.Eq("status", published))
|
||||
}
|
||||
|
||||
// 添加排序
|
||||
orderBy := "created_at"
|
||||
if params != nil && params.OrderBy != "" {
|
||||
orderBy = params.OrderBy
|
||||
}
|
||||
order := cassandra.ASC
|
||||
if params != nil && params.OrderDirection == "DESC" {
|
||||
order = cassandra.DESC
|
||||
}
|
||||
query = query.OrderBy(orderBy, order)
|
||||
|
||||
// 添加分頁
|
||||
pageSize := int64(20)
|
||||
if params != nil && params.PageSize > 0 {
|
||||
pageSize = params.PageSize
|
||||
}
|
||||
limit := int(pageSize)
|
||||
query = query.Limit(limit)
|
||||
|
||||
// 執行查詢
|
||||
var comments []*entity.Comment
|
||||
if err := query.Scan(ctx, &comments); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query comments: %w", err)
|
||||
}
|
||||
|
||||
result := comments
|
||||
|
||||
total := int64(len(result))
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// FindByParentID 根據父評論 ID 查詢回覆
|
||||
func (r *CommentRepository) FindByParentID(ctx context.Context, parentID gocql.UUID, params *domainRepo.CommentQueryParams) ([]*entity.Comment, int64, error) {
|
||||
var zeroUUID gocql.UUID
|
||||
if parentID == zeroUUID {
|
||||
return nil, 0, ErrInvalidInput
|
||||
}
|
||||
|
||||
query := r.repo.Query().Where(cassandra.Eq("parent_id", parentID))
|
||||
|
||||
// 添加狀態過濾
|
||||
if params != nil && params.Status != nil {
|
||||
query = query.Where(cassandra.Eq("status", *params.Status))
|
||||
} else {
|
||||
published := post.CommentStatusPublished
|
||||
query = query.Where(cassandra.Eq("status", published))
|
||||
}
|
||||
|
||||
// 添加排序和分頁
|
||||
orderBy := "created_at"
|
||||
if params != nil && params.OrderBy != "" {
|
||||
orderBy = params.OrderBy
|
||||
}
|
||||
order := cassandra.ASC
|
||||
if params != nil && params.OrderDirection == "DESC" {
|
||||
order = cassandra.DESC
|
||||
}
|
||||
query = query.OrderBy(orderBy, order)
|
||||
|
||||
pageSize := int64(20)
|
||||
if params != nil && params.PageSize > 0 {
|
||||
pageSize = params.PageSize
|
||||
}
|
||||
query = query.Limit(int(pageSize))
|
||||
|
||||
var comments []*entity.Comment
|
||||
if err := query.Scan(ctx, &comments); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query replies: %w", err)
|
||||
}
|
||||
|
||||
return comments, int64(len(comments)), nil
|
||||
}
|
||||
|
||||
// FindByAuthorUID 根據作者 UID 查詢評論
|
||||
func (r *CommentRepository) FindByAuthorUID(ctx context.Context, authorUID string, params *domainRepo.CommentQueryParams) ([]*entity.Comment, int64, error) {
|
||||
if authorUID == "" {
|
||||
return nil, 0, ErrInvalidInput
|
||||
}
|
||||
|
||||
query := r.repo.Query().Where(cassandra.Eq("author_uid", authorUID))
|
||||
|
||||
// 添加狀態過濾
|
||||
if params != nil && params.Status != nil {
|
||||
query = query.Where(cassandra.Eq("status", *params.Status))
|
||||
}
|
||||
|
||||
// 添加排序和分頁
|
||||
orderBy := "created_at"
|
||||
if params != nil && params.OrderBy != "" {
|
||||
orderBy = params.OrderBy
|
||||
}
|
||||
order := cassandra.DESC
|
||||
if params != nil && params.OrderDirection == "ASC" {
|
||||
order = cassandra.ASC
|
||||
}
|
||||
query = query.OrderBy(orderBy, order)
|
||||
|
||||
pageSize := int64(20)
|
||||
if params != nil && params.PageSize > 0 {
|
||||
pageSize = params.PageSize
|
||||
}
|
||||
query = query.Limit(int(pageSize))
|
||||
|
||||
var comments []*entity.Comment
|
||||
if err := query.Scan(ctx, &comments); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query comments: %w", err)
|
||||
}
|
||||
|
||||
return comments, int64(len(comments)), nil
|
||||
}
|
||||
|
||||
// FindReplies 查詢指定評論的回覆
|
||||
func (r *CommentRepository) FindReplies(ctx context.Context, commentID gocql.UUID, params *domainRepo.CommentQueryParams) ([]*entity.Comment, int64, error) {
|
||||
return r.FindByParentID(ctx, commentID, params)
|
||||
}
|
||||
|
||||
// IncrementLikeCount 增加按讚數(使用 counter 原子操作避免競爭條件)
|
||||
// 注意:like_count 欄位必須是 counter 類型
|
||||
func (r *CommentRepository) IncrementLikeCount(ctx context.Context, commentID gocql.UUID) error {
|
||||
var zeroUUID gocql.UUID
|
||||
if commentID == zeroUUID {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
var zeroComment entity.Comment
|
||||
tableName := zeroComment.TableName()
|
||||
if r.keyspace == "" {
|
||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count + 1 WHERE id = ?", r.keyspace, tableName)
|
||||
query := r.db.GetSession().Query(stmt, nil).
|
||||
WithContext(ctx).
|
||||
Consistency(gocql.Quorum).
|
||||
Bind(commentID)
|
||||
|
||||
if err := query.ExecRelease(); err != nil {
|
||||
return fmt.Errorf("failed to increment like count: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecrementLikeCount 減少按讚數(使用 counter 原子操作避免競爭條件)
|
||||
// 注意:like_count 欄位必須是 counter 類型
|
||||
func (r *CommentRepository) DecrementLikeCount(ctx context.Context, commentID gocql.UUID) error {
|
||||
var zeroUUID gocql.UUID
|
||||
if commentID == zeroUUID {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
var zeroComment entity.Comment
|
||||
tableName := zeroComment.TableName()
|
||||
if r.keyspace == "" {
|
||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count - 1 WHERE id = ?", r.keyspace, tableName)
|
||||
query := r.db.GetSession().Query(stmt, nil).
|
||||
WithContext(ctx).
|
||||
Consistency(gocql.Quorum).
|
||||
Bind(commentID)
|
||||
|
||||
if err := query.ExecRelease(); err != nil {
|
||||
return fmt.Errorf("failed to decrement like count: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IncrementReplyCount 增加回覆數(使用 counter 原子操作避免競爭條件)
|
||||
// 注意:reply_count 欄位必須是 counter 類型
|
||||
func (r *CommentRepository) IncrementReplyCount(ctx context.Context, commentID gocql.UUID) error {
|
||||
var zeroUUID gocql.UUID
|
||||
if commentID == zeroUUID {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
var zeroComment entity.Comment
|
||||
tableName := zeroComment.TableName()
|
||||
if r.keyspace == "" {
|
||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("UPDATE %s.%s SET reply_count = reply_count + 1 WHERE id = ?", r.keyspace, tableName)
|
||||
query := r.db.GetSession().Query(stmt, nil).
|
||||
WithContext(ctx).
|
||||
Consistency(gocql.Quorum).
|
||||
Bind(commentID)
|
||||
|
||||
if err := query.ExecRelease(); err != nil {
|
||||
return fmt.Errorf("failed to increment reply count: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecrementReplyCount 減少回覆數(使用 counter 原子操作避免競爭條件)
|
||||
// 注意:reply_count 欄位必須是 counter 類型
|
||||
func (r *CommentRepository) DecrementReplyCount(ctx context.Context, commentID gocql.UUID) error {
|
||||
var zeroUUID gocql.UUID
|
||||
if commentID == zeroUUID {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
var zeroComment entity.Comment
|
||||
tableName := zeroComment.TableName()
|
||||
if r.keyspace == "" {
|
||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("UPDATE %s.%s SET reply_count = reply_count - 1 WHERE id = ?", r.keyspace, tableName)
|
||||
query := r.db.GetSession().Query(stmt, nil).
|
||||
WithContext(ctx).
|
||||
Consistency(gocql.Quorum).
|
||||
Bind(commentID)
|
||||
|
||||
if err := query.ExecRelease(); err != nil {
|
||||
return fmt.Errorf("failed to decrement reply count: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStatus 更新評論狀態
|
||||
func (r *CommentRepository) UpdateStatus(ctx context.Context, commentID gocql.UUID, status post.CommentStatus) error {
|
||||
comment, err := r.FindOne(ctx, commentID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
comment.Status = status
|
||||
return r.Update(ctx, comment)
|
||||
}
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"backend/pkg/library/cassandra"
|
||||
)
|
||||
|
||||
// Common repository errors
|
||||
var (
|
||||
// ErrNotFound is returned when a requested resource is not found
|
||||
ErrNotFound = errors.New("resource not found")
|
||||
|
||||
// ErrInvalidInput is returned when input validation fails
|
||||
ErrInvalidInput = errors.New("invalid input")
|
||||
|
||||
// ErrDuplicateKey is returned when attempting to insert a document with a duplicate key
|
||||
ErrDuplicateKey = errors.New("duplicate key error")
|
||||
)
|
||||
|
||||
// IsNotFound checks if the error is a not found error
|
||||
func IsNotFound(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if err == ErrNotFound {
|
||||
return true
|
||||
}
|
||||
if cassandra.IsNotFound(err) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,228 @@
|
|||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"backend/pkg/library/cassandra"
|
||||
"backend/pkg/post/domain/entity"
|
||||
domainRepo "backend/pkg/post/domain/repository"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
// LikeRepositoryParam 定義 LikeRepository 的初始化參數
|
||||
type LikeRepositoryParam struct {
|
||||
DB *cassandra.DB
|
||||
Keyspace string
|
||||
}
|
||||
|
||||
// LikeRepository 實作 domain repository 介面
|
||||
type LikeRepository struct {
|
||||
repo cassandra.Repository[*entity.Like]
|
||||
db *cassandra.DB
|
||||
}
|
||||
|
||||
// NewLikeRepository 創建新的 LikeRepository
|
||||
func NewLikeRepository(param LikeRepositoryParam) domainRepo.LikeRepository {
|
||||
repo, err := cassandra.NewRepository[*entity.Like](param.DB, param.Keyspace)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create like repository: %v", err))
|
||||
}
|
||||
|
||||
return &LikeRepository{
|
||||
repo: repo,
|
||||
db: param.DB,
|
||||
}
|
||||
}
|
||||
|
||||
// Insert 插入單筆按讚
|
||||
func (r *LikeRepository) Insert(ctx context.Context, data *entity.Like) error {
|
||||
if data == nil {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
// 驗證資料
|
||||
if err := data.Validate(); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
||||
}
|
||||
|
||||
// 設置時間戳
|
||||
data.SetTimestamps()
|
||||
|
||||
// 如果是新按讚,生成 ID
|
||||
if data.IsNew() {
|
||||
data.ID = gocql.TimeUUID()
|
||||
}
|
||||
|
||||
return r.repo.Insert(ctx, data)
|
||||
}
|
||||
|
||||
// FindOne 根據 ID 查詢單筆按讚
|
||||
func (r *LikeRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Like, error) {
|
||||
var zeroUUID gocql.UUID
|
||||
if id == zeroUUID {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
|
||||
like, err := r.repo.Get(ctx, id)
|
||||
if err != nil {
|
||||
if cassandra.IsNotFound(err) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to find like: %w", err)
|
||||
}
|
||||
|
||||
return like, nil
|
||||
}
|
||||
|
||||
// Delete 刪除按讚
|
||||
func (r *LikeRepository) Delete(ctx context.Context, id gocql.UUID) error {
|
||||
var zeroUUID gocql.UUID
|
||||
if id == zeroUUID {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
return r.repo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// FindByTargetID 根據目標 ID 查詢按讚列表
|
||||
func (r *LikeRepository) FindByTargetID(ctx context.Context, targetID gocql.UUID, targetType string) ([]*entity.Like, error) {
|
||||
var zeroUUID gocql.UUID
|
||||
if targetID == zeroUUID {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
|
||||
if targetType != "post" && targetType != "comment" {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
|
||||
// 構建查詢
|
||||
query := r.repo.Query().
|
||||
Where(cassandra.Eq("target_id", targetID)).
|
||||
Where(cassandra.Eq("target_type", targetType)).
|
||||
OrderBy("created_at", cassandra.DESC)
|
||||
|
||||
var likes []*entity.Like
|
||||
if err := query.Scan(ctx, &likes); err != nil {
|
||||
return nil, fmt.Errorf("failed to query likes: %w", err)
|
||||
}
|
||||
|
||||
return likes, nil
|
||||
}
|
||||
|
||||
// FindByUserUID 根據用戶 UID 查詢按讚列表
|
||||
func (r *LikeRepository) FindByUserUID(ctx context.Context, userUID string, params *domainRepo.LikeQueryParams) ([]*entity.Like, int64, error) {
|
||||
if userUID == "" {
|
||||
return nil, 0, ErrInvalidInput
|
||||
}
|
||||
|
||||
query := r.repo.Query().Where(cassandra.Eq("user_uid", userUID))
|
||||
|
||||
// 添加目標類型過濾
|
||||
if params != nil && params.TargetType != nil {
|
||||
query = query.Where(cassandra.Eq("target_type", *params.TargetType))
|
||||
}
|
||||
|
||||
// 添加目標 ID 過濾
|
||||
if params != nil && params.TargetID != nil {
|
||||
query = query.Where(cassandra.Eq("target_id", *params.TargetID))
|
||||
}
|
||||
|
||||
// 添加排序
|
||||
orderBy := "created_at"
|
||||
if params != nil && params.OrderBy != "" {
|
||||
orderBy = params.OrderBy
|
||||
}
|
||||
order := cassandra.DESC
|
||||
if params != nil && params.OrderDirection == "ASC" {
|
||||
order = cassandra.ASC
|
||||
}
|
||||
query = query.OrderBy(orderBy, order)
|
||||
|
||||
// 添加分頁
|
||||
pageSize := int64(20)
|
||||
if params != nil && params.PageSize > 0 {
|
||||
pageSize = params.PageSize
|
||||
}
|
||||
query = query.Limit(int(pageSize))
|
||||
|
||||
var likes []*entity.Like
|
||||
if err := query.Scan(ctx, &likes); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query likes: %w", err)
|
||||
}
|
||||
|
||||
result := likes
|
||||
|
||||
return result, int64(len(result)), nil
|
||||
}
|
||||
|
||||
// FindByTargetAndUser 根據目標和用戶查詢按讚
|
||||
func (r *LikeRepository) FindByTargetAndUser(ctx context.Context, targetID gocql.UUID, userUID string, targetType string) (*entity.Like, error) {
|
||||
var zeroUUID gocql.UUID
|
||||
if targetID == zeroUUID || userUID == "" {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
|
||||
if targetType != "post" && targetType != "comment" {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
|
||||
// 構建查詢
|
||||
query := r.repo.Query().
|
||||
Where(cassandra.Eq("target_id", targetID)).
|
||||
Where(cassandra.Eq("user_uid", userUID)).
|
||||
Where(cassandra.Eq("target_type", targetType)).
|
||||
Limit(1)
|
||||
|
||||
var likes []*entity.Like
|
||||
if err := query.Scan(ctx, &likes); err != nil {
|
||||
if cassandra.IsNotFound(err) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to query like: %w", err)
|
||||
}
|
||||
|
||||
if len(likes) == 0 {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
return likes[0], nil
|
||||
}
|
||||
|
||||
// CountByTargetID 計算目標的按讚數
|
||||
func (r *LikeRepository) CountByTargetID(ctx context.Context, targetID gocql.UUID, targetType string) (int64, error) {
|
||||
var zeroUUID gocql.UUID
|
||||
if targetID == zeroUUID {
|
||||
return 0, ErrInvalidInput
|
||||
}
|
||||
|
||||
if targetType != "post" && targetType != "comment" {
|
||||
return 0, ErrInvalidInput
|
||||
}
|
||||
|
||||
// 構建查詢
|
||||
query := r.repo.Query().
|
||||
Where(cassandra.Eq("target_id", targetID)).
|
||||
Where(cassandra.Eq("target_type", targetType))
|
||||
|
||||
count, err := query.Count(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to count likes: %w", err)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// DeleteByTargetAndUser 根據目標和用戶刪除按讚
|
||||
func (r *LikeRepository) DeleteByTargetAndUser(ctx context.Context, targetID gocql.UUID, userUID string, targetType string) error {
|
||||
// 先查詢按讚
|
||||
like, err := r.FindByTargetAndUser(ctx, targetID, userUID, targetType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 刪除按讚
|
||||
return r.Delete(ctx, like.ID)
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,511 @@
|
|||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"backend/pkg/library/cassandra"
|
||||
"backend/pkg/post/domain/entity"
|
||||
"backend/pkg/post/domain/post"
|
||||
domainRepo "backend/pkg/post/domain/repository"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
// PostRepositoryParam 定義 PostRepository 的初始化參數
|
||||
type PostRepositoryParam struct {
|
||||
DB *cassandra.DB
|
||||
Keyspace string
|
||||
}
|
||||
|
||||
// PostRepository 實作 domain repository 介面
|
||||
type PostRepository struct {
|
||||
repo cassandra.Repository[*entity.Post]
|
||||
db *cassandra.DB
|
||||
keyspace string
|
||||
}
|
||||
|
||||
// NewPostRepository 創建新的 PostRepository
|
||||
func NewPostRepository(param PostRepositoryParam) domainRepo.PostRepository {
|
||||
repo, err := cassandra.NewRepository[*entity.Post](param.DB, param.Keyspace)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create post repository: %v", err))
|
||||
}
|
||||
|
||||
keyspace := param.Keyspace
|
||||
if keyspace == "" {
|
||||
keyspace = param.DB.GetDefaultKeyspace()
|
||||
}
|
||||
|
||||
return &PostRepository{
|
||||
repo: repo,
|
||||
db: param.DB,
|
||||
keyspace: keyspace,
|
||||
}
|
||||
}
|
||||
|
||||
// Insert 插入單筆貼文
|
||||
func (r *PostRepository) Insert(ctx context.Context, data *entity.Post) error {
|
||||
if data == nil {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
// 驗證資料
|
||||
if err := data.Validate(); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
||||
}
|
||||
|
||||
// 設置時間戳
|
||||
data.SetTimestamps()
|
||||
|
||||
// 如果是新貼文,生成 ID
|
||||
if data.IsNew() {
|
||||
data.ID = gocql.TimeUUID()
|
||||
}
|
||||
|
||||
// 如果狀態是 published,設置發布時間
|
||||
if data.Status == post.PostStatusPublished && data.PublishedAt == nil {
|
||||
now := data.CreatedAt
|
||||
data.PublishedAt = &now
|
||||
}
|
||||
|
||||
return r.repo.Insert(ctx, data)
|
||||
}
|
||||
|
||||
// FindOne 根據 ID 查詢單筆貼文
|
||||
func (r *PostRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Post, error) {
|
||||
var zeroUUID gocql.UUID
|
||||
if id == zeroUUID {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
|
||||
post, err := r.repo.Get(ctx, id)
|
||||
if err != nil {
|
||||
if cassandra.IsNotFound(err) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to find post: %w", err)
|
||||
}
|
||||
|
||||
return post, nil
|
||||
}
|
||||
|
||||
// Update 更新貼文
|
||||
func (r *PostRepository) Update(ctx context.Context, data *entity.Post) error {
|
||||
if data == nil {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
// 驗證資料
|
||||
if err := data.Validate(); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
||||
}
|
||||
|
||||
// 更新時間戳
|
||||
data.SetTimestamps()
|
||||
|
||||
return r.repo.Update(ctx, data)
|
||||
}
|
||||
|
||||
// Delete 刪除貼文(軟刪除)
|
||||
func (r *PostRepository) Delete(ctx context.Context, id gocql.UUID) error {
|
||||
var zeroUUID gocql.UUID
|
||||
if id == zeroUUID {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
// 先查詢貼文
|
||||
post, err := r.FindOne(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 軟刪除:標記為已刪除
|
||||
post.Delete()
|
||||
return r.Update(ctx, post)
|
||||
}
|
||||
|
||||
// FindByAuthorUID 根據作者 UID 查詢貼文
|
||||
func (r *PostRepository) FindByAuthorUID(ctx context.Context, authorUID string, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
|
||||
if authorUID == "" {
|
||||
return nil, 0, ErrInvalidInput
|
||||
}
|
||||
|
||||
// 構建查詢
|
||||
query := r.repo.Query().Where(cassandra.Eq("author_uid", authorUID))
|
||||
|
||||
// 添加狀態過濾
|
||||
if params != nil && params.Status != nil {
|
||||
query = query.Where(cassandra.Eq("status", *params.Status))
|
||||
}
|
||||
|
||||
// 添加排序
|
||||
orderBy := "created_at"
|
||||
if params != nil && params.OrderBy != "" {
|
||||
orderBy = params.OrderBy
|
||||
}
|
||||
order := cassandra.DESC
|
||||
if params != nil && params.OrderDirection == "ASC" {
|
||||
order = cassandra.ASC
|
||||
}
|
||||
query = query.OrderBy(orderBy, order)
|
||||
|
||||
// 添加分頁
|
||||
pageSize := int64(20)
|
||||
if params != nil && params.PageSize > 0 {
|
||||
pageSize = params.PageSize
|
||||
}
|
||||
pageIndex := int64(1)
|
||||
if params != nil && params.PageIndex > 0 {
|
||||
pageIndex = params.PageIndex
|
||||
}
|
||||
limit := int(pageSize)
|
||||
query = query.Limit(limit)
|
||||
|
||||
// 執行查詢
|
||||
var posts []*entity.Post
|
||||
if err := query.Scan(ctx, &posts); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query posts: %w", err)
|
||||
}
|
||||
|
||||
result := posts
|
||||
|
||||
// 計算總數(簡化實作,實際應該使用 COUNT 查詢)
|
||||
total := int64(len(posts))
|
||||
if params != nil && params.PageIndex > 1 {
|
||||
// 這裡應該執行 COUNT 查詢,但為了簡化,我們假設有更多結果
|
||||
total = pageSize * pageIndex
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// FindByCategoryID 根據分類 ID 查詢貼文
|
||||
func (r *PostRepository) FindByCategoryID(ctx context.Context, categoryID gocql.UUID, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
|
||||
var zeroUUID gocql.UUID
|
||||
if categoryID == zeroUUID {
|
||||
return nil, 0, ErrInvalidInput
|
||||
}
|
||||
|
||||
// 構建查詢
|
||||
query := r.repo.Query().Where(cassandra.Eq("category_id", categoryID))
|
||||
|
||||
// 添加狀態過濾
|
||||
if params != nil && params.Status != nil {
|
||||
query = query.Where(cassandra.Eq("status", *params.Status))
|
||||
}
|
||||
|
||||
// 添加排序和分頁(類似 FindByAuthorUID)
|
||||
orderBy := "created_at"
|
||||
if params != nil && params.OrderBy != "" {
|
||||
orderBy = params.OrderBy
|
||||
}
|
||||
order := cassandra.DESC
|
||||
if params != nil && params.OrderDirection == "ASC" {
|
||||
order = cassandra.ASC
|
||||
}
|
||||
query = query.OrderBy(orderBy, order)
|
||||
|
||||
pageSize := int64(20)
|
||||
if params != nil && params.PageSize > 0 {
|
||||
pageSize = params.PageSize
|
||||
}
|
||||
limit := int(pageSize)
|
||||
query = query.Limit(limit)
|
||||
|
||||
var posts []*entity.Post
|
||||
if err := query.Scan(ctx, &posts); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query posts: %w", err)
|
||||
}
|
||||
|
||||
result := posts
|
||||
|
||||
total := int64(len(posts))
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// FindByTag 根據標籤查詢貼文
|
||||
func (r *PostRepository) FindByTag(ctx context.Context, tagName string, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
|
||||
if tagName == "" {
|
||||
return nil, 0, ErrInvalidInput
|
||||
}
|
||||
|
||||
// 構建查詢(注意:Cassandra 的集合查詢需要使用 CONTAINS,這裡簡化處理)
|
||||
// 實際實作中,可能需要使用 SAI 索引或 Materialized View
|
||||
query := r.repo.Query()
|
||||
|
||||
// 添加狀態過濾
|
||||
if params != nil && params.Status != nil {
|
||||
query = query.Where(cassandra.Eq("status", *params.Status))
|
||||
}
|
||||
|
||||
// 添加排序和分頁
|
||||
orderBy := "created_at"
|
||||
if params != nil && params.OrderBy != "" {
|
||||
orderBy = params.OrderBy
|
||||
}
|
||||
order := cassandra.DESC
|
||||
if params != nil && params.OrderDirection == "ASC" {
|
||||
order = cassandra.ASC
|
||||
}
|
||||
query = query.OrderBy(orderBy, order)
|
||||
|
||||
pageSize := int64(20)
|
||||
if params != nil && params.PageSize > 0 {
|
||||
pageSize = params.PageSize
|
||||
}
|
||||
limit := int(pageSize)
|
||||
query = query.Limit(limit)
|
||||
|
||||
var posts []*entity.Post
|
||||
if err := query.Scan(ctx, &posts); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query posts: %w", err)
|
||||
}
|
||||
|
||||
// 過濾包含指定標籤的貼文
|
||||
filtered := make([]*entity.Post, 0)
|
||||
for _, p := range posts {
|
||||
for _, tag := range p.Tags {
|
||||
if tag == tagName {
|
||||
filtered = append(filtered, p)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
total := int64(len(filtered))
|
||||
return filtered, total, nil
|
||||
}
|
||||
|
||||
// FindPinnedPosts 查詢置頂貼文
|
||||
func (r *PostRepository) FindPinnedPosts(ctx context.Context, limit int64) ([]*entity.Post, error) {
|
||||
query := r.repo.Query().
|
||||
Where(cassandra.Eq("is_pinned", true)).
|
||||
Where(cassandra.Eq("status", post.PostStatusPublished)).
|
||||
OrderBy("pinned_at", cassandra.DESC).
|
||||
Limit(int(limit))
|
||||
|
||||
var posts []*entity.Post
|
||||
if err := query.Scan(ctx, &posts); err != nil {
|
||||
return nil, fmt.Errorf("failed to query pinned posts: %w", err)
|
||||
}
|
||||
|
||||
return posts, nil
|
||||
}
|
||||
|
||||
// FindByStatus 根據狀態查詢貼文
|
||||
func (r *PostRepository) FindByStatus(ctx context.Context, status post.Status, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
|
||||
query := r.repo.Query().Where(cassandra.Eq("status", status))
|
||||
|
||||
// 添加排序和分頁
|
||||
orderBy := "created_at"
|
||||
if params != nil && params.OrderBy != "" {
|
||||
orderBy = params.OrderBy
|
||||
}
|
||||
order := cassandra.DESC
|
||||
if params != nil && params.OrderDirection == "ASC" {
|
||||
order = cassandra.ASC
|
||||
}
|
||||
query = query.OrderBy(orderBy, order)
|
||||
|
||||
pageSize := int64(20)
|
||||
if params != nil && params.PageSize > 0 {
|
||||
pageSize = params.PageSize
|
||||
}
|
||||
limit := int(pageSize)
|
||||
query = query.Limit(limit)
|
||||
|
||||
var posts []*entity.Post
|
||||
if err := query.Scan(ctx, &posts); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query posts: %w", err)
|
||||
}
|
||||
|
||||
result := posts
|
||||
|
||||
total := int64(len(posts))
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// IncrementLikeCount 增加按讚數(使用 counter 原子操作避免競爭條件)
|
||||
// 注意:like_count 欄位必須是 counter 類型
|
||||
func (r *PostRepository) IncrementLikeCount(ctx context.Context, postID gocql.UUID) error {
|
||||
var zeroUUID gocql.UUID
|
||||
if postID == zeroUUID {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
var zeroPost entity.Post
|
||||
tableName := zeroPost.TableName()
|
||||
if r.keyspace == "" {
|
||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count + 1 WHERE id = ?", r.keyspace, tableName)
|
||||
query := r.db.GetSession().Query(stmt, nil).
|
||||
WithContext(ctx).
|
||||
Consistency(gocql.Quorum).
|
||||
Bind(postID)
|
||||
|
||||
if err := query.ExecRelease(); err != nil {
|
||||
return fmt.Errorf("failed to increment like count: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecrementLikeCount 減少按讚數(使用 counter 原子操作避免競爭條件)
|
||||
// 注意:like_count 欄位必須是 counter 類型
|
||||
func (r *PostRepository) DecrementLikeCount(ctx context.Context, postID gocql.UUID) error {
|
||||
var zeroUUID gocql.UUID
|
||||
if postID == zeroUUID {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
var zeroPost entity.Post
|
||||
tableName := zeroPost.TableName()
|
||||
if r.keyspace == "" {
|
||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count - 1 WHERE id = ?", r.keyspace, tableName)
|
||||
query := r.db.GetSession().Query(stmt, nil).
|
||||
WithContext(ctx).
|
||||
Consistency(gocql.Quorum).
|
||||
Bind(postID)
|
||||
|
||||
if err := query.ExecRelease(); err != nil {
|
||||
return fmt.Errorf("failed to decrement like count: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IncrementCommentCount 增加評論數(使用 counter 原子操作避免競爭條件)
|
||||
// 注意:comment_count 欄位必須是 counter 類型
|
||||
func (r *PostRepository) IncrementCommentCount(ctx context.Context, postID gocql.UUID) error {
|
||||
var zeroUUID gocql.UUID
|
||||
if postID == zeroUUID {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
var zeroPost entity.Post
|
||||
tableName := zeroPost.TableName()
|
||||
if r.keyspace == "" {
|
||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("UPDATE %s.%s SET comment_count = comment_count + 1 WHERE id = ?", r.keyspace, tableName)
|
||||
query := r.db.GetSession().Query(stmt, nil).
|
||||
WithContext(ctx).
|
||||
Consistency(gocql.Quorum).
|
||||
Bind(postID)
|
||||
|
||||
if err := query.ExecRelease(); err != nil {
|
||||
return fmt.Errorf("failed to increment comment count: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecrementCommentCount 減少評論數(使用 counter 原子操作避免競爭條件)
|
||||
// 注意:comment_count 欄位必須是 counter 類型
|
||||
func (r *PostRepository) DecrementCommentCount(ctx context.Context, postID gocql.UUID) error {
|
||||
var zeroUUID gocql.UUID
|
||||
if postID == zeroUUID {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
var zeroPost entity.Post
|
||||
tableName := zeroPost.TableName()
|
||||
if r.keyspace == "" {
|
||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("UPDATE %s.%s SET comment_count = comment_count - 1 WHERE id = ?", r.keyspace, tableName)
|
||||
query := r.db.GetSession().Query(stmt, nil).
|
||||
WithContext(ctx).
|
||||
Consistency(gocql.Quorum).
|
||||
Bind(postID)
|
||||
|
||||
if err := query.ExecRelease(); err != nil {
|
||||
return fmt.Errorf("failed to decrement comment count: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IncrementViewCount 增加瀏覽數(使用 counter 原子操作避免競爭條件)
|
||||
// 注意:view_count 欄位必須是 counter 類型
|
||||
func (r *PostRepository) IncrementViewCount(ctx context.Context, postID gocql.UUID) error {
|
||||
var zeroUUID gocql.UUID
|
||||
if postID == zeroUUID {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
var zeroPost entity.Post
|
||||
tableName := zeroPost.TableName()
|
||||
if r.keyspace == "" {
|
||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("UPDATE %s.%s SET view_count = view_count + 1 WHERE id = ?", r.keyspace, tableName)
|
||||
query := r.db.GetSession().Query(stmt, nil).
|
||||
WithContext(ctx).
|
||||
Consistency(gocql.Quorum).
|
||||
Bind(postID)
|
||||
|
||||
if err := query.ExecRelease(); err != nil {
|
||||
return fmt.Errorf("failed to increment view count: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStatus 更新貼文狀態
|
||||
func (r *PostRepository) UpdateStatus(ctx context.Context, postID gocql.UUID, status post.Status) error {
|
||||
postEntity, err := r.FindOne(ctx, postID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
postEntity.Status = status
|
||||
publishedStatus := post.PostStatusPublished
|
||||
if status == publishedStatus && postEntity.PublishedAt == nil {
|
||||
now := postEntity.UpdatedAt
|
||||
postEntity.PublishedAt = &now
|
||||
}
|
||||
|
||||
return r.Update(ctx, postEntity)
|
||||
}
|
||||
|
||||
// PinPost 置頂貼文
|
||||
func (r *PostRepository) PinPost(ctx context.Context, postID gocql.UUID) error {
|
||||
post, err := r.FindOne(ctx, postID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
post.Pin()
|
||||
return r.Update(ctx, post)
|
||||
}
|
||||
|
||||
// UnpinPost 取消置頂
|
||||
func (r *PostRepository) UnpinPost(ctx context.Context, postID gocql.UUID) error {
|
||||
post, err := r.FindOne(ctx, postID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
post.Unpin()
|
||||
return r.Update(ctx, post)
|
||||
}
|
||||
|
||||
// calculateTotalPages 計算總頁數
|
||||
func calculateTotalPages(total, pageSize int64) int64 {
|
||||
if pageSize <= 0 {
|
||||
return 0
|
||||
}
|
||||
return int64(math.Ceil(float64(total) / float64(pageSize)))
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,250 @@
|
|||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"backend/pkg/library/cassandra"
|
||||
"backend/pkg/post/domain/entity"
|
||||
domainRepo "backend/pkg/post/domain/repository"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
// TagRepositoryParam 定義 TagRepository 的初始化參數
|
||||
type TagRepositoryParam struct {
|
||||
DB *cassandra.DB
|
||||
Keyspace string
|
||||
}
|
||||
|
||||
// TagRepository 實作 domain repository 介面
|
||||
type TagRepository struct {
|
||||
repo cassandra.Repository[*entity.Tag]
|
||||
db *cassandra.DB
|
||||
keyspace string
|
||||
}
|
||||
|
||||
// NewTagRepository 創建新的 TagRepository
|
||||
func NewTagRepository(param TagRepositoryParam) domainRepo.TagRepository {
|
||||
repo, err := cassandra.NewRepository[*entity.Tag](param.DB, param.Keyspace)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create tag repository: %v", err))
|
||||
}
|
||||
|
||||
keyspace := param.Keyspace
|
||||
if keyspace == "" {
|
||||
keyspace = param.DB.GetDefaultKeyspace()
|
||||
}
|
||||
|
||||
return &TagRepository{
|
||||
repo: repo,
|
||||
db: param.DB,
|
||||
keyspace: keyspace,
|
||||
}
|
||||
}
|
||||
|
||||
// Insert 插入單筆標籤
|
||||
func (r *TagRepository) Insert(ctx context.Context, data *entity.Tag) error {
|
||||
if data == nil {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
// 驗證資料
|
||||
if err := data.Validate(); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
||||
}
|
||||
|
||||
// 設置時間戳
|
||||
data.SetTimestamps()
|
||||
|
||||
// 如果是新標籤,生成 ID
|
||||
if data.IsNew() {
|
||||
data.ID = gocql.TimeUUID()
|
||||
}
|
||||
|
||||
// 標籤名稱轉為小寫(統一格式)
|
||||
data.Name = strings.ToLower(strings.TrimSpace(data.Name))
|
||||
|
||||
return r.repo.Insert(ctx, data)
|
||||
}
|
||||
|
||||
// FindOne 根據 ID 查詢單筆標籤
|
||||
func (r *TagRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Tag, error) {
|
||||
var zeroUUID gocql.UUID
|
||||
if id == zeroUUID {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
|
||||
tag, err := r.repo.Get(ctx, id)
|
||||
if err != nil {
|
||||
if cassandra.IsNotFound(err) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to find tag: %w", err)
|
||||
}
|
||||
|
||||
return tag, nil
|
||||
}
|
||||
|
||||
// Update 更新標籤
|
||||
func (r *TagRepository) Update(ctx context.Context, data *entity.Tag) error {
|
||||
if data == nil {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
// 驗證資料
|
||||
if err := data.Validate(); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
||||
}
|
||||
|
||||
// 更新時間戳
|
||||
data.SetTimestamps()
|
||||
|
||||
// 標籤名稱轉為小寫
|
||||
data.Name = strings.ToLower(strings.TrimSpace(data.Name))
|
||||
|
||||
return r.repo.Update(ctx, data)
|
||||
}
|
||||
|
||||
// Delete 刪除標籤
|
||||
func (r *TagRepository) Delete(ctx context.Context, id gocql.UUID) error {
|
||||
var zeroUUID gocql.UUID
|
||||
if id == zeroUUID {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
return r.repo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// FindByName 根據名稱查詢標籤
|
||||
func (r *TagRepository) FindByName(ctx context.Context, name string) (*entity.Tag, error) {
|
||||
if name == "" {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
|
||||
// 標準化名稱
|
||||
name = strings.ToLower(strings.TrimSpace(name))
|
||||
|
||||
// 構建查詢(假設有 SAI 索引在 name 欄位上)
|
||||
query := r.repo.Query().Where(cassandra.Eq("name", name))
|
||||
|
||||
var tags []*entity.Tag
|
||||
if err := query.Scan(ctx, &tags); err != nil {
|
||||
if cassandra.IsNotFound(err) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to query tag: %w", err)
|
||||
}
|
||||
|
||||
if len(tags) == 0 {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
return tags[0], nil
|
||||
}
|
||||
|
||||
// FindByNames 根據名稱列表查詢標籤
|
||||
func (r *TagRepository) FindByNames(ctx context.Context, names []string) ([]*entity.Tag, error) {
|
||||
if len(names) == 0 {
|
||||
return []*entity.Tag{}, nil
|
||||
}
|
||||
|
||||
// 標準化名稱
|
||||
normalizedNames := make([]string, len(names))
|
||||
for i, name := range names {
|
||||
normalizedNames[i] = strings.ToLower(strings.TrimSpace(name))
|
||||
}
|
||||
|
||||
// 構建查詢(使用 IN 條件)
|
||||
query := r.repo.Query().Where(cassandra.In("name", toAnySlice(normalizedNames)))
|
||||
|
||||
var tags []*entity.Tag
|
||||
if err := query.Scan(ctx, &tags); err != nil {
|
||||
return nil, fmt.Errorf("failed to query tags: %w", err)
|
||||
}
|
||||
|
||||
return tags, nil
|
||||
}
|
||||
|
||||
// FindPopular 查詢熱門標籤
|
||||
func (r *TagRepository) FindPopular(ctx context.Context, limit int64) ([]*entity.Tag, error) {
|
||||
// 構建查詢,按 post_count 降序排列
|
||||
query := r.repo.Query().
|
||||
OrderBy("post_count", cassandra.DESC).
|
||||
Limit(int(limit))
|
||||
|
||||
var tags []*entity.Tag
|
||||
if err := query.Scan(ctx, &tags); err != nil {
|
||||
return nil, fmt.Errorf("failed to query popular tags: %w", err)
|
||||
}
|
||||
|
||||
result := tags
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// IncrementPostCount 增加貼文數(使用 counter 原子操作避免競爭條件)
|
||||
// 注意:post_count 欄位必須是 counter 類型
|
||||
func (r *TagRepository) IncrementPostCount(ctx context.Context, tagID gocql.UUID) error {
|
||||
var zeroUUID gocql.UUID
|
||||
if tagID == zeroUUID {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
// 使用 counter 原子更新操作:UPDATE tags SET post_count = post_count + 1 WHERE id = ?
|
||||
var zeroTag entity.Tag
|
||||
tableName := zeroTag.TableName()
|
||||
if r.keyspace == "" {
|
||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count + 1 WHERE id = ?", r.keyspace, tableName)
|
||||
query := r.db.GetSession().Query(stmt, nil).
|
||||
WithContext(ctx).
|
||||
Consistency(gocql.Quorum).
|
||||
Bind(tagID)
|
||||
|
||||
if err := query.ExecRelease(); err != nil {
|
||||
return fmt.Errorf("failed to increment post count: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecrementPostCount 減少貼文數(使用 counter 原子操作避免競爭條件)
|
||||
// 注意:post_count 欄位必須是 counter 類型
|
||||
func (r *TagRepository) DecrementPostCount(ctx context.Context, tagID gocql.UUID) error {
|
||||
var zeroUUID gocql.UUID
|
||||
if tagID == zeroUUID {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
// 使用 counter 原子更新操作:UPDATE tags SET post_count = post_count - 1 WHERE id = ?
|
||||
var zeroTag entity.Tag
|
||||
tableName := zeroTag.TableName()
|
||||
if r.keyspace == "" {
|
||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count - 1 WHERE id = ?", r.keyspace, tableName)
|
||||
query := r.db.GetSession().Query(stmt, nil).
|
||||
WithContext(ctx).
|
||||
Consistency(gocql.Quorum).
|
||||
Bind(tagID)
|
||||
|
||||
if err := query.ExecRelease(); err != nil {
|
||||
return fmt.Errorf("failed to decrement post count: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// toAnySlice 將 string slice 轉換為 []any
|
||||
func toAnySlice(strs []string) []any {
|
||||
result := make([]any, len(strs))
|
||||
for i, s := range strs {
|
||||
result[i] = s
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
|
@ -0,0 +1,455 @@
|
|||
package usecase
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
errs "backend/pkg/library/errors"
|
||||
"backend/pkg/post/domain/entity"
|
||||
"backend/pkg/post/domain/post"
|
||||
domainRepo "backend/pkg/post/domain/repository"
|
||||
domainUsecase "backend/pkg/post/domain/usecase"
|
||||
"backend/pkg/post/repository"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
// CommentUseCaseParam 定義 CommentUseCase 的初始化參數
|
||||
type CommentUseCaseParam struct {
|
||||
Comment domainRepo.CommentRepository
|
||||
Post domainRepo.PostRepository
|
||||
Like domainRepo.LikeRepository
|
||||
Logger errs.Logger
|
||||
}
|
||||
|
||||
// CommentUseCase 實作 domain usecase 介面
|
||||
type CommentUseCase struct {
|
||||
CommentUseCaseParam
|
||||
}
|
||||
|
||||
// MustCommentUseCase 創建新的 CommentUseCase(如果失敗會 panic)
|
||||
func MustCommentUseCase(param CommentUseCaseParam) domainUsecase.CommentUseCase {
|
||||
return &CommentUseCase{
|
||||
CommentUseCaseParam: param,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateComment 創建新評論
|
||||
func (uc *CommentUseCase) CreateComment(ctx context.Context, req domainUsecase.CreateCommentRequest) (*domainUsecase.CommentResponse, error) {
|
||||
// 驗證輸入
|
||||
if err := uc.validateCreateCommentRequest(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 驗證貼文存在
|
||||
var zeroUUID gocql.UUID
|
||||
if req.PostID == zeroUUID {
|
||||
return nil, errs.InputInvalidRangeError("post_id is required")
|
||||
}
|
||||
|
||||
post, err := uc.Post.FindOne(ctx, req.PostID)
|
||||
if err != nil {
|
||||
if repository.IsNotFound(err) {
|
||||
return nil, errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
|
||||
}
|
||||
return nil, uc.handleDBError("Post.FindOne", req, err)
|
||||
}
|
||||
|
||||
// 檢查貼文是否可見
|
||||
if !post.IsVisible() {
|
||||
return nil, errs.ResNotFoundError("cannot comment on non-visible post")
|
||||
}
|
||||
|
||||
// 建立評論實體
|
||||
comment := &entity.Comment{
|
||||
PostID: req.PostID,
|
||||
AuthorUID: req.AuthorUID,
|
||||
ParentID: req.ParentID,
|
||||
Content: req.Content,
|
||||
Status: post.CommentStatusPublished,
|
||||
}
|
||||
|
||||
// 插入資料庫
|
||||
if err := uc.Comment.Insert(ctx, comment); err != nil {
|
||||
return nil, uc.handleDBError("Comment.Insert", req, err)
|
||||
}
|
||||
|
||||
// 如果是回覆,增加父評論的回覆數
|
||||
if req.ParentID != nil {
|
||||
if err := uc.Comment.IncrementReplyCount(ctx, *req.ParentID); err != nil {
|
||||
uc.Logger.Error(fmt.Sprintf("failed to increment reply count: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// 增加貼文的評論數
|
||||
if err := uc.Post.IncrementCommentCount(ctx, req.PostID); err != nil {
|
||||
uc.Logger.Error(fmt.Sprintf("failed to increment comment count: %v", err))
|
||||
}
|
||||
|
||||
return uc.mapCommentToResponse(comment), nil
|
||||
}
|
||||
|
||||
// GetComment 取得評論
|
||||
func (uc *CommentUseCase) GetComment(ctx context.Context, req domainUsecase.GetCommentRequest) (*domainUsecase.CommentResponse, error) {
|
||||
// 驗證輸入
|
||||
var zeroUUID gocql.UUID
|
||||
if req.CommentID == zeroUUID {
|
||||
return nil, errs.InputInvalidRangeError("comment_id is required")
|
||||
}
|
||||
|
||||
// 查詢評論
|
||||
comment, err := uc.Comment.FindOne(ctx, req.CommentID)
|
||||
if err != nil {
|
||||
if repository.IsNotFound(err) {
|
||||
return nil, errs.ResNotFoundError(fmt.Sprintf("comment not found: %s", req.CommentID))
|
||||
}
|
||||
return nil, uc.handleDBError("Comment.FindOne", req, err)
|
||||
}
|
||||
|
||||
return uc.mapCommentToResponse(comment), nil
|
||||
}
|
||||
|
||||
// UpdateComment 更新評論
|
||||
func (uc *CommentUseCase) UpdateComment(ctx context.Context, req domainUsecase.UpdateCommentRequest) (*domainUsecase.CommentResponse, error) {
|
||||
// 驗證輸入
|
||||
var zeroUUID gocql.UUID
|
||||
if req.CommentID == zeroUUID {
|
||||
return nil, errs.InputInvalidRangeError("comment_id is required")
|
||||
}
|
||||
if req.AuthorUID == "" {
|
||||
return nil, errs.InputInvalidRangeError("author_uid is required")
|
||||
}
|
||||
if req.Content == "" {
|
||||
return nil, errs.InputInvalidRangeError("content is required")
|
||||
}
|
||||
|
||||
// 查詢現有評論
|
||||
comment, err := uc.Comment.FindOne(ctx, req.CommentID)
|
||||
if err != nil {
|
||||
if repository.IsNotFound(err) {
|
||||
return nil, errs.ResNotFoundError(fmt.Sprintf("comment not found: %s", req.CommentID))
|
||||
}
|
||||
return nil, uc.handleDBError("Comment.FindOne", req, err)
|
||||
}
|
||||
|
||||
// 驗證權限
|
||||
if comment.AuthorUID != req.AuthorUID {
|
||||
return nil, errs.ResNotFoundError("not authorized to update this comment")
|
||||
}
|
||||
|
||||
// 檢查是否可見
|
||||
if !comment.IsVisible() {
|
||||
return nil, errs.ResNotFoundError("comment is not visible")
|
||||
}
|
||||
|
||||
// 更新內容
|
||||
comment.Content = req.Content
|
||||
|
||||
// 更新資料庫
|
||||
if err := uc.Comment.Update(ctx, comment); err != nil {
|
||||
return nil, uc.handleDBError("Comment.Update", req, err)
|
||||
}
|
||||
|
||||
return uc.mapCommentToResponse(comment), nil
|
||||
}
|
||||
|
||||
// DeleteComment 刪除評論(軟刪除)
|
||||
func (uc *CommentUseCase) DeleteComment(ctx context.Context, req domainUsecase.DeleteCommentRequest) error {
|
||||
// 驗證輸入
|
||||
var zeroUUID gocql.UUID
|
||||
if req.CommentID == zeroUUID {
|
||||
return errs.InputInvalidRangeError("comment_id is required")
|
||||
}
|
||||
if req.AuthorUID == "" {
|
||||
return errs.InputInvalidRangeError("author_uid is required")
|
||||
}
|
||||
|
||||
// 查詢評論
|
||||
comment, err := uc.Comment.FindOne(ctx, req.CommentID)
|
||||
if err != nil {
|
||||
if repository.IsNotFound(err) {
|
||||
return errs.ResNotFoundError(fmt.Sprintf("comment not found: %s", req.CommentID))
|
||||
}
|
||||
return uc.handleDBError("Comment.FindOne", req, err)
|
||||
}
|
||||
|
||||
// 驗證權限
|
||||
if comment.AuthorUID != req.AuthorUID {
|
||||
return errs.ResNotFoundError("not authorized to delete this comment")
|
||||
}
|
||||
|
||||
// 刪除評論
|
||||
if err := uc.Comment.Delete(ctx, req.CommentID); err != nil {
|
||||
return uc.handleDBError("Comment.Delete", req, err)
|
||||
}
|
||||
|
||||
// 如果是回覆,減少父評論的回覆數
|
||||
if comment.ParentID != nil {
|
||||
if err := uc.Comment.DecrementReplyCount(ctx, *comment.ParentID); err != nil {
|
||||
uc.Logger.Error(fmt.Sprintf("failed to decrement reply count: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// 減少貼文的評論數
|
||||
if err := uc.Post.DecrementCommentCount(ctx, comment.PostID); err != nil {
|
||||
uc.Logger.Error(fmt.Sprintf("failed to decrement comment count: %v", err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListComments 列出評論
|
||||
func (uc *CommentUseCase) ListComments(ctx context.Context, req domainUsecase.ListCommentsRequest) (*domainUsecase.ListCommentsResponse, error) {
|
||||
// 驗證輸入
|
||||
var zeroUUID gocql.UUID
|
||||
if req.PostID == zeroUUID {
|
||||
return nil, errs.InputInvalidRangeError("post_id is required")
|
||||
}
|
||||
if req.PageSize <= 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
if req.PageIndex <= 0 {
|
||||
req.PageIndex = 1
|
||||
}
|
||||
|
||||
// 構建查詢參數
|
||||
params := &domainRepo.CommentQueryParams{
|
||||
PostID: &req.PostID,
|
||||
ParentID: req.ParentID,
|
||||
PageSize: req.PageSize,
|
||||
PageIndex: req.PageIndex,
|
||||
OrderBy: req.OrderBy,
|
||||
OrderDirection: req.OrderDirection,
|
||||
}
|
||||
|
||||
// 如果 OrderBy 未指定,預設為 created_at
|
||||
if params.OrderBy == "" {
|
||||
params.OrderBy = "created_at"
|
||||
}
|
||||
// 如果 OrderDirection 未指定,預設為 ASC
|
||||
if params.OrderDirection == "" {
|
||||
params.OrderDirection = "ASC"
|
||||
}
|
||||
|
||||
// 執行查詢
|
||||
comments, total, err := uc.Comment.FindByPostID(ctx, req.PostID, params)
|
||||
if err != nil {
|
||||
return nil, uc.handleDBError("Comment.FindByPostID", req, err)
|
||||
}
|
||||
|
||||
// 轉換為 Response
|
||||
responses := make([]domainUsecase.CommentResponse, len(comments))
|
||||
for i, c := range comments {
|
||||
responses[i] = *uc.mapCommentToResponse(c)
|
||||
}
|
||||
|
||||
return &domainUsecase.ListCommentsResponse{
|
||||
Data: responses,
|
||||
Page: domainUsecase.Pager{
|
||||
PageIndex: req.PageIndex,
|
||||
PageSize: req.PageSize,
|
||||
Total: total,
|
||||
TotalPage: calculateTotalPages(total, req.PageSize),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListReplies 列出回覆
|
||||
func (uc *CommentUseCase) ListReplies(ctx context.Context, req domainUsecase.ListRepliesRequest) (*domainUsecase.ListCommentsResponse, error) {
|
||||
// 驗證輸入
|
||||
var zeroUUID gocql.UUID
|
||||
if req.CommentID == zeroUUID {
|
||||
return nil, errs.InputInvalidRangeError("comment_id is required")
|
||||
}
|
||||
if req.PageSize <= 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
if req.PageIndex <= 0 {
|
||||
req.PageIndex = 1
|
||||
}
|
||||
|
||||
// 構建查詢參數
|
||||
params := &domainRepo.CommentQueryParams{
|
||||
PageSize: req.PageSize,
|
||||
PageIndex: req.PageIndex,
|
||||
OrderBy: "created_at",
|
||||
OrderDirection: "ASC",
|
||||
}
|
||||
|
||||
// 執行查詢
|
||||
comments, total, err := uc.Comment.FindReplies(ctx, req.CommentID, params)
|
||||
if err != nil {
|
||||
return nil, uc.handleDBError("Comment.FindReplies", req, err)
|
||||
}
|
||||
|
||||
// 轉換為 Response
|
||||
responses := make([]domainUsecase.CommentResponse, len(comments))
|
||||
for i, c := range comments {
|
||||
responses[i] = *uc.mapCommentToResponse(c)
|
||||
}
|
||||
|
||||
return &domainUsecase.ListCommentsResponse{
|
||||
Data: responses,
|
||||
Page: domainUsecase.Pager{
|
||||
PageIndex: req.PageIndex,
|
||||
PageSize: req.PageSize,
|
||||
Total: total,
|
||||
TotalPage: calculateTotalPages(total, req.PageSize),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListCommentsByAuthor 根據作者列出評論
|
||||
func (uc *CommentUseCase) ListCommentsByAuthor(ctx context.Context, req domainUsecase.ListCommentsByAuthorRequest) (*domainUsecase.ListCommentsResponse, error) {
|
||||
if req.AuthorUID == "" {
|
||||
return nil, errs.InputInvalidRangeError("author_uid is required")
|
||||
}
|
||||
if req.PageSize <= 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
if req.PageIndex <= 0 {
|
||||
req.PageIndex = 1
|
||||
}
|
||||
|
||||
params := &domainRepo.CommentQueryParams{
|
||||
PageSize: req.PageSize,
|
||||
PageIndex: req.PageIndex,
|
||||
OrderBy: "created_at",
|
||||
OrderDirection: "DESC",
|
||||
}
|
||||
|
||||
comments, total, err := uc.Comment.FindByAuthorUID(ctx, req.AuthorUID, params)
|
||||
if err != nil {
|
||||
return nil, uc.handleDBError("Comment.FindByAuthorUID", req, err)
|
||||
}
|
||||
|
||||
responses := make([]domainUsecase.CommentResponse, len(comments))
|
||||
for i, c := range comments {
|
||||
responses[i] = *uc.mapCommentToResponse(c)
|
||||
}
|
||||
|
||||
return &domainUsecase.ListCommentsResponse{
|
||||
Data: responses,
|
||||
Page: domainUsecase.Pager{
|
||||
PageIndex: req.PageIndex,
|
||||
PageSize: req.PageSize,
|
||||
Total: total,
|
||||
TotalPage: calculateTotalPages(total, req.PageSize),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LikeComment 按讚評論
|
||||
func (uc *CommentUseCase) LikeComment(ctx context.Context, req domainUsecase.LikeCommentRequest) error {
|
||||
// 驗證輸入
|
||||
var zeroUUID gocql.UUID
|
||||
if req.CommentID == zeroUUID {
|
||||
return errs.InputInvalidRangeError("comment_id is required")
|
||||
}
|
||||
if req.UserUID == "" {
|
||||
return errs.InputInvalidRangeError("user_uid is required")
|
||||
}
|
||||
|
||||
// 檢查是否已經按讚
|
||||
existingLike, err := uc.Like.FindByTargetAndUser(ctx, req.CommentID, req.UserUID, "comment")
|
||||
if err == nil && existingLike != nil {
|
||||
// 已經按讚,直接返回成功
|
||||
return nil
|
||||
}
|
||||
if err != nil && !repository.IsNotFound(err) {
|
||||
return uc.handleDBError("Like.FindByTargetAndUser", req, err)
|
||||
}
|
||||
|
||||
// 建立按讚記錄
|
||||
like := &entity.Like{
|
||||
TargetID: req.CommentID,
|
||||
UserUID: req.UserUID,
|
||||
TargetType: "comment",
|
||||
}
|
||||
|
||||
if err := uc.Like.Insert(ctx, like); err != nil {
|
||||
return uc.handleDBError("Like.Insert", req, err)
|
||||
}
|
||||
|
||||
// 增加評論的按讚數
|
||||
if err := uc.Comment.IncrementLikeCount(ctx, req.CommentID); err != nil {
|
||||
uc.Logger.Error(fmt.Sprintf("failed to increment like count: %v", err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnlikeComment 取消按讚評論
|
||||
func (uc *CommentUseCase) UnlikeComment(ctx context.Context, req domainUsecase.UnlikeCommentRequest) error {
|
||||
// 驗證輸入
|
||||
var zeroUUID gocql.UUID
|
||||
if req.CommentID == zeroUUID {
|
||||
return errs.InputInvalidRangeError("comment_id is required")
|
||||
}
|
||||
if req.UserUID == "" {
|
||||
return errs.InputInvalidRangeError("user_uid is required")
|
||||
}
|
||||
|
||||
// 刪除按讚記錄
|
||||
if err := uc.Like.DeleteByTargetAndUser(ctx, req.CommentID, req.UserUID, "comment"); err != nil {
|
||||
if repository.IsNotFound(err) {
|
||||
// 已經取消按讚,直接返回成功
|
||||
return nil
|
||||
}
|
||||
return uc.handleDBError("Like.DeleteByTargetAndUser", req, err)
|
||||
}
|
||||
|
||||
// 減少評論的按讚數
|
||||
if err := uc.Comment.DecrementLikeCount(ctx, req.CommentID); err != nil {
|
||||
uc.Logger.Error(fmt.Sprintf("failed to decrement like count: %v", err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateCreateCommentRequest 驗證建立評論請求
|
||||
func (uc *CommentUseCase) validateCreateCommentRequest(req domainUsecase.CreateCommentRequest) error {
|
||||
var zeroUUID gocql.UUID
|
||||
if req.PostID == zeroUUID {
|
||||
return errs.InputInvalidRangeError("post_id is required")
|
||||
}
|
||||
if req.AuthorUID == "" {
|
||||
return errs.InputInvalidRangeError("author_uid is required")
|
||||
}
|
||||
if req.Content == "" {
|
||||
return errs.InputInvalidRangeError("content is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// mapCommentToResponse 將 Comment 實體轉換為 CommentResponse
|
||||
func (uc *CommentUseCase) mapCommentToResponse(comment *entity.Comment) *domainUsecase.CommentResponse {
|
||||
return &domainUsecase.CommentResponse{
|
||||
ID: comment.ID,
|
||||
PostID: comment.PostID,
|
||||
AuthorUID: comment.AuthorUID,
|
||||
ParentID: comment.ParentID,
|
||||
Content: comment.Content,
|
||||
Status: comment.Status,
|
||||
LikeCount: comment.LikeCount,
|
||||
ReplyCount: comment.ReplyCount,
|
||||
CreatedAt: comment.CreatedAt,
|
||||
UpdatedAt: comment.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// handleDBError 處理資料庫錯誤
|
||||
func (uc *CommentUseCase) handleDBError(funcName string, req any, err error) error {
|
||||
return errs.DBErrorErrorL(
|
||||
uc.Logger,
|
||||
[]errs.LogField{
|
||||
{Key: "func", Val: funcName},
|
||||
{Key: "req", Val: req},
|
||||
{Key: "error", Val: err.Error()},
|
||||
},
|
||||
fmt.Sprintf("database operation failed: %s", funcName),
|
||||
).Wrap(err)
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,801 @@
|
|||
package usecase
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
errs "backend/pkg/library/errors"
|
||||
"backend/pkg/post/domain/entity"
|
||||
"backend/pkg/post/domain/post"
|
||||
domainRepo "backend/pkg/post/domain/repository"
|
||||
domainUsecase "backend/pkg/post/domain/usecase"
|
||||
"backend/pkg/post/repository"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
// PostUseCaseParam 定義 PostUseCase 的初始化參數
|
||||
type PostUseCaseParam struct {
|
||||
Post domainRepo.PostRepository
|
||||
Comment domainRepo.CommentRepository
|
||||
Like domainRepo.LikeRepository
|
||||
Tag domainRepo.TagRepository
|
||||
Category domainRepo.CategoryRepository
|
||||
Logger errs.Logger
|
||||
}
|
||||
|
||||
// PostUseCase 實作 domain usecase 介面
|
||||
type PostUseCase struct {
|
||||
PostUseCaseParam
|
||||
}
|
||||
|
||||
// MustPostUseCase 創建新的 PostUseCase(如果失敗會 panic)
|
||||
func MustPostUseCase(param PostUseCaseParam) domainUsecase.PostUseCase {
|
||||
return &PostUseCase{
|
||||
PostUseCaseParam: param,
|
||||
}
|
||||
}
|
||||
|
||||
// CreatePost 創建新貼文
|
||||
func (uc *PostUseCase) CreatePost(ctx context.Context, req domainUsecase.CreatePostRequest) (*domainUsecase.PostResponse, error) {
|
||||
// 驗證輸入
|
||||
if err := uc.validateCreatePostRequest(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 建立貼文實體
|
||||
post := &entity.Post{
|
||||
AuthorUID: req.AuthorUID,
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Type: req.Type,
|
||||
CategoryID: req.CategoryID,
|
||||
Tags: req.Tags,
|
||||
Images: req.Images,
|
||||
VideoURL: req.VideoURL,
|
||||
LinkURL: req.LinkURL,
|
||||
Status: req.Status,
|
||||
}
|
||||
|
||||
// 如果狀態未指定,預設為草稿
|
||||
if post.Status == 0 {
|
||||
post.Status = post.PostStatusDraft
|
||||
}
|
||||
|
||||
// 插入資料庫
|
||||
if err := uc.Post.Insert(ctx, post); err != nil {
|
||||
return nil, uc.handleDBError("Post.Insert", req, err)
|
||||
}
|
||||
|
||||
// 處理標籤(更新標籤的貼文數)
|
||||
if err := uc.updateTagPostCounts(ctx, req.Tags, true); err != nil {
|
||||
// 記錄錯誤但不中斷流程
|
||||
uc.Logger.Error(fmt.Sprintf("failed to update tag post counts: %v", err))
|
||||
}
|
||||
|
||||
// 處理分類(更新分類的貼文數)
|
||||
if req.CategoryID != nil {
|
||||
if err := uc.Category.IncrementPostCount(ctx, *req.CategoryID); err != nil {
|
||||
uc.Logger.Error(fmt.Sprintf("failed to increment category post count: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
return uc.mapPostToResponse(post), nil
|
||||
}
|
||||
|
||||
// GetPost 取得貼文
|
||||
func (uc *PostUseCase) GetPost(ctx context.Context, req domainUsecase.GetPostRequest) (*domainUsecase.PostResponse, error) {
|
||||
// 驗證輸入
|
||||
var zeroUUID gocql.UUID
|
||||
if req.PostID == zeroUUID {
|
||||
return nil, errs.InputInvalidRangeError("post_id is required")
|
||||
}
|
||||
|
||||
// 查詢貼文
|
||||
post, err := uc.Post.FindOne(ctx, req.PostID)
|
||||
if err != nil {
|
||||
if repository.IsNotFound(err) {
|
||||
return nil, errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
|
||||
}
|
||||
return nil, uc.handleDBError("Post.FindOne", req, err)
|
||||
}
|
||||
|
||||
// 如果提供了 UserUID,增加瀏覽數
|
||||
if req.UserUID != nil {
|
||||
if err := uc.Post.IncrementViewCount(ctx, req.PostID); err != nil {
|
||||
uc.Logger.Error(fmt.Sprintf("failed to increment view count: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
return uc.mapPostToResponse(post), nil
|
||||
}
|
||||
|
||||
// UpdatePost 更新貼文
|
||||
func (uc *PostUseCase) UpdatePost(ctx context.Context, req domainUsecase.UpdatePostRequest) (*domainUsecase.PostResponse, error) {
|
||||
// 驗證輸入
|
||||
var zeroUUID gocql.UUID
|
||||
if req.PostID == zeroUUID {
|
||||
return nil, errs.InputInvalidRangeError("post_id is required")
|
||||
}
|
||||
if req.AuthorUID == "" {
|
||||
return nil, errs.InputInvalidRangeError("author_uid is required")
|
||||
}
|
||||
|
||||
// 查詢現有貼文
|
||||
post, err := uc.Post.FindOne(ctx, req.PostID)
|
||||
if err != nil {
|
||||
if repository.IsNotFound(err) {
|
||||
return nil, errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
|
||||
}
|
||||
return nil, uc.handleDBError("Post.FindOne", req, err)
|
||||
}
|
||||
|
||||
// 驗證權限
|
||||
if post.AuthorUID != req.AuthorUID {
|
||||
return nil, errs.ResNotFoundError("not authorized to update this post")
|
||||
}
|
||||
|
||||
// 檢查是否可編輯
|
||||
if !post.IsEditable() {
|
||||
return nil, errs.ResNotFoundError("post is not editable")
|
||||
}
|
||||
|
||||
// 更新欄位
|
||||
if req.Title != nil {
|
||||
post.Title = *req.Title
|
||||
}
|
||||
if req.Content != nil {
|
||||
post.Content = *req.Content
|
||||
}
|
||||
if req.Type != nil {
|
||||
post.Type = *req.Type
|
||||
}
|
||||
if req.CategoryID != nil {
|
||||
// 更新分類計數
|
||||
if post.CategoryID != nil && *post.CategoryID != *req.CategoryID {
|
||||
if err := uc.Category.DecrementPostCount(ctx, *post.CategoryID); err != nil {
|
||||
uc.Logger.Error("failed to decrement category post count", errs.LogField{Key: "error", Val: err.Error()})
|
||||
}
|
||||
if err := uc.Category.IncrementPostCount(ctx, *req.CategoryID); err != nil {
|
||||
uc.Logger.Error(fmt.Sprintf("failed to increment category post count: %v", err))
|
||||
}
|
||||
}
|
||||
post.CategoryID = req.CategoryID
|
||||
}
|
||||
if req.Tags != nil {
|
||||
// 更新標籤計數
|
||||
oldTags := post.Tags
|
||||
post.Tags = req.Tags
|
||||
if err := uc.updateTagPostCountsDiff(ctx, oldTags, req.Tags); err != nil {
|
||||
uc.Logger.Error(fmt.Sprintf("failed to update tag post counts: %v", err))
|
||||
}
|
||||
}
|
||||
if req.Images != nil {
|
||||
post.Images = req.Images
|
||||
}
|
||||
if req.VideoURL != nil {
|
||||
post.VideoURL = req.VideoURL
|
||||
}
|
||||
if req.LinkURL != nil {
|
||||
post.LinkURL = req.LinkURL
|
||||
}
|
||||
|
||||
// 更新資料庫
|
||||
if err := uc.Post.Update(ctx, post); err != nil {
|
||||
return nil, uc.handleDBError("Post.Update", req, err)
|
||||
}
|
||||
|
||||
return uc.mapPostToResponse(post), nil
|
||||
}
|
||||
|
||||
// DeletePost 刪除貼文(軟刪除)
|
||||
func (uc *PostUseCase) DeletePost(ctx context.Context, req domainUsecase.DeletePostRequest) error {
|
||||
// 驗證輸入
|
||||
var zeroUUID gocql.UUID
|
||||
if req.PostID == zeroUUID {
|
||||
return errs.InputInvalidRangeError("post_id is required")
|
||||
}
|
||||
if req.AuthorUID == "" {
|
||||
return errs.InputInvalidRangeError("author_uid is required")
|
||||
}
|
||||
|
||||
// 查詢貼文
|
||||
post, err := uc.Post.FindOne(ctx, req.PostID)
|
||||
if err != nil {
|
||||
if repository.IsNotFound(err) {
|
||||
return errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
|
||||
}
|
||||
return uc.handleDBError("Post.FindOne", req, err)
|
||||
}
|
||||
|
||||
// 驗證權限
|
||||
if post.AuthorUID != req.AuthorUID {
|
||||
return errs.ResNotFoundError("not authorized to delete this post")
|
||||
}
|
||||
|
||||
// 刪除貼文
|
||||
if err := uc.Post.Delete(ctx, req.PostID); err != nil {
|
||||
return uc.handleDBError("Post.Delete", req, err)
|
||||
}
|
||||
|
||||
// 更新標籤和分類計數
|
||||
if len(post.Tags) > 0 {
|
||||
if err := uc.updateTagPostCounts(ctx, post.Tags, false); err != nil {
|
||||
uc.Logger.Error(fmt.Sprintf("failed to update tag post counts: %v", err))
|
||||
}
|
||||
}
|
||||
if post.CategoryID != nil {
|
||||
if err := uc.Category.DecrementPostCount(ctx, *post.CategoryID); err != nil {
|
||||
uc.Logger.Error("failed to decrement category post count", errs.LogField{Key: "error", Val: err.Error()})
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishPost 發布貼文
|
||||
func (uc *PostUseCase) PublishPost(ctx context.Context, req domainUsecase.PublishPostRequest) (*domainUsecase.PostResponse, error) {
|
||||
// 驗證輸入
|
||||
var zeroUUID gocql.UUID
|
||||
if req.PostID == zeroUUID {
|
||||
return nil, errs.InputInvalidRangeError("post_id is required")
|
||||
}
|
||||
if req.AuthorUID == "" {
|
||||
return nil, errs.InputInvalidRangeError("author_uid is required")
|
||||
}
|
||||
|
||||
// 查詢貼文
|
||||
post, err := uc.Post.FindOne(ctx, req.PostID)
|
||||
if err != nil {
|
||||
if repository.IsNotFound(err) {
|
||||
return nil, errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
|
||||
}
|
||||
return nil, uc.handleDBError("Post.FindOne", req, err)
|
||||
}
|
||||
|
||||
// 驗證權限
|
||||
if post.AuthorUID != req.AuthorUID {
|
||||
return nil, errs.ResNotFoundError("not authorized to publish this post")
|
||||
}
|
||||
|
||||
// 發布貼文
|
||||
post.Publish()
|
||||
|
||||
// 更新資料庫
|
||||
if err := uc.Post.Update(ctx, post); err != nil {
|
||||
return nil, uc.handleDBError("Post.Update", req, err)
|
||||
}
|
||||
|
||||
return uc.mapPostToResponse(post), nil
|
||||
}
|
||||
|
||||
// ArchivePost 歸檔貼文
|
||||
func (uc *PostUseCase) ArchivePost(ctx context.Context, req domainUsecase.ArchivePostRequest) error {
|
||||
// 驗證輸入
|
||||
var zeroUUID gocql.UUID
|
||||
if req.PostID == zeroUUID {
|
||||
return errs.InputInvalidRangeError("post_id is required")
|
||||
}
|
||||
if req.AuthorUID == "" {
|
||||
return errs.InputInvalidRangeError("author_uid is required")
|
||||
}
|
||||
|
||||
// 查詢貼文
|
||||
post, err := uc.Post.FindOne(ctx, req.PostID)
|
||||
if err != nil {
|
||||
if repository.IsNotFound(err) {
|
||||
return errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
|
||||
}
|
||||
return uc.handleDBError("Post.FindOne", req, err)
|
||||
}
|
||||
|
||||
// 驗證權限
|
||||
if post.AuthorUID != req.AuthorUID {
|
||||
return errs.ResNotFoundError("not authorized to archive this post")
|
||||
}
|
||||
|
||||
// 歸檔貼文
|
||||
post.Archive()
|
||||
|
||||
// 更新資料庫
|
||||
return uc.Post.Update(ctx, post)
|
||||
}
|
||||
|
||||
// ListPosts 列出貼文
|
||||
func (uc *PostUseCase) ListPosts(ctx context.Context, req domainUsecase.ListPostsRequest) (*domainUsecase.ListPostsResponse, error) {
|
||||
// 驗證分頁參數
|
||||
if req.PageSize <= 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
if req.PageIndex <= 0 {
|
||||
req.PageIndex = 1
|
||||
}
|
||||
|
||||
// 構建查詢參數
|
||||
params := &domainRepo.PostQueryParams{
|
||||
CategoryID: req.CategoryID,
|
||||
Tag: req.Tag,
|
||||
Status: req.Status,
|
||||
Type: req.Type,
|
||||
AuthorUID: req.AuthorUID,
|
||||
CreateStartTime: req.CreateStartTime,
|
||||
CreateEndTime: req.CreateEndTime,
|
||||
PageSize: req.PageSize,
|
||||
PageIndex: req.PageIndex,
|
||||
OrderBy: req.OrderBy,
|
||||
OrderDirection: req.OrderDirection,
|
||||
}
|
||||
|
||||
// 執行查詢
|
||||
var posts []*entity.Post
|
||||
var total int64
|
||||
var err error
|
||||
|
||||
if req.CategoryID != nil {
|
||||
posts, total, err = uc.Post.FindByCategoryID(ctx, *req.CategoryID, params)
|
||||
} else if req.Tag != nil {
|
||||
posts, total, err = uc.Post.FindByTag(ctx, *req.Tag, params)
|
||||
} else if req.AuthorUID != nil {
|
||||
posts, total, err = uc.Post.FindByAuthorUID(ctx, *req.AuthorUID, params)
|
||||
} else if req.Status != nil {
|
||||
posts, total, err = uc.Post.FindByStatus(ctx, *req.Status, params)
|
||||
} else {
|
||||
// 預設查詢所有已發布的貼文
|
||||
published := post.PostStatusPublished
|
||||
params.Status = &published
|
||||
posts, total, err = uc.Post.FindByStatus(ctx, published, params)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, uc.handleDBError("Post.FindBy*", req, err)
|
||||
}
|
||||
|
||||
// 轉換為 Response
|
||||
responses := make([]domainUsecase.PostResponse, len(posts))
|
||||
for i, p := range posts {
|
||||
responses[i] = *uc.mapPostToResponse(p)
|
||||
}
|
||||
|
||||
return &domainUsecase.ListPostsResponse{
|
||||
Data: responses,
|
||||
Page: domainUsecase.Pager{
|
||||
PageIndex: req.PageIndex,
|
||||
PageSize: req.PageSize,
|
||||
Total: total,
|
||||
TotalPage: calculateTotalPages(total, req.PageSize),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListPostsByAuthor 根據作者列出貼文
|
||||
func (uc *PostUseCase) ListPostsByAuthor(ctx context.Context, req domainUsecase.ListPostsByAuthorRequest) (*domainUsecase.ListPostsResponse, error) {
|
||||
if req.AuthorUID == "" {
|
||||
return nil, errs.InputInvalidRangeError("author_uid is required")
|
||||
}
|
||||
if req.PageSize <= 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
if req.PageIndex <= 0 {
|
||||
req.PageIndex = 1
|
||||
}
|
||||
|
||||
params := &domainRepo.PostQueryParams{
|
||||
Status: req.Status,
|
||||
PageSize: req.PageSize,
|
||||
PageIndex: req.PageIndex,
|
||||
OrderBy: "created_at",
|
||||
OrderDirection: "DESC",
|
||||
}
|
||||
|
||||
posts, total, err := uc.Post.FindByAuthorUID(ctx, req.AuthorUID, params)
|
||||
if err != nil {
|
||||
return nil, uc.handleDBError("Post.FindByAuthorUID", req, err)
|
||||
}
|
||||
|
||||
responses := make([]domainUsecase.PostResponse, len(posts))
|
||||
for i, p := range posts {
|
||||
responses[i] = *uc.mapPostToResponse(p)
|
||||
}
|
||||
|
||||
return &domainUsecase.ListPostsResponse{
|
||||
Data: responses,
|
||||
Page: domainUsecase.Pager{
|
||||
PageIndex: req.PageIndex,
|
||||
PageSize: req.PageSize,
|
||||
Total: total,
|
||||
TotalPage: calculateTotalPages(total, req.PageSize),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListPostsByCategory 根據分類列出貼文
|
||||
func (uc *PostUseCase) ListPostsByCategory(ctx context.Context, req domainUsecase.ListPostsByCategoryRequest) (*domainUsecase.ListPostsResponse, error) {
|
||||
var zeroUUID gocql.UUID
|
||||
if req.CategoryID == zeroUUID {
|
||||
return nil, errs.InputInvalidRangeError("category_id is required")
|
||||
}
|
||||
if req.PageSize <= 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
if req.PageIndex <= 0 {
|
||||
req.PageIndex = 1
|
||||
}
|
||||
|
||||
params := &domainRepo.PostQueryParams{
|
||||
Status: req.Status,
|
||||
PageSize: req.PageSize,
|
||||
PageIndex: req.PageIndex,
|
||||
OrderBy: "created_at",
|
||||
OrderDirection: "DESC",
|
||||
}
|
||||
|
||||
posts, total, err := uc.Post.FindByCategoryID(ctx, req.CategoryID, params)
|
||||
if err != nil {
|
||||
return nil, uc.handleDBError("Post.FindByCategoryID", req, err)
|
||||
}
|
||||
|
||||
responses := make([]domainUsecase.PostResponse, len(posts))
|
||||
for i, p := range posts {
|
||||
responses[i] = *uc.mapPostToResponse(p)
|
||||
}
|
||||
|
||||
return &domainUsecase.ListPostsResponse{
|
||||
Data: responses,
|
||||
Page: domainUsecase.Pager{
|
||||
PageIndex: req.PageIndex,
|
||||
PageSize: req.PageSize,
|
||||
Total: total,
|
||||
TotalPage: calculateTotalPages(total, req.PageSize),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListPostsByTag 根據標籤列出貼文
|
||||
func (uc *PostUseCase) ListPostsByTag(ctx context.Context, req domainUsecase.ListPostsByTagRequest) (*domainUsecase.ListPostsResponse, error) {
|
||||
if req.Tag == "" {
|
||||
return nil, errs.InputInvalidRangeError("tag is required")
|
||||
}
|
||||
if req.PageSize <= 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
if req.PageIndex <= 0 {
|
||||
req.PageIndex = 1
|
||||
}
|
||||
|
||||
params := &domainRepo.PostQueryParams{
|
||||
Status: req.Status,
|
||||
PageSize: req.PageSize,
|
||||
PageIndex: req.PageIndex,
|
||||
OrderBy: "created_at",
|
||||
OrderDirection: "DESC",
|
||||
}
|
||||
|
||||
posts, total, err := uc.Post.FindByTag(ctx, req.Tag, params)
|
||||
if err != nil {
|
||||
return nil, uc.handleDBError("Post.FindByTag", req, err)
|
||||
}
|
||||
|
||||
responses := make([]domainUsecase.PostResponse, len(posts))
|
||||
for i, p := range posts {
|
||||
responses[i] = *uc.mapPostToResponse(p)
|
||||
}
|
||||
|
||||
return &domainUsecase.ListPostsResponse{
|
||||
Data: responses,
|
||||
Page: domainUsecase.Pager{
|
||||
PageIndex: req.PageIndex,
|
||||
PageSize: req.PageSize,
|
||||
Total: total,
|
||||
TotalPage: calculateTotalPages(total, req.PageSize),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetPinnedPosts 取得置頂貼文
|
||||
func (uc *PostUseCase) GetPinnedPosts(ctx context.Context, req domainUsecase.GetPinnedPostsRequest) (*domainUsecase.ListPostsResponse, error) {
|
||||
limit := int64(10)
|
||||
if req.Limit > 0 {
|
||||
limit = req.Limit
|
||||
}
|
||||
|
||||
posts, err := uc.Post.FindPinnedPosts(ctx, limit)
|
||||
if err != nil {
|
||||
return nil, uc.handleDBError("Post.FindPinnedPosts", req, err)
|
||||
}
|
||||
|
||||
responses := make([]domainUsecase.PostResponse, len(posts))
|
||||
for i, p := range posts {
|
||||
responses[i] = *uc.mapPostToResponse(p)
|
||||
}
|
||||
|
||||
return &domainUsecase.ListPostsResponse{
|
||||
Data: responses,
|
||||
Page: domainUsecase.Pager{
|
||||
PageIndex: 1,
|
||||
PageSize: limit,
|
||||
Total: int64(len(responses)),
|
||||
TotalPage: 1,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LikePost 按讚貼文
|
||||
func (uc *PostUseCase) LikePost(ctx context.Context, req domainUsecase.LikePostRequest) error {
|
||||
// 驗證輸入
|
||||
var zeroUUID gocql.UUID
|
||||
if req.PostID == zeroUUID {
|
||||
return errs.InputInvalidRangeError("post_id is required")
|
||||
}
|
||||
if req.UserUID == "" {
|
||||
return errs.InputInvalidRangeError("user_uid is required")
|
||||
}
|
||||
|
||||
// 檢查是否已經按讚
|
||||
existingLike, err := uc.Like.FindByTargetAndUser(ctx, req.PostID, req.UserUID, "post")
|
||||
if err == nil && existingLike != nil {
|
||||
// 已經按讚,直接返回成功
|
||||
return nil
|
||||
}
|
||||
if err != nil && !repository.IsNotFound(err) {
|
||||
return uc.handleDBError("Like.FindByTargetAndUser", req, err)
|
||||
}
|
||||
|
||||
// 建立按讚記錄
|
||||
like := &entity.Like{
|
||||
TargetID: req.PostID,
|
||||
UserUID: req.UserUID,
|
||||
TargetType: "post",
|
||||
}
|
||||
|
||||
if err := uc.Like.Insert(ctx, like); err != nil {
|
||||
return uc.handleDBError("Like.Insert", req, err)
|
||||
}
|
||||
|
||||
// 增加貼文的按讚數
|
||||
if err := uc.Post.IncrementLikeCount(ctx, req.PostID); err != nil {
|
||||
uc.Logger.Error(fmt.Sprintf("failed to increment like count: %v", err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnlikePost 取消按讚
|
||||
func (uc *PostUseCase) UnlikePost(ctx context.Context, req domainUsecase.UnlikePostRequest) error {
|
||||
// 驗證輸入
|
||||
var zeroUUID gocql.UUID
|
||||
if req.PostID == zeroUUID {
|
||||
return errs.InputInvalidRangeError("post_id is required")
|
||||
}
|
||||
if req.UserUID == "" {
|
||||
return errs.InputInvalidRangeError("user_uid is required")
|
||||
}
|
||||
|
||||
// 刪除按讚記錄
|
||||
if err := uc.Like.DeleteByTargetAndUser(ctx, req.PostID, req.UserUID, "post"); err != nil {
|
||||
if repository.IsNotFound(err) {
|
||||
// 已經取消按讚,直接返回成功
|
||||
return nil
|
||||
}
|
||||
return uc.handleDBError("Like.DeleteByTargetAndUser", req, err)
|
||||
}
|
||||
|
||||
// 減少貼文的按讚數
|
||||
if err := uc.Post.DecrementLikeCount(ctx, req.PostID); err != nil {
|
||||
uc.Logger.Error(fmt.Sprintf("failed to decrement like count: %v", err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ViewPost 瀏覽貼文(增加瀏覽數)
|
||||
func (uc *PostUseCase) ViewPost(ctx context.Context, req domainUsecase.ViewPostRequest) error {
|
||||
// 驗證輸入
|
||||
var zeroUUID gocql.UUID
|
||||
if req.PostID == zeroUUID {
|
||||
return errs.InputInvalidRangeError("post_id is required")
|
||||
}
|
||||
|
||||
// 增加瀏覽數
|
||||
if err := uc.Post.IncrementViewCount(ctx, req.PostID); err != nil {
|
||||
return uc.handleDBError("Post.IncrementViewCount", req, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PinPost 置頂貼文
|
||||
func (uc *PostUseCase) PinPost(ctx context.Context, req domainUsecase.PinPostRequest) error {
|
||||
// 驗證輸入
|
||||
var zeroUUID gocql.UUID
|
||||
if req.PostID == zeroUUID {
|
||||
return errs.InputInvalidRangeError("post_id is required")
|
||||
}
|
||||
if req.AuthorUID == "" {
|
||||
return errs.InputInvalidRangeError("author_uid is required")
|
||||
}
|
||||
|
||||
// 查詢貼文
|
||||
post, err := uc.Post.FindOne(ctx, req.PostID)
|
||||
if err != nil {
|
||||
if repository.IsNotFound(err) {
|
||||
return errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
|
||||
}
|
||||
return uc.handleDBError("Post.FindOne", req, err)
|
||||
}
|
||||
|
||||
// 驗證權限
|
||||
if post.AuthorUID != req.AuthorUID {
|
||||
return errs.ResNotFoundError("not authorized to pin this post")
|
||||
}
|
||||
|
||||
// 置頂貼文
|
||||
return uc.Post.PinPost(ctx, req.PostID)
|
||||
}
|
||||
|
||||
// UnpinPost 取消置頂
|
||||
func (uc *PostUseCase) UnpinPost(ctx context.Context, req domainUsecase.UnpinPostRequest) error {
|
||||
// 驗證輸入
|
||||
var zeroUUID gocql.UUID
|
||||
if req.PostID == zeroUUID {
|
||||
return errs.InputInvalidRangeError("post_id is required")
|
||||
}
|
||||
if req.AuthorUID == "" {
|
||||
return errs.InputInvalidRangeError("author_uid is required")
|
||||
}
|
||||
|
||||
// 查詢貼文
|
||||
post, err := uc.Post.FindOne(ctx, req.PostID)
|
||||
if err != nil {
|
||||
if repository.IsNotFound(err) {
|
||||
return errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
|
||||
}
|
||||
return uc.handleDBError("Post.FindOne", req, err)
|
||||
}
|
||||
|
||||
// 驗證權限
|
||||
if post.AuthorUID != req.AuthorUID {
|
||||
return errs.ResNotFoundError("not authorized to unpin this post")
|
||||
}
|
||||
|
||||
// 取消置頂
|
||||
return uc.Post.UnpinPost(ctx, req.PostID)
|
||||
}
|
||||
|
||||
// validateCreatePostRequest 驗證建立貼文請求
|
||||
func (uc *PostUseCase) validateCreatePostRequest(req domainUsecase.CreatePostRequest) error {
|
||||
if req.AuthorUID == "" {
|
||||
return errs.InputInvalidRangeError("author_uid is required")
|
||||
}
|
||||
if req.Title == "" {
|
||||
return errs.InputInvalidRangeError("title is required")
|
||||
}
|
||||
if req.Content == "" {
|
||||
return errs.InputInvalidRangeError("content is required")
|
||||
}
|
||||
if !req.Type.IsValid() {
|
||||
return errs.InputInvalidRangeError("invalid post type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// mapPostToResponse 將 Post 實體轉換為 PostResponse
|
||||
func (uc *PostUseCase) mapPostToResponse(post *entity.Post) *domainUsecase.PostResponse {
|
||||
return &domainUsecase.PostResponse{
|
||||
ID: post.ID,
|
||||
AuthorUID: post.AuthorUID,
|
||||
Title: post.Title,
|
||||
Content: post.Content,
|
||||
Type: post.Type,
|
||||
Status: post.Status,
|
||||
CategoryID: post.CategoryID,
|
||||
Tags: post.Tags,
|
||||
Images: post.Images,
|
||||
VideoURL: post.VideoURL,
|
||||
LinkURL: post.LinkURL,
|
||||
LikeCount: post.LikeCount,
|
||||
CommentCount: post.CommentCount,
|
||||
ViewCount: post.ViewCount,
|
||||
IsPinned: post.IsPinned,
|
||||
PinnedAt: post.PinnedAt,
|
||||
PublishedAt: post.PublishedAt,
|
||||
CreatedAt: post.CreatedAt,
|
||||
UpdatedAt: post.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// handleDBError 處理資料庫錯誤
|
||||
func (uc *PostUseCase) handleDBError(funcName string, req any, err error) error {
|
||||
return errs.DBErrorErrorL(
|
||||
uc.Logger,
|
||||
[]errs.LogField{
|
||||
{Key: "func", Val: funcName},
|
||||
{Key: "req", Val: req},
|
||||
{Key: "error", Val: err.Error()},
|
||||
},
|
||||
fmt.Sprintf("database operation failed: %s", funcName),
|
||||
).Wrap(err)
|
||||
}
|
||||
|
||||
// updateTagPostCounts 更新標籤的貼文數
|
||||
func (uc *PostUseCase) updateTagPostCounts(ctx context.Context, tags []string, increment bool) error {
|
||||
if len(tags) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 查詢或建立標籤
|
||||
for _, tagName := range tags {
|
||||
tag, err := uc.Tag.FindByName(ctx, tagName)
|
||||
if err != nil {
|
||||
if repository.IsNotFound(err) {
|
||||
// 建立新標籤
|
||||
newTag := &entity.Tag{
|
||||
Name: tagName,
|
||||
}
|
||||
if err := uc.Tag.Insert(ctx, newTag); err != nil {
|
||||
return fmt.Errorf("failed to create tag: %w", err)
|
||||
}
|
||||
tag = newTag
|
||||
} else {
|
||||
return fmt.Errorf("failed to find tag: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新計數
|
||||
if increment {
|
||||
if err := uc.Tag.IncrementPostCount(ctx, tag.ID); err != nil {
|
||||
return fmt.Errorf("failed to increment tag count: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := uc.Tag.DecrementPostCount(ctx, tag.ID); err != nil {
|
||||
return fmt.Errorf("failed to decrement tag count: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateTagPostCountsDiff 更新標籤計數(處理差異)
|
||||
func (uc *PostUseCase) updateTagPostCountsDiff(ctx context.Context, oldTags, newTags []string) error {
|
||||
// 找出新增和刪除的標籤
|
||||
oldTagMap := make(map[string]bool)
|
||||
for _, tag := range oldTags {
|
||||
oldTagMap[tag] = true
|
||||
}
|
||||
|
||||
newTagMap := make(map[string]bool)
|
||||
for _, tag := range newTags {
|
||||
newTagMap[tag] = true
|
||||
}
|
||||
|
||||
// 新增的標籤
|
||||
for _, tag := range newTags {
|
||||
if !oldTagMap[tag] {
|
||||
if err := uc.updateTagPostCounts(ctx, []string{tag}, true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 刪除的標籤
|
||||
for _, tag := range oldTags {
|
||||
if !newTagMap[tag] {
|
||||
if err := uc.updateTagPostCounts(ctx, []string{tag}, false); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateTotalPages 計算總頁數
|
||||
func calculateTotalPages(total, pageSize int64) int64 {
|
||||
if pageSize <= 0 {
|
||||
return 0
|
||||
}
|
||||
return int64(math.Ceil(float64(total) / float64(pageSize)))
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue