backend/pkg/library/cassandra/query.go

227 lines
4.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package cassandra
import (
	"context"
	"errors"
	"fmt"

	"github.com/gocql/gocql"
	"github.com/scylladb/gocqlx/v2/qb"
)
// Condition is a single WHERE predicate for a query.
//
// Build returns the qb comparison to place in the WHERE clause together with
// the named bind values referenced by that comparison's placeholder(s).
type Condition interface {
	Build() (cmp qb.Cmp, binds map[string]any)
}
// eqCondition is an equality predicate whose bind parameter is named after
// the column itself.
type eqCondition struct {
	column string
	value  any
}

// Eq returns a Condition matching rows whose column equals value.
func Eq(column string, value any) Condition {
	cond := eqCondition{column: column, value: value}
	return &cond
}

// Build implements Condition using qb.Eq and a single-entry bind map.
func (c *eqCondition) Build() (qb.Cmp, map[string]any) {
	binds := map[string]any{c.column: c.value}
	return qb.Eq(c.column), binds
}
// inCondition is an IN predicate. The whole values slice is bound as one
// parameter named after the column.
type inCondition struct {
	column string
	values []any
}

// In returns a Condition matching rows whose column equals any of values.
// The slice is stored as given (not copied); it is read when the query is
// built by a terminal method.
func In(column string, values []any) Condition {
	cond := inCondition{column: column, values: values}
	return &cond
}

// Build implements Condition using qb.In and a single-entry bind map.
func (c *inCondition) Build() (qb.Cmp, map[string]any) {
	binds := map[string]any{c.column: c.values}
	return qb.In(c.column), binds
}
// Gt returns a Condition matching rows whose column is strictly greater than
// value.
//
// The bind parameter is named "<column>__gt" instead of the bare column name.
// Conditions share one bind map keyed by bind name, so a range query such as
// Gt("ts", a) combined with Lt("ts", b) would otherwise have the later value
// overwrite the earlier one, and both placeholders would bind the same value.
func Gt(column string, value any) Condition {
	return &gtCondition{column: column, value: value}
}

type gtCondition struct {
	column string
	value  any
}

// Build implements Condition using qb.GtNamed with a per-operator bind name.
func (c *gtCondition) Build() (qb.Cmp, map[string]any) {
	name := c.column + "__gt"
	return qb.GtNamed(c.column, name), map[string]any{name: c.value}
}
// Lt returns a Condition matching rows whose column is strictly less than
// value.
//
// The bind parameter is named "<column>__lt" instead of the bare column name.
// Conditions share one bind map keyed by bind name, so a range query such as
// Gt("ts", a) combined with Lt("ts", b) would otherwise have one value
// overwrite the other, and both placeholders would bind the same value.
func Lt(column string, value any) Condition {
	return &ltCondition{column: column, value: value}
}

type ltCondition struct {
	column string
	value  any
}

// Build implements Condition using qb.LtNamed with a per-operator bind name.
func (c *ltCondition) Build() (qb.Cmp, map[string]any) {
	name := c.column + "__lt"
	return qb.LtNamed(c.column, name), map[string]any{name: c.value}
}
// QueryBuilder is a fluent SELECT builder for rows of table T.
//
// The chaining methods (Where, OrderBy, Limit, Select) record state and return
// the builder. The terminal methods (Scan, One, Count) build and execute the
// statement.
type QueryBuilder[T Table] interface {
	// Where adds a predicate to the WHERE clause.
	Where(condition Condition) QueryBuilder[T]
	// OrderBy appends an ORDER BY column with the given direction.
	OrderBy(column string, order Order) QueryBuilder[T]
	// Limit sets the row limit (the default implementation ignores n <= 0).
	Limit(n int) QueryBuilder[T]
	// Select sets the columns to read (the default reads all mapped columns).
	Select(columns ...string) QueryBuilder[T]

	// Scan executes the query and stores the rows in dest.
	Scan(ctx context.Context, dest *[]T) error
	// One executes the query and returns the first row.
	One(ctx context.Context) (T, error)
	// Count returns the number of rows matching the WHERE predicates.
	Count(ctx context.Context) (int64, error)
}
type (
	// queryBuilder is the concrete QueryBuilder implementation. It only
	// accumulates state; the CQL is rendered when a terminal method runs.
	queryBuilder[T Table] struct {
		repo       *repository[T] // source of table name, column metadata and db session
		conditions []Condition    // WHERE predicates, in insertion order
		orders     []orderBy      // ORDER BY entries, in insertion order
		limit      int            // LIMIT; only emitted when > 0
		columns    []string       // explicit projection; empty means metadata columns
	}

	// orderBy is a single ORDER BY entry.
	orderBy struct {
		column string
		order  Order
	}
)
// newQueryBuilder returns an empty builder bound to repo.
//
// The builder starts with no conditions, no ordering, no limit and the
// default column projection.
func newQueryBuilder[T Table](repo *repository[T]) QueryBuilder[T] {
	var b queryBuilder[T]
	b.repo = repo
	return &b
}
// Where appends condition to the WHERE predicates and returns the same
// builder for chaining.
func (q *queryBuilder[T]) Where(condition Condition) QueryBuilder[T] {
	conds := append(q.conditions, condition)
	q.conditions = conds
	return q
}

// OrderBy appends an ORDER BY entry for column and returns the same builder.
func (q *queryBuilder[T]) OrderBy(column string, order Order) QueryBuilder[T] {
	entry := orderBy{column: column, order: order}
	q.orders = append(q.orders, entry)
	return q
}

// Limit stores n as the row limit, replacing any earlier value, and returns
// the same builder. Scan only emits LIMIT when the stored value is positive.
func (q *queryBuilder[T]) Limit(n int) QueryBuilder[T] {
	q.limit = n
	return q
}

// Select adds columns to the explicit projection and returns the same
// builder. Repeated calls accumulate.
func (q *queryBuilder[T]) Select(columns ...string) QueryBuilder[T] {
	for _, col := range columns {
		q.columns = append(q.columns, col)
	}
	return q
}
// Scan builds a SELECT from the builder's state, executes it, and stores the
// resulting rows in dest.
//
// The projection is the columns given to Select, or the repository metadata
// columns when Select was never called. ORDER BY entries are emitted in
// insertion order. LIMIT is emitted only when the limit is positive.
//
// It returns ErrInvalidInput (tagged with the table name) when dest is nil.
// Otherwise it returns the error, if any, from query execution.
func (q *queryBuilder[T]) Scan(ctx context.Context, dest *[]T) error {
	if dest == nil {
		return ErrInvalidInput.WithTable(q.repo.table).WithError(
			fmt.Errorf("destination cannot be nil"),
		)
	}

	builder := qb.Select(q.repo.table)
	if len(q.columns) > 0 {
		builder = builder.Columns(q.columns...)
	} else {
		builder = builder.Columns(q.repo.metadata.Columns...)
	}

	cmps, bindMap := q.whereClause()
	if len(cmps) > 0 {
		builder = builder.Where(cmps...)
	}

	for _, o := range q.orders {
		order := qb.ASC
		if o.order == DESC {
			order = qb.DESC
		}
		builder = builder.OrderBy(o.column, order)
	}

	if q.limit > 0 {
		builder = builder.Limit(uint(q.limit))
	}

	stmt, names := builder.ToCql()
	query := q.repo.db.withContextAndTimestamp(ctx,
		q.repo.db.session.Query(stmt, names).BindMap(bindMap))
	return query.SelectRelease(dest)
}

// One executes the query with LIMIT 1 and returns the first row.
//
// It returns ErrNotFound (tagged with the table name) when no row matches.
// The builder's own limit is restored before returning, so the builder can
// still be used for a later Scan with the caller's original limit.
func (q *queryBuilder[T]) One(ctx context.Context) (T, error) {
	var zero T

	// Force LIMIT 1 only for this call. Overwriting q.limit permanently would
	// make any later Scan on the same builder silently return at most one row.
	prevLimit := q.limit
	q.limit = 1
	defer func() { q.limit = prevLimit }()

	var results []T
	if err := q.Scan(ctx, &results); err != nil {
		return zero, err
	}
	if len(results) == 0 {
		return zero, ErrNotFound.WithTable(q.repo.table)
	}
	return results[0], nil
}

// Count returns the number of rows matching the builder's WHERE predicates
// by executing SELECT COUNT(*). Select, OrderBy and Limit are ignored.
func (q *queryBuilder[T]) Count(ctx context.Context) (int64, error) {
	builder := qb.Select(q.repo.table).Columns("COUNT(*)")

	cmps, bindMap := q.whereClause()
	if len(cmps) > 0 {
		builder = builder.Where(cmps...)
	}

	stmt, names := builder.ToCql()
	query := q.repo.db.withContextAndTimestamp(ctx,
		q.repo.db.session.Query(stmt, names).BindMap(bindMap))

	var count int64
	if err := query.GetRelease(&count); err != nil {
		// A COUNT query should always yield a row. Treat not-found as zero
		// defensively; errors.Is also matches a wrapped ErrNotFound.
		if errors.Is(err, gocql.ErrNotFound) {
			return 0, nil
		}
		return 0, err
	}
	return count, nil
}

// whereClause converts the builder's conditions into qb comparisons (in
// insertion order) and merges their bind values into one map. Scan and Count
// share it. When two conditions use the same bind name, the later value wins.
func (q *queryBuilder[T]) whereClause() ([]qb.Cmp, map[string]any) {
	cmps := make([]qb.Cmp, 0, len(q.conditions))
	bindMap := make(map[string]any, len(q.conditions))
	for _, cond := range q.conditions {
		cmp, binds := cond.Build()
		cmps = append(cmps, cmp)
		for k, v := range binds {
			bindMap[k] = v
		}
	}
	return cmps, bindMap
}