package mongo import ( "fmt" "reflect" "github.com/shopspring/decimal" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/bson/primitive" ) type MgoDecimal struct{} var ( _ bsoncodec.ValueEncoder = &MgoDecimal{} _ bsoncodec.ValueDecoder = &MgoDecimal{} ) func (dc *MgoDecimal) EncodeValue(_ bsoncodec.EncodeContext, w bsonrw.ValueWriter, value reflect.Value) error { // TODO 待確認是否有非decimal.Decimal type而導致error的場景 dec, ok := value.Interface().(decimal.Decimal) if !ok { return fmt.Errorf("value %v to encode is not of type decimal.Decimal", value) } // Convert decimal.Decimal to primitive.Decimal128. primDec, err := primitive.ParseDecimal128(dec.String()) if err != nil { return fmt.Errorf("converting decimal.Decimal %v to primitive.Decimal128 error: %w", dec, err) } return w.WriteDecimal128(primDec) } func (dc *MgoDecimal) DecodeValue(_ bsoncodec.DecodeContext, r bsonrw.ValueReader, value reflect.Value) error { primDec, err := r.ReadDecimal128() if err != nil { return fmt.Errorf("reading primitive.Decimal128 from ValueReader error: %w", err) } // Convert primitive.Decimal128 to decimal.Decimal. dec, err := decimal.NewFromString(primDec.String()) if err != nil { return fmt.Errorf("converting primitive.Decimal128 %v to decimal.Decimal error: %w", primDec, err) } // set as decimal.Decimal type value.Set(reflect.ValueOf(dec)) return nil }