227 lines
4.9 KiB
Go
227 lines
4.9 KiB
Go
package cassandra
|
||
|
||
import (
	"context"
	"errors"
	"fmt"

	"github.com/gocql/gocql"
	"github.com/scylladb/gocqlx/v2/qb"
)
|
||
|
||
// Condition 定義查詢條件介面
|
||
type Condition interface {
|
||
Build() (qb.Cmp, map[string]any)
|
||
}
|
||
|
||
// Eq 等於條件
|
||
func Eq(column string, value any) Condition {
|
||
return &eqCondition{column: column, value: value}
|
||
}
|
||
|
||
// eqCondition is the Condition produced by Eq.
type eqCondition struct {
	column string // column compared with "="
	value  any    // value bound to the placeholder
}
|
||
|
||
func (c *eqCondition) Build() (qb.Cmp, map[string]any) {
|
||
return qb.Eq(c.column), map[string]any{c.column: c.value}
|
||
}
|
||
|
||
// In IN 條件
|
||
func In(column string, values []any) Condition {
|
||
return &inCondition{column: column, values: values}
|
||
}
|
||
|
||
// inCondition is the Condition produced by In.
type inCondition struct {
	column string // column tested for membership
	values []any  // candidate values, bound as one list parameter
}
|
||
|
||
func (c *inCondition) Build() (qb.Cmp, map[string]any) {
|
||
return qb.In(c.column), map[string]any{c.column: c.values}
|
||
}
|
||
|
||
// Gt 大於條件
|
||
func Gt(column string, value any) Condition {
|
||
return >Condition{column: column, value: value}
|
||
}
|
||
|
||
// gtCondition is the Condition produced by Gt.
type gtCondition struct {
	column string // column compared with ">"
	value  any    // exclusive lower bound
}
|
||
|
||
func (c *gtCondition) Build() (qb.Cmp, map[string]any) {
|
||
return qb.Gt(c.column), map[string]any{c.column: c.value}
|
||
}
|
||
|
||
// Lt 小於條件
|
||
func Lt(column string, value any) Condition {
|
||
return <Condition{column: column, value: value}
|
||
}
|
||
|
||
// ltCondition is the Condition produced by Lt.
type ltCondition struct {
	column string // column compared with "<"
	value  any    // exclusive upper bound
}
|
||
|
||
func (c *ltCondition) Build() (qb.Cmp, map[string]any) {
|
||
return qb.Lt(c.column), map[string]any{c.column: c.value}
|
||
}
|
||
|
||
// QueryBuilder 定義查詢構建器介面
|
||
type QueryBuilder[T Table] interface {
|
||
Where(condition Condition) QueryBuilder[T]
|
||
OrderBy(column string, order Order) QueryBuilder[T]
|
||
Limit(n int) QueryBuilder[T]
|
||
Select(columns ...string) QueryBuilder[T]
|
||
Scan(ctx context.Context, dest *[]T) error
|
||
One(ctx context.Context) (T, error)
|
||
Count(ctx context.Context) (int64, error)
|
||
}
|
||
|
||
// queryBuilder 是 QueryBuilder 的具體實作
|
||
type queryBuilder[T Table] struct {
|
||
repo *repository[T]
|
||
conditions []Condition
|
||
orders []orderBy
|
||
limit int
|
||
columns []string
|
||
}
|
||
|
||
type orderBy struct {
|
||
column string
|
||
order Order
|
||
}
|
||
|
||
// newQueryBuilder 創建新的查詢構建器
|
||
func newQueryBuilder[T Table](repo *repository[T]) QueryBuilder[T] {
|
||
return &queryBuilder[T]{
|
||
repo: repo,
|
||
}
|
||
}
|
||
|
||
// Where 添加 WHERE 條件
|
||
func (q *queryBuilder[T]) Where(condition Condition) QueryBuilder[T] {
|
||
q.conditions = append(q.conditions, condition)
|
||
return q
|
||
}
|
||
|
||
// OrderBy 添加排序
|
||
func (q *queryBuilder[T]) OrderBy(column string, order Order) QueryBuilder[T] {
|
||
q.orders = append(q.orders, orderBy{column: column, order: order})
|
||
return q
|
||
}
|
||
|
||
// Limit 設置限制
|
||
func (q *queryBuilder[T]) Limit(n int) QueryBuilder[T] {
|
||
q.limit = n
|
||
return q
|
||
}
|
||
|
||
// Select 指定要查詢的欄位
|
||
func (q *queryBuilder[T]) Select(columns ...string) QueryBuilder[T] {
|
||
q.columns = append(q.columns, columns...)
|
||
return q
|
||
}
|
||
|
||
// Scan 執行查詢並將結果掃描到 dest
|
||
func (q *queryBuilder[T]) Scan(ctx context.Context, dest *[]T) error {
|
||
if dest == nil {
|
||
return ErrInvalidInput.WithTable(q.repo.table).WithError(
|
||
fmt.Errorf("destination cannot be nil"),
|
||
)
|
||
}
|
||
|
||
builder := qb.Select(q.repo.table)
|
||
|
||
// 添加欄位
|
||
if len(q.columns) > 0 {
|
||
builder = builder.Columns(q.columns...)
|
||
} else {
|
||
builder = builder.Columns(q.repo.metadata.Columns...)
|
||
}
|
||
|
||
// 添加條件
|
||
bindMap := make(map[string]any)
|
||
var cmps []qb.Cmp
|
||
for _, cond := range q.conditions {
|
||
cmp, binds := cond.Build()
|
||
cmps = append(cmps, cmp)
|
||
for k, v := range binds {
|
||
bindMap[k] = v
|
||
}
|
||
}
|
||
if len(cmps) > 0 {
|
||
builder = builder.Where(cmps...)
|
||
}
|
||
|
||
// 添加排序
|
||
for _, o := range q.orders {
|
||
order := qb.ASC
|
||
if o.order == DESC {
|
||
order = qb.DESC
|
||
}
|
||
|
||
builder = builder.OrderBy(o.column, order)
|
||
}
|
||
|
||
// 添加限制
|
||
if q.limit > 0 {
|
||
builder = builder.Limit(uint(q.limit))
|
||
}
|
||
|
||
stmt, names := builder.ToCql()
|
||
query := q.repo.db.withContextAndTimestamp(ctx,
|
||
q.repo.db.session.Query(stmt, names).BindMap(bindMap))
|
||
|
||
return query.SelectRelease(dest)
|
||
}
|
||
|
||
// One 執行查詢並返回單筆結果
|
||
func (q *queryBuilder[T]) One(ctx context.Context) (T, error) {
|
||
var zero T
|
||
q.limit = 1
|
||
|
||
var results []T
|
||
if err := q.Scan(ctx, &results); err != nil {
|
||
return zero, err
|
||
}
|
||
|
||
if len(results) == 0 {
|
||
return zero, ErrNotFound.WithTable(q.repo.table)
|
||
}
|
||
|
||
return results[0], nil
|
||
}
|
||
|
||
// Count 計算符合條件的記錄數
|
||
func (q *queryBuilder[T]) Count(ctx context.Context) (int64, error) {
|
||
builder := qb.Select(q.repo.table).Columns("COUNT(*)")
|
||
|
||
// 添加條件
|
||
bindMap := make(map[string]any)
|
||
var cmps []qb.Cmp
|
||
for _, cond := range q.conditions {
|
||
cmp, binds := cond.Build()
|
||
cmps = append(cmps, cmp)
|
||
for k, v := range binds {
|
||
bindMap[k] = v
|
||
}
|
||
}
|
||
if len(cmps) > 0 {
|
||
builder = builder.Where(cmps...)
|
||
}
|
||
|
||
stmt, names := builder.ToCql()
|
||
query := q.repo.db.withContextAndTimestamp(ctx,
|
||
q.repo.db.session.Query(stmt, names).BindMap(bindMap))
|
||
|
||
var count int64
|
||
err := query.GetRelease(&count)
|
||
if err == gocql.ErrNotFound {
|
||
return 0, nil // COUNT 查詢不會返回 ErrNotFound,但為了安全起見
|
||
}
|
||
return count, err
|
||
}
|