92 lines
2.2 KiB
Go
92 lines
2.2 KiB
Go
package cassandra
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strconv"
|
|
"testing"
|
|
|
|
"github.com/testcontainers/testcontainers-go"
|
|
"github.com/testcontainers/testcontainers-go/wait"
|
|
)
|
|
|
|
// startCassandraContainer 啟動 Cassandra 測試容器
|
|
func startCassandraContainer(ctx context.Context) (string, string, func(), error) {
|
|
req := testcontainers.ContainerRequest{
|
|
Image: "cassandra:4.1",
|
|
ExposedPorts: []string{"9042/tcp"},
|
|
WaitingFor: wait.ForListeningPort("9042/tcp"),
|
|
Env: map[string]string{
|
|
"CASSANDRA_CLUSTER_NAME": "test-cluster",
|
|
},
|
|
}
|
|
|
|
cassandraC, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
|
|
ContainerRequest: req,
|
|
Started: true,
|
|
})
|
|
if err != nil {
|
|
return "", "", nil, fmt.Errorf("failed to start Cassandra container: %w", err)
|
|
}
|
|
|
|
port, err := cassandraC.MappedPort(ctx, "9042")
|
|
if err != nil {
|
|
cassandraC.Terminate(ctx)
|
|
return "", "", nil, fmt.Errorf("failed to get mapped port: %w", err)
|
|
}
|
|
|
|
host, err := cassandraC.Host(ctx)
|
|
if err != nil {
|
|
cassandraC.Terminate(ctx)
|
|
return "", "", nil, fmt.Errorf("failed to get host: %w", err)
|
|
}
|
|
|
|
tearDown := func() {
|
|
_ = cassandraC.Terminate(ctx)
|
|
}
|
|
|
|
fmt.Printf("Cassandra test container started: %s:%s\n", host, port.Port())
|
|
|
|
return host, port.Port(), tearDown, nil
|
|
}
|
|
|
|
// setupTestDB 設置測試用的 DB 實例
|
|
func setupTestDB(t testing.TB) (*DB, func()) {
|
|
ctx := context.Background()
|
|
host, port, tearDown, err := startCassandraContainer(ctx)
|
|
if err != nil {
|
|
t.Fatalf("Failed to start Cassandra container: %v", err)
|
|
}
|
|
|
|
portInt, err := strconv.Atoi(port)
|
|
if err != nil {
|
|
tearDown()
|
|
t.Fatalf("Failed to convert port to int: %v", err)
|
|
}
|
|
|
|
db, err := New(
|
|
WithHosts(host),
|
|
WithPort(portInt),
|
|
WithKeyspace("test_keyspace"),
|
|
)
|
|
if err != nil {
|
|
tearDown()
|
|
t.Fatalf("Failed to create DB: %v", err)
|
|
}
|
|
|
|
// 創建 keyspace
|
|
createKeyspaceStmt := "CREATE KEYSPACE IF NOT EXISTS test_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}"
|
|
if err := db.session.Query(createKeyspaceStmt, nil).Exec(); err != nil {
|
|
db.Close()
|
|
tearDown()
|
|
t.Fatalf("Failed to create keyspace: %v", err)
|
|
}
|
|
|
|
cleanup := func() {
|
|
db.Close()
|
|
tearDown()
|
|
}
|
|
|
|
return db, cleanup
|
|
}
|