feat: add notification param

This commit is contained in:
王性驊 2025-11-18 17:45:38 +08:00
parent 61fefe26b4
commit 785e7c88e5
33 changed files with 1108 additions and 4165 deletions

2
go.mod
View File

@ -16,6 +16,7 @@ require (
github.com/matcornic/hermes/v2 v2.1.0 github.com/matcornic/hermes/v2 v2.1.0
github.com/minchao/go-mitake v1.0.0 github.com/minchao/go-mitake v1.0.0
github.com/panjf2000/ants/v2 v2.11.3 github.com/panjf2000/ants/v2 v2.11.3
github.com/scylladb/gocqlx/v3 v3.0.4
github.com/segmentio/ksuid v1.0.4 github.com/segmentio/ksuid v1.0.4
github.com/shopspring/decimal v1.4.0 github.com/shopspring/decimal v1.4.0
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
@ -105,6 +106,7 @@ require (
github.com/redis/go-redis/v9 v9.14.0 // indirect github.com/redis/go-redis/v9 v9.14.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect
github.com/russross/blackfriday/v2 v2.0.1 // indirect github.com/russross/blackfriday/v2 v2.0.1 // indirect
github.com/scylladb/go-reflectx v1.0.1 // indirect
github.com/shirou/gopsutil/v4 v4.25.6 // indirect github.com/shirou/gopsutil/v4 v4.25.6 // indirect
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect github.com/sirupsen/logrus v1.9.3 // indirect

4
go.sum
View File

@ -219,6 +219,10 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q= github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/scylladb/go-reflectx v1.0.1 h1:b917wZM7189pZdlND9PbIJ6NQxfDPfBvUaQ7cjj1iZQ=
github.com/scylladb/go-reflectx v1.0.1/go.mod h1:rWnOfDIRWBGN0miMLIcoPt/Dhi2doCMZqwMCJ3KupFc=
github.com/scylladb/gocqlx/v3 v3.0.4 h1:37rMVFEUlsGGNYB7OLR7991KwBYR2WA5TU7wtduClas=
github.com/scylladb/gocqlx/v3 v3.0.4/go.mod h1:3vBkGO+HRh/BYypLWXzurQ45u1BAO0VGBhg5VgperPY=
github.com/segmentio/ksuid v1.0.4 h1:sBo2BdShXjmcugAMwjugoGUdUV0pcxY5mW4xKRn3v4c= github.com/segmentio/ksuid v1.0.4 h1:sBo2BdShXjmcugAMwjugoGUdUV0pcxY5mW4xKRn3v4c=
github.com/segmentio/ksuid v1.0.4/go.mod h1:/XUiZBD3kVx5SmUOl55voK5yeAbBNNIed+2O73XgrPE= github.com/segmentio/ksuid v1.0.4/go.mod h1:/XUiZBD3kVx5SmUOl55voK5yeAbBNNIed+2O73XgrPE=
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs= github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=

View File

@ -1,55 +0,0 @@
run:
timeout: 3m
issues-exit-code: 2
tests: false # 不檢查測試檔案
linters:
enable:
- govet # 官方靜態分析,抓潛在 bug
- staticcheck # 最強 bug/反模式偵測
- revive # golint 進化版,風格與註解規範
- gofmt # 風格格式化檢查
- goimports # import 排序
- errcheck # error 忽略警告
- ineffassign # 無效賦值
- unused # 未使用變數
- bodyclose # HTTP body close
- gosimple # 靜態分析簡化警告staticcheck 也包含,可選)
- typecheck # 型別檢查
- misspell # 拼字檢查
- gocritic # bug-prone code
- gosec # 資安檢查
- prealloc # slice/array 預分配
- unparam # 未使用參數
issues:
exclude-rules:
- path: _test\.go
linters:
- funlen
- goconst
- cyclop
- gocognit
- lll
- wrapcheck
- contextcheck
linters-settings:
revive:
severity: warning
rules:
- name: blank-imports
severity: error
gofmt:
simplify: true
lll:
line-length: 140
# 可自訂目錄忽略(視專案需求加上)
# skip-dirs:
# - vendor
# - third_party
# 可以設定本機與 CI 上都一致
# env:
# GOLANGCI_LINT_CACHE: ".golangci-lint-cache"

View File

@ -1,8 +0,0 @@
GOFMT ?= gofmt "-s"
GOFILES := $(shell find . -name "*.go")
.PHONY: fmt
fmt: # 格式優化
$(GOFMT) -w $(GOFILES)
goimports -w ./
golangci-lint run

View File

@ -0,0 +1,158 @@
# Cassandra2 - 新一代 Cassandra 客戶端
Cassandra2 是重新設計的 Cassandra 客戶端,提供更簡潔的 API、更好的類型安全性和更清晰的架構。
## 特色
- ✅ **Repository 模式**:每個 Repository 綁定一個 keyspace無需到處傳遞
- ✅ **類型安全**:使用泛型,編譯期類型檢查
- ✅ **簡潔的 API**:統一的查詢介面,流暢的鏈式調用
- ✅ **符合 cursor.md 原則**:小介面、依賴注入、顯式錯誤處理
## 快速開始
### 1. 初始化
```go
import "your-module/pkg/library/cassandra"
// 創建 DB 連接
db, err := cassandra2.New(
cassandra2.WithHosts("localhost"),
cassandra2.WithKeyspace("my_keyspace"),
cassandra2.WithPort(9042),
)
if err != nil {
log.Fatal(err)
}
defer db.Close()
```
### 2. 定義資料模型
```go
type User struct {
ID gocql.UUID `db:"id" partition_key:"true"`
Name string `db:"name"`
Email string `db:"email"`
CreatedAt time.Time `db:"created_at"`
}
func (u User) TableName() string {
return "users"
}
```
### 3. 使用 Repository
```go
// 獲取 Repository
repo, err := db.Repository[User]("my_keyspace")
// 插入
user := User{
ID: gocql.TimeUUID(),
Name: "Alice",
Email: "alice@example.com",
CreatedAt: time.Now(),
}
err = repo.Insert(ctx, user)
// 查詢
var result User
result, err = repo.Get(ctx, user.ID)
// 更新
user.Email = "newemail@example.com"
err = repo.Update(ctx, user)
// 刪除
err = repo.Delete(ctx, user.ID)
```
### 4. 使用 Query Builder
```go
// 條件查詢
var users []User
err = repo.Query().
Where(cassandra2.Eq("status", "active")).
OrderBy("created_at", cassandra2.DESC).
Limit(10).
Scan(ctx, &users)
// 單筆查詢
user, err := repo.Query().
Where(cassandra2.Eq("id", userID)).
One(ctx)
// 計數
count, err := repo.Query().
Where(cassandra2.Eq("status", "active")).
Count(ctx)
```
### 5. Batch 操作
```go
batch := repo.Batch(ctx)
batch.Insert(user1).
Insert(user2).
Update(user3)
err = batch.Commit(ctx)
```
### 6. Transaction 操作
```go
tx := db.Begin(ctx, "my_keyspace")
tx.Insert(user1)
tx.Update(user2)
if err := tx.Commit(ctx); err != nil {
tx.Rollback(ctx)
}
```
## API 對比
### 舊 API (Cassandra 1)
```go
db.Insert(ctx, user, "keyspace")
db.Model(ctx, &User{}, "keyspace").Where(...).Scan(&result)
```
### 新 API (Cassandra 2)
```go
repo := db.Repository[User]("keyspace")
repo.Insert(ctx, user)
repo.Query().Where(...).Scan(ctx, &result)
```
## 主要改進
1. **移除 keyspace 參數**Repository 綁定 keyspace無需重複傳遞
2. **類型安全**:使用泛型,編譯期檢查
3. **統一 API**:只有一套查詢介面
4. **更好的錯誤處理**:統一的錯誤類型,支援 errors.Is/As
## 注意事項
1. **主鍵查詢**`Get` 方法需要完整的 Primary Key。如果是多欄位主鍵需要傳入包含所有主鍵欄位的 struct。
2. **更新行為**`Update` 預設只更新非零值欄位,使用 `UpdateAll` 可更新所有欄位。
3. **Transaction**:這是補償式交易,不是真正的 ACID 交易,適用於最終一致性場景。
## 遷移指南
從 Cassandra 1 遷移到 Cassandra 2
1. 將 `cassandra.NewCassandraDB` 改為 `cassandra2.New`
2. 將 `db.Insert(ctx, doc, keyspace)` 改為 `repo.Insert(ctx, doc)`,其中 `repo = db.Repository[Type](keyspace)`
3. 將 `db.Model(...)` 改為 `repo.Query()`
4. 更新錯誤處理:使用 `cassandra2.IsNotFound` 等函數
## 文檔
詳細的技術設計請參考:
- `REFACTORING_PLAN.md` - 重構計畫
- `TECHNICAL_DESIGN.md` - 技術設計文檔

View File

@ -1,113 +0,0 @@
package cassandra
import (
"context"
"reflect"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx/v3"
"github.com/scylladb/gocqlx/v3/qb"
"github.com/scylladb/gocqlx/v3/table"
)
// TODO: 只保證同一個 PK 下有一致性,中間有失敗的話可能只有失敗不會寫入,其他成功的還是會成功。
// 之後會朝兩個方向走
// 1. 最終一致性:目前的設計是直接寫入副表,然後透過 background worker 讀取 sync_task 表,補寫副表資料。
// 2. 研究 自己做 TX_ID 以及 STATUS 的方案
// 這個是已知問題,一定要解決
// NewBatch 創建一個新的 Batch 操作
// keyspace 如果為空,則使用初始化時設定的預設 keyspace
func (db *CassandraDB) NewBatch(ctx context.Context, keyspace string) *Batch {
keyspace = getKeyspace(db, keyspace)
session := db.GetSession()
return &Batch{
ctx: ctx,
keyspace: keyspace,
db: db,
batch: gocqlx.Batch{
Batch: session.NewBatch(gocql.LoggedBatch).WithContext(ctx),
},
}
}
type Batch struct {
ctx context.Context
keyspace string
db *CassandraDB
batch gocqlx.Batch
}
func (tx *Batch) Insert(doc any) error {
metadata, err := GenerateTableMetadata(doc, tx.keyspace)
if err != nil {
return err
}
tbl := table.New(metadata)
stmt, names := tbl.Insert()
return tx.batch.BindStruct(tx.db.GetSession().Query(stmt, names), doc)
}
func (tx *Batch) Delete(doc any) error {
metadata, err := GenerateTableMetadata(doc, tx.keyspace)
if err != nil {
return err
}
tbl := table.New(metadata)
stmt, names := tbl.Delete()
return tx.batch.BindStruct(tx.db.GetSession().Query(stmt, names), doc)
}
func (tx *Batch) Update(doc any) error {
metadata, err := GenerateTableMetadata(doc, tx.keyspace)
if err != nil {
return err
}
v := reflect.ValueOf(doc)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
typ := v.Type()
setCols := make([]string, 0)
setVals := make([]any, 0)
whereCols := make([]string, 0)
whereVals := make([]any, 0)
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
tag := field.Tag.Get("db")
if tag == "" || tag == "-" {
continue
}
val := v.Field(i)
if !val.IsValid() {
continue
}
if contains(metadata.PartKey, tag) || contains(metadata.SortKey, tag) {
whereCols = append(whereCols, tag)
whereVals = append(whereVals, val.Interface())
} else if !isZero(val) {
setCols = append(setCols, tag)
setVals = append(setVals, val.Interface())
}
}
if len(setCols) == 0 {
return ErrNoFieldsToUpdate.WithTable(metadata.Name)
}
builder := qb.Update(metadata.Name).Set(setCols...)
for _, col := range whereCols {
builder = builder.Where(qb.Eq(col))
}
stmt, names := builder.ToCql()
setVals = append(setVals, whereVals...)
return tx.batch.Bind(tx.db.GetSession().Query(stmt, names), setVals...)
}
func (tx *Batch) Commit() error {
session := tx.db.GetSession()
return session.ExecuteBatch(&tx.batch)
}

View File

@ -1,44 +0,0 @@
package cassandra
import (
"context"
"testing"
"time"
"github.com/gocql/gocql"
"github.com/stretchr/testify/assert"
)
func TestBatchTx_AllSuccess(t *testing.T) {
t.Parallel()
ctx := context.Background()
ks := generateRandomKeySpace(t)
now := time.Now()
id1 := gocql.TimeUUID()
id2 := gocql.TimeUUID()
tx := cassandraDBTest.NewBatch(ctx, ks)
err := tx.Insert(&MonkeyEntity{ID: id1, Name: "Alice", UpdateAt: now, CreateAt: now})
assert.NoError(t, err)
err = tx.Insert(&MonkeyEntity{ID: id2, Name: "Bob", UpdateAt: now, CreateAt: now})
assert.NoError(t, err)
err = tx.Update(&MonkeyEntity{ID: id1, Name: "Alice", UpdateAt: now.Add(5 * time.Minute)})
assert.NoError(t, err)
err = tx.Delete(&MonkeyEntity{ID: id2, Name: "Bob"})
assert.NoError(t, err)
err = tx.Commit()
assert.NoError(t, err)
// Alice 應該還在,且被更新
var alice MonkeyEntity
alice.ID, alice.Name = id1, "Alice"
err = cassandraDBTest.Get(ctx, &alice, ks)
assert.NoError(t, err)
assert.WithinDuration(t, now.Add(5*time.Minute), alice.UpdateAt, time.Second)
// Bob 應該被刪除
err = cassandraDBTest.Get(ctx, &MonkeyEntity{ID: id2, Name: "Bob"}, ks)
assert.Error(t, err)
}

View File

@ -1,204 +0,0 @@
package cassandra
import (
"context"
"fmt"
"log"
"os"
"sync/atomic"
"testing"
"time"
"github.com/gocql/gocql"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
)
type Container struct {
Ctx context.Context
Container testcontainers.Container
Host string
Port int
}
var cassandraDBTest *CassandraDB
var keyspaceSequence atomic.Int64
func TestMain(m *testing.M) {
container, db := connCassandraForTest()
cassandraDBTest = db
code := m.Run()
cassandraDBTest.Close()
if err := container.Container.Terminate(container.Ctx); err != nil {
log.Fatalf("Failed to terminate Cassandra container: %v", err)
}
log.Println("[TEST] Container terminated")
os.Exit(code)
}
func initCassandraContainer(version string) (Container, error) {
ctx := context.Background()
req := testcontainers.ContainerRequest{
Image: fmt.Sprintf("cassandra:%s", version),
Env: map[string]string{
"CASSANDRA_START_RPC": "true",
"CASSANDRA_NUM_TOKENS": "1",
"CASSANDRA_ENDPOINT_SNITCH": "GossipingPropertyFileSnitch",
"CASSANDRA_DC": "datacenter1",
"CASSANDRA_RACK": "rack1",
"MAX_HEAP_SIZE": "256M",
"HEAP_NEWSIZE": "100M",
},
ExposedPorts: []string{"9042/tcp"},
// 等待 Cassandra 啟動完成的指標字串,依據實際啟動 log 可調整
WaitingFor: wait.ForLog("Created default superuser role 'cassandra'").
WithStartupTimeout(2 * time.Minute),
}
cassandraContainer, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
ContainerRequest: req,
Started: true,
})
if err != nil {
return Container{}, err
}
host, err := cassandraContainer.Host(ctx)
if err != nil {
return Container{}, err
}
mappedPort, err := cassandraContainer.MappedPort(ctx, "9042")
if err != nil {
return Container{}, err
}
return Container{ctx, cassandraContainer, host, mappedPort.Int()}, nil
}
func connCassandraForTest() (Container, *CassandraDB) {
// 啟動 Cassandra container
dbContainer, err := initCassandraContainer("5.0.4")
if err != nil {
log.Fatalf("Failed to initialize Cassandra container: %v", err)
}
db, err := NewCassandraDB(
[]string{dbContainer.Host},
WithPort(dbContainer.Port),
WithConsistency(gocql.One),
WithNumConns(5),
)
if err != nil {
log.Fatalf("Failed to initialize Cassandra DB: %v", err)
}
// 建立 keyspace 和 table
err = db.EnsureTable(`
CREATE KEYSPACE IF NOT EXISTS my_keyspace
WITH replication = {
'class': 'SimpleStrategy',
'replication_factor': 1
};`)
if err != nil {
log.Fatalf("Failed to create keyspace: %v", err)
}
err = db.EnsureTable(`
CREATE TABLE IF NOT EXISTS my_keyspace.monkey_entity (
id UUID,
name TEXT,
update_at TIMESTAMP,
create_at TIMESTAMP,
PRIMARY KEY ((id), name)
);`)
if err != nil {
log.Fatalf("Failed to create table: %v", err)
}
return dbContainer, db
}
func generateRandomKeySpace(t *testing.T) string {
ks := fmt.Sprintf("my_keyspace_%d", keyspaceSequence.Add(1))
err := cassandraDBTest.EnsureTable(fmt.Sprintf(`
CREATE KEYSPACE IF NOT EXISTS %s
WITH replication = {
'class': 'SimpleStrategy',
'replication_factor': 1
};`, ks))
if err != nil {
t.Fatalf("Failed to create keyspace: %v", err)
}
err = cassandraDBTest.EnsureTable(fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s.monkey_entity (
id UUID,
name TEXT,
update_at TIMESTAMP,
create_at TIMESTAMP,
PRIMARY KEY ((id), name)
);`, ks))
if err != nil {
log.Fatalf("Failed to create table: %v", err)
}
return ks
}
// Animal 為不實作 TableName 方法的範例 struct則會以型別名稱轉換成 snake_case
type Animal struct {
ID gocql.UUID `db:"id" partition_key:"true"`
Type string `db:"type"`
}
func (m *Animal) TableName() string {
return "animal"
}
// InvalidEntity 為無 partition key 的範例 struct預期產生錯誤
type InvalidEntity struct {
Field string `db:"field"`
}
type MonkeyEntity struct {
ID gocql.UUID `db:"id" partition_key:"true"`
Name string `db:"name" clustering_key:"true" sai:"true"`
UpdateAt time.Time `db:"update_at"`
CreateAt time.Time `db:"create_at"`
}
func (m *MonkeyEntity) TableName() string {
return "monkey_entity"
}
type CatEntity struct {
ID *gocql.UUID `db:"id" partition_key:"true"`
Name *string `db:"name" partition_key:"true"`
UpdateAt *time.Time `db:"update_at"`
CreateAt *time.Time `db:"create_at" clustering_key:"true"`
}
func (m *CatEntity) TableName() string {
return "cat_entity"
}
type Consistency struct {
ID gocql.UUID `db:"id" partition_key:"true"`
ConsistencyName string `db:"consistency_name"` // can editor
ConsistencyType string `db:"consistency_type"`
LastTaskID string `db:"last_task_id"` // ConsistencyTask ID
Target string `db:"target"` // file name can editor
Status string `db:"status"`
ConsistencyMap string `db:"consistency_map"` // JSON string
CreateAT int64 `db:"create_at"`
UpdateAT int64 `db:"update_at"`
}
func (c *Consistency) TableName() string {
return "consistency"
}

View File

@ -1,209 +0,0 @@
package cassandra
import (
"fmt"
"strconv"
"strings"
"time"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx/v3"
)
// cassandraConf 是初始化 CassandraDB 所需的內部設定(私有)
type cassandraConf struct {
Hosts []string // Cassandra 主機列表
Port int // 連線埠
Keyspace string // 預設使用的 Keyspace
Username string // 認證用戶名
Password string // 認證密碼
Consistency gocql.Consistency // 一致性級別
ConnectTimeoutSec int // 連線逾時秒數
NumConns int // 每個節點連線數
MaxRetries int // 重試次數
UseAuth bool // 是否使用帳號密碼驗證
RetryMinInterval time.Duration // 重試間隔最小值
RetryMaxInterval time.Duration // 重試間隔最大值
ReconnectInitialInterval time.Duration // 重連初始間隔
ReconnectMaxInterval time.Duration // 重連最大間隔
CQLVersion string // 執行連線的CQL 版本號
}
// CassandraDB 是封裝了 Cassandra 資料庫 session 的結構
type CassandraDB struct {
session gocqlx.Session
SaiSupported bool // 是否支援 sai
Version string // 資料庫版本
defaultKeyspace string // 預設 keyspace
}
// NewCassandraDB 初始化並建立 Cassandra 資料庫連線使用預設設定並可透過Option修改
func NewCassandraDB(hosts []string, opts ...Option) (*CassandraDB, error) {
config := &cassandraConf{
Hosts: hosts,
Port: defaultPort,
Consistency: defaultConsistency,
ConnectTimeoutSec: defaultTimeoutSec,
NumConns: defaultNumConns,
MaxRetries: defaultMaxRetries,
RetryMinInterval: defaultRetryMinInterval,
RetryMaxInterval: defaultRetryMaxInterval,
ReconnectInitialInterval: defaultReconnectInitialInterval,
ReconnectMaxInterval: defaultReconnectMaxInterval,
CQLVersion: defaultCqlVersion,
}
// 套用Option設定選項
for _, opt := range opts {
opt(config)
}
// 建立連線設定
cluster := gocql.NewCluster(config.Hosts...)
cluster.Port = config.Port
cluster.Consistency = config.Consistency
cluster.Timeout = time.Duration(config.ConnectTimeoutSec) * time.Second
cluster.NumConns = config.NumConns
cluster.RetryPolicy = &gocql.ExponentialBackoffRetryPolicy{
NumRetries: config.MaxRetries,
Min: config.RetryMinInterval,
Max: config.RetryMaxInterval,
}
cluster.ReconnectionPolicy = &gocql.ExponentialReconnectionPolicy{
MaxRetries: config.MaxRetries,
InitialInterval: config.ReconnectInitialInterval,
MaxInterval: config.ReconnectMaxInterval,
}
// 若有提供 Keyspace 則指定
if config.Keyspace != "" {
cluster.Keyspace = config.Keyspace
}
// 若啟用驗證則設定帳號密碼
if config.UseAuth {
cluster.Authenticator = gocql.PasswordAuthenticator{
Username: config.Username,
Password: config.Password,
}
}
// 建立 Session
s, err := gocqlx.WrapSession(cluster.CreateSession())
if err != nil {
return nil, fmt.Errorf("failed to connect to Cassandra cluster (hosts: %v, port: %d): %w", config.Hosts, config.Port, err)
}
db := &CassandraDB{
session: s,
defaultKeyspace: config.Keyspace,
}
version, err := db.getReleaseVersion()
if err != nil {
return nil, fmt.Errorf("failed to get DB version: %w", err)
}
db.Version = version
db.SaiSupported = isSAISupported(version)
return db, nil
}
// Close 關閉 Cassandra 資料庫連線
func (db *CassandraDB) Close() {
db.session.Close()
}
// GetSession 返回目前使用的 Cassandra Session
func (db *CassandraDB) GetSession() gocqlx.Session {
return db.session
}
// GetDefaultKeyspace 返回預設的 keyspace
func (db *CassandraDB) GetDefaultKeyspace() string {
return db.defaultKeyspace
}
// WithKeyspace 返回一個帶有指定 keyspace 的查詢構建器
// 如果 keyspace 為空,則使用預設 keyspace
func (db *CassandraDB) WithKeyspace(keyspace string) *KeyspaceDB {
if keyspace == "" {
keyspace = db.defaultKeyspace
}
return &KeyspaceDB{
db: db,
keyspace: keyspace,
}
}
// KeyspaceDB 是帶有 keyspace 的資料庫包裝器
type KeyspaceDB struct {
db *CassandraDB
keyspace string
}
// GetSession 返回 session
func (kdb *KeyspaceDB) GetSession() gocqlx.Session {
return kdb.db.GetSession()
}
// GetKeyspace 返回 keyspace
func (kdb *KeyspaceDB) GetKeyspace() string {
return kdb.keyspace
}
// EnsureTable 確認並建立資料表
func (db *CassandraDB) EnsureTable(schema string) error {
return db.session.ExecStmt(schema)
}
func (db *CassandraDB) InitVersionSupport() error {
version, err := db.getReleaseVersion()
if err != nil {
return err
}
db.Version = version
db.SaiSupported = isSAISupported(version)
return nil
}
func (db *CassandraDB) getReleaseVersion() (string, error) {
var version string
stmt := "SELECT release_version FROM system.local"
err := db.GetSession().Query(stmt, []string{"release_version"}).Consistency(gocql.One).Scan(&version)
return version, err
}
func isSAISupported(version string) bool {
// 只要 major >=5 就支援
// 4.0.9+ 才有 SAI但不穩強烈建議 5.0+
parts := strings.Split(version, ".")
if len(parts) < 2 {
return false
}
major, _ := strconv.Atoi(parts[0])
minor, _ := strconv.Atoi(parts[1])
if major >= 5 {
return true
}
if major == 4 {
if minor > 0 { // 4.1.x、4.2.x 直接支援
return true
}
if minor == 0 {
patch := 0
if len(parts) >= 3 {
patch, _ = strconv.Atoi(parts[2])
}
if patch >= 9 {
return true
}
}
}
return false
}

View File

@ -1,210 +0,0 @@
package cassandra
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsSAISupported(t *testing.T) {
tests := []struct {
version string
expected bool
}{
{"5.0.0", true}, // 5.x 支援
{"5.1.2", true}, // 5.x 支援
{"6.0.0", true}, // 6.x 理論上也支援
{"4.0.8", false}, // 4.0.8 不支援
{"4.0.9", true}, // 4.0.9 支援
{"4.1.0", true}, // 4.1.0 支援
{"4.2.2", true}, // 4.2.2 支援
{"3.11.10", false}, // 3.x 不支援
{"3.0.0", false},
{"", false}, // 空字串,不支援
{"unknown", false}, // 無效格式
{"4", false}, // 缺 patch不支援
{"4.0", false}, // 缺 patch不支援
{"5", false}, // 缺 minor
{"5.0", true}, // 5.0 預設支援
}
for _, tt := range tests {
t.Run(tt.version, func(t *testing.T) {
result := isSAISupported(tt.version)
assert.Equal(t, tt.expected, result, "version: %s", tt.version)
})
}
}
// TestCassandraDB_Integration_TableDriven 使用 table-driven 方式整合測試
// func TestCassandraDB_Integration_TableDriven(t *testing.T) {
// // 啟動 Cassandra container
// dbContainer, err := initCassandraContainer("5.0.4")
// defer func() {
// _ = dbContainer.Container.Terminate(dbContainer.Ctx)
// fmt.Println("[TEST] Container terminated")
// }()
// // 建立 CassandraDB 連線
// hosts := []string{dbContainer.Host}
// db, err := NewCassandraDB(
// hosts,
// WithPort(dbContainer.Port),
// WithConsistency(gocql.One),
// WithNumConns(2),
// )
// assert.NoError(t, err, "should success create CassandraDB")
// assert.NotNil(t, db, "db should not be nil")
// assert.NotNil(t, db.GetSession(), "get Session should not be nil")
// err = db.EnsureTable("CREATE KEYSPACE my_keyspace\nWITH replication = {\n 'class': 'SimpleStrategy',\n 'replication_factor': 1\n};\n")
// assert.NoError(t, err, "should success ensure table")
// // 注意:由於 Close 會關閉 session因此請把測試 Close 的子案例放在所有使用 session 的子案例之後
// tests := []struct {
// name string
// action func() error
// wantErr bool
// }{
// {
// name: "ok",
// action: func() error {
// // 建立一個合法的資料表 (使用 IF NOT EXISTS 避免重複建立錯誤)
// schema := "CREATE TABLE IF NOT EXISTS my_keyspace.test (id uuid PRIMARY KEY, name text)"
// return db.EnsureTable(schema)
// },
// wantErr: false,
// },
// {
// name: "failed to ensure table since wrong schema",
// action: func() error {
// // 傳入無效的 CQL 語法,預期應回傳錯誤
// schema := "CREATE TABLE invalid schema"
// return db.EnsureTable(schema)
// },
// wantErr: true,
// },
// {
// name: "GetSession 返回有效 Session",
// action: func() error {
// if db.GetSession().Session == nil {
// return fmt.Errorf("session is nil")
// }
// return nil
// },
// wantErr: false,
// },
// {
// name: "Close close Session",
// action: func() error {
// db.Close()
// // 無法直接驗證內部是否已關閉,但可避免再次使用 session 產生 panic
// return nil
// },
// wantErr: false,
// },
// }
// // 依序執行各子案例
// for _, tc := range tests {
// t.Run(tc.name, func(t *testing.T) {
// err := tc.action()
// if (err != nil) != tc.wantErr {
// t.Errorf("%s havs error = %v, wantErr %v", tc.name, err, tc.wantErr)
// }
// })
// }
// }
// Mark: new multiple container lead to unit test too slow
// func TestCassandraDB_getReleaseVersion(t *testing.T) {
// t.Parallel()
// type fields struct {
// Version string
// }
// tests := []struct {
// name string
// fields fields
// want string
// wantError bool
// }{
// {
// name: "3",
// fields: fields{Version: "3.11"},
// want: "3.11.19",
// wantError: false,
// },
// {
// name: "5",
// fields: fields{Version: "5.0.4"},
// want: "5.0.4",
// wantError: false,
// },
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// container, err := initCassandraContainer(tt.fields.Version)
// defer func() {
// _ = container.Container.Terminate(container.Ctx)
// fmt.Println("[TEST] Container terminated")
// }()
// if !tt.wantError {
// assert.NoError(t, err)
// // 建立 CassandraDB 連線
// hosts := []string{container.Host}
// db, err := NewCassandraDB(
// hosts,
// WithPort(container.Port),
// WithConsistency(gocql.One),
// WithNumConns(2),
// )
// assert.NoError(t, err)
// version, err := db.getReleaseVersion()
// assert.NoError(t, err)
// assert.Equal(t, version, tt.want)
// }
// })
// }
// }
func TestCassandraDB_getReleaseVersion(t *testing.T) {
t.Parallel()
type fields struct {
Version string
}
tests := []struct {
name string
fields fields
want string
wantError bool
}{
// {
// name: "3",
// fields: fields{Version: "3.11"},
// want: "3.11.19",
// wantError: false,
// },
{
name: "5.0.4",
fields: fields{Version: "5.0.4"},
want: "5.0.4",
wantError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if !tt.wantError {
version, err := cassandraDBTest.getReleaseVersion()
assert.NoError(t, err)
assert.Equal(t, version, tt.want)
}
})
}
}

View File

@ -1,203 +0,0 @@
package cassandra
import (
"context"
"reflect"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx/v3/qb"
"github.com/scylladb/gocqlx/v3/table"
)
var qh = &queryHelper{}
// Insert 依據 document 自動產生 INSERT 語句並執行
// keyspace 如果為空,則使用初始化時設定的預設 keyspace
func (db *CassandraDB) Insert(ctx context.Context, document any, keyspace string) error {
keyspace = getKeyspace(db, keyspace)
metadata, err := GenerateTableMetadata(document, keyspace)
if err != nil {
return err
}
t := table.New(metadata)
q := qh.withContextAndTimestamp(ctx, db.GetSession().Query(t.Insert()).BindStruct(document))
return q.ExecRelease()
}
// Get 根據 struct 的 Primary Key 查詢單筆資料Get ByPK
// - filter 為目標資料 struct其欄位需對應表格的 Primary Key 欄位Partition Key + Clustering Key
// - Cassandra 中 Primary Key 是由 Partition Key 與 Clustering Key 組成的整體,作為唯一識別一筆資料的 key
// - Cassandra 並不保證 Partition Key 或 Clustering Key 單獨具有唯一性,只有整個 Primary Key 才是唯一
// - Partition Key 的作用是將資料分布到不同節點NodeClustering Key 則是節點內排序資料用
// - 如果僅提供 Partition Key會查到分區內的多筆資料但由於 .Get() 預設加 LIMIT 1僅會取得其中一筆排序第一
// - 若想查詢特定欄位(如 name但該欄位不是 Primary Key 組成部分,則無法使用 .Get() 查詢,也無法用該欄位直接篩選資料(會報錯)
// - 解法是1. 改變 table 結構使欲查欄位成為 PK或 2. 建立額外 table 以該欄位為 Partition Key或 3. 使用 ALLOW FILTERING不建議
// Get 根據 struct 的 Primary Key 查詢單筆資料Get ByPK
// keyspace 如果為空,則使用初始化時設定的預設 keyspace
func (db *CassandraDB) Get(ctx context.Context, dest any, keyspace string) error {
keyspace = getKeyspace(db, keyspace)
metadata, err := GenerateTableMetadata(dest, keyspace)
if err != nil {
return err
}
t := table.New(metadata)
q := qh.withContextAndTimestamp(ctx, db.GetSession().Query(t.Get()).BindStruct(dest))
err = q.GetRelease(dest)
if err == gocql.ErrNotFound {
return ErrNotFound.WithTable(metadata.Name)
} else if err != nil {
return ErrInvalidInput.WithTable(metadata.Name).WithError(err)
}
return nil
}
// Delete 依據 document 的主鍵產生 DELETE 語句並執行
// keyspace 如果為空,則使用初始化時設定的預設 keyspace
func (db *CassandraDB) Delete(ctx context.Context, filter any, keyspace string) error {
keyspace = getKeyspace(db, keyspace)
metadata, err := GenerateTableMetadata(filter, keyspace)
if err != nil {
return err
}
t := table.New(metadata)
stmt, names := t.Delete()
q := qh.withContextAndTimestamp(ctx, db.GetSession().Query(stmt, names).BindStruct(filter))
return q.ExecRelease()
}
// Update 根據 document 欄位產生 UPDATE 語句並執行
// - 只會更新非零值或非 nil 的欄位(零值欄位會被排除)
// - 主鍵欄位一定會保留,作為 WHERE 條件使用
// keyspace 如果為空,則使用初始化時設定的預設 keyspace
func (db *CassandraDB) Update(ctx context.Context, document any, keyspace string) error {
return db.UpdateSelective(ctx, document, keyspace, false)
}
// UpdateSelective 根據 document 欄位產生 UPDATE 語句並執行
// - includeZero: false 時只更新非零值欄位(等同於 Updatetrue 時更新所有欄位(包括零值)
// - 主鍵欄位一定會保留,作為 WHERE 條件使用
// keyspace 如果為空,則使用初始化時設定的預設 keyspace
func (db *CassandraDB) UpdateSelective(ctx context.Context, document any, keyspace string, includeZero bool) error {
keyspace = getKeyspace(db, keyspace)
metadata, err := GenerateTableMetadata(document, keyspace)
if err != nil {
return err
}
v := reflect.ValueOf(document)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
typ := v.Type()
// 收集更新欄位與其值(根據 includeZero 決定是否包含零值,保留主鍵)
setCols := make([]string, 0)
setVals := make([]any, 0)
whereCols := make([]string, 0)
whereVals := make([]any, 0)
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
tag := field.Tag.Get("db")
if tag == "" || tag == "-" {
continue
}
val := v.Field(i)
if !val.IsValid() {
continue
}
if contains(metadata.PartKey, tag) || contains(metadata.SortKey, tag) {
whereCols = append(whereCols, tag)
whereVals = append(whereVals, val.Interface())
continue
}
if !includeZero && isZero(val) {
continue
}
setCols = append(setCols, tag)
setVals = append(setVals, val.Interface())
}
if len(setCols) == 0 {
return ErrNoFieldsToUpdate.WithTable(metadata.Name)
}
// Build UPDATE statement
builder := qb.Update(metadata.Name).Set(setCols...)
for _, col := range whereCols {
builder = builder.Where(qb.Eq(col))
}
stmt, names := builder.ToCql()
setVals = append(setVals, whereVals...)
q := qh.withContextAndTimestamp(ctx, db.GetSession().Query(stmt, names).Bind(setVals...))
return q.ExecRelease()
}
// UpdateAll 更新所有欄位(包括零值)
// keyspace 如果為空,則使用初始化時設定的預設 keyspace
func (db *CassandraDB) UpdateAll(ctx context.Context, document any, keyspace string) error {
return db.UpdateSelective(ctx, document, keyspace, true)
}
// TODO: Cassandra 不支援 OFFSET 方式的分頁(例如查詢第 N 頁)
// 原因Cassandra 是分散式資料庫,設計上不允許像傳統 SQL 那樣用 OFFSET 跳頁,會導致效能極差
// ✅ 正確方式為使用 PagingState 做游標式Cursor-based分頁一頁一頁往後翻
// ✅ 如果需要快取第 N 頁位置,應在應用層儲存每一頁的 PagingState 以供跳轉
// ❌ Cassandra 不適合直接實作全站排行榜或全表分頁查詢,除非搭配 ElasticSearch 或針對 Partition Key 分頁設計
// 若未來有特定分區(如 user_id條件可考慮實作分區內的分頁邏輯以提高效能
// GetAll 取得指定 struct 類型在 Cassandra 中的所有資料
// - filter用來推斷 table 結構的範例物件(可為指標)
// - result要寫入的 slice 指標,如 *[]MyStruct
// keyspace 如果為空,則使用初始化時設定的預設 keyspace
func (db *CassandraDB) GetAll(ctx context.Context, filter any, result any, keyspace string) error {
keyspace = getKeyspace(db, keyspace)
metadata, err := GenerateTableMetadata(filter, keyspace)
if err != nil {
return err
}
t := table.New(metadata)
stmt, names := qb.Select(t.Name()).Columns(metadata.Columns...).ToCql()
q := qh.withContextAndTimestamp(ctx, db.GetSession().Query(stmt, names))
return q.SelectRelease(result)
}
// QueryBuilder executes a query with optional conditions on Cassandra table
// keyspace 如果為空,則使用初始化時設定的預設 keyspace
func (db *CassandraDB) QueryBuilder(
ctx context.Context,
tableStruct any,
result any,
keyspace string,
opts ...QueryOption,
) error {
keyspace = getKeyspace(db, keyspace)
metadata, err := GenerateTableMetadata(tableStruct, keyspace)
if err != nil {
return err
}
tbl := table.New(metadata)
builder := qb.Select(tbl.Name()).Columns(metadata.Columns...)
bindMap := qb.M{}
for _, opt := range opts {
opt(builder, bindMap)
}
stmt, names := builder.ToCql()
query := qh.withContextAndTimestamp(ctx, db.GetSession().Query(stmt, names).BindMap(bindMap))
return query.SelectRelease(result)
}

View File

@ -1,363 +0,0 @@
package cassandra
import (
"context"
"fmt"
"testing"
"time"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx/v3/qb"
"github.com/stretchr/testify/assert"
)
func TestInsert(t *testing.T) {
t.Parallel()
ks := generateRandomKeySpace(t)
ctx := context.Background()
now := time.Now()
// 測試案例(可擴充)
tests := []struct {
name string
input MonkeyEntity
}{
{
name: "insert George",
input: MonkeyEntity{
ID: gocql.TimeUUID(),
Name: "George",
UpdateAt: now,
CreateAt: now,
},
},
{
name: "insert Bob",
input: MonkeyEntity{
ID: gocql.TimeUUID(),
Name: "Bob",
UpdateAt: now,
CreateAt: now,
},
},
{
name: "insert Alice",
input: MonkeyEntity{
ID: gocql.TimeUUID(),
Name: "Alice",
UpdateAt: now,
CreateAt: now,
},
},
}
// 執行測試
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
err := cassandraDBTest.Insert(ctx, &tc.input, ks)
assert.NoError(t, err)
// 驗證寫入
var name string
q := cassandraDBTest.GetSession().Query(fmt.Sprintf("SELECT name FROM %s.monkey_entity WHERE id = ?", ks), []string{"name"})
err = q.Bind(tc.input.ID).GetRelease(&name)
assert.NoError(t, err)
assert.Equal(t, tc.input.Name, name)
})
}
}
func TestGet(t *testing.T) {
t.Parallel()
ctx := context.Background()
ks := generateRandomKeySpace(t)
now := time.Now()
monkey := MonkeyEntity{
ID: gocql.TimeUUID(),
Name: "George",
UpdateAt: now,
CreateAt: now,
}
// 插入一筆資料
err := cassandraDBTest.Insert(ctx, &monkey, ks)
assert.NoError(t, err)
tests := []struct {
name string
filter MonkeyEntity
expect string
}{
{
name: "Get existing monkey",
filter: MonkeyEntity{ID: monkey.ID, Name: monkey.Name},
expect: "George",
},
{
name: "Get non-existent monkey",
filter: MonkeyEntity{ID: gocql.TimeUUID(), Name: "GG"},
expect: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := tc.filter // 預設填入主鍵
err := cassandraDBTest.Get(ctx, &result, ks)
if tc.expect == "" {
assert.Error(t, err, "expected error for missing record")
} else {
assert.NoError(t, err)
assert.Equal(t, tc.expect, result.Name)
}
})
}
}
func TestDelete(t *testing.T) {
t.Parallel()
ks := generateRandomKeySpace(t)
ctx := context.Background()
now := time.Now()
monkey := MonkeyEntity{
ID: gocql.TimeUUID(),
Name: "DeleteMe",
UpdateAt: now,
CreateAt: now,
}
// 插入資料
err := cassandraDBTest.Insert(ctx, &monkey, ks)
assert.NoError(t, err)
// 先確認有插入成功
verify := MonkeyEntity{ID: monkey.ID, Name: monkey.Name}
err = cassandraDBTest.Get(ctx, &verify, ks)
assert.NoError(t, err)
assert.Equal(t, "DeleteMe", verify.Name)
// 執行刪除
err = cassandraDBTest.Delete(ctx, &monkey, ks)
assert.NoError(t, err)
// 再查,應該查不到
result := MonkeyEntity{ID: monkey.ID, Name: monkey.Name}
err = cassandraDBTest.Get(ctx, &result, ks)
assert.Error(t, err, "expected error because record should be deleted")
}
func TestUpdate(t *testing.T) {
t.Parallel()
ctx := context.Background()
ks := generateRandomKeySpace(t)
now := time.Now()
id := gocql.TimeUUID()
// Step 1: 插入初始資料
monkey := MonkeyEntity{
ID: id,
Name: "OldName",
UpdateAt: now,
CreateAt: now,
}
err := cassandraDBTest.Insert(ctx, &monkey, ks)
assert.NoError(t, err)
// Step 2: 更新 UpdateAt 欄位(模擬只更新一欄)
updatedTime := now.Add(10 * time.Minute)
updateDoc := MonkeyEntity{
ID: id,
Name: "OldName", // 主鍵
UpdateAt: updatedTime,
// CreateAt 是零值,不會被更新
}
err = cassandraDBTest.Update(ctx, &updateDoc, ks)
assert.NoError(t, err)
// Step 3: 查詢回來驗證更新
result := MonkeyEntity{
ID: id,
Name: "OldName",
}
err = cassandraDBTest.Get(ctx, &result, ks)
assert.NoError(t, err)
assert.WithinDuration(t, updatedTime, result.UpdateAt, time.Second)
assert.WithinDuration(t, now, result.CreateAt, time.Second) // 未被更新
}
func insertSampleConsistency(t *testing.T, db *CassandraDB, ctx context.Context, keyspace string) *Consistency {
err := db.EnsureTable(`
CREATE TABLE IF NOT EXISTS my_keyspace.consistency (
id UUID,
consistency_name TEXT,
last_task_id TEXT,
target TEXT,
status TEXT,
consistency_type TEXT,
consistency_map TEXT,
create_at BIGINT,
update_at BIGINT,
PRIMARY KEY ((id))
);`)
assert.NoError(t, err)
c := &Consistency{
ID: gocql.TimeUUID(),
ConsistencyName: "query-test",
LastTaskID: "task-1",
Target: "test.csv",
Status: "Running",
ConsistencyType: "simple",
ConsistencyMap: `{"example": "value"}`,
CreateAT: time.Now().UnixNano(),
UpdateAT: time.Now().UnixNano(),
}
err = db.Insert(ctx, c, keyspace)
assert.NoError(t, err)
return c
}
func TestQueryBuilder_WithWhere(t *testing.T) {
t.Parallel()
ctx := context.Background()
saved := insertSampleConsistency(t, cassandraDBTest, ctx, "my_keyspace")
t.Run("query by id", func(t *testing.T) {
var results []*Consistency
e := &Consistency{}
field := GetCqlTag(e, &e.ID)
err := cassandraDBTest.QueryBuilder(
ctx,
&Consistency{},
&results,
"my_keyspace",
WithWhere(
[]qb.Cmp{qb.Eq(field)},
map[string]any{field: saved.ID.String()},
),
)
assert.NoError(t, err)
assert.NotEmpty(t, results)
found := false
for _, r := range results {
if r.ID == saved.ID {
found = true
break
}
}
assert.True(t, found, "should find inserted consistency")
})
t.Run("query with unmatched id", func(t *testing.T) {
var results []*Consistency
e := &Consistency{}
field := GetCqlTag(e, &e.ID)
err := cassandraDBTest.QueryBuilder(
ctx,
&Consistency{},
&results,
"my_keyspace",
WithWhere(
[]qb.Cmp{qb.Eq(field)},
map[string]any{field: "NonExist"},
),
)
assert.Error(t, err)
assert.Empty(t, results)
})
t.Run("query by in", func(t *testing.T) {
var results []*Consistency
e := &Consistency{}
field := GetCqlTag(e, &e.ID)
err := cassandraDBTest.QueryBuilder(
ctx,
&Consistency{},
&results,
"my_keyspace",
WithWhere(
[]qb.Cmp{qb.In(field)},
map[string]any{field: []gocql.UUID{saved.ID}},
),
)
assert.NoError(t, err)
assert.NotEmpty(t, results)
found := false
for _, r := range results {
if r.ID == saved.ID {
found = true
break
}
}
assert.True(t, found, "should find inserted consistency")
})
t.Run("query by one is not in", func(t *testing.T) {
var results []*Consistency
e := &Consistency{}
field := GetCqlTag(e, &e.ID)
err := cassandraDBTest.QueryBuilder(
ctx,
&Consistency{},
&results,
"my_keyspace",
WithWhere(
[]qb.Cmp{qb.In(field)},
map[string]any{field: []gocql.UUID{saved.ID, gocql.TimeUUID()}},
),
)
assert.NoError(t, err)
assert.NotEmpty(t, results)
found := false
for _, r := range results {
if r.ID == saved.ID {
found = true
break
}
}
assert.True(t, found, "should find inserted consistency")
})
t.Run("query get all", func(t *testing.T) {
var results []*Consistency
e := &Consistency{}
err := cassandraDBTest.QueryBuilder(
ctx,
e,
&results,
"my_keyspace",
)
assert.NoError(t, err)
assert.NotEmpty(t, results)
found := false
for _, r := range results {
if r.ID == saved.ID {
found = true
break
}
}
assert.True(t, found, "should find inserted consistency")
})
}

162
pkg/library/cassandra/db.go Normal file
View File

@ -0,0 +1,162 @@
package cassandra
import (
"context"
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx/v3"
)
// DB 是 Cassandra 的核心資料庫連接
type DB struct {
session gocqlx.Session
defaultKeyspace string
version string
saiSupported bool
// 內部快取
metadataCache sync.Map // 重用現有的 metadata 快取邏輯
}
// New 創建新的 DB 實例
func New(opts ...Option) (*DB, error) {
cfg := defaultConfig()
for _, opt := range opts {
opt(cfg)
}
if len(cfg.Hosts) == 0 {
return nil, fmt.Errorf("at least one host is required")
}
// 建立連線設定
cluster := gocql.NewCluster(cfg.Hosts...)
cluster.Port = cfg.Port
cluster.Consistency = cfg.Consistency
cluster.Timeout = time.Duration(cfg.ConnectTimeoutSec) * time.Second
cluster.NumConns = cfg.NumConns
cluster.RetryPolicy = &gocql.ExponentialBackoffRetryPolicy{
NumRetries: cfg.MaxRetries,
Min: cfg.RetryMinInterval,
Max: cfg.RetryMaxInterval,
}
cluster.ReconnectionPolicy = &gocql.ExponentialReconnectionPolicy{
MaxRetries: cfg.MaxRetries,
InitialInterval: cfg.ReconnectInitialInterval,
MaxInterval: cfg.ReconnectMaxInterval,
}
// 若有提供 Keyspace 則指定
if cfg.Keyspace != "" {
cluster.Keyspace = cfg.Keyspace
}
// 若啟用驗證則設定帳號密碼
if cfg.UseAuth {
cluster.Authenticator = gocql.PasswordAuthenticator{
Username: cfg.Username,
Password: cfg.Password,
}
}
// 建立 Session
session, err := gocqlx.WrapSession(cluster.CreateSession())
if err != nil {
return nil, fmt.Errorf("failed to connect to Cassandra cluster (hosts: %v, port: %d): %w", cfg.Hosts, cfg.Port, err)
}
db := &DB{
session: session,
defaultKeyspace: cfg.Keyspace,
}
// 初始化版本資訊
version, err := db.getVersion(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get DB version: %w", err)
}
db.version = version
db.saiSupported = isSAISupported(version)
return db, nil
}
// Close 關閉資料庫連線
func (db *DB) Close() {
db.session.Close()
}
// GetSession 返回底層的 gocqlx Session用於進階操作
func (db *DB) GetSession() gocqlx.Session {
return db.session
}
// GetDefaultKeyspace 返回預設的 keyspace
func (db *DB) GetDefaultKeyspace() string {
return db.defaultKeyspace
}
// Version 返回資料庫版本
func (db *DB) Version() string {
return db.version
}
// SaiSupported 返回是否支援 SAI
func (db *DB) SaiSupported() bool {
return db.saiSupported
}
// getVersion 獲取資料庫版本
func (db *DB) getVersion(ctx context.Context) (string, error) {
var version string
stmt := "SELECT release_version FROM system.local"
err := db.session.Query(stmt, []string{"release_version"}).
WithContext(ctx).
Consistency(gocql.One).
Scan(&version)
return version, err
}
// isSAISupported 檢查版本是否支援 SAI
func isSAISupported(version string) bool {
// 只要 major >=5 就支援
// 4.0.9+ 才有 SAI但不穩強烈建議 5.0+
parts := strings.Split(version, ".")
if len(parts) < 2 {
return false
}
major, _ := strconv.Atoi(parts[0])
minor, _ := strconv.Atoi(parts[1])
if major >= 5 {
return true
}
if major == 4 {
if minor > 0 { // 4.1.x、4.2.x 直接支援
return true
}
if minor == 0 {
patch := 0
if len(parts) >= 3 {
patch, _ = strconv.Atoi(parts[2])
}
if patch >= 9 {
return true
}
}
}
return false
}
// withContextAndTimestamp 為查詢添加 context 和時間戳
func (db *DB) withContextAndTimestamp(ctx context.Context, q *gocqlx.Queryx) *gocqlx.Queryx {
return q.WithContext(ctx).WithTimestamp(time.Now().UnixNano() / 1e3)
}

View File

@ -5,63 +5,32 @@ import (
"fmt" "fmt"
) )
// 定義統一的錯誤類型 // ErrorCode 定義錯誤代碼
var ( type ErrorCode string
// ErrNotFound 表示記錄未找到
ErrNotFound = &Error{
Code: "NOT_FOUND",
Message: "record not found",
}
// ErrAcquireLockFailed 表示獲取鎖失敗 const (
ErrAcquireLockFailed = &Error{ // ErrCodeNotFound 表示記錄未找到
Code: "LOCK_ACQUIRE_FAILED", ErrCodeNotFound ErrorCode = "NOT_FOUND"
Message: "acquire lock failed", // ErrCodeConflict 表示衝突(如唯一鍵衝突)
} ErrCodeConflict ErrorCode = "CONFLICT"
// ErrCodeInvalidInput 表示輸入參數無效
// ErrInvalidInput 表示輸入參數無效 ErrCodeInvalidInput ErrorCode = "INVALID_INPUT"
ErrInvalidInput = &Error{ // ErrCodeMissingPartition 表示缺少 Partition Key
Code: "INVALID_INPUT", ErrCodeMissingPartition ErrorCode = "MISSING_PARTITION_KEY"
Message: "invalid input parameter", // ErrCodeNoFieldsToUpdate 表示沒有欄位需要更新
} ErrCodeNoFieldsToUpdate ErrorCode = "NO_FIELDS_TO_UPDATE"
// ErrCodeMissingTableName 表示缺少 TableName 方法
// ErrNoPartitionKey 表示缺少 Partition Key ErrCodeMissingTableName ErrorCode = "MISSING_TABLE_NAME"
ErrNoPartitionKey = &Error{ // ErrCodeMissingWhereCondition 表示缺少 WHERE 條件
Code: "NO_PARTITION_KEY", ErrCodeMissingWhereCondition ErrorCode = "MISSING_WHERE_CONDITION"
Message: "no partition key defined in struct",
}
// ErrMissingTableName 表示缺少 TableName 方法
ErrMissingTableName = &Error{
Code: "MISSING_TABLE_NAME",
Message: "struct must implement TableName() method",
}
// ErrNoFieldsToUpdate 表示沒有欄位需要更新
ErrNoFieldsToUpdate = &Error{
Code: "NO_FIELDS_TO_UPDATE",
Message: "no fields to update",
}
// ErrMissingWhereCondition 表示缺少 WHERE 條件
ErrMissingWhereCondition = &Error{
Code: "MISSING_WHERE_CONDITION",
Message: "operation requires at least one WHERE condition for safety",
}
// ErrMissingPartitionKey 表示 WHERE 條件中缺少 Partition Key
ErrMissingPartitionKey = &Error{
Code: "MISSING_PARTITION_KEY",
Message: "operation requires all partition keys in WHERE clause",
}
) )
// Error 是統一的錯誤類型 // Error 是統一的錯誤類型
type Error struct { type Error struct {
Code string // 錯誤代碼 Code ErrorCode
Message string // 錯誤訊息 Message string
Table string // 相關的表名(可選) Table string
Err error // 底層錯誤(可選) Err error
} }
// Error 實現 error 介面 // Error 實現 error 介面
@ -104,19 +73,72 @@ func (e *Error) WithError(err error) *Error {
} }
// NewError 創建新的錯誤 // NewError 創建新的錯誤
func NewError(code, message string) *Error { func NewError(code ErrorCode, message string) *Error {
return &Error{ return &Error{
Code: code, Code: code,
Message: message, Message: message,
} }
} }
// IsNotFound 檢查錯誤是否為 NotFound // 預定義錯誤
func IsNotFound(err error) bool { var (
return errors.Is(err, ErrNotFound) // ErrNotFound 表示記錄未找到
ErrNotFound = &Error{
Code: ErrCodeNotFound,
Message: "record not found",
} }
// IsLockFailed 檢查錯誤是否為獲取鎖失敗 // ErrInvalidInput 表示輸入參數無效
func IsLockFailed(err error) bool { ErrInvalidInput = &Error{
return errors.Is(err, ErrAcquireLockFailed) Code: ErrCodeInvalidInput,
Message: "invalid input parameter",
}
// ErrNoPartitionKey 表示缺少 Partition Key
ErrNoPartitionKey = &Error{
Code: ErrCodeMissingPartition,
Message: "no partition key defined in struct",
}
// ErrMissingTableName 表示缺少 TableName 方法
ErrMissingTableName = &Error{
Code: ErrCodeMissingTableName,
Message: "struct must implement TableName() method",
}
// ErrNoFieldsToUpdate 表示沒有欄位需要更新
ErrNoFieldsToUpdate = &Error{
Code: ErrCodeNoFieldsToUpdate,
Message: "no fields to update",
}
// ErrMissingWhereCondition 表示缺少 WHERE 條件
ErrMissingWhereCondition = &Error{
Code: ErrCodeMissingWhereCondition,
Message: "operation requires at least one WHERE condition for safety",
}
// ErrMissingPartitionKey 表示 WHERE 條件中缺少 Partition Key
ErrMissingPartitionKey = &Error{
Code: ErrCodeMissingPartition,
Message: "operation requires all partition keys in WHERE clause",
}
)
// IsNotFound 檢查錯誤是否為 NotFound
func IsNotFound(err error) bool {
var e *Error
if errors.As(err, &e) {
return e.Code == ErrCodeNotFound
}
return false
}
// IsConflict 檢查錯誤是否為 Conflict
func IsConflict(err error) bool {
var e *Error
if errors.As(err, &e) {
return e.Code == ErrCodeConflict
}
return false
} }

View File

@ -1,285 +0,0 @@
package cassandra
import (
"context"
"fmt"
"reflect"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx/v3"
"github.com/scylladb/gocqlx/v3/qb"
"github.com/scylladb/gocqlx/v3/table"
)
/*
todo 目前尚未實作的部分但因為目前使用上並沒有嚴格一致性故目前簡易的版本可先行
1. 讀寫一致性問題
Cassandra 本身為最終一致性如果在 Commit 期間網路有短暫中斷可能造成部分操作成功部分失敗的半提交狀態
Commit 之後再次掃描 Steps看是否所有 IsExec 都為 true若有 false則觸發額外的重試或警示機制
2. 反射收集欄位的可靠度
Update 方法透過反射與 isZero 來排除不更新欄位但若結構體中出現自訂零值如自訂型態有預設值可能誤過濾掉真正要更新的欄位
可能在資料模型層先明確標示要更新的欄位列表或提供外部參數指明更新欄位以減少反射過濾錯誤
3. 交易邊界與隔離度
此實作並未提供交易隔離Isolation外部程式仍可能在交易尚未 Commit 時讀到中間狀態
若對讀取一致性有嚴格要求可考慮使用 Cassandra Lightweight TransactionsLWT搭配 IF NOT EXISTS / IF 條件確保寫入前的前置檢查
4. 錯誤重試與警示
Commit 中某個步驟失敗直接返回錯誤但沒有集中收集失敗資訊
建議整合一個監控與重試機制將失敗細節step index錯誤訊息記錄到外部持久化系統以便運維人員介入或自動重試
5. 崩潰恢復
如果程式在 Commit 過程中程式本身當掉記憶體中的 Steps 會丟失無法回滾
可以把 OperationLog 持久化到可靠的日誌表Cassandra 或外部 DBCommit 之前就先寫入並在啟動時掃描未完成的交易回滾或重試
*/
type Action int64
const (
ActionUnknown Action = iota
ActionInsert
ActionDelete
ActionUpdate
)
// OperationLog 記錄操作日誌,用於補償回滾
type OperationLog struct {
ID gocql.UUID // 操作ID用來標識該操作
Action Action // 操作類型(增、刪、改)
IsExec bool
Exec []*gocqlx.Queryx // 這一個步驟要執行的東西
OldData any // 變更前的數據,僅對修改和刪除有效
NewData any // 變更後的數據,僅對新增和修改有效
}
// CompensatingTransaction 補償式交易介面
// 這是一個基於補償操作Compensating Action的交易模式適用於最終一致性場景
// 與傳統 ACID 交易不同,它不提供隔離性保證,但可以確保「要嘛全成功,要嘛全失敗」
// 注意:這不是真正的原子性交易,而是透過記錄操作日誌並在失敗時執行補償操作來實現
type CompensatingTransaction interface {
Insert(ctx context.Context, document any) error
Delete(ctx context.Context, filter any) error
Update(ctx context.Context, document any) error
Rollback() error
Commit() error
}
// transaction 定義補償操作的結構
type transaction struct {
ctx context.Context
keyspace string
db *CassandraDB
Steps []OperationLog // 用來記錄所有操作步驟的日誌
}
// NewCompensatingTransaction 創建一個新的補償式交易
// keyspace 如果為空,則使用初始化時設定的預設 keyspace
func NewCompensatingTransaction(ctx context.Context, keyspace string, db *CassandraDB) CompensatingTransaction {
keyspace = getKeyspace(db, keyspace)
return &transaction{
ctx: ctx,
keyspace: keyspace,
db: db,
Steps: []OperationLog{},
}
}
// NewEZTransaction 創建一個新的補償式交易(向後相容的別名)
// Deprecated: 使用 NewCompensatingTransaction 代替
func NewEZTransaction(ctx context.Context, keyspace string, db *CassandraDB) CompensatingTransaction {
return NewCompensatingTransaction(ctx, keyspace, db)
}
func (tx *transaction) Insert(ctx context.Context, document any) error {
metadata, err := GenerateTableMetadata(document, tx.keyspace)
if err != nil {
return err
}
t := table.New(metadata)
q := qh.withContextAndTimestamp(ctx, tx.db.GetSession().Query(t.Insert()).BindStruct(document))
logEntry := OperationLog{
ID: gocql.TimeUUID(),
Action: ActionInsert,
Exec: []*gocqlx.Queryx{q},
NewData: document,
}
tx.Steps = append(tx.Steps, logEntry)
return nil
}
func (tx *transaction) Delete(ctx context.Context, filter any) error {
metadata, err := GenerateTableMetadata(filter, tx.keyspace)
if err != nil {
return err
}
t := table.New(metadata)
doc := filter
get := tx.db.GetSession().Query(t.Get()).BindStruct(doc).WithContext(ctx)
q := qh.withContextAndTimestamp(ctx, tx.db.GetSession().Query(t.Delete()).BindStruct(filter))
logEntry := OperationLog{
ID: gocql.TimeUUID(),
Action: ActionDelete,
Exec: []*gocqlx.Queryx{get, q}, // 有順序,要先拿取保留舊資料,
OldData: doc, // 保留結構體才有機會回復
}
tx.Steps = append(tx.Steps, logEntry)
return nil
}
func (tx *transaction) Update(ctx context.Context, document any) error {
metadata, err := GenerateTableMetadata(document, tx.keyspace)
if err != nil {
return err
}
t := table.New(metadata)
v := reflect.ValueOf(document)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
typ := v.Type()
// 收集更新欄位與其值(排除零值,保留主鍵)
setCols := make([]string, 0)
setVals := make([]any, 0)
whereCols := make([]string, 0)
whereVals := make([]any, 0)
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
tag := field.Tag.Get("db")
if tag == "" || tag == "-" {
continue
}
val := v.Field(i)
if !val.IsValid() {
continue
}
if contains(metadata.PartKey, tag) || contains(metadata.SortKey, tag) {
whereCols = append(whereCols, tag)
whereVals = append(whereVals, val.Interface())
continue
}
if isZero(val) {
continue
}
setCols = append(setCols, tag)
setVals = append(setVals, val.Interface())
}
if len(setCols) == 0 {
return ErrNoFieldsToUpdate.WithTable(metadata.Name)
}
// Build UPDATE statement
builder := qb.Update(metadata.Name).Set(setCols...)
for _, col := range whereCols {
builder = builder.Where(qb.Eq(col))
}
stmt, names := builder.ToCql()
setVals = append(setVals, whereVals...)
q := qh.withContextAndTimestamp(ctx, tx.db.GetSession().Query(stmt, names).Bind(setVals...))
doc := document
get := tx.db.GetSession().Query(t.Get()).BindStruct(doc).WithContext(ctx)
logEntry := OperationLog{
ID: gocql.TimeUUID(),
Action: ActionUpdate,
Exec: []*gocqlx.Queryx{get, q}, // 有順序,要先拿取保留舊資料,才可以 update
OldData: doc, // 保留結構體才有機會回復
NewData: document,
}
tx.Steps = append(tx.Steps, logEntry)
return nil
}
func (tx *transaction) Rollback() error {
for _, item := range tx.Steps {
// 沒有做過的就不用回復了
if !item.IsExec {
continue
}
switch item.Action {
case ActionInsert:
err := tx.db.Delete(tx.ctx, item.NewData, tx.keyspace)
if err != nil {
// Rollback 失敗時繼續處理其他步驟,但最終會返回錯誤
// 注意:這裡不記錄日誌,因為 library 包不應該直接記錄日誌
// 調用者應該根據返回的錯誤進行日誌記錄
continue
}
case ActionUpdate:
err := tx.db.Update(tx.ctx, item.OldData, tx.keyspace)
if err != nil {
// Rollback 失敗時繼續處理其他步驟,但最終會返回錯誤
continue
}
case ActionDelete:
err := tx.db.Insert(tx.ctx, item.OldData, tx.keyspace)
if err != nil {
// Rollback 失敗時繼續處理其他步驟,但最終會返回錯誤
continue
}
}
}
return nil
}
func (tx *transaction) Commit() error {
for i, step := range tx.Steps {
switch step.Action {
case ActionInsert:
// 單純插入,不用回滾額外做事,插入的資料已經放在 New Data 裡面了
if err := step.Exec[0].ExecRelease(); err != nil {
return fmt.Errorf("failed to insert: %w", err)
}
// 標示為以執行,如果有錯誤要回復,指座椅執行的就好
tx.Steps[i].IsExec = true
case ActionUpdate:
// 要先 get 之後再 Update
// 單純插入,不用回滾額外做事,插入的資料已經放在 New Data 裡面了
if err := step.Exec[0].GetRelease(step.OldData); err != nil {
return fmt.Errorf("failed to get: %w", err)
}
if err := step.Exec[1].ExecRelease(); err != nil {
return fmt.Errorf("failed to update: %w", err)
}
// 標示為以執行,如果有錯誤要回復,指座椅執行的就好
tx.Steps[i].IsExec = true
case ActionDelete:
// 要先 get 之後再 Update
// 單純插入,不用回滾額外做事,插入的資料已經放在 New Data 裡面了
if err := step.Exec[0].GetRelease(step.OldData); err != nil {
return fmt.Errorf("failed to get: %w", err)
}
if err := step.Exec[1].ExecRelease(); err != nil {
return fmt.Errorf("failed to delete: %w", err)
}
// 標示為以執行,如果有錯誤要回復,指座椅執行的就好
tx.Steps[i].IsExec = true
default:
return fmt.Errorf("unknown action: %v", step.Action)
}
}
return nil
}

View File

@ -1,231 +0,0 @@
package cassandra
import (
"context"
"testing"
"github.com/gocql/gocql"
"github.com/stretchr/testify/assert"
)
type TE struct {
ID gocql.UUID `db:"id" partition_key:"true"`
Name string `db:"name"`
}
func (m *TE) TableName() string {
return "test_entity"
}
func TestNewEZTransactionInsert(t *testing.T) {
ctx := context.Background()
err := cassandraDBTest.EnsureTable(`
CREATE TABLE IF NOT EXISTS my_keyspace.test_entity (
id UUID PRIMARY KEY,
name TEXT
);`)
assert.NoError(t, err)
// 定義 table-driven 測試案例
tests := []struct {
name string
doc TE
}{
{
name: "insert_record_alice",
doc: TE{
ID: gocql.TimeUUID(),
Name: "Alice",
},
},
{
name: "insert_record_bob",
doc: TE{
ID: gocql.TimeUUID(),
Name: "Bob",
},
},
{
name: "insert_record_empty_name",
doc: TE{
ID: gocql.TimeUUID(),
Name: "",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 每個子案例都使用新的 transaction
tx := NewEZTransaction(ctx, "my_keyspace", cassandraDBTest)
// 1. 呼叫 Insert
err := tx.Insert(ctx, &tt.doc)
assert.NoError(t, err, "Insert() 應該不會錯誤")
// 2. 呼叫 Commit真正寫入 Cassandra
err = tx.Commit()
assert.NoError(t, err, "Commit() 應該不會錯誤")
// 3. 從 Cassandra 查回資料,驗證
var got TE
got.ID = tt.doc.ID
err = cassandraDBTest.Get(ctx, &got, "my_keyspace")
assert.NoError(t, err)
// 驗證欄位值符合
assert.Equal(t, tt.doc.ID, got.ID, "ID 應一致")
assert.Equal(t, tt.doc.Name, got.Name, "Name 應一致")
})
}
}
func TestNewEZTransactionDelete(t *testing.T) {
ctx := context.Background()
err := cassandraDBTest.EnsureTable(`
CREATE TABLE IF NOT EXISTS my_keyspace.test_entity (
id UUID PRIMARY KEY,
name TEXT
);`)
assert.NoError(t, err)
// 定義 table-driven 測試案例
tests := []struct {
name string
doc TE
}{
{
name: "ok",
doc: TE{
ID: gocql.TimeUUID(),
Name: "Alice",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 每個子案例都使用新的 transaction
tx := NewEZTransaction(ctx, "my_keyspace", cassandraDBTest)
// 1. 呼叫 Delete
err := tx.Insert(ctx, &tt.doc)
assert.NoError(t, err, "Insert() 應該不會錯誤")
// 2. 呼叫 Delete
err = tx.Delete(ctx, &tt.doc)
assert.NoError(t, err, "Delete() 應該不會錯誤")
// 3. 呼叫 Commit真正寫入 Cassandra
err = tx.Commit()
assert.NoError(t, err, "Commit() 應該不會錯誤")
//
// 4. 從 Cassandra 查回資料,驗證
var got TE
got.ID = tt.doc.ID
err = cassandraDBTest.Get(ctx, &got, "my_keyspace")
assert.Equal(t, err, gocql.ErrNotFound)
})
}
}
func TestNewEZTransactionUpdate(t *testing.T) {
ctx := context.Background()
assert.NoError(t, cassandraDBTest.EnsureTable(`
CREATE TABLE IF NOT EXISTS my_keyspace.test_entity (
id UUID PRIMARY KEY,
name TEXT
);
`))
// 2. 插入初始資料
id := gocql.TimeUUID()
before := TE{ID: id, Name: "Before"}
assert.NoError(t, cassandraDBTest.Insert(ctx, &before, "my_keyspace"))
// 定義多組更新案例
tests := []struct {
name string
newName string
wantErr bool
}{
{name: "update_to_Alice", newName: "Alice"},
{name: "update_to_empty", newName: "", wantErr: true},
{name: "update_to_Bob", newName: "Bob"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 為每個案例都重置為 Before
// 重新 insert 一次(覆蓋舊值)
assert.NoError(t, cassandraDBTest.Insert(ctx, &before, "my_keyspace"))
// 3. 建立 transaction 並呼叫 Update
tx := NewEZTransaction(ctx, "my_keyspace", cassandraDBTest)
updateDoc := TE{ID: id, Name: tt.newName}
err := tx.Update(ctx, &updateDoc)
if tt.wantErr {
assert.Error(t, err, "Update() 應該會出錯")
return
}
assert.NoError(t, err, "Update() 不應出錯")
// 4. Commit 實際寫入
err = tx.Commit()
assert.NoError(t, err, "Commit() 不應出錯")
// 5. 查詢並驗證
var got TE
got.ID = id
err = cassandraDBTest.Get(ctx, &got, "my_keyspace")
assert.NoError(t, err, "db.Get() 應成功")
assert.Equal(t, id, got.ID, "ID 應一致")
assert.Equal(t, tt.newName, got.Name, "Name 應被更新為最新值")
})
}
}
func Test_Rollback(t *testing.T) {
ctx := context.Background()
assert.NoError(t, cassandraDBTest.EnsureTable(`
CREATE TABLE IF NOT EXISTS my_keyspace.test_entity (
id UUID PRIMARY KEY,
name TEXT
);
`))
// 3. 用 Transaction 插入一筆資料,並 Commit
id := gocql.TimeUUID()
doc := TE{ID: id, Name: "Alice"}
tx := NewEZTransaction(ctx, "my_keyspace", cassandraDBTest)
err := tx.Insert(ctx, &doc)
assert.NoError(t, err)
err = tx.Commit()
assert.NoError(t, err)
// 4. Query 確認資料已存在
var got TE
got.ID = id
err = cassandraDBTest.Get(ctx, &got, "my_keyspace")
assert.NoError(t, err)
assert.Equal(t, got.Name, doc.Name)
// 5. 呼叫 Rollback應自動刪除剛剛那筆
err = tx.Rollback()
assert.NoError(t, err)
var afterGot TE
afterGot.ID = id
err = cassandraDBTest.Get(ctx, &afterGot, "my_keyspace")
assert.Error(t, err)
// Output:
// after commit: Alice
// after rollback: not found
}

View File

@ -1,74 +0,0 @@
module gitlab.supermicro.com/infra/infra-core/storage/cassandra
go 1.24.2
require (
github.com/gocql/gocql v1.7.0
github.com/scylladb/gocqlx/v3 v3.0.1
github.com/stretchr/testify v1.10.0
github.com/testcontainers/testcontainers-go v0.37.0
)
require (
dario.cat/mergo v1.0.1 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.0.1+incompatible // indirect
github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/ebitengine/purego v0.8.2 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/snappy v1.0.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.10 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/patternmatcher v0.6.0 // indirect
github.com/moby/sys/sequential v0.5.0 // indirect
github.com/moby/sys/user v0.1.0 // indirect
github.com/moby/sys/userns v0.1.0 // indirect
github.com/moby/term v0.5.0 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/scylladb/go-reflectx v1.0.1 // indirect
github.com/shirou/gopsutil/v4 v4.25.1 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.35.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.24.0 // indirect
go.opentelemetry.io/otel/metric v1.35.0 // indirect
go.opentelemetry.io/otel/sdk v1.24.0 // indirect
go.opentelemetry.io/otel/trace v1.35.0 // indirect
go.opentelemetry.io/proto/otlp v1.3.1 // indirect
golang.org/x/crypto v0.37.0 // indirect
golang.org/x/net v0.35.0 // indirect
golang.org/x/sync v0.13.0 // indirect
golang.org/x/sys v0.32.0 // indirect
golang.org/x/time v0.10.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240711142825-46eb208f015d // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect
google.golang.org/grpc v1.65.0 // indirect
google.golang.org/protobuf v1.36.5 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@ -1,207 +0,0 @@
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v28.0.1+incompatible h1:FCHjSRdXhNRFjlHMTv4jUNlIBbTeRjrWfeFuJp7jpo0=
github.com/docker/docker v28.0.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/ebitengine/purego v0.8.2 h1:jPPGWs2sZ1UgOSgD2bClL0MJIqu58nOmIcBuXr62z1I=
github.com/ebitengine/purego v0.8.2/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
github.com/moby/sys/sequential v0.5.0 h1:OPvI35Lzn9K04PBbCLW0g4LcFAJgHsvXsRyewg5lXtc=
github.com/moby/sys/sequential v0.5.0/go.mod h1:tH2cOOs5V9MlPiXcQzRC+eEyab644PWKGRYaaV5ZZlo=
github.com/moby/sys/user v0.1.0 h1:WmZ93f5Ux6het5iituh9x2zAG7NFY9Aqi49jjE1PaQg=
github.com/moby/sys/user v0.1.0/go.mod h1:fKJhFOnsCN6xZ5gSfbM6zaHGgDJMrqt9/reuj4T7MmU=
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/scylladb/go-reflectx v1.0.1 h1:b917wZM7189pZdlND9PbIJ6NQxfDPfBvUaQ7cjj1iZQ=
github.com/scylladb/go-reflectx v1.0.1/go.mod h1:rWnOfDIRWBGN0miMLIcoPt/Dhi2doCMZqwMCJ3KupFc=
github.com/scylladb/gocqlx/v3 v3.0.1 h1:JBvOUBz62LQ2lbIgJqQbwVMiDftbtrJSi63KVxvRYOQ=
github.com/scylladb/gocqlx/v3 v3.0.1/go.mod h1:EjbSZM0VR2a57ZUxCRQ3v3CSoWIkH1WTMwxeDbFQorY=
github.com/shirou/gopsutil/v4 v4.25.1 h1:QSWkTc+fu9LTAWfkZwZ6j8MSUk4A2LV7rbH0ZqmLjXs=
github.com/shirou/gopsutil/v4 v4.25.1/go.mod h1:RoUCUpndaJFtT+2zsZzzmhvbfGoDCJ7nFXKJf8GqJbI=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/testcontainers/testcontainers-go v0.37.0 h1:L2Qc0vkTw2EHWQ08djon0D2uw7Z/PtHS/QzZZ5Ra/hg=
github.com/testcontainers/testcontainers-go v0.37.0/go.mod h1:QPzbxZhQ6Bclip9igjLFj6z0hs01bU8lrl2dHQmgFGM=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.24.0 h1:t6wl9SPayj+c7lEIFgm4ooDBZVb01IhLB4InpomhRw8=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.24.0/go.mod h1:iSDOcsnSA5INXzZtwaBPrKp/lWu/V14Dd+llD0oI2EA=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.24.0 h1:Xw8U6u2f8DK2XAkGRFV7BBLENgnTGX9i4rQRxJf+/vs=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.24.0/go.mod h1:6KW1Fm6R/s6Z3PGXwSJN2K4eT6wQB3vXX6CVnYX9NmM=
go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M=
go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE=
go.opentelemetry.io/otel/sdk v1.24.0 h1:YMPPDNymmQN3ZgczicBY3B6sf9n62Dlj9pWD3ucgoDw=
go.opentelemetry.io/otel/sdk v1.24.0/go.mod h1:KVrIYw6tEubO9E96HQpcmpTKDVn9gdv35HoYiQWGDFg=
go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs=
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o=
golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto/googleapis/api v0.0.0-20240711142825-46eb208f015d h1:kHjw/5UfflP/L5EbledDrcG4C2597RtymmGRZvHiCuY=
google.golang.org/genproto/googleapis/api v0.0.0-20240711142825-46eb208f015d/go.mod h1:mw8MG/Qz5wfgYr6VqVCiZcHe/GJEfI+oGGDCohaVgB0=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 h1:BwIjyKYGsK9dMCBOorzRri8MQwmi7mT9rGHsCEinZkA=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY=
google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc=
google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ=
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=

View File

@ -11,13 +11,11 @@ import (
) )
const ( const (
defaultTTLSec = 30 defaultLockTTLSec = 30
defaultRetry = 3 defaultLockRetry = 3
baseDelay = 100 * time.Millisecond lockBaseDelay = 100 * time.Millisecond
) )
// 使用 error.go 中定義的統一錯誤
// LockOption 用來設定 TryLock 的 TTL 行為 // LockOption 用來設定 TryLock 的 TTL 行為
type LockOption func(*lockOptions) type LockOption func(*lockOptions)
@ -25,6 +23,7 @@ type lockOptions struct {
ttlSeconds int // TTL單位秒<=0 代表不 expire ttlSeconds int // TTL單位秒<=0 代表不 expire
} }
// WithLockTTL 設定鎖的 TTL
func WithLockTTL(d time.Duration) LockOption { func WithLockTTL(d time.Duration) LockOption {
return func(o *lockOptions) { return func(o *lockOptions) {
o.ttlSeconds = int(d.Seconds()) o.ttlSeconds = int(d.Seconds())
@ -40,30 +39,17 @@ func WithNoLockExpire() LockOption {
// TryLock 嘗試在表上插入一筆唯一鍵IF NOT EXISTS作為鎖 // TryLock 嘗試在表上插入一筆唯一鍵IF NOT EXISTS作為鎖
// 預設 30 秒 TTL可透過 option 調整或取消 TTL // 預設 30 秒 TTL可透過 option 調整或取消 TTL
// keyspace 如果為空,則使用初始化時設定的預設 keyspace func (r *repository[T]) TryLock(ctx context.Context, doc T, opts ...LockOption) error {
func (db *CassandraDB) TryLock( // 組合 option
ctx context.Context, options := &lockOptions{ttlSeconds: defaultLockTTLSec}
document any,
keyspace string,
opts ...LockOption,
) error {
keyspace = getKeyspace(db, keyspace)
// 1. 解析 metadata
metadata, err := GenerateTableMetadata(document, keyspace)
if err != nil {
return err
}
// 2. 組合 option
options := &lockOptions{ttlSeconds: defaultTTLSec}
for _, opt := range opts { for _, opt := range opts {
opt(options) opt(options)
} }
// 3. 建 TTL 子句 // 建 TTL 子句
builder := qb.Insert(metadata.Name). builder := qb.Insert(r.table).
Unique(). // IF NOT EXISTS Unique(). // IF NOT EXISTS
Columns(metadata.Columns...) Columns(r.metadata.Columns...)
if options.ttlSeconds > 0 { if options.ttlSeconds > 0 {
ttl := time.Duration(options.ttlSeconds) * time.Second ttl := time.Duration(options.ttlSeconds) * time.Second
@ -71,51 +57,36 @@ func (db *CassandraDB) TryLock(
} }
stmt, names := builder.ToCql() stmt, names := builder.ToCql()
// 4. 執行 CAS // 執行 CAS
q := db.GetSession().Query(stmt, names).BindStruct(document). q := r.db.session.Query(stmt, names).BindStruct(doc).
WithContext(ctx). WithContext(ctx).
WithTimestamp(time.Now().UnixNano() / 1e3). WithTimestamp(time.Now().UnixNano() / 1e3).
SerialConsistency(gocql.Serial) SerialConsistency(gocql.Serial)
applied, err := q.ExecCASRelease() applied, err := q.ExecCASRelease()
if err != nil { if err != nil {
return err return ErrInvalidInput.WithTable(r.table).WithError(err)
} }
if !applied { if !applied {
return ErrAcquireLockFailed.WithTable(metadata.Name) return NewError(ErrCodeConflict, "acquire lock failed").WithTable(r.table)
} }
return nil return nil
} }
// UnLock 釋放鎖,其實就是 Delete // UnLock 釋放鎖,其實就是 Delete
// keyspace 如果為空,則使用初始化時設定的預設 keyspace func (r *repository[T]) UnLock(ctx context.Context, doc T) error {
func (db *CassandraDB) UnLock(ctx context.Context, filter any, keyspace string) error {
keyspace = getKeyspace(db, keyspace)
if filter == nil {
return errors.New("unlock: filter cannot be nil")
}
metadata, err := GenerateTableMetadata(filter, keyspace)
if err != nil {
return fmt.Errorf("unlock: failed to generate metadata: %w", err)
}
if len(metadata.Columns) == 0 {
return fmt.Errorf("unlock: missing primary key in struct (table: %s)", metadata.Name)
}
var lastErr error var lastErr error
for i := 0; i < defaultRetry; i++ { for i := 0; i < defaultLockRetry; i++ {
builder := qb.Delete(metadata.Name).Existing() builder := qb.Delete(r.table).Existing()
// 動態添加 WHERE 條件 // 動態添加 WHERE 條件(使用 Partition Key
for _, key := range metadata.PartKey { for _, key := range r.metadata.PartKey {
builder = builder.Where(qb.Eq(key)) builder = builder.Where(qb.Eq(key))
} }
stmt, names := builder.ToCql() stmt, names := builder.ToCql()
q := db.GetSession().Query(stmt, names).BindStruct(filter). q := r.db.session.Query(stmt, names).BindStruct(doc).
WithContext(ctx). WithContext(ctx).
WithTimestamp(time.Now().UnixNano() / 1e3). WithTimestamp(time.Now().UnixNano() / 1e3).
SerialConsistency(gocql.Serial) SerialConsistency(gocql.Serial)
@ -126,13 +97,24 @@ func (db *CassandraDB) UnLock(ctx context.Context, filter any, keyspace string)
} }
if err != nil { if err != nil {
lastErr = fmt.Errorf("unlock: execution failed (table: %s, attempt: %d/%d): %w", metadata.Name, i+1, defaultRetry, err) lastErr = fmt.Errorf("unlock error: %w", err)
} else if !applied { } else if !applied {
lastErr = fmt.Errorf("unlock: operation not applied - row not found or not visible yet (table: %s)", metadata.Name) lastErr = fmt.Errorf("unlock not applied: row not found or not visible yet")
} }
time.Sleep(baseDelay * time.Duration(1<<i)) // 100ms → 200ms → 400ms time.Sleep(lockBaseDelay * time.Duration(1<<i)) // 100ms → 200ms → 400ms
} }
return fmt.Errorf("unlock: failed after %d retries (table: %s): %w", defaultRetry, metadata.Name, lastErr) return ErrInvalidInput.WithTable(r.table).WithError(
fmt.Errorf("unlock failed after %d retries: %w", defaultLockRetry, lastErr),
)
}
// IsLockFailed 檢查錯誤是否為獲取鎖失敗
func IsLockFailed(err error) bool {
var e *Error
if errors.As(err, &e) {
return e.Code == ErrCodeConflict && e.Message == "acquire lock failed"
}
return false
} }

View File

@ -1,64 +0,0 @@
package cassandra
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
type LockTest struct {
ID string `db:"id" partition_key:"true"`
Holder string `db:"holder"`
}
func (l *LockTest) TableName() string { return "lock_test" }
func TestTryLockAndUnLock(t *testing.T) {
ctx := context.Background()
assert.NoError(t, cassandraDBTest.EnsureTable(`
CREATE TABLE IF NOT EXISTS my_keyspace.lock_test (
id TEXT PRIMARY KEY,
holder TEXT
);
`))
lockID := "lock-123"
holder := "daniel"
lockDoc := &LockTest{ID: lockID, Holder: holder}
t.Run("acquire lock - success", func(t *testing.T) {
err := cassandraDBTest.TryLock(ctx, lockDoc, "my_keyspace")
assert.NoError(t, err, "TryLock 應該成功")
})
t.Run("acquire lock again - fail", func(t *testing.T) {
err := cassandraDBTest.TryLock(ctx, lockDoc, "my_keyspace")
assert.Error(t, err, "重複上鎖應該失敗")
})
t.Run("unlock", func(t *testing.T) {
err := cassandraDBTest.UnLock(ctx, lockDoc, "my_keyspace")
assert.NoError(t, err, "UnLock 應該成功")
})
t.Run("lock with TTL", func(t *testing.T) {
lockWithTTL := &LockTest{ID: "lock-ttl", Holder: "jack"}
err := cassandraDBTest.TryLock(ctx, lockWithTTL, "my_keyspace", WithLockTTL(2*time.Second))
assert.NoError(t, err)
// 兩秒後嘗試再次上鎖應該成功TTL 過期)
time.Sleep(3 * time.Second)
err = cassandraDBTest.TryLock(ctx, lockWithTTL, "my_keyspace")
assert.NoError(t, err)
_ = cassandraDBTest.UnLock(ctx, lockWithTTL, "my_keyspace")
})
t.Run("unlock not exist", func(t *testing.T) {
nonExist := &LockTest{ID: "not-exist", Holder: "nobody"}
err := cassandraDBTest.UnLock(ctx, nonExist, "my_keyspace")
assert.Error(t, err, "unlock 不存在的鎖應該失敗")
})
}

View File

@ -3,26 +3,63 @@ package cassandra
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"sync"
"unicode"
"github.com/scylladb/gocqlx/v3/table" "github.com/scylladb/gocqlx/v3/table"
) )
// GenerateTableMetadata 根據傳入的 struct 產生 table.Metadata var (
func GenerateTableMetadata(document any, keyspace string) (table.Metadata, error) { // metadataCache 快取已生成的 Metadata避免重複反射解析
// 取得型別資訊,若是指標則取 Elem // key: tableName + ":" + structType (不包含 keyspace因為同一個 struct 在不同 keyspace 結構相同)
t := reflect.TypeOf(document) metadataCache sync.Map
)
type cachedMetadata struct {
columns []string
partKeys []string
sortKeys []string
err error
}
// generateMetadata 根據傳入的 struct 產生 table.Metadata
// 使用快取機制避免重複反射解析,提升效能
func generateMetadata[T Table](doc T, keyspace string) (table.Metadata, error) {
// 取得型別資訊
t := reflect.TypeOf(doc)
if t.Kind() == reflect.Ptr { if t.Kind() == reflect.Ptr {
t = t.Elem() t = t.Elem()
} }
// 取得表名稱:若 model 有實作 TableName() 則使用該方法,否則轉換型別名稱為 snake_case // 取得表名稱
var tableName string tableName := doc.TableName()
if tm, ok := document.(interface{ TableName() string }); ok { if tableName == "" {
tableName = fmt.Sprintf("%s.%s", keyspace, tm.TableName())
} else {
return table.Metadata{}, ErrMissingTableName return table.Metadata{}, ErrMissingTableName
} }
// 構建快取 key: tableName:structType (不包含 keyspace)
cacheKey := fmt.Sprintf("%s:%s", tableName, t.String())
// 檢查快取
if cached, ok := metadataCache.Load(cacheKey); ok {
cachedMeta := cached.(cachedMetadata)
if cachedMeta.err != nil {
return table.Metadata{}, cachedMeta.err
}
// 從快取構建 metadata動態加上 keyspace
meta := table.Metadata{
Name: fmt.Sprintf("%s.%s", keyspace, tableName),
Columns: make([]string, len(cachedMeta.columns)),
PartKey: make([]string, len(cachedMeta.partKeys)),
SortKey: make([]string, len(cachedMeta.sortKeys)),
}
copy(meta.Columns, cachedMeta.columns)
copy(meta.PartKey, cachedMeta.partKeys)
copy(meta.SortKey, cachedMeta.sortKeys)
return meta, nil
}
// 快取未命中,生成 metadata
columns := make([]string, 0, t.NumField()) columns := make([]string, 0, t.NumField())
partKeys := make([]string, 0, t.NumField()) partKeys := make([]string, 0, t.NumField())
sortKeys := make([]string, 0, t.NumField()) sortKeys := make([]string, 0, t.NumField())
@ -44,22 +81,36 @@ func GenerateTableMetadata(document any, keyspace string) (table.Metadata, error
colName = toSnakeCase(field.Name) colName = toSnakeCase(field.Name)
} }
columns = append(columns, colName) columns = append(columns, colName)
// 若有 partition:"true" 標記,加入 PartKey // 若有 partition_key:"true" 標記,加入 PartKey
if field.Tag.Get("partition_key") == "true" { if field.Tag.Get("partition_key") == "true" {
partKeys = append(partKeys, colName) partKeys = append(partKeys, colName)
} }
// 若有 sort:"true" 標記,加入 SortKey // 若有 clustering_key:"true" 標記,加入 SortKey
if field.Tag.Get("clustering_key") == "true" { if field.Tag.Get("clustering_key") == "true" {
sortKeys = append(sortKeys, colName) sortKeys = append(sortKeys, colName)
} }
} }
if len(partKeys) == 0 { if len(partKeys) == 0 {
return table.Metadata{}, ErrNoPartitionKey err := ErrNoPartitionKey
// 快取錯誤結果
metadataCache.Store(cacheKey, cachedMetadata{err: err})
return table.Metadata{}, err
} }
// 組合 Metadata // 快取成功結果(只存結構資訊,不包含 keyspace
cachedMeta := cachedMetadata{
columns: make([]string, len(columns)),
partKeys: make([]string, len(partKeys)),
sortKeys: make([]string, len(sortKeys)),
}
copy(cachedMeta.columns, columns)
copy(cachedMeta.partKeys, partKeys)
copy(cachedMeta.sortKeys, sortKeys)
metadataCache.Store(cacheKey, cachedMeta)
// 組合並返回 Metadata包含 keyspace
meta := table.Metadata{ meta := table.Metadata{
Name: tableName, Name: fmt.Sprintf("%s.%s", keyspace, tableName),
Columns: columns, Columns: columns,
PartKey: partKeys, PartKey: partKeys,
SortKey: sortKeys, SortKey: sortKeys,
@ -67,3 +118,19 @@ func GenerateTableMetadata(document any, keyspace string) (table.Metadata, error
return meta, nil return meta, nil
} }
// toSnakeCase 將 CamelCase 字串轉換為 snake_case
func toSnakeCase(s string) string {
var result []rune
for i, r := range s {
if unicode.IsUpper(r) {
if i > 0 {
result = append(result, '_')
}
result = append(result, unicode.ToLower(r))
} else {
result = append(result, r)
}
}
return string(result)
}

View File

@ -1,74 +0,0 @@
package cassandra
import (
"testing"
"github.com/scylladb/gocqlx/v3/table"
"github.com/stretchr/testify/assert"
)
// TestGenerateTableMetadata 測試 GenerateTableMetadata 函式
func TestGenerateTableMetadata(t *testing.T) {
testCases := []struct {
name string
document any
expected table.Metadata
expectError bool
}{
{
name: "MonkeyEntity with TableName",
document: &MonkeyEntity{},
expected: table.Metadata{
Name: "test.monkey_entity",
Columns: []string{"id", "name", "update_at", "create_at"},
PartKey: []string{"id"},
SortKey: []string{"name"},
},
expectError: false,
},
{
name: "Animal without TableName, type name converted to snake_case",
document: &Animal{},
expected: table.Metadata{
Name: "test.animal",
Columns: []string{"id", "type"},
PartKey: []string{"id"},
SortKey: []string{},
},
expectError: false,
},
{
name: "InvalidEntity without partition key",
document: &InvalidEntity{},
expected: table.Metadata{},
expectError: true,
},
{
name: "CatEntity with TableName",
document: &CatEntity{},
expected: table.Metadata{
Name: "test.cat_entity",
Columns: []string{"id", "name", "update_at", "create_at"},
PartKey: []string{"id", "name"},
SortKey: []string{"create_at"},
},
expectError: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
meta, err := GenerateTableMetadata(tc.document, "test")
if tc.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
// 比較 Metadata 的各個欄位
assert.Equal(t, tc.expected.Name, meta.Name, "table name mismatch")
assert.Equal(t, tc.expected.Columns, meta.Columns, "columns mismatch")
assert.Equal(t, tc.expected.PartKey, meta.PartKey, "partition keys mismatch")
assert.Equal(t, tc.expected.SortKey, meta.SortKey, "sort keys mismatch")
}
})
}
}

View File

@ -3,142 +3,160 @@ package cassandra
import ( import (
"time" "time"
"github.com/scylladb/gocqlx/v3/qb"
"github.com/gocql/gocql" "github.com/gocql/gocql"
) )
// Option 是設定選項的函數型別 // config 是初始化 DB 所需的內部設定(私有)
type Option func(*cassandraConf) type config struct {
Hosts []string // Cassandra 主機列表
Port int // 連線埠
Keyspace string // 預設使用的 Keyspace
Username string // 認證用戶名
Password string // 認證密碼
Consistency gocql.Consistency // 一致性級別
ConnectTimeoutSec int // 連線逾時秒數
NumConns int // 每個節點連線數
MaxRetries int // 重試次數
UseAuth bool // 是否使用帳號密碼驗證
RetryMinInterval time.Duration // 重試間隔最小值
RetryMaxInterval time.Duration // 重試間隔最大值
ReconnectInitialInterval time.Duration // 重連初始間隔
ReconnectMaxInterval time.Duration // 重連最大間隔
CQLVersion string // 執行連線的CQL 版本號
}
// defaultConfig 返回預設配置
func defaultConfig() *config {
return &config{
Port: defaultPort,
Consistency: defaultConsistency,
ConnectTimeoutSec: defaultTimeoutSec,
NumConns: defaultNumConns,
MaxRetries: defaultMaxRetries,
RetryMinInterval: defaultRetryMinInterval,
RetryMaxInterval: defaultRetryMaxInterval,
ReconnectInitialInterval: defaultReconnectInitialInterval,
ReconnectMaxInterval: defaultReconnectMaxInterval,
CQLVersion: defaultCqlVersion,
}
}
// Option 是設定選項的函數型別
type Option func(*config)
// WithHosts 設定 Cassandra 主機列表
func WithHosts(hosts ...string) Option {
return func(c *config) {
c.Hosts = hosts
}
}
// WithPort 設定連線埠
func WithPort(port int) Option { func WithPort(port int) Option {
return func(c *cassandraConf) { return func(c *config) {
c.Port = port c.Port = port
} }
} }
// WithKeyspace 設定預設 keyspace
func WithKeyspace(keyspace string) Option { func WithKeyspace(keyspace string) Option {
return func(c *cassandraConf) { return func(c *config) {
c.Keyspace = keyspace c.Keyspace = keyspace
} }
} }
// WithAuth 設定認證資訊
func WithAuth(username, password string) Option { func WithAuth(username, password string) Option {
return func(c *cassandraConf) { return func(c *config) {
c.Username = username c.Username = username
c.Password = password c.Password = password
c.UseAuth = true c.UseAuth = true
} }
} }
// WithConsistency is used to set the consistency level, default is Quorum // WithConsistency 設定一致性級別
func WithConsistency(consistency gocql.Consistency) Option { func WithConsistency(consistency gocql.Consistency) Option {
return func(c *cassandraConf) { return func(c *config) {
c.Consistency = consistency c.Consistency = consistency
} }
} }
// WithConnectTimeoutSec is used to set the connect timeout, default is 10 seconds // WithConnectTimeoutSec 設定連線逾時秒數
func WithConnectTimeoutSec(timeout int) Option { func WithConnectTimeoutSec(timeout int) Option {
return func(c *cassandraConf) { return func(c *config) {
if timeout <= 0 { if timeout <= 0 {
timeout = defaultTimeoutSec timeout = defaultTimeoutSec
} }
c.ConnectTimeoutSec = timeout c.ConnectTimeoutSec = timeout
} }
} }
// WithNumConns is used to set the number of connections to each node, default is 10 // WithNumConns 設定每個節點的連線數
func WithNumConns(numConns int) Option { func WithNumConns(numConns int) Option {
return func(c *cassandraConf) { return func(c *config) {
if numConns <= 0 { if numConns <= 0 {
numConns = defaultNumConns numConns = defaultNumConns
} }
c.NumConns = numConns c.NumConns = numConns
} }
} }
// WithMaxRetries is used to set the maximum retries, default is 3 // WithMaxRetries 設定最大重試次數
func WithMaxRetries(maxRetries int) Option { func WithMaxRetries(maxRetries int) Option {
return func(c *cassandraConf) { return func(c *config) {
if maxRetries <= 0 { if maxRetries <= 0 {
maxRetries = defaultMaxRetries maxRetries = defaultMaxRetries
} }
c.MaxRetries = maxRetries c.MaxRetries = maxRetries
} }
} }
// WithRetryMinInterval is used to set the minimum retry interval, default is 1 second // WithRetryMinInterval 設定最小重試間隔
func WithRetryMinInterval(duration time.Duration) Option { func WithRetryMinInterval(duration time.Duration) Option {
return func(c *cassandraConf) { return func(c *config) {
if duration <= 0 { if duration <= 0 {
duration = defaultRetryMinInterval duration = defaultRetryMinInterval
} }
c.RetryMinInterval = duration c.RetryMinInterval = duration
} }
} }
// WithRetryMaxInterval is used to set the maximum retry interval, default is 30 seconds // WithRetryMaxInterval 設定最大重試間隔
func WithRetryMaxInterval(duration time.Duration) Option { func WithRetryMaxInterval(duration time.Duration) Option {
return func(c *cassandraConf) { return func(c *config) {
if duration <= 0 { if duration <= 0 {
duration = defaultRetryMaxInterval duration = defaultRetryMaxInterval
} }
c.RetryMaxInterval = duration c.RetryMaxInterval = duration
} }
} }
// WithReconnectInitialInterval is used to set the initial reconnect interval, default is 1 second // WithReconnectInitialInterval 設定初始重連間隔
func WithReconnectInitialInterval(duration time.Duration) Option { func WithReconnectInitialInterval(duration time.Duration) Option {
return func(c *cassandraConf) { return func(c *config) {
if duration <= 0 { if duration <= 0 {
duration = defaultReconnectInitialInterval duration = defaultReconnectInitialInterval
} }
c.ReconnectInitialInterval = duration c.ReconnectInitialInterval = duration
} }
} }
// WithReconnectMaxInterval is used to set the maximum reconnect interval, default is 60 seconds // WithReconnectMaxInterval 設定最大重連間隔
func WithReconnectMaxInterval(duration time.Duration) Option { func WithReconnectMaxInterval(duration time.Duration) Option {
return func(c *cassandraConf) { return func(c *config) {
if duration <= 0 { if duration <= 0 {
duration = defaultReconnectMaxInterval duration = defaultReconnectMaxInterval
} }
c.ReconnectMaxInterval = duration c.ReconnectMaxInterval = duration
} }
} }
// WithCQLVersion is used to set the CQL version, default is 3.0.0 // WithCQLVersion 設定 CQL 版本
func WithCQLVersion(version string) Option { func WithCQLVersion(version string) Option {
return func(c *cassandraConf) { return func(c *config) {
if version == "" { if version == "" {
version = defaultCqlVersion version = defaultCqlVersion
} }
c.CQLVersion = version c.CQLVersion = version
} }
} }
// ===============================================================
// QueryOption defines a function that modifies a query builder
type QueryOption func(*qb.SelectBuilder, qb.M)
// WithWhere adds WHERE clauses to the query
func WithWhere(where []qb.Cmp, args map[string]any) QueryOption {
return func(b *qb.SelectBuilder, bind qb.M) {
if len(where) > 0 {
b.Where(where...)
for k, v := range args {
bind[k] = v
}
}
}
}

View File

@ -1,158 +0,0 @@
package cassandra
import (
"testing"
"time"
"github.com/gocql/gocql"
"github.com/stretchr/testify/assert"
)
func TestOptions(t *testing.T) {
tests := []struct {
name string
option Option
check func(conf *cassandraConf)
}{
{
name: "WithPort",
option: WithPort(1234),
check: func(conf *cassandraConf) {
assert.Equal(t, 1234, conf.Port, "Port 設定錯誤")
},
},
{
name: "WithKeyspace",
option: WithKeyspace("my_keyspace"),
check: func(conf *cassandraConf) {
assert.Equal(t, "my_keyspace", conf.Keyspace, "Keyspace 設定錯誤")
},
},
{
name: "WithAuth",
option: WithAuth("user", "pass"),
check: func(conf *cassandraConf) {
assert.Equal(t, "user", conf.Username, "Username 設定錯誤")
assert.Equal(t, "pass", conf.Password, "Password 設定錯誤")
assert.True(t, conf.UseAuth, "UseAuth 應該為 true")
},
},
{
name: "WithConsistency",
option: WithConsistency(gocql.Quorum),
check: func(conf *cassandraConf) {
assert.Equal(t, gocql.Quorum, conf.Consistency, "Consistency 設定錯誤")
},
},
{
name: "WithConnectTimeoutSec",
option: WithConnectTimeoutSec(45),
check: func(conf *cassandraConf) {
assert.Equal(t, 45, conf.ConnectTimeoutSec, "ConnectTimeoutSec 設定錯誤")
},
},
{
name: "WithConnectTimeoutSec_default",
option: WithConnectTimeoutSec(0),
check: func(conf *cassandraConf) {
assert.Equal(t, defaultTimeoutSec, conf.ConnectTimeoutSec, "ConnectTimeoutSec 設定錯誤")
},
},
{
name: "WithNumConns",
option: WithNumConns(10),
check: func(conf *cassandraConf) {
assert.Equal(t, 10, conf.NumConns, "NumConns 設定錯誤")
},
},
{
name: "WithNumConns_default",
option: WithNumConns(0),
check: func(conf *cassandraConf) {
assert.Equal(t, defaultNumConns, conf.NumConns, "NumConns 設定錯誤")
},
},
{
name: "WithMaxRetries",
option: WithMaxRetries(5),
check: func(conf *cassandraConf) {
assert.Equal(t, 5, conf.MaxRetries, "MaxRetries 設定錯誤")
},
},
{
name: "WithMaxRetries_default",
option: WithMaxRetries(0),
check: func(conf *cassandraConf) {
assert.Equal(t, defaultMaxRetries, conf.MaxRetries, "MaxRetries 設定錯誤")
},
},
{
name: "WithRetryMinInterval",
option: WithRetryMinInterval(2 * time.Second),
check: func(conf *cassandraConf) {
assert.Equal(t, 2*time.Second, conf.RetryMinInterval, "RetryMinInterval 設定錯誤")
},
},
{
name: "WithRetryMinInterval_default",
option: WithRetryMinInterval(0),
check: func(conf *cassandraConf) {
assert.Equal(t, defaultRetryMinInterval, conf.RetryMinInterval, "RetryMinInterval 設定錯誤")
},
},
{
name: "WithRetryMaxInterval",
option: WithRetryMaxInterval(10 * time.Second),
check: func(conf *cassandraConf) {
assert.Equal(t, 10*time.Second, conf.RetryMaxInterval, "RetryMaxInterval 設定錯誤")
},
},
{
name: "WithRetryMaxInterval_default",
option: WithRetryMaxInterval(0),
check: func(conf *cassandraConf) {
assert.Equal(t, defaultRetryMaxInterval, conf.RetryMaxInterval, "RetryMaxInterval 設定錯誤")
},
},
{
name: "WithReconnectInitialInterval",
option: WithReconnectInitialInterval(1 * time.Second),
check: func(conf *cassandraConf) {
assert.Equal(t, 1*time.Second, conf.ReconnectInitialInterval, "ReconnectInitialInterval 設定錯誤")
},
},
{
name: "WithReconnectInitialInterval_default",
option: WithReconnectInitialInterval(0),
check: func(conf *cassandraConf) {
assert.Equal(t, defaultReconnectInitialInterval, conf.ReconnectInitialInterval, "ReconnectInitialInterval 設定錯誤")
},
},
{
name: "WithReconnectMaxInterval",
option: WithReconnectMaxInterval(10 * time.Second),
check: func(conf *cassandraConf) {
assert.Equal(t, 10*time.Second, conf.ReconnectMaxInterval, "ReconnectMaxInterval 設定錯誤")
},
},
{
name: "WithReconnectMaxInterval_default",
option: WithReconnectMaxInterval(0),
check: func(conf *cassandraConf) {
assert.Equal(t, defaultReconnectMaxInterval, conf.ReconnectMaxInterval, "ReconnectMaxInterval 設定錯誤")
},
},
}
for _, tc := range tests {
tc := tc // 避免 closure 捕捉迴圈變數
t.Run(tc.name, func(t *testing.T) {
// 為每個測試案例產生一個新的 cassandraConf 實例
conf := &cassandraConf{}
// 套用 Option
tc.option(conf)
// 執行檢查
tc.check(conf)
})
}
}

View File

@ -0,0 +1,226 @@
package cassandra
import (
"context"
"fmt"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx/v3/qb"
)
// Condition 定義查詢條件介面
type Condition interface {
Build() (qb.Cmp, map[string]any)
}
// Eq 等於條件
func Eq(column string, value any) Condition {
return &eqCondition{column: column, value: value}
}
type eqCondition struct {
column string
value any
}
func (c *eqCondition) Build() (qb.Cmp, map[string]any) {
return qb.Eq(c.column), map[string]any{c.column: c.value}
}
// In IN 條件
func In(column string, values []any) Condition {
return &inCondition{column: column, values: values}
}
type inCondition struct {
column string
values []any
}
func (c *inCondition) Build() (qb.Cmp, map[string]any) {
return qb.In(c.column), map[string]any{c.column: c.values}
}
// Gt 大於條件
func Gt(column string, value any) Condition {
return &gtCondition{column: column, value: value}
}
type gtCondition struct {
column string
value any
}
func (c *gtCondition) Build() (qb.Cmp, map[string]any) {
return qb.Gt(c.column), map[string]any{c.column: c.value}
}
// Lt 小於條件
func Lt(column string, value any) Condition {
return &ltCondition{column: column, value: value}
}
type ltCondition struct {
column string
value any
}
func (c *ltCondition) Build() (qb.Cmp, map[string]any) {
return qb.Lt(c.column), map[string]any{c.column: c.value}
}
// QueryBuilder 定義查詢構建器介面
type QueryBuilder[T Table] interface {
Where(condition Condition) QueryBuilder[T]
OrderBy(column string, order Order) QueryBuilder[T]
Limit(n int) QueryBuilder[T]
Select(columns ...string) QueryBuilder[T]
Scan(ctx context.Context, dest *[]T) error
One(ctx context.Context) (T, error)
Count(ctx context.Context) (int64, error)
}
// queryBuilder 是 QueryBuilder 的具體實作
type queryBuilder[T Table] struct {
repo *repository[T]
conditions []Condition
orders []orderBy
limit int
columns []string
}
type orderBy struct {
column string
order Order
}
// newQueryBuilder 創建新的查詢構建器
func newQueryBuilder[T Table](repo *repository[T]) QueryBuilder[T] {
return &queryBuilder[T]{
repo: repo,
}
}
// Where 添加 WHERE 條件
func (q *queryBuilder[T]) Where(condition Condition) QueryBuilder[T] {
q.conditions = append(q.conditions, condition)
return q
}
// OrderBy 添加排序
func (q *queryBuilder[T]) OrderBy(column string, order Order) QueryBuilder[T] {
q.orders = append(q.orders, orderBy{column: column, order: order})
return q
}
// Limit 設置限制
func (q *queryBuilder[T]) Limit(n int) QueryBuilder[T] {
q.limit = n
return q
}
// Select 指定要查詢的欄位
func (q *queryBuilder[T]) Select(columns ...string) QueryBuilder[T] {
q.columns = append(q.columns, columns...)
return q
}
// Scan 執行查詢並將結果掃描到 dest
func (q *queryBuilder[T]) Scan(ctx context.Context, dest *[]T) error {
if dest == nil {
return ErrInvalidInput.WithTable(q.repo.table).WithError(
fmt.Errorf("destination cannot be nil"),
)
}
builder := qb.Select(q.repo.table)
// 添加欄位
if len(q.columns) > 0 {
builder = builder.Columns(q.columns...)
} else {
builder = builder.Columns(q.repo.metadata.Columns...)
}
// 添加條件
bindMap := make(map[string]any)
var cmps []qb.Cmp
for _, cond := range q.conditions {
cmp, binds := cond.Build()
cmps = append(cmps, cmp)
for k, v := range binds {
bindMap[k] = v
}
}
if len(cmps) > 0 {
builder = builder.Where(cmps...)
}
// 添加排序
for _, o := range q.orders {
order := qb.ASC
if o.order == DESC {
order = qb.DESC
}
builder = builder.OrderBy(o.column, order)
}
// 添加限制
if q.limit > 0 {
builder = builder.Limit(uint(q.limit))
}
stmt, names := builder.ToCql()
query := q.repo.db.withContextAndTimestamp(ctx,
q.repo.db.session.Query(stmt, names).BindMap(bindMap))
return query.SelectRelease(dest)
}
// One 執行查詢並返回單筆結果
func (q *queryBuilder[T]) One(ctx context.Context) (T, error) {
var zero T
q.limit = 1
var results []T
if err := q.Scan(ctx, &results); err != nil {
return zero, err
}
if len(results) == 0 {
return zero, ErrNotFound.WithTable(q.repo.table)
}
return results[0], nil
}
// Count 計算符合條件的記錄數
func (q *queryBuilder[T]) Count(ctx context.Context) (int64, error) {
builder := qb.Select(q.repo.table).Columns("COUNT(*)")
// 添加條件
bindMap := make(map[string]any)
var cmps []qb.Cmp
for _, cond := range q.conditions {
cmp, binds := cond.Build()
cmps = append(cmps, cmp)
for k, v := range binds {
bindMap[k] = v
}
}
if len(cmps) > 0 {
builder = builder.Where(cmps...)
}
stmt, names := builder.ToCql()
query := q.repo.db.withContextAndTimestamp(ctx,
q.repo.db.session.Query(stmt, names).BindMap(bindMap))
var count int64
err := query.GetRelease(&count)
if err == gocql.ErrNotFound {
return 0, nil // COUNT 查詢不會返回 ErrNotFound但為了安全起見
}
return count, err
}

View File

@ -1,30 +0,0 @@
package cassandra
import (
"context"
"time"
"github.com/scylladb/gocqlx/v3"
)
// queryHelper 封裝查詢相關的輔助方法
type queryHelper struct{}
// withTimestamp 為查詢添加時間戳
func (h *queryHelper) withTimestamp(q *gocqlx.Queryx) *gocqlx.Queryx {
return q.WithTimestamp(time.Now().UnixNano() / 1e3)
}
// withContextAndTimestamp 為查詢添加 context 和時間戳
func (h *queryHelper) withContextAndTimestamp(ctx context.Context, q *gocqlx.Queryx) *gocqlx.Queryx {
return q.WithContext(ctx).WithTimestamp(time.Now().UnixNano() / 1e3)
}
// getKeyspace 獲取 keyspace如果為空則使用預設值
func getKeyspace(db *CassandraDB, keyspace string) string {
if keyspace == "" {
return db.defaultKeyspace
}
return keyspace
}

View File

@ -1,438 +0,0 @@
# Cassandra Database Client for Go with Advanced CRUD Operations and Transaction Support
一套功能完備的 Go 語言 Apache Cassandra 客戶端,支援進階 CRUD 操作、Batch 交易、分散式鎖機制、SAI (Storage-Attached Indexing) 索引與 Fluent API 鏈式查詢介面,讓你用最簡潔的程式碼玩轉 Cassandra
## 特色
* Go struct 自動生成 Table Metadata
* 批次操作與原子性交易支援(含 rollback
* 內建分散式鎖 (基於唯一索引)
* 支援 SAI 二級索引
* 類 GORM 流暢式Fluent API查詢體驗
* 單筆/多筆操作自動處理
* 完善的連線管理與組態選項
## 專案結構
```
.
├── batch.go # Batch 批次操作/交易
├── client.go # Cassandra 連線管理主體
├── crud.go # 基本 CRUD 操作
├── ez_transaction.go # 支援 rollback 的交易系統
├── lock.go # 分散式鎖實作
├── metadata.go # 由 struct 產生 Table metadata
├── option.go # 組態與查詢選項
├── table.go # Table 操作、查詢組合
├── utils.go # 工具函式
└── tests/ # 全面測試
```
## 安裝方式
```bash
go get gitlab.supermicro.com/infra/infra-core/storage/cassandra
```
## 快速開始
### 1. 初始化 Client
```go
import "gitlab.supermicro.com/infra/infra-core/storage/cassandra"
// 基本初始化(使用預設 keyspace
client, err := cassandra.NewCassandraDB(
[]string{"localhost"},
cassandra.WithPort(9042),
cassandra.WithKeyspace("my_keyspace"),
cassandra.WithAuth("username", "password"), // 可選
)
if err != nil {
log.Fatal(err)
}
defer client.Close()
// 使用預設 keyspace 時,後續操作可以省略 keyspace 參數
// 如果傳入空字串 "",會自動使用初始化時設定的預設 keyspace
```
### 2. 定義資料模型
```go
type User struct {
ID gocql.UUID `db:"id" partition_key:"true"`
Name string `db:"name" clustering_key:"true" sai:"true"`
Email string `db:"email"`
CreatedAt time.Time `db:"created_at"`
}
func (u *User) TableName() string {
return "users"
}
```
### 3. 基本 CRUD 操作
```go
// 新增keyspace 為空時使用預設 keyspace
user := &User{
ID: gocql.TimeUUID(),
Name: "John Doe",
Email: "john@example.com",
CreatedAt: time.Now(),
}
err = client.Insert(ctx, user, "") // 使用預設 keyspace
// 或明確指定 keyspace
err = client.Insert(ctx, user, "my_keyspace")
// 查詢
result := &User{ID: user.ID}
err = client.Get(ctx, result, "")
if cassandra.IsNotFound(err) {
// 處理記錄不存在的情況
log.Println("User not found")
}
// 更新(只更新非零值欄位)
result.Email = "newemail@example.com"
err = client.Update(ctx, result, "")
// 更新所有欄位(包括零值)
result.Email = ""
err = client.UpdateAll(ctx, result, "")
// 選擇性更新(可控制是否包含零值)
err = client.UpdateSelective(ctx, result, "", false) // false = 排除零值
// 刪除
err = client.Delete(ctx, result, "")
```
### 4. 進階Batch 與補償式交易操作
```go
// Batch 操作(原子性批次操作)
// Batch 是 Cassandra 原生的批次操作,保證原子性
batch := client.NewBatch(ctx, "") // 使用預設 keyspace
batch.Insert(user1)
batch.Insert(user2)
batch.Update(user3)
err := batch.Commit()
// 補償式交易Compensating Transaction
// 注意:這不是真正的 ACID 交易,而是基於補償操作的模式
// 適用於最終一致性場景,可以確保「要嘛全成功,要嘛全失敗」
tx := cassandra.NewCompensatingTransaction(ctx, "", client)
// 或使用向後相容的別名
// tx := cassandra.NewEZTransaction(ctx, "", client)
tx.Insert(user1)
tx.Update(user2)
if err := tx.Commit(); err != nil {
// 如果 Commit 失敗,執行 Rollback 進行補償操作
if rollbackErr := tx.Rollback(); rollbackErr != nil {
log.Printf("Rollback failed: %v", rollbackErr)
}
return err
}
```
**Batch vs CompensatingTransaction 的區別:**
- **Batch**: Cassandra 原生的原子性批次操作,所有操作要嘛全部成功,要嘛全部失敗。但無法跨表操作,且不支援條件操作。
- **CompensatingTransaction**: 基於補償操作的交易模式,可以跨表操作,支援複雜的業務邏輯。透過記錄操作日誌,在失敗時執行補償操作來實現「要嘛全成功,要嘛全失敗」的語義。
### 5. 錯誤處理
```go
import "gitlab.supermicro.com/infra/infra-core/storage/cassandra"
// 統一的錯誤處理
result := &User{ID: userID}
err := client.Get(ctx, result, "")
if err != nil {
// 檢查特定錯誤類型
if cassandra.IsNotFound(err) {
// 處理記錄不存在
log.Println("User not found")
} else if cassandra.IsLockFailed(err) {
// 處理獲取鎖失敗
log.Println("Failed to acquire lock")
} else {
// 處理其他錯誤
log.Printf("Error: %v", err)
}
}
// 錯誤類型包含詳細資訊
var cassandraErr *cassandra.Error
if errors.As(err, &cassandraErr) {
log.Printf("Error Code: %s", cassandraErr.Code)
log.Printf("Error Message: %s", cassandraErr.Message)
log.Printf("Table: %s", cassandraErr.Table)
if cassandraErr.Err != nil {
log.Printf("Underlying Error: %v", cassandraErr.Err)
}
}
```
### 6. IN 操作
```go
// 使用 QueryBuilder 進行 IN 查詢
where := []qb.Cmp{qb.In("id")}
args := map[string]any{"id": uuids}
var result []User
err := client.QueryBuilder(
ctx,
&User{},
&result,
"", // 使用預設 keyspace
cassandra.WithWhere(where, args),
)
```
---
## Fluent API 鏈式查詢 (GORM 風格)
支援類 GORM 直覺式鏈式呼叫查詢方式,快速進行 CRUD、條件過濾、排序、分頁、單筆查詢、更新、刪除等操作
```go
type TestUser struct {
ID gocql.UUID `db:"id" partition_key:"true"`
Name string `db:"name" sai:"true"`
Age int64 `db:"age"`
}
func (TestUser) TableName() string { return "test_user" }
// 新增單筆
user := TestUser{ID: gocql.TimeUUID(), Name: "Alice", Age: 20}
err := db.Model(ctx, TestUser{}, keyspace).InsertOne(user)
// 批量新增
users := []TestUser{{...}, {...}}
err := db.Model(ctx, TestUser{}, keyspace).InsertMany(users)
// 查詢所有
var got []TestUser
err := db.Model(ctx, TestUser{}, keyspace).GetAll(&got)
// 查詢某些欄位
var got []TestUser
err := db.Model(ctx, TestUser{}, ""). // 使用預設 keyspace
Select("name").GetAll(&got)
// 條件查詢 + 排序 + 分頁
var result []TestUser
err := db.Model(ctx, TestUser{}, "").
Where(qb.Eq("name"), map[string]any{"name": "Alice"}).
OrderBy("age", qb.DESC).
Limit(10).
Scan(&result)
// IN 操作
var result []TestUser
err := db.Model(ctx, TestUser{}, "").
Where(qb.In("name"), map[string]any{"name": []string{"Alice", "Bob"}}).
Scan(&result)
// 單筆查詢
var user TestUser
err := db.Model(ctx, TestUser{}, "").
Where(qb.Eq("id"), map[string]any{"id": userID}).
Take(&user)
// 更新欄位(必須提供 partition_key 或 sai indexed 欄位在 WHERE 中)
err := db.Model(ctx, TestUser{}, "").
Where(qb.Eq("id"), map[string]any{"id": userID}).
Set("age", 30).
Update()
// 刪除(必須提供所有 partition keys
err := db.Model(ctx, TestUser{}, "").
Where(qb.Eq("id"), map[string]any{"id": userID}).
Delete()
// 計數
count, err := db.Model(ctx, TestUser{}, "").
Where(qb.Eq("name"), map[string]any{"name": "Alice"}).
Count()
```
### 常用查詢語法總結
| 操作 | 用法範例 |
| ---- | --------------------------------------------- |
| 條件查詢 | .Where(qb.Eq("欄位"), map\[string]any{"欄位": 值}) |
| 指定欄位 | .Select("id", "name") |
| 排序 | .OrderBy("age", qb.DESC) |
| 分頁 | .Limit(10) |
| 查單筆 | .Take(\&result) |
| 更新欄位 | .Set("age", 25).Update() |
| 刪除 | .Delete() |
| 計數 | .Count() |
---
## 完整 API 參考
### 初始化選項
```go
// 連線選項
cassandra.WithPort(port int)
cassandra.WithKeyspace(keyspace string)
cassandra.WithAuth(username, password string)
cassandra.WithConsistency(consistency gocql.Consistency)
cassandra.WithConnectTimeoutSec(timeout int)
cassandra.WithNumConns(numConns int)
cassandra.WithMaxRetries(maxRetries int)
cassandra.WithRetryMinInterval(duration time.Duration)
cassandra.WithRetryMaxInterval(duration time.Duration)
cassandra.WithReconnectInitialInterval(duration time.Duration)
cassandra.WithReconnectMaxInterval(duration time.Duration)
cassandra.WithCQLVersion(version string)
```
### 基本 CRUD 方法
```go
// 插入
func (db *CassandraDB) Insert(ctx context.Context, document any, keyspace string) error
// 查詢(根據 Primary Key
func (db *CassandraDB) Get(ctx context.Context, dest any, keyspace string) error
// 更新(只更新非零值欄位)
func (db *CassandraDB) Update(ctx context.Context, document any, keyspace string) error
// 選擇性更新(可控制是否包含零值)
func (db *CassandraDB) UpdateSelective(ctx context.Context, document any, keyspace string, includeZero bool) error
// 更新所有欄位(包括零值)
func (db *CassandraDB) UpdateAll(ctx context.Context, document any, keyspace string) error
// 刪除
func (db *CassandraDB) Delete(ctx context.Context, filter any, keyspace string) error
// 查詢所有
func (db *CassandraDB) GetAll(ctx context.Context, filter any, result any, keyspace string) error
// 查詢構建器
func (db *CassandraDB) QueryBuilder(ctx context.Context, tableStruct any, result any, keyspace string, opts ...QueryOption) error
```
### Fluent API 方法
```go
// 創建查詢構建器
func (db *CassandraDB) Model(ctx context.Context, document any, keyspace string) *Query
// Query 方法
func (q *Query) Where(cmp qb.Cmp, args map[string]any) *Query
func (q *Query) Select(cols ...string) *Query
func (q *Query) OrderBy(column string, order qb.Order) *Query
func (q *Query) Limit(limit uint) *Query
func (q *Query) Set(col string, val any) *Query
func (q *Query) Scan(dest any) error
func (q *Query) Take(dest any) error
func (q *Query) GetAll(dest any) error
func (q *Query) Count() (int64, error)
func (q *Query) InsertOne(data any) error
func (q *Query) InsertMany(documents any) error
func (q *Query) Update() error
func (q *Query) Delete() error
```
### Batch 操作
```go
// 創建 Batch
func (db *CassandraDB) NewBatch(ctx context.Context, keyspace string) *Batch
// Batch 方法
func (tx *Batch) Insert(doc any) error
func (tx *Batch) Delete(doc any) error
func (tx *Batch) Update(doc any) error
func (tx *Batch) Commit() error
```
### 補償式交易
```go
// 創建補償式交易
func NewCompensatingTransaction(ctx context.Context, keyspace string, db *CassandraDB) CompensatingTransaction
// 向後相容的別名(已棄用)
func NewEZTransaction(ctx context.Context, keyspace string, db *CassandraDB) CompensatingTransaction
// Transaction 方法
func (tx CompensatingTransaction) Insert(ctx context.Context, document any) error
func (tx CompensatingTransaction) Delete(ctx context.Context, filter any) error
func (tx CompensatingTransaction) Update(ctx context.Context, document any) error
func (tx CompensatingTransaction) Commit() error
func (tx CompensatingTransaction) Rollback() error
```
### 分散式鎖
```go
// 嘗試獲取鎖
func (db *CassandraDB) TryLock(ctx context.Context, document any, keyspace string, opts ...LockOption) error
// 釋放鎖
func (db *CassandraDB) UnLock(ctx context.Context, filter any, keyspace string) error
// 鎖選項
func WithLockTTL(d time.Duration) LockOption
func WithNoLockExpire() LockOption
```
### 錯誤處理
```go
// 錯誤類型
type Error struct {
Code string
Message string
Table string
Err error
}
// 預定義錯誤
var ErrNotFound
var ErrAcquireLockFailed
var ErrInvalidInput
var ErrNoPartitionKey
var ErrMissingTableName
var ErrNoFieldsToUpdate
var ErrMissingWhereCondition
var ErrMissingPartitionKey
// 錯誤檢查函數
func IsNotFound(err error) bool
func IsLockFailed(err error) bool
```
## 注意事項
1. **Keyspace 處理**: 如果方法參數中的 `keyspace` 為空字串 `""`,會自動使用初始化時設定的預設 keyspace。
2. **WHERE 條件限制**: Cassandra 的 WHERE 條件只能使用:
- Partition Key 欄位
- 有 SAI 索引的欄位
- Clustering Key 欄位(在 Partition Key 之後)
3. **Update 方法**:
- `Update()`: 只更新非零值欄位
- `UpdateAll()`: 更新所有欄位(包括零值)
- `UpdateSelective()`: 可控制是否包含零值
4. **補償式交易**: 這不是真正的 ACID 交易,而是基於補償操作的模式,適用於最終一致性場景。
5. **錯誤處理**: 建議使用 `IsNotFound()``IsLockFailed()` 等輔助函數來檢查特定錯誤類型。
---

View File

@ -0,0 +1,257 @@
package cassandra
import (
"context"
"fmt"
"reflect"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx/v3"
"github.com/scylladb/gocqlx/v3/qb"
"github.com/scylladb/gocqlx/v3/table"
)
// Repository 定義資料存取介面(小介面,符合 M3
type Repository[T Table] interface {
// 基本 CRUD
Insert(ctx context.Context, doc T) error
Get(ctx context.Context, pk any) (T, error)
Update(ctx context.Context, doc T) error
Delete(ctx context.Context, pk any) error
// 批次操作
InsertMany(ctx context.Context, docs []T) error
// 查詢構建器
Query() QueryBuilder[T]
// 分散式鎖
TryLock(ctx context.Context, doc T, opts ...LockOption) error
UnLock(ctx context.Context, doc T) error
}
// repository 是 Repository 的具體實作
type repository[T Table] struct {
db *DB
keyspace string
table string
metadata table.Metadata
}
// NewRepository 獲取指定類型的 Repository
// keyspace 如果為空,使用預設 keyspace
func NewRepository[T Table](db *DB, keyspace string) (Repository[T], error) {
if keyspace == "" {
keyspace = db.defaultKeyspace
}
var zero T
metadata, err := generateMetadata(zero, keyspace)
if err != nil {
return nil, fmt.Errorf("failed to generate metadata: %w", err)
}
return &repository[T]{
db: db,
keyspace: keyspace,
table: metadata.Name,
metadata: metadata,
}, nil
}
// Insert 插入單筆資料
func (r *repository[T]) Insert(ctx context.Context, doc T) error {
t := table.New(r.metadata)
q := r.db.withContextAndTimestamp(ctx,
r.db.session.Query(t.Insert()).BindStruct(doc))
return q.ExecRelease()
}
// Get 根據主鍵查詢單筆資料
// 注意pk 必須是完整的 Primary Key包含所有 Partition Key 和 Clustering Key
// 如果主鍵是多欄位,需要傳入包含所有主鍵欄位的 struct
// pk 可以是string, int, int64, gocql.UUID, []byte 或包含主鍵欄位的 struct
func (r *repository[T]) Get(ctx context.Context, pk any) (T, error) {
var zero T
t := table.New(r.metadata)
// 使用 table.Get() 方法,它會自動根據 metadata 構建主鍵查詢
// 如果 pk 是 struct使用 BindStruct否則使用 Bind
var q *gocqlx.Queryx
if reflect.TypeOf(pk).Kind() == reflect.Struct {
q = r.db.withContextAndTimestamp(ctx,
r.db.session.Query(t.Get()).BindStruct(pk))
} else {
// 單一主鍵欄位的情況
// 注意:這只適用於單一 Partition Key 且無 Clustering Key 的情況
if len(r.metadata.PartKey) != 1 || len(r.metadata.SortKey) > 0 {
return zero, ErrInvalidInput.WithTable(r.table).WithError(
fmt.Errorf("single value primary key only supported for single partition key without clustering key"),
)
}
q = r.db.withContextAndTimestamp(ctx,
r.db.session.Query(t.Get()).Bind(pk))
}
var result T
err := q.GetRelease(&result)
if err == gocql.ErrNotFound {
return zero, ErrNotFound.WithTable(r.table)
}
if err != nil {
return zero, ErrInvalidInput.WithTable(r.table).WithError(err)
}
return result, nil
}
// Update 更新資料(只更新非零值欄位)
func (r *repository[T]) Update(ctx context.Context, doc T) error {
return r.updateSelective(ctx, doc, false)
}
// UpdateAll 更新所有欄位(包括零值)
func (r *repository[T]) UpdateAll(ctx context.Context, doc T) error {
return r.updateSelective(ctx, doc, true)
}
// updateSelective 選擇性更新
func (r *repository[T]) updateSelective(ctx context.Context, doc T, includeZero bool) error {
// 重用現有的 BuildUpdateFields 邏輯
// 由於在不同套件,我們需要重新實作或導入
fields, err := r.buildUpdateFields(doc, includeZero)
if err != nil {
return err
}
stmt, names := r.buildUpdateStatement(fields.setCols, fields.whereCols)
setVals := append(fields.setVals, fields.whereVals...)
q := r.db.withContextAndTimestamp(ctx,
r.db.session.Query(stmt, names).Bind(setVals...))
return q.ExecRelease()
}
// Delete 刪除資料
// pk 可以是string, int, int64, gocql.UUID, []byte 或包含主鍵欄位的 struct
func (r *repository[T]) Delete(ctx context.Context, pk any) error {
t := table.New(r.metadata)
stmt, names := t.Delete()
q := r.db.withContextAndTimestamp(ctx,
r.db.session.Query(stmt, names).Bind(pk))
return q.ExecRelease()
}
// InsertMany 批次插入資料
func (r *repository[T]) InsertMany(ctx context.Context, docs []T) error {
if len(docs) == 0 {
return nil
}
// 使用 Batch 操作
batch := r.db.session.NewBatch(gocql.LoggedBatch).WithContext(ctx)
t := table.New(r.metadata)
stmt, names := t.Insert()
for _, doc := range docs {
if err := batch.BindStruct(r.db.session.Query(stmt, names), doc); err != nil {
return fmt.Errorf("failed to bind document: %w", err)
}
}
return r.db.session.ExecuteBatch(batch)
}
// Query 返回查詢構建器
func (r *repository[T]) Query() QueryBuilder[T] {
return newQueryBuilder(r)
}
// updateFields 包含更新操作所需的欄位資訊
type updateFields struct {
setCols []string
setVals []any
whereCols []string
whereVals []any
}
// buildUpdateFields 從 document 中提取更新所需的欄位資訊
func (r *repository[T]) buildUpdateFields(doc T, includeZero bool) (*updateFields, error) {
v := reflect.ValueOf(doc)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
typ := v.Type()
setCols := make([]string, 0)
setVals := make([]any, 0)
whereCols := make([]string, 0)
whereVals := make([]any, 0)
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
tag := field.Tag.Get("db")
if tag == "" || tag == "-" {
continue
}
val := v.Field(i)
if !val.IsValid() {
continue
}
// 主鍵欄位放入 WHERE 條件
if contains(r.metadata.PartKey, tag) || contains(r.metadata.SortKey, tag) {
whereCols = append(whereCols, tag)
whereVals = append(whereVals, val.Interface())
continue
}
// 根據 includeZero 決定是否包含零值欄位
if !includeZero && isZero(val) {
continue
}
setCols = append(setCols, tag)
setVals = append(setVals, val.Interface())
}
if len(setCols) == 0 {
return nil, ErrNoFieldsToUpdate.WithTable(r.table)
}
return &updateFields{
setCols: setCols,
setVals: setVals,
whereCols: whereCols,
whereVals: whereVals,
}, nil
}
// buildUpdateStatement 構建 UPDATE CQL 語句
func (r *repository[T]) buildUpdateStatement(setCols, whereCols []string) (string, []string) {
builder := qb.Update(r.table).Set(setCols...)
for _, col := range whereCols {
builder = builder.Where(qb.Eq(col))
}
return builder.ToCql()
}
// contains 判斷字串是否存在於 slice 中
func contains(list []string, target string) bool {
for _, item := range list {
if item == target {
return true
}
}
return false
}
// isZero 判斷欄位是否為零值或 nil
func isZero(v reflect.Value) bool {
switch v.Kind() {
case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice:
return v.IsNil()
default:
return reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface())
}
}

View File

@ -1,462 +0,0 @@
package cassandra
import (
"context"
"errors"
"fmt"
"reflect"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx/v3/qb"
"github.com/scylladb/gocqlx/v3/table"
)
func (db *CassandraDB) AutoCreateSAIIndexes(doc any, keyspace string) error {
metadata, err := GenerateTableMetadata(doc, keyspace)
if err != nil {
return err
}
t := reflect.TypeOf(doc)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.Tag.Get("sai") == "true" {
col := f.Tag.Get("db")
if col == "" {
col = toSnakeCase(f.Name)
}
stmt := fmt.Sprintf("CREATE INDEX IF NOT EXISTS ON %s (%s) USING 'sai';", metadata.Name, col)
if err := db.GetSession().ExecStmt(stmt); err != nil {
return fmt.Errorf("failed to create SAI index on table %s, column %s: %w", metadata.Name, col, err)
}
}
}
return nil
}
type Query struct {
db *CassandraDB
ctx context.Context
table string
keyspace string
columns []string
cmps []qb.Cmp
bindMap map[string]any
orders []orderBy
limit uint
document any
sets []setField // 欲更新欄位及其值
errs []error
}
type orderBy struct {
Column string
Order qb.Order
}
type setField struct {
Col string
Val any
}
// Model 創建一個新的查詢構建器
// document: 用於推斷表結構的範例物件(必須實現 TableName() 方法)
// keyspace: 如果為空,則使用初始化時設定的預設 keyspace
func (db *CassandraDB) Model(ctx context.Context, document any, keyspace string) *Query {
keyspace = getKeyspace(db, keyspace)
metadata, err := GenerateTableMetadata(document, keyspace)
if err != nil {
// 如果 metadata 生成失敗,創建一個帶錯誤的 Query
return &Query{
db: db,
ctx: ctx,
keyspace: keyspace,
document: document,
errs: []error{err},
}
}
return &Query{
db: db,
ctx: ctx,
table: metadata.Name,
keyspace: keyspace,
columns: make([]string, 0),
cmps: make([]qb.Cmp, 0),
bindMap: make(map[string]any),
orders: make([]orderBy, 0),
limit: 0,
document: document, // document 用於生成 metadata 和驗證 SAI 欄位
errs: make([]error, 0),
}
}
// Where 添加 WHERE 條件
// 只允許 partition key 或有 sai index 的欄位進行 where 查詢
// cmp: 查詢條件(如 qb.Eq("id")
// args: 參數映射(如 map[string]any{"id": uuid}
func (q *Query) Where(cmp qb.Cmp, args map[string]any) *Query {
// 如果之前有錯誤,直接返回
if len(q.errs) > 0 {
return q
}
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
if err != nil {
q.errs = append(q.errs, err)
return q
}
for k := range args {
// 允許 partition_key 或 sai 欄位
isPartition := contains(metadata.PartKey, k)
isSAI := IsSAIField(q.document, k)
if !isPartition && !isSAI {
q.errs = append(q.errs, NewError(
"INVALID_WHERE_FIELD",
fmt.Sprintf("where condition on field %s requires partition_key or sai index", k),
).WithTable(q.table))
}
}
q.cmps = append(q.cmps, cmp)
for k, v := range args {
q.bindMap[k] = v
}
return q
}
func (q *Query) Select(cols ...string) *Query {
q.columns = append(q.columns, cols...)
return q
}
func (q *Query) OrderBy(column string, order qb.Order) *Query {
q.orders = append(q.orders, orderBy{Column: column, Order: order})
return q
}
func (q *Query) Limit(limit uint) *Query {
q.limit = limit
return q
}
func (q *Query) Set(col string, val any) *Query {
q.sets = append(q.sets, setField{Col: col, Val: val})
return q
}
// Scan 執行查詢並將結果掃描到 dest
// dest 必須是指標類型:*Struct 用於單筆查詢,*[]Struct 用於多筆查詢
func (q *Query) Scan(dest any) error {
if len(q.errs) > 0 {
return errors.Join(q.errs...)
}
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
if err != nil {
return err
}
builder := qb.Select(q.table)
if len(q.columns) > 0 {
builder = builder.Columns(q.columns...)
} else {
// 如果沒有指定欄位,使用所有欄位
builder = builder.Columns(metadata.Columns...)
}
if len(q.cmps) > 0 {
builder = builder.Where(q.cmps...)
}
if len(q.orders) > 0 {
for _, o := range q.orders {
builder = builder.OrderBy(o.Column, o.Order)
}
}
if q.limit > 0 {
builder = builder.Limit(q.limit)
}
stmt, names := builder.ToCql()
query := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names))
if q.bindMap == nil {
q.bindMap = qb.M{}
}
query = query.BindMap(q.bindMap)
// 型態判斷自動選用單筆/多筆查詢
destType := reflect.TypeOf(dest)
if destType.Kind() != reflect.Ptr {
return ErrInvalidInput.WithTable(q.table).WithError(fmt.Errorf("destination must be a pointer, got %T", dest))
}
elemType := destType.Elem()
switch elemType.Kind() {
case reflect.Slice:
return query.SelectRelease(dest)
case reflect.Struct:
err := query.GetRelease(dest)
if err == gocql.ErrNotFound {
return ErrNotFound.WithTable(q.table)
}
return err
default:
return ErrInvalidInput.WithTable(q.table).WithError(fmt.Errorf("destination must be pointer to struct or slice, got %T", dest))
}
}
func (q *Query) Take(dest any) error {
q.limit = 1
return q.Scan(dest)
}
// Delete 執行刪除操作
// 要求:必須提供所有 partition keys 在 WHERE 條件中
func (q *Query) Delete() error {
if len(q.errs) > 0 {
return errors.Join(q.errs...)
}
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
if err != nil {
return err
}
// 檢查是否提供所有 partition keys
missingKeys := make([]string, 0)
for _, pk := range metadata.PartKey {
if _, ok := q.bindMap[pk]; !ok {
missingKeys = append(missingKeys, pk)
}
}
if len(missingKeys) > 0 {
return ErrMissingPartitionKey.WithTable(q.table).WithError(
fmt.Errorf("missing partition keys: %v", missingKeys),
)
}
if len(q.cmps) == 0 {
return ErrMissingWhereCondition.WithTable(q.table)
}
// 組 Delete 語句
builder := qb.Delete(q.table)
builder = builder.Where(q.cmps...)
stmt, names := builder.ToCql()
query := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names))
if q.bindMap == nil {
q.bindMap = qb.M{}
}
query = query.BindMap(q.bindMap)
return query.ExecRelease()
}
// Update 執行更新操作
// 要求:必須提供至少一個 partition_key 或 sai indexed 欄位在 WHERE 條件中,且至少有一個 Set 欄位
func (q *Query) Update() error {
if len(q.errs) > 0 {
return errors.Join(q.errs...)
}
if q.document == nil {
return ErrInvalidInput.WithTable(q.table).WithError(
fmt.Errorf("update requires document model to check partition keys"),
)
}
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
if err != nil {
return err
}
// 先收集所有可被當作主查詢條件的欄位
allowed := make(map[string]struct{})
// 收集 partition_key
for _, pk := range metadata.PartKey {
allowed[pk] = struct{}{}
}
// 收集所有 sai 欄位
for _, f := range reflect.VisibleFields(reflect.TypeOf(q.document)) {
if f.Tag.Get("sai") == "true" {
col := f.Tag.Get("db")
if col == "" {
col = toSnakeCase(f.Name)
}
allowed[col] = struct{}{}
}
}
// 檢查 bindMap 有沒有 hit 到
hasCondition := false
for k := range q.bindMap {
if _, ok := allowed[k]; ok {
hasCondition = true
break
}
}
if !hasCondition {
return ErrMissingPartitionKey.WithTable(q.table).WithError(
fmt.Errorf("requires at least one partition_key or sai indexed field in WHERE clause"),
)
}
// 至少要有一個 set 欄位
if len(q.sets) == 0 {
return ErrNoFieldsToUpdate.WithTable(q.table)
}
// 至少一個 where
if len(q.cmps) == 0 {
return ErrMissingWhereCondition.WithTable(q.table)
}
// 組合 set 欄位
setCols := make([]string, 0, len(q.sets))
setVals := make([]any, 0, len(q.sets))
for _, s := range q.sets {
setCols = append(setCols, s.Col)
setVals = append(setVals, s.Val)
}
// 組合 CQL
builder := qb.Update(q.table).Set(setCols...)
builder = builder.Where(q.cmps...)
stmt, names := builder.ToCql()
// setVals 要先,剩下的 where bind 順序依照 names
bindVals := append([]any{}, setVals...)
for _, name := range names[len(setCols):] {
if v, ok := q.bindMap[name]; ok {
bindVals = append(bindVals, v)
}
}
query := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names))
if len(bindVals) > 0 {
query = query.Bind(bindVals...)
}
return query.ExecRelease()
}
// InsertOne 插入單筆資料
func (q *Query) InsertOne(data any) error {
if len(q.errs) > 0 {
return errors.Join(q.errs...)
}
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
if err != nil {
return err
}
tbl := table.New(metadata)
qry := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(tbl.Insert()))
switch reflect.TypeOf(data).Kind() {
case reflect.Map:
qry = qry.BindMap(data.(map[string]any))
default:
qry = qry.BindStruct(data)
}
return qry.ExecRelease()
}
func (q *Query) InsertMany(documents any) error {
if len(q.errs) > 0 {
return errors.Join(q.errs...)
}
v := reflect.ValueOf(documents)
if v.Kind() != reflect.Slice {
return fmt.Errorf("insert many: input must be a slice, got %T", documents)
}
if v.Len() == 0 {
return nil
}
for i := 0; i < v.Len(); i++ {
item := v.Index(i).Interface()
if err := q.InsertOne(item); err != nil {
return fmt.Errorf("insert many: failed at index %d (table: %s): %w", i, q.table, err)
}
}
return nil
}
// GetAll 查詢所有資料(不帶條件)
func (q *Query) GetAll(dest any) error {
if len(q.errs) > 0 {
return errors.Join(q.errs...)
}
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
if err != nil {
return err
}
t := table.New(metadata)
stmt, names := qb.Select(t.Name()).Columns(metadata.Columns...).ToCql()
exec := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names))
return exec.SelectRelease(dest)
}
// Count 計算符合條件的記錄數
func (q *Query) Count() (int64, error) {
if len(q.errs) > 0 {
return 0, errors.Join(q.errs...)
}
metadata, err := GenerateTableMetadata(q.document, q.keyspace)
if err != nil {
return 0, err
}
t := table.New(metadata)
builder := qb.Select(t.Name()).Columns("COUNT(*)")
if len(q.cmps) > 0 {
builder = builder.Where(q.cmps...)
}
stmt, names := builder.ToCql()
query := qh.withContextAndTimestamp(q.ctx, q.db.GetSession().Query(stmt, names))
if q.bindMap == nil {
q.bindMap = qb.M{}
}
query = query.BindMap(q.bindMap)
var count int64
if err := query.GetRelease(&count); err != nil {
if err == gocql.ErrNotFound {
return 0, nil // COUNT 查詢不會返回 ErrNotFound但為了安全起見
}
return 0, err
}
return count, nil
}
func IsSAIField(model any, fieldName string) bool {
t := reflect.TypeOf(model)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
tag := f.Tag.Get("sai")
col := f.Tag.Get("db")
if col == "" {
col = toSnakeCase(f.Name)
}
if (col == fieldName || f.Name == fieldName) && tag == "true" {
return true
}
}
return false
}

View File

@ -1,324 +0,0 @@
package cassandra
import (
"context"
"testing"
"time"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx/v3/qb"
"github.com/stretchr/testify/assert"
)
func TestQueryBuilder(t *testing.T) {
ctx := context.Background()
db := &CassandraDB{} // 可以用 mock DB
type args struct {
cmp qb.Cmp
whereArg map[string]any
selects []string
orderCol string
order qb.Order
limit uint
setCol string
setVal any
}
tests := []struct {
name string
args args
wantPanic bool
wantColumns []string
wantOrderCol string
wantOrder qb.Order
wantLimit uint
wantSetCol string
wantSetVal any
}{
{
name: "where by partition key",
args: args{
cmp: qb.Eq("id"),
whereArg: map[string]any{"id": "abc"},
selects: []string{"id", "name"},
orderCol: "id",
order: qb.ASC,
limit: 1,
setCol: "name",
setVal: "Daniel",
},
wantPanic: false,
wantColumns: []string{"id", "name"},
wantOrderCol: "id",
wantOrder: qb.ASC,
wantLimit: 1,
wantSetCol: "name",
wantSetVal: "Daniel",
},
{
name: "where by sai index",
args: args{
cmp: qb.Eq("name"),
whereArg: map[string]any{"name": "daniel"},
selects: []string{"id", "name"},
orderCol: "name",
order: qb.DESC,
limit: 2,
setCol: "name",
setVal: "Jacky",
},
wantPanic: false,
wantColumns: []string{"id", "name"},
wantOrderCol: "name",
wantOrder: qb.DESC,
wantLimit: 2,
wantSetCol: "name",
wantSetVal: "Jacky",
},
{
name: "where by non-partition-non-sai",
args: args{
cmp: qb.Eq("age"),
whereArg: map[string]any{"age": 18},
selects: []string{"id", "name"},
orderCol: "age",
order: qb.ASC,
limit: 3,
setCol: "age",
setVal: 20,
},
wantPanic: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
q := db.Model(ctx, &MonkeyEntity{}, "my_keyspace").
Where(tc.args.cmp, tc.args.whereArg).
Select(tc.args.selects...).
OrderBy(tc.args.orderCol, tc.args.order).
Limit(tc.args.limit).
Set(tc.args.setCol, tc.args.setVal)
if tc.wantPanic {
assert.Error(t, q.Update())
} else {
assert.Equal(t, tc.wantColumns, q.columns)
if len(q.orders) > 0 {
assert.Equal(t, tc.wantOrderCol, q.orders[0].Column)
assert.Equal(t, tc.wantOrder, q.orders[0].Order)
}
assert.Equal(t, tc.wantLimit, q.limit)
if len(q.sets) > 0 {
assert.Equal(t, tc.wantSetCol, q.sets[0].Col)
assert.Equal(t, tc.wantSetVal, q.sets[0].Val)
}
}
})
}
}
func TestQuery_Select(t *testing.T) {
tests := []struct {
name string
selectCalls [][]string
wantColumns []string
}{
{
name: "select one col",
selectCalls: [][]string{{"id"}},
wantColumns: []string{"id"},
},
{
name: "select multi col in one call",
selectCalls: [][]string{{"id", "name"}},
wantColumns: []string{"id", "name"},
},
{
name: "multiple select calls append columns",
selectCalls: [][]string{{"id"}, {"name"}, {"age"}},
wantColumns: []string{"id", "name", "age"},
},
{
name: "multiple select calls with overlap",
selectCalls: [][]string{{"id"}, {"id", "name"}, {"name", "age"}},
wantColumns: []string{"id", "id", "name", "name", "age"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
q := &Query{columns: make([]string, 0)}
for _, call := range tc.selectCalls {
q = q.Select(call...)
}
assert.Equal(t, tc.wantColumns, q.columns)
})
}
}
func TestQuery_Count(t *testing.T) {
// 準備測試用資料
ctx := context.Background()
ks := generateRandomKeySpace(t)
cassandraDBTest.AutoCreateSAIIndexes(&MonkeyEntity{}, ks)
now := time.Now().UTC()
// 批量插入資料
docs := []MonkeyEntity{
{ID: gocql.TimeUUID(), Name: "Alice", CreateAt: now, UpdateAt: now},
{ID: gocql.TimeUUID(), Name: "Bob", CreateAt: now, UpdateAt: now},
{ID: gocql.TimeUUID(), Name: "Alice", CreateAt: now, UpdateAt: now},
}
for _, doc := range docs {
assert.NoError(t, cassandraDBTest.Insert(ctx, &doc, ks))
}
tests := []struct {
name string
filterName string
wantCount int64
}{
{"CountAll", "", 3},
{"CountAlice", "Alice", 2},
{"CountBob", "Bob", 1},
{"CountNobody", "Charlie", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := cassandraDBTest.Model(ctx, &MonkeyEntity{}, ks)
if tt.filterName != "" {
q = q.Where(qb.Eq("name"), qb.M{"name": tt.filterName})
}
count, err := q.Count()
assert.NoError(t, err)
assert.Equal(t, tt.wantCount, count)
})
}
}
type TestUser struct {
ID gocql.UUID `db:"id" partition_key:"true"`
Name string `db:"name" sai:"true"`
Age int64 `db:"age"`
}
func (TestUser) TableName() string { return "test_user" }
func TestQueryBasicFlow(t *testing.T) {
// 啟動 Cassandra container
ctx := context.Background()
keyspace := "my_keyspace"
err := cassandraDBTest.EnsureTable(`
CREATE TABLE IF NOT EXISTS my_keyspace.test_user (
id UUID,
name TEXT,
age BIGINT,
PRIMARY KEY (id)
);`)
assert.NoError(t, err)
err = cassandraDBTest.AutoCreateSAIIndexes(&TestUser{}, keyspace)
assert.NoError(t, err)
// 測試資料
u1 := TestUser{ID: gocql.TimeUUID(), Name: "Alice", Age: 20}
u2 := TestUser{ID: gocql.TimeUUID(), Name: "Bob", Age: 22}
u3 := TestUser{ID: gocql.TimeUUID(), Name: "Carol", Age: 23}
// InsertOne/InsertMany
t.Run("InsertOne", func(t *testing.T) {
q := cassandraDBTest.Model(ctx, TestUser{}, keyspace)
assert.NoError(t, q.InsertOne(u1))
})
t.Run("InsertMany", func(t *testing.T) {
q := cassandraDBTest.Model(ctx, TestUser{}, keyspace)
assert.NoError(t, q.InsertMany([]TestUser{u2, u3}))
})
// GetAll
t.Run("GetAll", func(t *testing.T) {
var got []TestUser
q := cassandraDBTest.Model(ctx, TestUser{}, keyspace)
assert.NoError(t, q.GetAll(&got))
assert.GreaterOrEqual(t, len(got), 3)
})
// Count
t.Run("Count All", func(t *testing.T) {
q := cassandraDBTest.Model(ctx, TestUser{}, keyspace)
count, err := q.Count()
assert.NoError(t, err)
assert.GreaterOrEqual(t, count, int64(3))
})
// Delete
t.Run("Delete Carol", func(t *testing.T) {
q2 := cassandraDBTest.Model(ctx, TestUser{}, keyspace)
q2.Where(qb.Eq("id"), map[string]any{"id": u3.ID})
assert.NoError(t, q2.Delete())
// 驗證已刪除
var user TestUser
err := cassandraDBTest.Model(ctx, TestUser{}, keyspace).
Where(qb.Eq("id"), map[string]any{"id": u3.ID}).Scan(&user)
assert.Error(t, err)
q3 := cassandraDBTest.Model(ctx, TestUser{}, keyspace)
count, err := q3.Count()
assert.NoError(t, err)
assert.GreaterOrEqual(t, count, int64(2))
})
// Scan
t.Run("Scan Find Alice", func(t *testing.T) {
var user []TestUser
err := cassandraDBTest.Model(ctx, TestUser{}, keyspace).
Where(qb.Eq("name"), map[string]any{"name": "Alice"}).Scan(&user)
assert.NoError(t, err)
assert.Equal(t, u1.Name, user[0].Name)
})
//
// Take (僅取一筆)
t.Run("Take Get Bob", func(t *testing.T) {
var user TestUser
q2 := cassandraDBTest.Model(ctx, TestUser{}, keyspace).
Where(qb.Eq("name"), map[string]any{"name": "Bob"})
assert.NoError(t, q2.Take(&user))
assert.Equal(t, u2.Name, user.Name)
})
// Update
t.Run("Update Age of Alice", func(t *testing.T) {
q := cassandraDBTest.Model(ctx, TestUser{}, keyspace)
assert.NoError(t, q.InsertMany([]TestUser{u1, u2, u3}))
err = cassandraDBTest.Model(ctx,
TestUser{}, keyspace).
Where(qb.Eq("id"), map[string]any{"id": u1.ID}).
Set("age", 30).
Update()
assert.NoError(t, err)
// 驗證
var user TestUser
assert.NoError(t, cassandraDBTest.Model(ctx, TestUser{}, keyspace).
Where(qb.Eq("id"), map[string]any{"id": u1.ID}).Take(&user))
assert.Equal(t, int64(30), user.Age)
})
// In 這個 case 不通過,原因是 sai key 也不一定可以確認 cassandra 分區
t.Run("In", func(t *testing.T) {
q := cassandraDBTest.Model(ctx, TestUser{}, keyspace)
assert.NoError(t, q.InsertMany([]TestUser{u1, u2, u3}))
var user []TestUser
err = cassandraDBTest.Model(ctx,
TestUser{}, keyspace).
Where(qb.In("name"), map[string]any{"name": []string{u1.Name, u2.Name}}).
Scan(&user)
assert.Error(t, err)
})
}

View File

@ -0,0 +1,32 @@
package cassandra
import (
"github.com/gocql/gocql"
)
// Table 定義資料表模型必須實作的介面
type Table interface {
TableName() string
}
// PrimaryKey 定義主鍵類型(使用類型約束)
// 注意Go 1.18+ 才支持類型約束,如果需要兼容舊版本,可以使用 interface{}
type PrimaryKey interface {
~string | ~int | ~int64 | gocql.UUID | []byte
}
// Order 定義排序順序
type Order int
const (
ASC Order = 0
DESC Order = 1
)
// 將 Order 轉換為 toGocqlX 的 Order
func (o Order) toGocqlX() string {
if o == DESC {
return "DESC"
}
return "ASC"
}

View File

@ -1,65 +0,0 @@
package cassandra
import (
"reflect"
"unicode"
)
// GetCqlTag 取得指定欄位的 cql tag
// model 必須為 struct 指標fieldPtr 為該 struct 欄位的指標
func GetCqlTag(model interface{}, fieldPtr interface{}) string {
v := reflect.ValueOf(model)
// 確保 model 為 struct 指標
if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
return ""
}
s := v.Elem()
// 遍歷所有欄位,找出地址與傳入 fieldPtr 相符的欄位
for i := 0; i < s.NumField(); i++ {
field := s.Type().Field(i)
fieldVal := s.Field(i)
// 如果能取地址且地址與 fieldPtr 相等,則取得 tag
if fieldVal.CanAddr() && fieldVal.Addr().Interface() == fieldPtr {
return field.Tag.Get("db")
}
}
return ""
}
// toSnakeCase 將 CamelCase 字串轉換為 snake_case
func toSnakeCase(s string) string {
var result []rune
for i, r := range s {
if unicode.IsUpper(r) {
if i > 0 {
result = append(result, '_')
}
result = append(result, unicode.ToLower(r))
} else {
result = append(result, r)
}
}
return string(result)
}
// 判斷欄位是否為零值或 nil
func isZero(v reflect.Value) bool {
switch v.Kind() {
case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice:
return v.IsNil()
default:
return reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface())
}
}
// 判斷字串是否存在於 slice 中
func contains(list []string, target string) bool {
for _, item := range list {
if item == target {
return true
}
}
return false
}

View File

@ -1,166 +0,0 @@
package cassandra
import (
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/gocql/gocql"
)
func TestGetCqlTag(t *testing.T) {
monkey := &MonkeyEntity{
// 為了測試用,欄位內容可以不給值
ID: gocql.TimeUUID(),
Name: "TestMonkey",
UpdateAt: time.Now(),
CreateAt: time.Now(),
}
tests := []struct {
name string
model interface{}
fieldPtr interface{}
expected string
expectPanic bool
}{
{
name: "取得 Name 的 cql tag",
model: monkey,
fieldPtr: &monkey.Name,
expected: "name",
},
{
name: "取得 ID 的 cql tag",
model: monkey,
fieldPtr: &monkey.ID,
expected: "id",
},
{
name: "取得 UpdateAt 的 cql tag",
model: monkey,
fieldPtr: &monkey.UpdateAt,
expected: "update_at",
},
{
name: "取得 CreateAt 的 cql tag",
model: monkey,
fieldPtr: &monkey.CreateAt,
expected: "create_at",
},
{
name: "找不到對應欄位,回傳空字串",
model: monkey,
fieldPtr: new(int), // 傳入與 MonkeyEntity 無關的欄位指標
expected: "",
},
{
name: "非指向 struct 的 model應該 panic",
model: MonkeyEntity{}, // 非指針
fieldPtr: &monkey.Name,
expected: "",
},
}
for _, tt := range tests {
tt := tt // 捕捉迴圈變數
t.Run(tt.name, func(t *testing.T) {
// 如果預期會 panic則用 recover 進行驗證
if tt.expectPanic {
defer func() {
if r := recover(); r == nil {
t.Errorf("預期測試案例 %q 發生 panic但實際並未 panic", tt.name)
}
}()
_ = GetCqlTag(tt.model, tt.fieldPtr)
} else {
result := GetCqlTag(tt.model, tt.fieldPtr)
if result != tt.expected {
t.Errorf("測試案例 %q: 預期 %q, 但得到 %q", tt.name, tt.expected, result)
}
}
})
}
}
// -------------------- 測試函式 --------------------
// TestToSnakeCase 測試 toSnakeCase 函式
func TestToSnakeCase(t *testing.T) {
testCases := []struct {
input string
expected string
}{
{"CamelCase", "camel_case"},
{"snake_case", "snake_case"},
{"HttpServer", "http_server"},
{"A", "a"},
{"Already_Snake", "already__snake"}, // 依照實作,"Already_Snake" 轉換後會產生 double underscore
}
for _, tc := range testCases {
t.Run(tc.input, func(t *testing.T) {
result := toSnakeCase(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}
func TestIsZero(t *testing.T) {
type testCase struct {
name string
input any
expected bool
}
tests := []testCase{
{"zero int", 0, true},
{"non-zero int", 42, false},
{"zero string", "", true},
{"non-zero string", "hello", false},
{"zero bool", false, true},
{"non-zero bool", true, false},
{"nil slice", []string(nil), true},
{"empty slice", []string{}, false},
{"nil pointer", (*int)(nil), true},
{"non-nil pointer", new(int), false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
v := reflect.ValueOf(tc.input)
actual := isZero(v)
if actual != tc.expected {
t.Errorf("isZero(%v) = %v; want %v", tc.input, actual, tc.expected)
}
})
}
}
func TestContains(t *testing.T) {
type testCase struct {
name string
list []string
target string
expected bool
}
tests := []testCase{
{"contains first", []string{"a", "b", "c"}, "a", true},
{"contains middle", []string{"a", "b", "c"}, "b", true},
{"contains last", []string{"a", "b", "c"}, "c", true},
{"not contains", []string{"a", "b", "c"}, "d", false},
{"empty list", []string{}, "a", false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
actual := contains(tc.list, tc.target)
if actual != tc.expected {
t.Errorf("contains(%v, %q) = %v; want %v", tc.list, tc.target, actual, tc.expected)
}
})
}
}