package cassandra import ( "errors" "testing" "time" "github.com/gocql/gocql" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestIsSAISupported(t *testing.T) { tests := []struct { name string version string expected bool }{ { name: "version 5.0.0 should support SAI", version: "5.0.0", expected: true, }, { name: "version 5.1.0 should support SAI", version: "5.1.0", expected: true, }, { name: "version 6.0.0 should support SAI", version: "6.0.0", expected: true, }, { name: "version 4.1.0 should support SAI", version: "4.1.0", expected: true, }, { name: "version 4.2.0 should support SAI", version: "4.2.0", expected: true, }, { name: "version 4.0.9 should support SAI", version: "4.0.9", expected: true, }, { name: "version 4.0.10 should support SAI", version: "4.0.10", expected: true, }, { name: "version 4.0.8 should not support SAI", version: "4.0.8", expected: false, }, { name: "version 4.0.0 should not support SAI", version: "4.0.0", expected: false, }, { name: "version 3.11.0 should not support SAI", version: "3.11.0", expected: false, }, { name: "invalid version format should not support SAI", version: "invalid", expected: false, }, { name: "empty version should not support SAI", version: "", expected: false, }, { name: "version with only major should not support SAI", version: "5", expected: false, }, { name: "version 4.0.9 with extra parts should support SAI", version: "4.0.9.1", expected: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := isSAISupported(tt.version) assert.Equal(t, tt.expected, result, "version %s should have SAI support = %v", tt.version, tt.expected) }) } } func TestNew_Validation(t *testing.T) { tests := []struct { name string opts []Option wantErr bool errMsg string }{ { name: "no hosts should return error", opts: []Option{}, wantErr: true, errMsg: "at least one host is required", }, { name: "empty hosts should return error", opts: []Option{WithHosts()}, wantErr: true, errMsg: "at least one host is required", }, { name: "valid hosts should not return error on validation", opts: []Option{ WithHosts("localhost"), }, wantErr: false, }, { name: "multiple hosts should not return error on validation", opts: []Option{ WithHosts("localhost", "127.0.0.1"), }, wantErr: false, }, { name: "with keyspace should not return error on validation", opts: []Option{ WithHosts("localhost"), WithKeyspace("test_keyspace"), }, wantErr: false, }, { name: "with port should not return error on validation", opts: []Option{ WithHosts("localhost"), WithPort(9042), }, wantErr: false, }, { name: "with auth should not return error on validation", opts: []Option{ WithHosts("localhost"), WithAuth("user", "pass"), }, wantErr: false, }, { name: "with all options should not return error on validation", opts: []Option{ WithHosts("localhost"), WithKeyspace("test_keyspace"), WithPort(9042), WithAuth("user", "pass"), WithConsistency(gocql.Quorum), WithConnectTimeoutSec(10), WithNumConns(10), WithMaxRetries(3), }, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db, err := New(tt.opts...) if tt.wantErr { require.Error(t, err) if tt.errMsg != "" { assert.Contains(t, err.Error(), tt.errMsg) } assert.Nil(t, db) } else { // 注意:這裡可能會因為無法連接到真實的 Cassandra 而失敗 // 但至少驗證了配置驗證邏輯 if err != nil { // 如果錯誤不是驗證錯誤,而是連接錯誤,這是可以接受的 assert.NotContains(t, err.Error(), "at least one host is required") } } }) } } func TestDB_GetDefaultKeyspace(t *testing.T) { tests := []struct { name string keyspace string expectedResult string }{ { name: "empty keyspace should return empty string", keyspace: "", expectedResult: "", }, { name: "non-empty keyspace should return keyspace", keyspace: "test_keyspace", expectedResult: "test_keyspace", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // 注意:這需要一個有效的 DB 實例 // 在實際測試中,可能需要 mock 或使用 testcontainers // 這裡只是展示測試結構 _ = tt }) } } func TestDB_Version(t *testing.T) { tests := []struct { name string version string expected string }{ { name: "version 5.0.0", version: "5.0.0", expected: "5.0.0", }, { name: "version 4.0.9", version: "4.0.9", expected: "4.0.9", }, { name: "version 3.11.0", version: "3.11.0", expected: "3.11.0", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // 注意:這需要一個有效的 DB 實例 // 在實際測試中,可能需要 mock 或使用 testcontainers _ = tt }) } } func TestDB_SaiSupported(t *testing.T) { tests := []struct { name string version string expected bool }{ { name: "version 5.0.0 should support SAI", version: "5.0.0", expected: true, }, { name: "version 4.0.9 should support SAI", version: "4.0.9", expected: true, }, { name: "version 4.0.8 should not support SAI", version: "4.0.8", expected: false, }, { name: "version 3.11.0 should not support SAI", version: "3.11.0", expected: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // 注意:這需要一個有效的 DB 實例 // 在實際測試中,可能需要 mock 或使用 testcontainers // 這裡只是展示測試結構 _ = tt }) } } func TestDB_GetSession(t *testing.T) { t.Run("GetSession should return non-nil session", func(t *testing.T) { // 注意:這需要一個有效的 DB 實例 // 在實際測試中,可能需要 mock 或使用 testcontainers }) } func TestDB_Close(t *testing.T) { t.Run("Close should not panic", func(t *testing.T) { // 注意:這需要一個有效的 DB 實例 // 在實際測試中,可能需要 mock 或使用 testcontainers }) } func TestDB_getVersion(t *testing.T) { tests := []struct { name string version string queryErr error wantErr bool expectedVer string }{ { name: "successful version query", version: "5.0.0", queryErr: nil, wantErr: false, expectedVer: "5.0.0", }, { name: "query error should return error", version: "", queryErr: errors.New("connection failed"), wantErr: true, expectedVer: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // 注意:這需要 mock session // 在實際測試中,需要使用 mock 或 testcontainers _ = tt }) } } func TestDB_withContextAndTimestamp(t *testing.T) { t.Run("withContextAndTimestamp should add context and timestamp", func(t *testing.T) { // 注意:這需要 mock query // 在實際測試中,需要使用 mock }) } func TestDefaultConfig(t *testing.T) { t.Run("defaultConfig should return valid config", func(t *testing.T) { cfg := defaultConfig() require.NotNil(t, cfg) assert.Equal(t, defaultPort, cfg.Port) assert.Equal(t, defaultConsistency, cfg.Consistency) assert.Equal(t, defaultTimeoutSec, cfg.ConnectTimeoutSec) assert.Equal(t, defaultNumConns, cfg.NumConns) assert.Equal(t, defaultMaxRetries, cfg.MaxRetries) assert.Equal(t, defaultRetryMinInterval, cfg.RetryMinInterval) assert.Equal(t, defaultRetryMaxInterval, cfg.RetryMaxInterval) assert.Equal(t, defaultReconnectInitialInterval, cfg.ReconnectInitialInterval) assert.Equal(t, defaultReconnectMaxInterval, cfg.ReconnectMaxInterval) assert.Equal(t, defaultCqlVersion, cfg.CQLVersion) }) } func TestOptionFunctions(t *testing.T) { tests := []struct { name string opt Option validateConfig func(*testing.T, *config) }{ { name: "WithHosts should set hosts", opt: WithHosts("host1", "host2"), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, []string{"host1", "host2"}, c.Hosts) }, }, { name: "WithPort should set port", opt: WithPort(9999), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, 9999, c.Port) }, }, { name: "WithKeyspace should set keyspace", opt: WithKeyspace("test_keyspace"), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, "test_keyspace", c.Keyspace) }, }, { name: "WithAuth should set auth and enable UseAuth", opt: WithAuth("user", "pass"), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, "user", c.Username) assert.Equal(t, "pass", c.Password) assert.True(t, c.UseAuth) }, }, { name: "WithConsistency should set consistency", opt: WithConsistency(gocql.One), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, gocql.One, c.Consistency) }, }, { name: "WithConnectTimeoutSec should set timeout", opt: WithConnectTimeoutSec(20), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, 20, c.ConnectTimeoutSec) }, }, { name: "WithConnectTimeoutSec with zero should use default", opt: WithConnectTimeoutSec(0), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, defaultTimeoutSec, c.ConnectTimeoutSec) }, }, { name: "WithNumConns should set numConns", opt: WithNumConns(20), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, 20, c.NumConns) }, }, { name: "WithNumConns with zero should use default", opt: WithNumConns(0), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, defaultNumConns, c.NumConns) }, }, { name: "WithMaxRetries should set maxRetries", opt: WithMaxRetries(5), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, 5, c.MaxRetries) }, }, { name: "WithMaxRetries with zero should use default", opt: WithMaxRetries(0), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, defaultMaxRetries, c.MaxRetries) }, }, { name: "WithRetryMinInterval should set retryMinInterval", opt: WithRetryMinInterval(2 * time.Second), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, 2*time.Second, c.RetryMinInterval) }, }, { name: "WithRetryMinInterval with zero should use default", opt: WithRetryMinInterval(0), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, defaultRetryMinInterval, c.RetryMinInterval) }, }, { name: "WithRetryMaxInterval should set retryMaxInterval", opt: WithRetryMaxInterval(60 * time.Second), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, 60*time.Second, c.RetryMaxInterval) }, }, { name: "WithRetryMaxInterval with zero should use default", opt: WithRetryMaxInterval(0), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, defaultRetryMaxInterval, c.RetryMaxInterval) }, }, { name: "WithReconnectInitialInterval should set reconnectInitialInterval", opt: WithReconnectInitialInterval(2 * time.Second), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, 2*time.Second, c.ReconnectInitialInterval) }, }, { name: "WithReconnectInitialInterval with zero should use default", opt: WithReconnectInitialInterval(0), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, defaultReconnectInitialInterval, c.ReconnectInitialInterval) }, }, { name: "WithReconnectMaxInterval should set reconnectMaxInterval", opt: WithReconnectMaxInterval(120 * time.Second), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, 120*time.Second, c.ReconnectMaxInterval) }, }, { name: "WithReconnectMaxInterval with zero should use default", opt: WithReconnectMaxInterval(0), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, defaultReconnectMaxInterval, c.ReconnectMaxInterval) }, }, { name: "WithCQLVersion should set CQLVersion", opt: WithCQLVersion("3.1.0"), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, "3.1.0", c.CQLVersion) }, }, { name: "WithCQLVersion with empty should use default", opt: WithCQLVersion(""), validateConfig: func(t *testing.T, c *config) { assert.Equal(t, defaultCqlVersion, c.CQLVersion) }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := defaultConfig() tt.opt(cfg) tt.validateConfig(t, cfg) }) } } func TestMultipleOptions(t *testing.T) { t.Run("multiple options should be applied correctly", func(t *testing.T) { cfg := defaultConfig() WithHosts("host1", "host2")(cfg) WithPort(9999)(cfg) WithKeyspace("test")(cfg) WithAuth("user", "pass")(cfg) assert.Equal(t, []string{"host1", "host2"}, cfg.Hosts) assert.Equal(t, 9999, cfg.Port) assert.Equal(t, "test", cfg.Keyspace) assert.Equal(t, "user", cfg.Username) assert.Equal(t, "pass", cfg.Password) assert.True(t, cfg.UseAuth) }) }