163 lines
3.6 KiB
Go
163 lines
3.6 KiB
Go
package cassandra
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/gocql/gocql"
|
||
"github.com/scylladb/gocqlx/v3"
|
||
)
|
||
|
||
// DB 是 Cassandra 的核心資料庫連接
|
||
type DB struct {
|
||
session gocqlx.Session
|
||
defaultKeyspace string
|
||
version string
|
||
saiSupported bool
|
||
|
||
// 內部快取
|
||
metadataCache sync.Map // 重用現有的 metadata 快取邏輯
|
||
}
|
||
|
||
// New creates a DB connected to the Cassandra cluster described by opts.
//
// Each Option is applied, in order, to the configuration returned by
// defaultConfig. New then opens a session, reads the server release_version
// from system.local, and records whether that version is treated as
// supporting SAI.
//
// It returns an error when no host is configured, when the session cannot be
// created, or when the version query fails. In the last case the session that
// was just opened is closed first, so a failed New does not leak it.
func New(opts ...Option) (*DB, error) {
	cfg := defaultConfig()
	for _, opt := range opts {
		opt(cfg)
	}

	if len(cfg.Hosts) == 0 {
		return nil, fmt.Errorf("at least one host is required")
	}

	// Build the cluster (connection) configuration.
	cluster := gocql.NewCluster(cfg.Hosts...)
	cluster.Port = cfg.Port
	cluster.Consistency = cfg.Consistency
	cluster.Timeout = time.Duration(cfg.ConnectTimeoutSec) * time.Second
	// ConnectTimeoutSec is, by name, a connection timeout, but gocql keeps the
	// initial-dial timeout in ConnectTimeout, separate from Timeout. Apply it
	// there as well when set; zero keeps gocql's default dial timeout.
	if cfg.ConnectTimeoutSec > 0 {
		cluster.ConnectTimeout = time.Duration(cfg.ConnectTimeoutSec) * time.Second
	}
	cluster.NumConns = cfg.NumConns
	cluster.RetryPolicy = &gocql.ExponentialBackoffRetryPolicy{
		NumRetries: cfg.MaxRetries,
		Min:        cfg.RetryMinInterval,
		Max:        cfg.RetryMaxInterval,
	}
	cluster.ReconnectionPolicy = &gocql.ExponentialReconnectionPolicy{
		MaxRetries:      cfg.MaxRetries,
		InitialInterval: cfg.ReconnectInitialInterval,
		MaxInterval:     cfg.ReconnectMaxInterval,
	}

	// Bind the session to a keyspace only when one is configured.
	if cfg.Keyspace != "" {
		cluster.Keyspace = cfg.Keyspace
	}

	// Configure username/password authentication when enabled.
	if cfg.UseAuth {
		cluster.Authenticator = gocql.PasswordAuthenticator{
			Username: cfg.Username,
			Password: cfg.Password,
		}
	}

	// Open the session.
	session, err := gocqlx.WrapSession(cluster.CreateSession())
	if err != nil {
		return nil, fmt.Errorf("failed to connect to Cassandra cluster (hosts: %v, port: %d): %w", cfg.Hosts, cfg.Port, err)
	}

	db := &DB{
		session:         session,
		defaultKeyspace: cfg.Keyspace,
	}

	// Read the server version once; Version and SaiSupported return the
	// cached results.
	version, err := db.getVersion(context.Background())
	if err != nil {
		// The session opened above is not reachable by the caller on this
		// path, so close it here or its connections would be leaked.
		session.Close()
		return nil, fmt.Errorf("failed to get DB version: %w", err)
	}
	db.version = version
	db.saiSupported = isSAISupported(version)

	return db, nil
}
|
||
|
||
// Close 關閉資料庫連線
|
||
func (db *DB) Close() {
|
||
db.session.Close()
|
||
}
|
||
|
||
// GetSession 返回底層的 gocqlx Session(用於進階操作)
|
||
func (db *DB) GetSession() gocqlx.Session {
|
||
return db.session
|
||
}
|
||
|
||
// GetDefaultKeyspace 返回預設的 keyspace
|
||
func (db *DB) GetDefaultKeyspace() string {
|
||
return db.defaultKeyspace
|
||
}
|
||
|
||
// Version 返回資料庫版本
|
||
func (db *DB) Version() string {
|
||
return db.version
|
||
}
|
||
|
||
// SaiSupported 返回是否支援 SAI
|
||
func (db *DB) SaiSupported() bool {
|
||
return db.saiSupported
|
||
}
|
||
|
||
// getVersion reads the server's release_version from system.local.
//
// The query runs under ctx with consistency ONE. Any query error is
// returned unchanged, together with whatever was scanned so far.
func (db *DB) getVersion(ctx context.Context) (string, error) {
	const query = "SELECT release_version FROM system.local"

	q := db.session.Query(query, []string{"release_version"})
	q = q.WithContext(ctx).Consistency(gocql.One)

	var release string
	if err := q.Scan(&release); err != nil {
		return release, err
	}
	return release, nil
}
|
||
|
||
// isSAISupported 檢查版本是否支援 SAI
|
||
func isSAISupported(version string) bool {
|
||
// 只要 major >=5 就支援
|
||
// 4.0.9+ 才有 SAI,但不穩,強烈建議 5.0+
|
||
parts := strings.Split(version, ".")
|
||
if len(parts) < 2 {
|
||
return false
|
||
}
|
||
major, _ := strconv.Atoi(parts[0])
|
||
minor, _ := strconv.Atoi(parts[1])
|
||
|
||
if major >= 5 {
|
||
return true
|
||
}
|
||
|
||
if major == 4 {
|
||
if minor > 0 { // 4.1.x、4.2.x 直接支援
|
||
return true
|
||
}
|
||
if minor == 0 {
|
||
patch := 0
|
||
if len(parts) >= 3 {
|
||
patch, _ = strconv.Atoi(parts[2])
|
||
}
|
||
if patch >= 9 {
|
||
return true
|
||
}
|
||
}
|
||
}
|
||
|
||
return false
|
||
}
|
||
|
||
// withContextAndTimestamp attaches ctx to q, then sets the query timestamp
// to the current wall-clock time expressed in microseconds since the Unix
// epoch. It returns the resulting query for chaining.
func (db *DB) withContextAndTimestamp(ctx context.Context, q *gocqlx.Queryx) *gocqlx.Queryx {
	withCtx := q.WithContext(ctx)
	micros := time.Now().UnixNano() / int64(time.Microsecond)
	return withCtx.WithTimestamp(micros)
}
|