53 lines
1.5 KiB
Go
53 lines
1.5 KiB
Go
|
package mongo
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"reflect"
|
||
|
|
||
|
"github.com/shopspring/decimal"
|
||
|
"go.mongodb.org/mongo-driver/bson/bsoncodec"
|
||
|
"go.mongodb.org/mongo-driver/bson/bsonrw"
|
||
|
"go.mongodb.org/mongo-driver/bson/primitive"
|
||
|
)
|
||
|
|
||
|
// MgoDecimal is a bsoncodec value encoder and decoder that converts
// github.com/shopspring/decimal.Decimal values to and from BSON Decimal128.
// It carries no state, so the zero value is ready to use.
type MgoDecimal struct{}
|
||
|
|
||
|
var (
|
||
|
_ bsoncodec.ValueEncoder = &MgoDecimal{}
|
||
|
_ bsoncodec.ValueDecoder = &MgoDecimal{}
|
||
|
)
|
||
|
|
||
|
func (dc *MgoDecimal) EncodeValue(_ bsoncodec.EncodeContext, w bsonrw.ValueWriter, value reflect.Value) error {
|
||
|
// TODO 待確認是否有非decimal.Decimal type而導致error的場景
|
||
|
dec, ok := value.Interface().(decimal.Decimal)
|
||
|
if !ok {
|
||
|
return fmt.Errorf("value %v to encode is not of type decimal.Decimal", value)
|
||
|
}
|
||
|
|
||
|
// Convert decimal.Decimal to primitive.Decimal128.
|
||
|
primDec, err := primitive.ParseDecimal128(dec.String())
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("converting decimal.Decimal %v to primitive.Decimal128 error: %w", dec, err)
|
||
|
}
|
||
|
|
||
|
return w.WriteDecimal128(primDec)
|
||
|
}
|
||
|
|
||
|
func (dc *MgoDecimal) DecodeValue(_ bsoncodec.DecodeContext, r bsonrw.ValueReader, value reflect.Value) error {
|
||
|
primDec, err := r.ReadDecimal128()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("reading primitive.Decimal128 from ValueReader error: %w", err)
|
||
|
}
|
||
|
|
||
|
// Convert primitive.Decimal128 to decimal.Decimal.
|
||
|
dec, err := decimal.NewFromString(primDec.String())
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("converting primitive.Decimal128 %v to decimal.Decimal error: %w", primDec, err)
|
||
|
}
|
||
|
|
||
|
// set as decimal.Decimal type
|
||
|
value.Set(reflect.ValueOf(dec))
|
||
|
|
||
|
return nil
|
||
|
}
|