// backend/pkg/library/mongo/custom_mongo_decimal_test.go
// (source-viewer metadata: 275 lines, 6.8 KiB, Go; last change 2025-10-01 16:30:27 +00:00)

package mongo
import (
"reflect"
"testing"
"github.com/shopspring/decimal"
"go.mongodb.org/mongo-driver/v2/bson"
)
// TestMgoDecimal_InterfaceCompliance checks that *MgoDecimal satisfies both
// bson.ValueEncoder and bson.ValueDecoder, and that instances can be stored
// in a TypeCodec registered for decimal.Decimal.
func TestMgoDecimal_InterfaceCompliance(t *testing.T) {
	enc, dec := &MgoDecimal{}, &MgoDecimal{}

	// These assignments only type-check if MgoDecimal implements both
	// interfaces, so a regression surfaces as a build failure.
	var (
		_ bson.ValueEncoder = enc
		_ bson.ValueDecoder = dec
	)

	codec := TypeCodec{
		ValueType: reflect.TypeOf(decimal.Decimal{}),
		Encoder:   enc,
		Decoder:   dec,
	}

	// Both fields are checked independently so that every mismatch is reported.
	if encOK := codec.Encoder == enc; !encOK {
		t.Error("Expected encoder to be set correctly")
	}
	if decOK := codec.Decoder == dec; !decOK {
		t.Error("Expected decoder to be set correctly")
	}
}
// TestMgoDecimal_EncodeValue_InvalidType verifies that EncodeValue rejects a
// value that is not a decimal.Decimal and reports it with a descriptive error.
func TestMgoDecimal_EncodeValue_InvalidType(t *testing.T) {
	encoder := &MgoDecimal{}
	value := reflect.ValueOf("not a decimal")

	// A nil ValueWriter is passed. NOTE(review): this relies on EncodeValue
	// rejecting the type before it touches the writer. Confirm against the
	// implementation.
	err := encoder.EncodeValue(bson.EncodeContext{}, nil, value)
	if err == nil {
		// Fatal rather than Error: the message check below would otherwise
		// call Error() on a nil error and panic.
		t.Fatal("Expected error for invalid type, got nil")
	}

	const expectedErr = "value not a decimal to encode is not of type decimal.Decimal"
	if got := err.Error(); got != expectedErr {
		t.Errorf("Expected error '%s', got '%s'", expectedErr, got)
	}
}
// Test decimal conversion functions
// TestDecimalConversion checks that representative decimal strings keep their
// exact textual form after parsing with decimal.NewFromString and again after
// converting that result to a bson.Decimal128.
func TestDecimalConversion(t *testing.T) {
	cases := []struct {
		input    string
		expected string
	}{
		{"0", "0"},
		{"123.45", "123.45"},
		{"-123.45", "-123.45"},
		{"0.000001", "0.000001"},
		{"9999999999999999999.999999999999999", "9999999999999999999.999999999999999"},
		{"-9999999999999999999.999999999999999", "-9999999999999999999.999999999999999"},
	}

	for _, c := range cases {
		t.Run(c.input, func(t *testing.T) {
			d, err := decimal.NewFromString(c.input)
			if err != nil {
				t.Fatalf("Failed to create decimal from %s: %v", c.input, err)
			}
			text := d.String()
			if text != c.expected {
				t.Errorf("Expected %s, got %s", c.expected, text)
			}

			// The Decimal128 is built from the decimal's string form,
			// mirroring how the two representations are bridged.
			d128, err := bson.ParseDecimal128(text)
			if err != nil {
				t.Fatalf("Failed to parse decimal128 from %s: %v", text, err)
			}
			if got := d128.String(); got != c.expected {
				t.Errorf("Expected %s, got %s", c.expected, got)
			}
		})
	}
}
// Test error cases
// TestDecimalConversionErrors checks that malformed numeric strings are
// rejected both by decimal.NewFromString and by bson.ParseDecimal128.
func TestDecimalConversionErrors(t *testing.T) {
	for _, bad := range []string{
		"invalid",
		"not a number",
		"",
		"123.45.67",
		"abc123",
	} {
		t.Run(bad, func(t *testing.T) {
			if _, err := decimal.NewFromString(bad); err == nil {
				t.Errorf("Expected error for invalid decimal string: %s", bad)
			}
			if _, err := bson.ParseDecimal128(bad); err == nil {
				t.Errorf("Expected error for invalid decimal128 string: %s", bad)
			}
		})
	}
}
// Test edge cases for decimal values
// TestDecimalEdgeCases checks that edge-case decimal values (zero, tiny
// fractions, large integers, both signs) format as expected and survive a
// round trip through bson.Decimal128 without changing value.
func TestDecimalEdgeCases(t *testing.T) {
	testCases := []struct {
		name     string
		value    decimal.Decimal
		expected string
	}{
		{"zero", decimal.Zero, "0"},
		{"positive small", decimal.NewFromFloat(0.000001), "0.000001"},
		{"negative small", decimal.NewFromFloat(-0.000001), "-0.000001"},
		{"positive large", decimal.NewFromInt(999999999999999), "999999999999999"},
		{"negative large", decimal.NewFromInt(-999999999999999), "-999999999999999"},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Pin the textual form first. That string is what gets handed to
			// ParseDecimal128, so a formatting regression is reported here.
			if got := tc.value.String(); got != tc.expected {
				t.Errorf("Expected %s, got %s", tc.expected, got)
			}
			// Test conversion to BSON Decimal128
			primDec, err := bson.ParseDecimal128(tc.value.String())
			if err != nil {
				t.Fatalf("Failed to parse decimal128 from %s: %v", tc.value.String(), err)
			}
			// Test conversion back to decimal
			dec, err := decimal.NewFromString(primDec.String())
			if err != nil {
				t.Fatalf("Failed to create decimal from %s: %v", primDec.String(), err)
			}
			// Equal compares numeric value, not representation (exponent).
			if !dec.Equal(tc.value) {
				t.Errorf("Round trip failed: original=%s, result=%s", tc.value.String(), dec.String())
			}
		})
	}
}
// Test error handling in encoder
// TestMgoDecimal_EncoderErrors checks that EncodeValue returns an error for
// several Go types that are not decimal.Decimal.
func TestMgoDecimal_EncoderErrors(t *testing.T) {
	enc := &MgoDecimal{}
	inputs := []struct {
		name  string
		value any
	}{
		{name: "string", value: "not a decimal"},
		{name: "int", value: 123},
		{name: "float", value: 123.45},
	}

	for _, in := range inputs {
		t.Run(in.name, func(t *testing.T) {
			rv := reflect.ValueOf(in.value)
			// A nil writer is passed, as in the original; only the type check is exercised.
			if err := enc.EncodeValue(bson.EncodeContext{}, nil, rv); err == nil {
				t.Errorf("Expected error for type %T, got nil", in.value)
			}
		})
	}
}
// Test decimal precision
// TestDecimalPrecision checks that small fractional values down to 1e-8
// round-trip through bson.Decimal128 without losing precision.
func TestDecimalPrecision(t *testing.T) {
	testCases := []string{
		"0.1",
		"0.01",
		"0.001",
		"0.0001",
		"0.00001",
		"0.000001",
		"0.0000001",
		"0.00000001",
	}
	for _, tc := range testCases {
		t.Run(tc, func(t *testing.T) {
			assertDecimal128RoundTrip(t, tc, "Precision lost")
		})
	}
}

// TestDecimalLargeNumbers checks that large integral values (1e15 to 1e18)
// round-trip through bson.Decimal128 unchanged.
func TestDecimalLargeNumbers(t *testing.T) {
	testCases := []string{
		"1000000000000000",
		"10000000000000000",
		"100000000000000000",
		"1000000000000000000",
	}
	for _, tc := range testCases {
		t.Run(tc, func(t *testing.T) {
			assertDecimal128RoundTrip(t, tc, "Large number lost")
		})
	}
}

// assertDecimal128RoundTrip parses s as a decimal.Decimal and converts it to a
// bson.Decimal128 through its string form. It then parses the Decimal128's
// string back into a decimal.Decimal. Parse failures abort the test. If the
// final value differs numerically from the original, it reports an error
// prefixed with lossMsg.
func assertDecimal128RoundTrip(t *testing.T, s, lossMsg string) {
	t.Helper()
	dec, err := decimal.NewFromString(s)
	if err != nil {
		t.Fatalf("Failed to create decimal from %s: %v", s, err)
	}
	// Test conversion to BSON Decimal128
	primDec, err := bson.ParseDecimal128(dec.String())
	if err != nil {
		t.Fatalf("Failed to parse decimal128 from %s: %v", dec.String(), err)
	}
	// Test conversion back to decimal. Decimal128 may render small values in
	// exponent form, which decimal.NewFromString also accepts.
	result, err := decimal.NewFromString(primDec.String())
	if err != nil {
		t.Fatalf("Failed to create decimal from %s: %v", primDec.String(), err)
	}
	if !result.Equal(dec) {
		t.Errorf("%s: original=%s, result=%s", lossMsg, dec.String(), result.String())
	}
}
// Benchmark tests
// BenchmarkMgoDecimal_ParseDecimal128 measures bson.ParseDecimal128 on the
// string form of 123.45. The string is computed once, outside the timed loop,
// so the benchmark measures parsing only and not decimal.Decimal formatting.
func BenchmarkMgoDecimal_ParseDecimal128(b *testing.B) {
	s := decimal.NewFromFloat(123.45).String()
	// Validate the fixed input once so the loop's ignored error cannot hide a failure.
	if _, err := bson.ParseDecimal128(s); err != nil {
		b.Fatalf("Failed to parse decimal128 from %s: %v", s, err)
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// Error ignored: the same input was validated above.
		_, _ = bson.ParseDecimal128(s)
	}
}
// BenchmarkMgoDecimal_DecimalFromString measures decimal.NewFromString on the
// string form of a bson.Decimal128 holding 123.45. The Decimal128 string is
// produced once, outside the timed loop, so only the decimal parse is measured.
func BenchmarkMgoDecimal_DecimalFromString(b *testing.B) {
	const input = "123.45"
	primDec, err := bson.ParseDecimal128(input)
	if err != nil {
		// Previously ignored; a failed setup would have benchmarked garbage input.
		b.Fatalf("Failed to parse decimal128 from %s: %v", input, err)
	}
	s := primDec.String()
	if _, err := decimal.NewFromString(s); err != nil {
		b.Fatalf("Failed to create decimal from %s: %v", s, err)
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// Error ignored: the same input was validated above.
		_, _ = decimal.NewFromString(s)
	}
}
func BenchmarkMgoDecimal_RoundTrip(b *testing.B) {
dec := decimal.NewFromFloat(123.45)
b.ResetTimer()
for i := 0; i < b.N; i++ {
primDec, _ := bson.ParseDecimal128(dec.String())
_, _ = decimal.NewFromString(primDec.String())
}
}