package cassandra

import (
	"context"
	"errors"
	"fmt"

	"github.com/gocql/gocql"
	"github.com/scylladb/gocqlx/v3/qb"
)

// Condition describes a single WHERE-clause predicate.
type Condition interface {
	// Build returns the comparison to add to the statement together with
	// the named bind values that comparison refers to.
	Build() (qb.Cmp, map[string]any)
}

// Eq returns a condition matching rows where column equals value.
func Eq(column string, value any) Condition {
	return &eqCondition{column: column, value: value}
}

// eqCondition is the Condition produced by Eq.
type eqCondition struct {
	column string
	value  any
}

// Build returns the "column = ?" comparison, bound under the column name.
func (c *eqCondition) Build() (qb.Cmp, map[string]any) {
	return qb.Eq(c.column), map[string]any{c.column: c.value}
}

// In returns a condition matching rows where column is one of values.
func In(column string, values []any) Condition {
	return &inCondition{column: column, values: values}
}

// inCondition is the Condition produced by In.
type inCondition struct {
	column string
	values []any
}

// Build returns the "column IN ?" comparison, bound under the column name.
func (c *inCondition) Build() (qb.Cmp, map[string]any) {
	return qb.In(c.column), map[string]any{c.column: c.values}
}

// Gt returns a condition matching rows where column is greater than value.
func Gt(column string, value any) Condition {
	return &gtCondition{column: column, value: value}
}

// gtCondition is the Condition produced by Gt.
type gtCondition struct {
	column string
	value  any
}

// Build returns the "column > ?" comparison.
//
// The value is bound under "<column>_gt" rather than the bare column name so
// that Gt and Lt on the same column (a range query) do not overwrite each
// other's value when the builder merges bind maps.
func (c *gtCondition) Build() (qb.Cmp, map[string]any) {
	name := c.column + "_gt"
	return qb.GtNamed(c.column, name), map[string]any{name: c.value}
}

// Lt returns a condition matching rows where column is less than value.
func Lt(column string, value any) Condition {
	return &ltCondition{column: column, value: value}
}

// ltCondition is the Condition produced by Lt.
type ltCondition struct {
	column string
	value  any
}

// Build returns the "column < ?" comparison.
//
// The value is bound under "<column>_lt" rather than the bare column name so
// that Gt and Lt on the same column (a range query) do not overwrite each
// other's value when the builder merges bind maps.
func (c *ltCondition) Build() (qb.Cmp, map[string]any) {
	name := c.column + "_lt"
	return qb.LtNamed(c.column, name), map[string]any{name: c.value}
}

// QueryBuilder is a chainable builder for SELECT queries over a table of T.
type QueryBuilder[T Table] interface {
	Where(condition Condition) QueryBuilder[T]
	OrderBy(column string, order Order) QueryBuilder[T]
	Limit(n int) QueryBuilder[T]
	Select(columns ...string) QueryBuilder[T]
	Scan(ctx context.Context, dest *[]T) error
	One(ctx context.Context) (T, error)
	Count(ctx context.Context) (int64, error)
}

// queryBuilder is the concrete QueryBuilder implementation.
type queryBuilder[T Table] struct {
	repo       *repository[T]
	conditions []Condition
	orders     []orderBy
	limit      int
	columns    []string
}

// orderBy is one ORDER BY term.
type orderBy struct {
	column string
	order  Order
}

// newQueryBuilder creates an empty query builder bound to repo.
func newQueryBuilder[T Table](repo *repository[T]) QueryBuilder[T] {
	return &queryBuilder[T]{
		repo: repo,
	}
}

// Where appends a WHERE condition. All conditions are passed together to the
// statement builder when the query is executed.
func (q *queryBuilder[T]) Where(condition Condition) QueryBuilder[T] {
	q.conditions = append(q.conditions, condition)
	return q
}

// OrderBy appends an ORDER BY term.
func (q *queryBuilder[T]) OrderBy(column string, order Order) QueryBuilder[T] {
	q.orders = append(q.orders, orderBy{column: column, order: order})
	return q
}

// Limit sets the maximum number of rows to return. Values <= 0 mean no LIMIT
// clause is emitted.
func (q *queryBuilder[T]) Limit(n int) QueryBuilder[T] {
	q.limit = n
	return q
}

// Select appends columns to the projection. When no columns are selected,
// Scan falls back to the repository metadata's column list.
func (q *queryBuilder[T]) Select(columns ...string) QueryBuilder[T] {
	q.columns = append(q.columns, columns...)
	return q
}

// whereClause builds the comparisons and the merged bind map for every
// condition registered on the builder.
func (q *queryBuilder[T]) whereClause() ([]qb.Cmp, map[string]any) {
	cmps := make([]qb.Cmp, 0, len(q.conditions))
	binds := make(map[string]any, len(q.conditions))
	for _, cond := range q.conditions {
		cmp, values := cond.Build()
		cmps = append(cmps, cmp)
		for name, value := range values {
			binds[name] = value
		}
	}
	return cmps, binds
}

// Scan executes the query and stores the resulting rows in dest.
//
// It returns ErrInvalidInput when dest is nil; otherwise it returns whatever
// error the underlying query reports.
func (q *queryBuilder[T]) Scan(ctx context.Context, dest *[]T) error {
	if dest == nil {
		return ErrInvalidInput.WithTable(q.repo.table).WithError(
			fmt.Errorf("destination cannot be nil"),
		)
	}

	builder := qb.Select(q.repo.table)

	// Projection: explicit columns if given, otherwise all known columns.
	if len(q.columns) > 0 {
		builder = builder.Columns(q.columns...)
	} else {
		builder = builder.Columns(q.repo.metadata.Columns...)
	}

	// Conditions.
	cmps, bindMap := q.whereClause()
	if len(cmps) > 0 {
		builder = builder.Where(cmps...)
	}

	// Ordering: anything other than DESC is treated as ascending.
	for _, o := range q.orders {
		order := qb.ASC
		if o.order == DESC {
			order = qb.DESC
		}
		builder = builder.OrderBy(o.column, order)
	}

	// Limit.
	if q.limit > 0 {
		builder = builder.Limit(uint(q.limit))
	}

	stmt, names := builder.ToCql()
	query := q.repo.db.withContextAndTimestamp(ctx, q.repo.db.session.Query(stmt, names).BindMap(bindMap))
	return query.SelectRelease(dest)
}

// One executes the query with a limit of 1 and returns the single result.
//
// It returns ErrNotFound when no row matches. The limit is applied to a
// shallow copy of the builder, so the receiver's own limit is left untouched
// and the builder can be reused for a later Scan.
func (q *queryBuilder[T]) One(ctx context.Context) (T, error) {
	var zero T

	// Scan only reads the builder's fields, so a shallow copy is sufficient.
	single := *q
	single.limit = 1

	var results []T
	if err := single.Scan(ctx, &results); err != nil {
		return zero, err
	}
	if len(results) == 0 {
		return zero, ErrNotFound.WithTable(q.repo.table)
	}
	return results[0], nil
}

// Count returns the number of rows matching the builder's conditions.
// Ordering, limit and selected columns are not applied.
func (q *queryBuilder[T]) Count(ctx context.Context) (int64, error) {
	builder := qb.Select(q.repo.table).Columns("COUNT(*)")

	// Conditions.
	cmps, bindMap := q.whereClause()
	if len(cmps) > 0 {
		builder = builder.Where(cmps...)
	}

	stmt, names := builder.ToCql()
	query := q.repo.db.withContextAndTimestamp(ctx, q.repo.db.session.Query(stmt, names).BindMap(bindMap))

	var count int64
	err := query.GetRelease(&count)
	if errors.Is(err, gocql.ErrNotFound) {
		// A COUNT query is not expected to return ErrNotFound; handled
		// defensively and reported as zero rows.
		return 0, nil
	}
	return count, err
}