package cassandra

import (
	"errors"
	"fmt"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestError_Error checks the string produced by (*Error).Error for the
// different combinations of populated fields.
func TestError_Error(t *testing.T) {
	tests := []struct {
		name     string
		err      *Error
		want     string
		contains []string // when want is empty, the result must contain each of these substrings
	}{
		{
			name: "error with code and message only",
			err: &Error{
				Code:    ErrCodeNotFound,
				Message: "record not found",
			},
			want: "cassandra[NOT_FOUND]: record not found",
		},
		{
			name: "error with code, message and table",
			err: &Error{
				Code:    ErrCodeNotFound,
				Message: "record not found",
				Table:   "users",
			},
			want: "cassandra[NOT_FOUND] (table: users): record not found",
		},
		{
			name: "error with code, message and underlying error",
			err: &Error{
				Code:    ErrCodeInvalidInput,
				Message: "invalid input parameter",
				Err:     errors.New("validation failed"),
			},
			contains: []string{
				"cassandra[INVALID_INPUT]",
				"invalid input parameter",
				"validation failed",
			},
		},
		{
			name: "error with all fields",
			err: &Error{
				Code:    ErrCodeConflict,
				Message: "acquire lock failed",
				Table:   "locks",
				Err:     errors.New("lock already exists"),
			},
			contains: []string{
				"cassandra[CONFLICT]",
				"(table: locks)",
				"acquire lock failed",
				"lock already exists",
			},
		},
		{
			name: "error with empty message",
			err: &Error{
				Code: ErrCodeNotFound,
			},
			want: "cassandra[NOT_FOUND]: ",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := tt.err.Error()
			if tt.want != "" {
				assert.Equal(t, tt.want, result)
				return
			}
			// Guard against a case that sets neither want nor contains,
			// which would otherwise pass without asserting anything.
			require.NotEmpty(t, tt.contains, "test case must set want or contains")
			for _, substr := range tt.contains {
				assert.Contains(t, result, substr)
			}
		})
	}
}

// TestError_Unwrap checks that Unwrap returns exactly the wrapped error
// instance, or nil when none is set.
func TestError_Unwrap(t *testing.T) {
	// A single shared instance lets the test check identity rather than
	// message text: an equal-text but different error would not satisfy
	// errors.Is.
	underlying := errors.New("underlying error")

	tests := []struct {
		name    string
		err     *Error
		wantErr error
	}{
		{
			name: "error with underlying error",
			err: &Error{
				Code:    ErrCodeInvalidInput,
				Message: "invalid input",
				Err:     underlying,
			},
			wantErr: underlying,
		},
		{
			name: "error without underlying error",
			err: &Error{
				Code:    ErrCodeNotFound,
				Message: "not found",
			},
			wantErr: nil,
		},
		{
			name: "error with nil underlying error",
			err: &Error{
				Code:    ErrCodeNotFound,
				Message: "not found",
				Err:     nil,
			},
			wantErr: nil,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := tt.err.Unwrap()
			if tt.wantErr == nil {
				assert.Nil(t, result)
				return
			}
			// require (not assert) so a nil result fails cleanly instead of
			// panicking on the identity comparison below.
			require.NotNil(t, result)
			assert.Same(t, tt.wantErr, result)
		})
	}
}

// TestError_WithTable checks that WithTable returns a new *Error with the
// given table, keeps the other fields, and leaves the receiver unchanged.
func TestError_WithTable(t *testing.T) {
	tests := []struct {
		name     string
		err      *Error
		table    string
		wantCode ErrorCode
		wantMsg  string
		wantTbl  string
	}{
		{
			name: "add table to error without table",
			err: &Error{
				Code:    ErrCodeNotFound,
				Message: "record not found",
			},
			table:    "users",
			wantCode: ErrCodeNotFound,
			wantMsg:  "record not found",
			wantTbl:  "users",
		},
		{
			name: "replace existing table",
			err: &Error{
				Code:    ErrCodeNotFound,
				Message: "record not found",
				Table:   "old_table",
			},
			table:    "new_table",
			wantCode: ErrCodeNotFound,
			wantMsg:  "record not found",
			wantTbl:  "new_table",
		},
		{
			name: "add table to error with underlying error",
			err: &Error{
				Code:    ErrCodeInvalidInput,
				Message: "invalid input",
				Err:     errors.New("validation failed"),
			},
			table:    "products",
			wantCode: ErrCodeInvalidInput,
			wantMsg:  "invalid input",
			wantTbl:  "products",
		},
		{
			name: "add empty table",
			err: &Error{
				Code:    ErrCodeNotFound,
				Message: "not found",
			},
			table:    "",
			wantCode: ErrCodeNotFound,
			wantMsg:  "not found",
			wantTbl:  "",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			origTable := tt.err.Table

			result := tt.err.WithTable(tt.table)

			require.NotNil(t, result)
			assert.Equal(t, tt.wantCode, result.Code)
			assert.Equal(t, tt.wantMsg, result.Message)
			assert.Equal(t, tt.wantTbl, result.Table)
			// The underlying error must carry over to the new instance.
			assert.Equal(t, tt.err.Err, result.Err)

			// Ensure a new instance is returned and the original is not modified;
			// NotSame alone would not catch a mutate-then-copy implementation.
			assert.NotSame(t, tt.err, result)
			assert.Equal(t, origTable, tt.err.Table)
		})
	}
}

// TestError_WithError checks that WithError returns a new *Error holding
// exactly the given underlying error, keeps the other fields, and leaves
// the receiver unchanged.
func TestError_WithError(t *testing.T) {
	tests := []struct {
		name       string
		err        *Error
		underlying error
		wantCode   ErrorCode
		wantMsg    string
	}{
		{
			name: "add underlying error to error without error",
			err: &Error{
				Code:    ErrCodeInvalidInput,
				Message: "invalid input",
			},
			underlying: errors.New("validation failed"),
			wantCode:   ErrCodeInvalidInput,
			wantMsg:    "invalid input",
		},
		{
			name: "replace existing underlying error",
			err: &Error{
				Code:    ErrCodeInvalidInput,
				Message: "invalid input",
				Err:     errors.New("old error"),
			},
			underlying: errors.New("new error"),
			wantCode:   ErrCodeInvalidInput,
			wantMsg:    "invalid input",
		},
		{
			name: "add nil underlying error",
			err: &Error{
				Code:    ErrCodeNotFound,
				Message: "not found",
			},
			underlying: nil,
			wantCode:   ErrCodeNotFound,
			wantMsg:    "not found",
		},
		{
			name: "add error to error with table",
			err: &Error{
				Code:    ErrCodeConflict,
				Message: "conflict",
				Table:   "locks",
			},
			underlying: errors.New("lock exists"),
			wantCode:   ErrCodeConflict,
			wantMsg:    "conflict",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			origErr := tt.err.Err

			result := tt.err.WithError(tt.underlying)

			require.NotNil(t, result)
			assert.Equal(t, tt.wantCode, result.Code)
			assert.Equal(t, tt.wantMsg, result.Message)
			// The table name must carry over to the new instance.
			assert.Equal(t, tt.err.Table, result.Table)

			// Check the underlying error: it must be exactly the value passed in.
			if tt.underlying == nil {
				assert.Nil(t, result.Err)
			} else {
				assert.Same(t, tt.underlying, result.Err)
			}

			// Ensure a new instance is returned and the original is not modified.
			assert.NotSame(t, tt.err, result)
			assert.Equal(t, origErr, tt.err.Err)
		})
	}
}

// TestNewError checks that NewError sets only Code and Message.
func TestNewError(t *testing.T) {
	tests := []struct {
		name    string
		code    ErrorCode
		message string
		want    *Error
	}{
		{
			name:    "create NOT_FOUND error",
			code:    ErrCodeNotFound,
			message: "record not found",
			want: &Error{
				Code:    ErrCodeNotFound,
				Message: "record not found",
			},
		},
		{
			name:    "create CONFLICT error",
			code:    ErrCodeConflict,
			message: "lock acquisition failed",
			want: &Error{
				Code:    ErrCodeConflict,
				Message: "lock acquisition failed",
			},
		},
		{
			name:    "create INVALID_INPUT error",
			code:    ErrCodeInvalidInput,
			message: "invalid parameter",
			want: &Error{
				Code:    ErrCodeInvalidInput,
				Message: "invalid parameter",
			},
		},
		{
			name:    "create error with empty message",
			code:    ErrCodeNotFound,
			message: "",
			want: &Error{
				Code:    ErrCodeNotFound,
				Message: "",
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := NewError(tt.code, tt.message)

			require.NotNil(t, result)
			assert.Equal(t, tt.want.Code, result.Code)
			assert.Equal(t, tt.want.Message, result.Message)
			assert.Empty(t, result.Table)
			assert.Nil(t, result.Err)
		})
	}
}

// TestIsNotFound checks IsNotFound against *Error values of each code,
// plain errors, nil, and the predefined ErrNotFound.
func TestIsNotFound(t *testing.T) {
	tests := []struct {
		name string
		err  error
		want bool
	}{
		{
			name: "Error with NOT_FOUND code",
			err: &Error{
				Code:    ErrCodeNotFound,
				Message: "record not found",
			},
			want: true,
		},
		{
			name: "Error with CONFLICT code",
			err: &Error{
				Code:    ErrCodeConflict,
				Message: "conflict",
			},
			want: false,
		},
		{
			name: "Error with INVALID_INPUT code",
			err: &Error{
				Code:    ErrCodeInvalidInput,
				Message: "invalid input",
			},
			want: false,
		},
		{
			name: "Error with NOT_FOUND code and underlying error",
			err: &Error{
				Code:    ErrCodeNotFound,
				Message: "record not found",
				Err:     errors.New("underlying error"),
			},
			want: true,
		},
		{
			name: "standard error",
			err:  errors.New("standard error"),
			want: false,
		},
		{
			name: "nil error",
			err:  nil,
			want: false,
		},
		{
			name: "predefined ErrNotFound",
			err:  ErrNotFound,
			want: true,
		},
		{
			name: "predefined ErrNotFound with table",
			err:  ErrNotFound.WithTable("users"),
			want: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.want, IsNotFound(tt.err))
		})
	}
}

// TestIsConflict checks IsConflict against *Error values of each code,
// plain errors, nil, and a NewError-constructed conflict.
func TestIsConflict(t *testing.T) {
	tests := []struct {
		name string
		err  error
		want bool
	}{
		{
			name: "Error with CONFLICT code",
			err: &Error{
				Code:    ErrCodeConflict,
				Message: "conflict",
			},
			want: true,
		},
		{
			name: "Error with NOT_FOUND code",
			err: &Error{
				Code:    ErrCodeNotFound,
				Message: "record not found",
			},
			want: false,
		},
		{
			name: "Error with INVALID_INPUT code",
			err: &Error{
				Code:    ErrCodeInvalidInput,
				Message: "invalid input",
			},
			want: false,
		},
		{
			name: "Error with CONFLICT code and underlying error",
			err: &Error{
				Code:    ErrCodeConflict,
				Message: "conflict",
				Err:     errors.New("underlying error"),
			},
			want: true,
		},
		{
			name: "standard error",
			err:  errors.New("standard error"),
			want: false,
		},
		{
			name: "nil error",
			err:  nil,
			want: false,
		},
		{
			name: "NewError with CONFLICT code",
			err:  NewError(ErrCodeConflict, "lock failed"),
			want: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.want, IsConflict(tt.err))
		})
	}
}

// TestPredefinedErrors pins the code and message of each package-level
// sentinel and checks that none carries a table or underlying error.
func TestPredefinedErrors(t *testing.T) {
	tests := []struct {
		name     string
		err      *Error
		wantCode ErrorCode
		wantMsg  string
	}{
		{
			name:     "ErrNotFound",
			err:      ErrNotFound,
			wantCode: ErrCodeNotFound,
			wantMsg:  "record not found",
		},
		{
			name:     "ErrInvalidInput",
			err:      ErrInvalidInput,
			wantCode: ErrCodeInvalidInput,
			wantMsg:  "invalid input parameter",
		},
		{
			name:     "ErrNoPartitionKey",
			err:      ErrNoPartitionKey,
			wantCode: ErrCodeMissingPartition,
			wantMsg:  "no partition key defined in struct",
		},
		{
			name:     "ErrMissingTableName",
			err:      ErrMissingTableName,
			wantCode: ErrCodeMissingTableName,
			wantMsg:  "struct must implement TableName() method",
		},
		{
			name:     "ErrNoFieldsToUpdate",
			err:      ErrNoFieldsToUpdate,
			wantCode: ErrCodeNoFieldsToUpdate,
			wantMsg:  "no fields to update",
		},
		{
			name:     "ErrMissingWhereCondition",
			err:      ErrMissingWhereCondition,
			wantCode: ErrCodeMissingWhereCondition,
			wantMsg:  "operation requires at least one WHERE condition for safety",
		},
		{
			name:     "ErrMissingPartitionKey",
			err:      ErrMissingPartitionKey,
			wantCode: ErrCodeMissingPartition,
			wantMsg:  "operation requires all partition keys in WHERE clause",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			require.NotNil(t, tt.err)
			assert.Equal(t, tt.wantCode, tt.err.Code)
			assert.Equal(t, tt.wantMsg, tt.err.Message)
			assert.Empty(t, tt.err.Table)
			assert.Nil(t, tt.err.Err)
		})
	}
}

// TestError_Chaining checks that WithTable and WithError compose and that
// each call in a chain yields an independent instance.
func TestError_Chaining(t *testing.T) {
	t.Run("chain WithTable and WithError", func(t *testing.T) {
		err := NewError(ErrCodeNotFound, "record not found").
			WithTable("users").
			WithError(errors.New("database error"))

		assert.Equal(t, ErrCodeNotFound, err.Code)
		assert.Equal(t, "record not found", err.Message)
		assert.Equal(t, "users", err.Table)
		// EqualError fails cleanly on a nil error instead of panicking.
		assert.EqualError(t, err.Err, "database error")
		assert.True(t, IsNotFound(err))
	})

	t.Run("chain multiple WithTable calls", func(t *testing.T) {
		err1 := ErrNotFound.WithTable("table1")
		err2 := err1.WithTable("table2")

		assert.Equal(t, "table1", err1.Table)
		assert.Equal(t, "table2", err2.Table)
		assert.NotSame(t, err1, err2)
	})

	t.Run("chain multiple WithError calls", func(t *testing.T) {
		err1 := ErrInvalidInput.WithError(errors.New("error1"))
		err2 := err1.WithError(errors.New("error2"))

		assert.EqualError(t, err1.Err, "error1")
		assert.EqualError(t, err2.Err, "error2")
		assert.NotSame(t, err1, err2)
	})
}

// TestError_ErrorsAs checks interoperability of *Error with errors.As and
// errors.Is, including when the *Error is wrapped by fmt.Errorf.
func TestError_ErrorsAs(t *testing.T) {
	t.Run("errors.As works with Error", func(t *testing.T) {
		err := ErrNotFound.WithTable("users")

		var target *Error
		// require so a failed match stops before target is dereferenced.
		require.True(t, errors.As(err, &target))
		require.NotNil(t, target)
		assert.Equal(t, ErrCodeNotFound, target.Code)
		assert.Equal(t, "users", target.Table)
	})

	t.Run("errors.As works with Error carrying an underlying error", func(t *testing.T) {
		underlying := errors.New("underlying error")
		err := ErrInvalidInput.WithError(underlying)

		var target *Error
		require.True(t, errors.As(err, &target))
		require.NotNil(t, target)
		assert.Equal(t, ErrCodeInvalidInput, target.Code)
		assert.Equal(t, underlying, target.Err)
	})

	t.Run("errors.As finds Error through fmt.Errorf wrapping", func(t *testing.T) {
		err := fmt.Errorf("loading user: %w", ErrNotFound.WithTable("users"))

		var target *Error
		require.True(t, errors.As(err, &target))
		require.NotNil(t, target)
		assert.Equal(t, ErrCodeNotFound, target.Code)
		assert.Equal(t, "users", target.Table)
	})

	t.Run("errors.Is works with Error", func(t *testing.T) {
		err := ErrNotFound

		assert.True(t, errors.Is(err, ErrNotFound))
		assert.False(t, errors.Is(err, ErrInvalidInput))
	})
}