package mongo

import (
	"reflect"
	"testing"

	"github.com/shopspring/decimal"
	"go.mongodb.org/mongo-driver/v2/bson"
)

// TestMgoDecimal_InterfaceCompliance checks that *MgoDecimal satisfies
// bson.ValueEncoder and bson.ValueDecoder, and that instances can be stored
// as the Encoder/Decoder of a TypeCodec whose ValueType is decimal.Decimal.
func TestMgoDecimal_InterfaceCompliance(t *testing.T) {
	encoder := &MgoDecimal{}
	decoder := &MgoDecimal{}

	// Compile-time checks: the build breaks if *MgoDecimal stops
	// implementing either bson interface.
	var _ bson.ValueEncoder = encoder
	var _ bson.ValueDecoder = decoder

	codec := TypeCodec{
		ValueType: reflect.TypeOf(decimal.Decimal{}),
		Encoder:   encoder,
		Decoder:   decoder,
	}

	if codec.Encoder != encoder {
		t.Error("Expected encoder to be set correctly")
	}
	if codec.Decoder != decoder {
		t.Error("Expected decoder to be set correctly")
	}
}

// TestMgoDecimal_EncodeValue_InvalidType checks that EncodeValue rejects a
// string value and reports the exact expected error message.
func TestMgoDecimal_EncodeValue_InvalidType(t *testing.T) {
	encoder := &MgoDecimal{}

	value := reflect.ValueOf("not a decimal")
	err := encoder.EncodeValue(bson.EncodeContext{}, nil, value)
	if err == nil {
		// Fatal, not Error: the message check below dereferences err.
		t.Fatal("Expected error for invalid type, got nil")
	}

	const expectedErr = "value not a decimal to encode is not of type decimal.Decimal"
	if got := err.Error(); got != expectedErr {
		t.Errorf("Expected error %q, got %q", expectedErr, got)
	}
}

// decimal128RoundTrip converts dec to a bson.Decimal128 through its string
// form and parses the Decimal128's string form back into a decimal.Decimal.
// Any parse error fails the calling test immediately.
func decimal128RoundTrip(t *testing.T, dec decimal.Decimal) decimal.Decimal {
	t.Helper()

	primDec, err := bson.ParseDecimal128(dec.String())
	if err != nil {
		t.Fatalf("Failed to parse decimal128 from %s: %v", dec.String(), err)
	}

	result, err := decimal.NewFromString(primDec.String())
	if err != nil {
		t.Fatalf("Failed to create decimal from %s: %v", primDec.String(), err)
	}
	return result
}

// TestDecimalConversion checks that each input keeps its exact string form
// after decimal.NewFromString and again after bson.ParseDecimal128. The
// inputs include 34-significant-digit values of both signs.
func TestDecimalConversion(t *testing.T) {
	testCases := []struct {
		input    string
		expected string
	}{
		{"0", "0"},
		{"123.45", "123.45"},
		{"-123.45", "-123.45"},
		{"0.000001", "0.000001"},
		{"9999999999999999999.999999999999999", "9999999999999999999.999999999999999"},
		{"-9999999999999999999.999999999999999", "-9999999999999999999.999999999999999"},
	}

	for _, tc := range testCases {
		t.Run(tc.input, func(t *testing.T) {
			dec, err := decimal.NewFromString(tc.input)
			if err != nil {
				t.Fatalf("Failed to create decimal from %s: %v", tc.input, err)
			}
			if dec.String() != tc.expected {
				t.Errorf("Expected %s, got %s", tc.expected, dec.String())
			}

			primDec, err := bson.ParseDecimal128(dec.String())
			if err != nil {
				t.Fatalf("Failed to parse decimal128 from %s: %v", dec.String(), err)
			}
			if primDec.String() != tc.expected {
				t.Errorf("Expected %s, got %s", tc.expected, primDec.String())
			}
		})
	}
}

// TestDecimalConversionErrors checks that malformed numeric strings are
// rejected by both decimal.NewFromString and bson.ParseDecimal128.
func TestDecimalConversionErrors(t *testing.T) {
	invalidCases := []string{
		"invalid",
		"not a number",
		"",
		"123.45.67",
		"abc123",
	}

	for _, invalid := range invalidCases {
		t.Run(invalid, func(t *testing.T) {
			if _, err := decimal.NewFromString(invalid); err == nil {
				t.Errorf("Expected error for invalid decimal string: %s", invalid)
			}
			if _, err := bson.ParseDecimal128(invalid); err == nil {
				t.Errorf("Expected error for invalid decimal128 string: %s", invalid)
			}
		})
	}
}

// TestDecimalEdgeCases checks that zero, small fractions and large integers
// of both signs compare equal after a decimal -> Decimal128 -> decimal
// round trip.
func TestDecimalEdgeCases(t *testing.T) {
	testCases := []struct {
		name     string
		value    decimal.Decimal
		expected string
	}{
		{"zero", decimal.Zero, "0"},
		{"positive small", decimal.NewFromFloat(0.000001), "0.000001"},
		{"negative small", decimal.NewFromFloat(-0.000001), "-0.000001"},
		{"positive large", decimal.NewFromInt(999999999999999), "999999999999999"},
		{"negative large", decimal.NewFromInt(-999999999999999), "-999999999999999"},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			dec := decimal128RoundTrip(t, tc.value)
			if !dec.Equal(tc.value) {
				t.Errorf("Round trip failed: original=%s, result=%s", tc.value.String(), dec.String())
			}
		})
	}
}

// TestMgoDecimal_EncoderErrors checks that EncodeValue returns an error for
// several non-decimal Go types.
func TestMgoDecimal_EncoderErrors(t *testing.T) {
	encoder := &MgoDecimal{}

	testCases := []struct {
		name  string
		value any
	}{
		{"string", "not a decimal"},
		{"int", 123},
		{"float", 123.45},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			value := reflect.ValueOf(tc.value)
			if err := encoder.EncodeValue(bson.EncodeContext{}, nil, value); err == nil {
				t.Errorf("Expected error for type %T, got nil", tc.value)
			}
		})
	}
}

// TestDecimalPrecision checks that fractions from 10^-1 down to 10^-8 keep
// their value across a decimal -> Decimal128 -> decimal round trip.
func TestDecimalPrecision(t *testing.T) {
	testCases := []string{
		"0.1",
		"0.01",
		"0.001",
		"0.0001",
		"0.00001",
		"0.000001",
		"0.0000001",
		"0.00000001",
	}

	for _, tc := range testCases {
		t.Run(tc, func(t *testing.T) {
			dec, err := decimal.NewFromString(tc)
			if err != nil {
				t.Fatalf("Failed to create decimal from %s: %v", tc, err)
			}

			result := decimal128RoundTrip(t, dec)
			if !result.Equal(dec) {
				t.Errorf("Precision lost: original=%s, result=%s", dec.String(), result.String())
			}
		})
	}
}

// TestDecimalLargeNumbers checks that integers from 10^15 up to 10^18 keep
// their value across a decimal -> Decimal128 -> decimal round trip.
func TestDecimalLargeNumbers(t *testing.T) {
	testCases := []string{
		"1000000000000000",
		"10000000000000000",
		"100000000000000000",
		"1000000000000000000",
	}

	for _, tc := range testCases {
		t.Run(tc, func(t *testing.T) {
			dec, err := decimal.NewFromString(tc)
			if err != nil {
				t.Fatalf("Failed to create decimal from %s: %v", tc, err)
			}

			result := decimal128RoundTrip(t, dec)
			if !result.Equal(dec) {
				t.Errorf("Large number lost: original=%s, result=%s", dec.String(), result.String())
			}
		})
	}
}

// BenchmarkMgoDecimal_ParseDecimal128 measures decimal.Decimal.String plus
// bson.ParseDecimal128. Both calls are inside the timed loop.
func BenchmarkMgoDecimal_ParseDecimal128(b *testing.B) {
	dec := decimal.NewFromFloat(123.45)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// Fail loudly rather than silently timing an error path.
		if _, err := bson.ParseDecimal128(dec.String()); err != nil {
			b.Fatal(err)
		}
	}
}

// BenchmarkMgoDecimal_DecimalFromString measures bson.Decimal128.String plus
// decimal.NewFromString. Both calls are inside the timed loop.
func BenchmarkMgoDecimal_DecimalFromString(b *testing.B) {
	primDec, err := bson.ParseDecimal128("123.45")
	if err != nil {
		b.Fatalf("Failed to parse decimal128 setup value: %v", err)
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := decimal.NewFromString(primDec.String()); err != nil {
			b.Fatal(err)
		}
	}
}

// BenchmarkMgoDecimal_RoundTrip measures a full decimal -> Decimal128 ->
// decimal conversion through the string representations.
func BenchmarkMgoDecimal_RoundTrip(b *testing.B) {
	dec := decimal.NewFromFloat(123.45)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		primDec, err := bson.ParseDecimal128(dec.String())
		if err != nil {
			b.Fatal(err)
		}
		if _, err := decimal.NewFromString(primDec.String()); err != nil {
			b.Fatal(err)
		}
	}
}