diff --git a/marshal.go b/marshal.go index d5e511242..503b4bfe0 100644 --- a/marshal.go +++ b/marshal.go @@ -8,6 +8,7 @@ import ( "bytes" "errors" "fmt" + "github.com/gocql/gocql/serialization/boolean" "math" "math/big" "math/bits" @@ -154,7 +155,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { case TypeAscii: return marshalAscii(value) case TypeBoolean: - return marshalBool(info, value) + return marshalBool(value) case TypeTinyInt: return marshalTinyInt(value) case TypeSmallInt: @@ -266,7 +267,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error { case TypeAscii: return unmarshalAscii(data, value) case TypeBoolean: - return unmarshalBool(info, data, value) + return unmarshalBool(data, value) case TypeInt: return unmarshalInt(data, value) case TypeBigInt: @@ -525,61 +526,19 @@ func decBigInt(data []byte) int64 { int64(data[6])<<8 | int64(data[7]) } -func marshalBool(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case bool: - return encBool(v), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Bool: - return encBool(rv.Bool()), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func encBool(v bool) []byte { - if v { - return []byte{1} - } - return []byte{0} -} - -func unmarshalBool(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *bool: - *v = decBool(data) - return nil - } - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) - } - rv = rv.Elem() - switch rv.Type().Kind() { - case reflect.Bool: - rv.SetBool(decBool(data)) - return nil +func marshalBool(value interface{}) ([]byte, error) { + data, err := boolean.Marshal(value) + if err != nil { + return nil, wrapMarshalError(err, "marshal error") } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) + return data, nil } -func decBool(v []byte) bool { - if len(v) == 0 { - return false +func unmarshalBool(data []byte, value interface{}) error { + if err := boolean.Unmarshal(data, value); err != nil { + return wrapUnmarshalError(err, "unmarshal error") } - return v[0] != 0 + return nil } func marshalFloat(value interface{}) ([]byte, error) { diff --git a/marshal_test.go b/marshal_test.go index 59a5211dc..d5494f02a 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -31,20 +31,6 @@ var marshalTests = []struct { MarshalError error UnmarshalError error }{ - { - NativeType{proto: 2, typ: TypeBoolean}, - []byte("\x00"), - false, - nil, - nil, - }, - { - NativeType{proto: 2, typ: TypeBoolean}, - []byte("\x01"), - true, - nil, - nil, - }, { NativeType{proto: 2, typ: TypeDecimal}, []byte("\x00\x00\x00\x00\x00"), @@ -303,33 +289,6 @@ var marshalTests = []struct { nil, nil, }, - { - NativeType{proto: 2, typ: TypeBoolean}, - []byte("\x00"), - func() *bool { - b := false - return &b - }(), - nil, - nil, - }, - { - NativeType{proto: 2, typ: TypeBoolean}, - []byte("\x01"), - func() *bool { - b := true - return &b - }(), - nil, - nil, - }, - { - NativeType{proto: 2, typ: TypeBoolean}, - []byte(nil), - (*bool)(nil), - nil, - nil, - }, { NativeType{proto: 2, typ: TypeInet}, []byte("\x7F\x00\x00\x01"), diff --git a/serialization/boolean/marshal.go b/serialization/boolean/marshal.go new file mode 100644 index 000000000..36c9b2942 --- /dev/null +++ b/serialization/boolean/marshal.go @@ -0,0 +1,24 @@ +package boolean + +import ( + "reflect" +) + +func Marshal(value interface{}) ([]byte, error) { + switch v := value.(type) { + case nil: + return nil, nil + case bool: + return EncBool(v) + case *bool: + return EncBoolR(v) + default: + // Custom types (type MyBool bool) can be serialized only via `reflect` package. + // Later, when generic-based serialization is introduced we can do that via generics. + rv := reflect.TypeOf(value) + if rv.Kind() != reflect.Ptr { + return EncReflect(reflect.ValueOf(v)) + } + return EncReflectR(reflect.ValueOf(v)) + } +} diff --git a/serialization/boolean/marshal_utils.go b/serialization/boolean/marshal_utils.go new file mode 100644 index 000000000..1166cd3f5 --- /dev/null +++ b/serialization/boolean/marshal_utils.go @@ -0,0 +1,45 @@ +package boolean + +import ( + "fmt" + "reflect" +) + +func EncBool(v bool) ([]byte, error) { + return encBool(v), nil +} + +func EncBoolR(v *bool) ([]byte, error) { + if v == nil { + return nil, nil + } + return encBool(*v), nil +} + +func EncReflect(v reflect.Value) ([]byte, error) { + switch v.Kind() { + case reflect.Bool: + return encBool(v.Bool()), nil + case reflect.Struct: + if v.Type().String() == "gocql.unsetColumn" { + return nil, nil + } + return nil, fmt.Errorf("failed to marshal boolean: unsupported value type (%T)(%[1]v)", v.Interface()) + default: + return nil, fmt.Errorf("failed to marshal boolean: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func EncReflectR(v reflect.Value) ([]byte, error) { + if v.IsNil() { + return nil, nil + } + return EncReflect(v.Elem()) +} + +func encBool(v bool) []byte { + if v { + return []byte{1} + } + return []byte{0} +} diff --git a/serialization/boolean/unmarshal.go b/serialization/boolean/unmarshal.go new file mode 100644 index 000000000..0bf746cf9 --- /dev/null +++ b/serialization/boolean/unmarshal.go @@ -0,0 +1,29 @@ +package boolean + +import ( + "fmt" + "reflect" +) + +func Unmarshal(data []byte, value interface{}) error { + switch v := value.(type) { + case nil: + return nil + case *bool: + return DecBool(data, v) + case **bool: + return DecBoolR(data, v) + default: + // Custom types (type MyBool bool) can be deserialized only via `reflect` package. + // Later, when generic-based serialization is introduced we can do that via generics. + rv := reflect.ValueOf(value) + rt := rv.Type() + if rt.Kind() != reflect.Ptr { + return fmt.Errorf("failed to unmarshal boolean: unsupported value type (%T)(%[1]v)", v) + } + if rt.Elem().Kind() != reflect.Ptr { + return DecReflect(data, rv) + } + return DecReflectR(data, rv) + } +} diff --git a/serialization/boolean/unmarshal_utils.go b/serialization/boolean/unmarshal_utils.go new file mode 100644 index 000000000..53c21da46 --- /dev/null +++ b/serialization/boolean/unmarshal_utils.go @@ -0,0 +1,108 @@ +package boolean + +import ( + "fmt" + "reflect" +) + +var errWrongDataLen = fmt.Errorf("failed to unmarshal boolean: the length of the data should be 0 or 1") + +func errNilReference(v interface{}) error { + return fmt.Errorf("failed to unmarshal boolean: can not unmarshal into nil reference(%T)(%[1]v)", v) +} + +func DecBool(p []byte, v *bool) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + *v = false + case 1: + *v = decBool(p) + default: + return errWrongDataLen + } + return nil +} + +func DecBoolR(p []byte, v **bool) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(bool) + } + case 1: + val := decBool(p) + *v = &val + default: + return errWrongDataLen + } + return nil +} + +func DecReflect(p []byte, v reflect.Value) error { + if v.IsNil() { + return errNilReference(v) + } + + switch v = v.Elem(); v.Kind() { + case reflect.Bool: + return decReflectBool(p, v) + default: + return fmt.Errorf("failed to unmarshal boolean: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func DecReflectR(p []byte, v reflect.Value) error { + if v.IsNil() { + return errNilReference(v) + } + + switch v.Type().Elem().Elem().Kind() { + case reflect.Bool: + return decReflectBoolR(p, v) + default: + return fmt.Errorf("failed to unmarshal boolean: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func decReflectBool(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.SetBool(false) + case 1: + v.SetBool(decBool(p)) + default: + return errWrongDataLen + } + return nil +} + +func decReflectBoolR(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + if p == nil { + v.Elem().Set(reflect.Zero(v.Type().Elem())) + } else { + val := reflect.New(v.Type().Elem().Elem()) + v.Elem().Set(val) + } + case 1: + val := reflect.New(v.Type().Elem().Elem()) + val.Elem().SetBool(decBool(p)) + v.Elem().Set(val) + default: + return errWrongDataLen + } + return nil +} + +func decBool(p []byte) bool { + return p[0] != 0 +} diff --git a/tests/serialization/marshal_1_boolean_corrupt_test.go b/tests/serialization/marshal_1_boolean_corrupt_test.go new file mode 100644 index 000000000..3b1e76fa8 --- /dev/null +++ b/tests/serialization/marshal_1_boolean_corrupt_test.go @@ -0,0 +1,54 @@ +//go:build all || unit +// +build all unit + +package serialization_test + +import ( + "github.com/gocql/gocql/serialization/boolean" + "testing" + + "github.com/gocql/gocql" + "github.com/gocql/gocql/internal/tests/serialization" + "github.com/gocql/gocql/internal/tests/serialization/mod" +) + +func TestMarshalBooleanCorrupt(t *testing.T) { + tType := gocql.NewNativeType(4, gocql.TypeBoolean, "") + + type testSuite struct { + name string + marshal func(interface{}) ([]byte, error) + unmarshal func(bytes []byte, i interface{}) error + } + + testSuites := [2]testSuite{ + { + name: "serialization.boolean", + marshal: boolean.Marshal, + unmarshal: boolean.Unmarshal, + }, + { + name: "glob", + marshal: func(i interface{}) ([]byte, error) { + return gocql.Marshal(tType, i) + }, + unmarshal: func(bytes []byte, i interface{}) error { + return gocql.Unmarshal(tType, bytes, i) + }, + }, + } + + for _, tSuite := range testSuites { + unmarshal := tSuite.unmarshal + + t.Run(tSuite.name, func(t *testing.T) { + + serialization.NegativeUnmarshalSet{ + Data: []byte("\x00\x00"), + Values: mod.Values{ + false, + }.AddVariants(mod.All...), + }.Run("big_data", t, unmarshal) + }) + } +} diff --git a/tests/serialization/marshal_1_boolean_test.go b/tests/serialization/marshal_1_boolean_test.go index 77d75b5b2..40c3a6751 100644 --- a/tests/serialization/marshal_1_boolean_test.go +++ b/tests/serialization/marshal_1_boolean_test.go @@ -9,43 +9,70 @@ import ( "github.com/gocql/gocql" "github.com/gocql/gocql/internal/tests/serialization" "github.com/gocql/gocql/internal/tests/serialization/mod" + "github.com/gocql/gocql/serialization/boolean" ) func TestMarshalBoolean(t *testing.T) { tType := gocql.NewNativeType(4, gocql.TypeBoolean, "") - marshal := func(i interface{}) ([]byte, error) { return gocql.Marshal(tType, i) } - unmarshal := func(bytes []byte, i interface{}) error { - return gocql.Unmarshal(tType, bytes, i) + type testSuite struct { + name string + marshal func(interface{}) ([]byte, error) + unmarshal func(bytes []byte, i interface{}) error } - serialization.PositiveSet{ - Data: nil, - Values: mod.Values{(*bool)(nil)}.AddVariants(mod.CustomType), - }.Run("[nil]nullable", t, marshal, unmarshal) - - serialization.PositiveSet{ - Data: nil, - Values: mod.Values{false}.AddVariants(mod.CustomType), - }.Run("[nil]unmarshal", t, nil, unmarshal) - - serialization.PositiveSet{ - Data: make([]byte, 0), - Values: mod.Values{false}.AddVariants(mod.All...), - }.Run("[]unmarshal", t, nil, unmarshal) - - serialization.PositiveSet{ - Data: []byte("\x00"), - Values: mod.Values{false}.AddVariants(mod.All...), - }.Run("zeros", t, marshal, unmarshal) - - serialization.PositiveSet{ - Data: []byte("\x01"), - Values: mod.Values{true}.AddVariants(mod.All...), - }.Run("[ff]unmarshal", t, nil, unmarshal) - - serialization.PositiveSet{ - Data: []byte("\xff"), - Values: mod.Values{true}.AddVariants(mod.All...), - }.Run("[01]", t, nil, unmarshal) + testSuites := [2]testSuite{ + { + name: "serialization.boolean", + marshal: boolean.Marshal, + unmarshal: boolean.Unmarshal, + }, + { + name: "glob", + marshal: func(i interface{}) ([]byte, error) { + return gocql.Marshal(tType, i) + }, + unmarshal: func(bytes []byte, i interface{}) error { + return gocql.Unmarshal(tType, bytes, i) + }, + }, + } + + for _, tSuite := range testSuites { + marshal := tSuite.marshal + unmarshal := tSuite.unmarshal + + t.Run(tSuite.name, func(t *testing.T) { + + serialization.PositiveSet{ + Data: nil, + Values: mod.Values{(*bool)(nil)}.AddVariants(mod.CustomType), + }.Run("[nil]nullable", t, marshal, unmarshal) + + serialization.PositiveSet{ + Data: nil, + Values: mod.Values{false}.AddVariants(mod.CustomType), + }.Run("[nil]unmarshal", t, nil, unmarshal) + + serialization.PositiveSet{ + Data: make([]byte, 0), + Values: mod.Values{false}.AddVariants(mod.All...), + }.Run("[]unmarshal", t, nil, unmarshal) + + serialization.PositiveSet{ + Data: []byte("\x00"), + Values: mod.Values{false}.AddVariants(mod.All...), + }.Run("zeros", t, marshal, unmarshal) + + serialization.PositiveSet{ + Data: []byte("\x01"), + Values: mod.Values{true}.AddVariants(mod.All...), + }.Run("[1]unmarshal", t, nil, unmarshal) + + serialization.PositiveSet{ + Data: []byte("\xff"), + Values: mod.Values{true}.AddVariants(mod.All...), + }.Run("[255]", t, nil, unmarshal) + }) + } }