diff --git a/marshal.go b/marshal.go index ac8a30b9c..1b21251fb 100644 --- a/marshal.go +++ b/marshal.go @@ -9,7 +9,6 @@ import ( "encoding/binary" "errors" "fmt" - "gopkg.in/inf.v0" "math" "math/big" "math/bits" @@ -23,6 +22,7 @@ import ( "github.com/gocql/gocql/serialization/blob" "github.com/gocql/gocql/serialization/counter" "github.com/gocql/gocql/serialization/cqlint" + "github.com/gocql/gocql/serialization/decimal" "github.com/gocql/gocql/serialization/double" "github.com/gocql/gocql/serialization/float" "github.com/gocql/gocql/serialization/inet" @@ -168,7 +168,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { case TypeDouble: return marshalDouble(value) case TypeDecimal: - return marshalDecimal(info, value) + return marshalDecimal(value) case TypeTime: return marshalTime(info, value) case TypeTimestamp: @@ -282,7 +282,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error { case TypeDouble: return unmarshalDouble(data, value) case TypeDecimal: - return unmarshalDecimal(info, data, value) + return unmarshalDecimal(data, value) case TypeTime: return unmarshalTime(info, data, value) case TypeTimestamp: @@ -611,44 +611,19 @@ func unmarshalDouble(data []byte, value interface{}) error { return nil } -func marshalDecimal(info TypeInfo, value interface{}) ([]byte, error) { - if value == nil { - return nil, nil - } - - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case inf.Dec: - unscaled := encBigInt2C(v.UnscaledBig()) - if unscaled == nil { - return nil, marshalErrorf("can not marshal %T into %s", value, info) - } - - buf := make([]byte, 4+len(unscaled)) - copy(buf[0:4], encInt(int32(v.Scale()))) - copy(buf[4:], unscaled) - return buf, nil +func marshalDecimal(value interface{}) ([]byte, error) { + data, err := decimal.Marshal(value) + if err != nil { + return nil, wrapMarshalError(err, "marshal error") } - return nil, marshalErrorf("can not marshal %T into %s", value, info) + return data, nil } -func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *inf.Dec: - if len(data) < 4 { - return unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data)) - } - scale := decInt(data[0:4]) - unscaled := decBigInt2C(data[4:], nil) - *v = *inf.NewDecBig(unscaled, inf.Scale(scale)) - return nil +func unmarshalDecimal(data []byte, value interface{}) error { + if err := decimal.Unmarshal(data, value); err != nil { + return wrapUnmarshalError(err, "unmarshal error") } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) + return nil } // decBigInt2C sets the value of n to the big-endian two's complement diff --git a/serialization/decimal/marshal.go b/serialization/decimal/marshal.go new file mode 100644 index 000000000..dbf95a249 --- /dev/null +++ b/serialization/decimal/marshal.go @@ -0,0 +1,29 @@ +package decimal + +import ( + "gopkg.in/inf.v0" + "reflect" +) + +func Marshal(value interface{}) ([]byte, error) { + switch v := value.(type) { + case nil: + return nil, nil + case inf.Dec: + return EncInfDec(v) + case *inf.Dec: + return EncInfDecR(v) + case string: + return EncString(v) + case *string: + return EncStringR(v) + default: + // Custom types (type MyString string) can be serialized only via `reflect` package. + // Later, when generic-based serialization is introduced we can do that via generics. + rv := reflect.TypeOf(value) + if rv.Kind() != reflect.Ptr { + return EncReflect(reflect.ValueOf(v)) + } + return EncReflectR(reflect.ValueOf(v)) + } +} diff --git a/serialization/decimal/marshal_utils.go b/serialization/decimal/marshal_utils.go new file mode 100644 index 000000000..6384af79f --- /dev/null +++ b/serialization/decimal/marshal_utils.go @@ -0,0 +1,141 @@ +package decimal + +import ( + "fmt" + "gopkg.in/inf.v0" + "math/big" + "reflect" + "strconv" + "strings" + + "github.com/gocql/gocql/serialization/varint" +) + +func EncInfDec(v inf.Dec) ([]byte, error) { + sign := v.Sign() + if sign == 0 { + return []byte{0, 0, 0, 0, 0}, nil + } + return append(encScale(v.Scale()), varint.EncBigIntRS(v.UnscaledBig())...), nil +} + +func EncInfDecR(v *inf.Dec) ([]byte, error) { + if v == nil { + return nil, nil + } + return encInfDecR(v), nil +} + +// EncString encodes decimal string which should contains `scale` and `unscaled` strings separated by `;`. +func EncString(v string) ([]byte, error) { + if v == "" { + return nil, nil + } + vs := strings.Split(v, ";") + if len(vs) != 2 { + return nil, fmt.Errorf("failed to marshal decimal: invalid decimal string %s", v) + } + scale, err := strconv.ParseInt(vs[0], 10, 32) + if err != nil { + return nil, fmt.Errorf("failed to marshal decimal: invalid decimal scale string %s", vs[0]) + } + unscaleData, err := encUnscaledString(vs[1]) + if err != nil { + return nil, err + } + return append(encScale64(scale), unscaleData...), nil +} + +func EncStringR(v *string) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncString(*v) +} + +func EncReflect(v reflect.Value) ([]byte, error) { + switch v.Type().Kind() { + case reflect.String: + return encReflectString(v) + case reflect.Struct: + if v.Type().String() == "gocql.unsetColumn" { + return nil, nil + } + return nil, fmt.Errorf("failed to marshal decimal: unsupported value type (%T)(%[1]v)", v.Interface()) + default: + return nil, fmt.Errorf("failed to marshal decimal: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func EncReflectR(v reflect.Value) ([]byte, error) { + if v.IsNil() { + return nil, nil + } + return EncReflect(v.Elem()) +} + +func encReflectString(v reflect.Value) ([]byte, error) { + val := v.String() + if val == "" { + return nil, nil + } + vs := strings.Split(val, ";") + if len(vs) != 2 { + return nil, fmt.Errorf("failed to marshal decimal: invalid decimal string (%T)(%[1]v)", v.Interface()) + } + scale, err := strconv.ParseInt(vs[0], 10, 32) + if err != nil { + return nil, fmt.Errorf("failed to marshal decimal: invalid decimal scale string (%T)(%s)", v.Interface(), vs[0]) + } + unscaledData, err := encUnscaledString(vs[1]) + if err != nil { + return nil, err + } + return append(encScale64(scale), unscaledData...), nil +} + +func encInfDecR(v *inf.Dec) []byte { + sign := v.Sign() + if sign == 0 { + return []byte{0, 0, 0, 0, 0} + } + return append(encScale(v.Scale()), varint.EncBigIntRS(v.UnscaledBig())...) +} + +func encScale(v inf.Scale) []byte { + return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} +} + +func encScale64(v int64) []byte { + return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} +} + +func encUnscaledString(v string) ([]byte, error) { + switch { + case len(v) == 0: + return nil, nil + case len(v) <= 18: + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to marshal decimal: invalid unscaled string %s, %s", v, err) + } + return varint.EncInt64Ext(n), nil + case len(v) <= 20: + n, err := strconv.ParseInt(v, 10, 64) + if err == nil { + return varint.EncInt64Ext(n), nil + } + + t, ok := new(big.Int).SetString(v, 10) + if !ok { + return nil, fmt.Errorf("failed to marshal decimal: invalid unscaled string %s", v) + } + return varint.EncBigIntRS(t), nil + default: + t, ok := new(big.Int).SetString(v, 10) + if !ok { + return nil, fmt.Errorf("failed to marshal decimal: invalid unscaled string %s", v) + } + return varint.EncBigIntRS(t), nil + } +} diff --git a/serialization/decimal/unmarshal.go b/serialization/decimal/unmarshal.go new file mode 100644 index 000000000..6433f0dd2 --- /dev/null +++ b/serialization/decimal/unmarshal.go @@ -0,0 +1,34 @@ +package decimal + +import ( + "fmt" + "gopkg.in/inf.v0" + "reflect" +) + +func Unmarshal(data []byte, value interface{}) error { + switch v := value.(type) { + case nil: + return nil + case *inf.Dec: + return DecInfDec(data, v) + case **inf.Dec: + return DecInfDecR(data, v) + case *string: + return DecString(data, v) + case **string: + return DecStringR(data, v) + default: + // Custom types (type MyString string) can be deserialized only via `reflect` package. + // Later, when generic-based serialization is introduced we can do that via generics. + rv := reflect.ValueOf(value) + rt := rv.Type() + if rt.Kind() != reflect.Ptr { + return fmt.Errorf("failed to unmarshal decimal: unsupported value type (%T)(%#[1]v)", value) + } + if rt.Elem().Kind() != reflect.Ptr { + return DecReflect(data, rv) + } + return DecReflectR(data, rv) + } +} diff --git a/serialization/decimal/unmarshal_ints.go b/serialization/decimal/unmarshal_ints.go new file mode 100644 index 000000000..f21910047 --- /dev/null +++ b/serialization/decimal/unmarshal_ints.go @@ -0,0 +1,80 @@ +package decimal + +import ( + "gopkg.in/inf.v0" +) + +const ( + neg8 = int64(-1) << 8 + neg16 = int64(-1) << 16 + neg24 = int64(-1) << 24 + neg32 = int64(-1) << 32 + neg40 = int64(-1) << 40 + neg48 = int64(-1) << 48 + neg56 = int64(-1) << 56 + neg32Int = int(-1) << 32 +) + +func decScale(p []byte) inf.Scale { + return inf.Scale(p[0])<<24 | inf.Scale(p[1])<<16 | inf.Scale(p[2])<<8 | inf.Scale(p[3]) +} + +func decScaleInt64(p []byte) int64 { + if p[0] > 127 { + return neg32 | int64(p[0])<<24 | int64(p[1])<<16 | int64(p[2])<<8 | int64(p[3]) + } + return int64(p[0])<<24 | int64(p[1])<<16 | int64(p[2])<<8 | int64(p[3]) +} + +func dec1toInt64(p []byte) int64 { + if p[4] > 127 { + return neg8 | int64(p[4]) + } + return int64(p[4]) +} + +func dec2toInt64(p []byte) int64 { + if p[4] > 127 { + return neg16 | int64(p[4])<<8 | int64(p[5]) + } + return int64(p[4])<<8 | int64(p[5]) +} + +func dec3toInt64(p []byte) int64 { + if p[4] > 127 { + return neg24 | int64(p[4])<<16 | int64(p[5])<<8 | int64(p[6]) + } + return int64(p[4])<<16 | int64(p[5])<<8 | int64(p[6]) +} + +func dec4toInt64(p []byte) int64 { + if p[4] > 127 { + return neg32 | int64(p[4])<<24 | int64(p[5])<<16 | int64(p[6])<<8 | int64(p[7]) + } + return int64(p[4])<<24 | int64(p[5])<<16 | int64(p[6])<<8 | int64(p[7]) +} + +func dec5toInt64(p []byte) int64 { + if p[4] > 127 { + return neg40 | int64(p[4])<<32 | int64(p[5])<<24 | int64(p[6])<<16 | int64(p[7])<<8 | int64(p[8]) + } + return int64(p[4])<<32 | int64(p[5])<<24 | int64(p[6])<<16 | int64(p[7])<<8 | int64(p[8]) +} + +func dec6toInt64(p []byte) int64 { + if p[4] > 127 { + return neg48 | int64(p[4])<<40 | int64(p[5])<<32 | int64(p[6])<<24 | int64(p[7])<<16 | int64(p[8])<<8 | int64(p[9]) + } + return int64(p[4])<<40 | int64(p[5])<<32 | int64(p[6])<<24 | int64(p[7])<<16 | int64(p[8])<<8 | int64(p[9]) +} + +func dec7toInt64(p []byte) int64 { + if p[4] > 127 { + return neg56 | int64(p[4])<<48 | int64(p[5])<<40 | int64(p[6])<<32 | int64(p[7])<<24 | int64(p[8])<<16 | int64(p[9])<<8 | int64(p[10]) + } + return int64(p[4])<<48 | int64(p[5])<<40 | int64(p[6])<<32 | int64(p[7])<<24 | int64(p[8])<<16 | int64(p[9])<<8 | int64(p[10]) +} + +func dec8toInt64(p []byte) int64 { + return int64(p[4])<<56 | int64(p[5])<<48 | int64(p[6])<<40 | int64(p[7])<<32 | int64(p[8])<<24 | int64(p[9])<<16 | int64(p[10])<<8 | int64(p[11]) +} diff --git a/serialization/decimal/unmarshal_utils.go b/serialization/decimal/unmarshal_utils.go new file mode 100644 index 000000000..557f6cec3 --- /dev/null +++ b/serialization/decimal/unmarshal_utils.go @@ -0,0 +1,323 @@ +package decimal + +import ( + "fmt" + "gopkg.in/inf.v0" + "reflect" + "strconv" + + "github.com/gocql/gocql/serialization/varint" +) + +var errWrongDataLen = fmt.Errorf("failed to unmarshal decimal: the length of the data should be 0 or more than 5") + +func errBrokenData(p []byte) error { + if p[4] == 0 && p[5] <= 127 || p[4] == 255 && p[5] > 127 { + return fmt.Errorf("failed to unmarshal decimal: the data is broken") + } + return nil +} + +func errNilReference(v interface{}) error { + return fmt.Errorf("failed to unmarshal decimal: can not unmarshal into nil reference(%T)(%[1]v)", v) +} + +func DecInfDec(p []byte, v *inf.Dec) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + v.SetScale(0).SetUnscaled(0) + return nil + case 1, 2, 3, 4: + return errWrongDataLen + case 5: + v.SetScale(decScale(p)).SetUnscaled(dec1toInt64(p)) + return nil + case 6: + v.SetScale(decScale(p)).SetUnscaled(dec2toInt64(p)) + case 7: + v.SetScale(decScale(p)).SetUnscaled(dec3toInt64(p)) + case 8: + v.SetScale(decScale(p)).SetUnscaled(dec4toInt64(p)) + case 9: + v.SetScale(decScale(p)).SetUnscaled(dec5toInt64(p)) + case 10: + v.SetScale(decScale(p)).SetUnscaled(dec6toInt64(p)) + case 11: + v.SetScale(decScale(p)).SetUnscaled(dec7toInt64(p)) + case 12: + v.SetScale(decScale(p)).SetUnscaled(dec8toInt64(p)) + default: + v.SetScale(decScale(p)).SetUnscaledBig(varint.Dec2BigInt(p[4:])) + } + return errBrokenData(p) +} + +func DecInfDecR(p []byte, v **inf.Dec) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = inf.NewDec(0, 0) + } + return nil + case 1, 2, 3, 4: + return errWrongDataLen + case 5: + *v = inf.NewDec(dec1toInt64(p), decScale(p)) + return nil + case 6: + *v = inf.NewDec(dec2toInt64(p), decScale(p)) + case 7: + *v = inf.NewDec(dec3toInt64(p), decScale(p)) + case 8: + *v = inf.NewDec(dec4toInt64(p), decScale(p)) + case 9: + *v = inf.NewDec(dec5toInt64(p), decScale(p)) + case 10: + *v = inf.NewDec(dec6toInt64(p), decScale(p)) + case 11: + *v = inf.NewDec(dec7toInt64(p), decScale(p)) + case 12: + *v = inf.NewDec(dec8toInt64(p), decScale(p)) + default: + *v = inf.NewDecBig(varint.Dec2BigInt(p[4:]), decScale(p)) + } + return errBrokenData(p) +} + +func DecString(p []byte, v *string) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = "" + } else { + *v = "0;0" + } + return nil + case 1, 2, 3, 4: + return errWrongDataLen + case 5: + *v = decString5(p) + return nil + case 6: + *v = decString6(p) + case 7: + *v = decString7(p) + case 8: + *v = decString8(p) + case 9: + *v = decString9(p) + case 10: + *v = decString10(p) + case 11: + *v = decString11(p) + case 12: + *v = decString12(p) + default: + *v = decString(p) + } + return errBrokenData(p) +} + +func DecStringR(p []byte, v **string) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + tmp := "0;0" + *v = &tmp + } + return nil + case 1, 2, 3, 4: + return errWrongDataLen + case 5: + tmp := decString5(p) + *v = &tmp + return nil + case 6: + tmp := decString6(p) + *v = &tmp + case 7: + tmp := decString7(p) + *v = &tmp + case 8: + tmp := decString8(p) + *v = &tmp + case 9: + tmp := decString9(p) + *v = &tmp + case 10: + tmp := decString10(p) + *v = &tmp + case 11: + tmp := decString11(p) + *v = &tmp + case 12: + tmp := decString12(p) + *v = &tmp + default: + tmp := decString(p) + *v = &tmp + } + return errBrokenData(p) +} + +func DecReflect(p []byte, v reflect.Value) error { + if v.IsNil() { + return fmt.Errorf("failed to unmarshal decimal: can not unmarshal into nil reference (%T)(%#[1]v)", v.Interface()) + } + + switch v = v.Elem(); v.Kind() { + case reflect.String: + return decReflectString(p, v) + default: + return fmt.Errorf("failed to unmarshal decimal: unsupported value type (%T)(%#[1]v)", v.Interface()) + } +} + +func DecReflectR(p []byte, v reflect.Value) error { + if v.IsNil() { + return fmt.Errorf("failed to unmarshal decimal: can not unmarshal into nil reference (%T)(%[1]v)", v.Interface()) + } + + switch v.Type().Elem().Elem().Kind() { + case reflect.String: + return decReflectStringR(p, v) + default: + return fmt.Errorf("failed to unmarshal decimal: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func decReflectString(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + if p == nil { + v.SetString("") + } else { + v.SetString("0;0") + } + return nil + case 1, 2, 3, 4: + return errWrongDataLen + case 5: + v.SetString(decString5(p)) + return nil + case 6: + v.SetString(decString6(p)) + case 7: + v.SetString(decString7(p)) + case 8: + v.SetString(decString8(p)) + case 9: + v.SetString(decString9(p)) + case 10: + v.SetString(decString10(p)) + case 11: + v.SetString(decString11(p)) + case 12: + v.SetString(decString12(p)) + default: + v.SetString(decString(p)) + } + return errBrokenData(p) +} + +func decReflectStringR(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + var val reflect.Value + if p == nil { + val = reflect.Zero(v.Type().Elem()) + } else { + val = reflect.New(v.Type().Elem().Elem()) + val.Elem().SetString("0;0") + } + v.Elem().Set(val) + return nil + case 1, 2, 3, 4: + return errWrongDataLen + case 5: + newVal := reflect.New(v.Type().Elem().Elem()) + newVal.Elem().SetString(decString5(p)) + v.Elem().Set(newVal) + return nil + case 6: + newVal := reflect.New(v.Type().Elem().Elem()) + newVal.Elem().SetString(decString6(p)) + v.Elem().Set(newVal) + case 7: + newVal := reflect.New(v.Type().Elem().Elem()) + newVal.Elem().SetString(decString7(p)) + v.Elem().Set(newVal) + case 8: + newVal := reflect.New(v.Type().Elem().Elem()) + newVal.Elem().SetString(decString8(p)) + v.Elem().Set(newVal) + case 9: + newVal := reflect.New(v.Type().Elem().Elem()) + newVal.Elem().SetString(decString9(p)) + v.Elem().Set(newVal) + case 10: + newVal := reflect.New(v.Type().Elem().Elem()) + newVal.Elem().SetString(decString10(p)) + v.Elem().Set(newVal) + case 11: + newVal := reflect.New(v.Type().Elem().Elem()) + newVal.Elem().SetString(decString11(p)) + v.Elem().Set(newVal) + case 12: + newVal := reflect.New(v.Type().Elem().Elem()) + newVal.Elem().SetString(decString12(p)) + v.Elem().Set(newVal) + default: + newVal := reflect.New(v.Type().Elem().Elem()) + newVal.Elem().SetString(decString(p)) + v.Elem().Set(newVal) + } + return errBrokenData(p) +} + +func decString5(p []byte) string { + return strconv.FormatInt(decScaleInt64(p), 10) + ";" + strconv.FormatInt(dec1toInt64(p), 10) +} + +func decString6(p []byte) string { + return strconv.FormatInt(decScaleInt64(p), 10) + ";" + strconv.FormatInt(dec2toInt64(p), 10) +} + +func decString7(p []byte) string { + return strconv.FormatInt(decScaleInt64(p), 10) + ";" + strconv.FormatInt(dec3toInt64(p), 10) +} +func decString8(p []byte) string { + return strconv.FormatInt(decScaleInt64(p), 10) + ";" + strconv.FormatInt(dec4toInt64(p), 10) +} +func decString9(p []byte) string { + return strconv.FormatInt(decScaleInt64(p), 10) + ";" + strconv.FormatInt(dec5toInt64(p), 10) +} +func decString10(p []byte) string { + return strconv.FormatInt(decScaleInt64(p), 10) + ";" + strconv.FormatInt(dec6toInt64(p), 10) +} +func decString11(p []byte) string { + return strconv.FormatInt(decScaleInt64(p), 10) + ";" + strconv.FormatInt(dec7toInt64(p), 10) +} +func decString12(p []byte) string { + return strconv.FormatInt(decScaleInt64(p), 10) + ";" + strconv.FormatInt(dec8toInt64(p), 10) +} + +func decString(p []byte) string { + return strconv.FormatInt(decScaleInt64(p), 10) + ";" + varint.Dec2BigInt(p[4:]).String() +}