Skip to content

Commit

Permalink
Merge pull request #379 from illia-li/il/fix/marshal/boolean
Browse files Browse the repository at this point in the history
Fix `boolean`, marshal, unmarshall functions
  • Loading branch information
dkropachev authored Dec 27, 2024
2 parents 3785839 + be3b863 commit fffa208
Show file tree
Hide file tree
Showing 8 changed files with 331 additions and 126 deletions.
65 changes: 12 additions & 53 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bytes"
"errors"
"fmt"
"github.com/gocql/gocql/serialization/boolean"
"math"
"math/big"
"math/bits"
Expand Down Expand Up @@ -154,7 +155,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
case TypeAscii:
return marshalAscii(value)
case TypeBoolean:
return marshalBool(info, value)
return marshalBool(value)
case TypeTinyInt:
return marshalTinyInt(value)
case TypeSmallInt:
Expand Down Expand Up @@ -266,7 +267,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
case TypeAscii:
return unmarshalAscii(data, value)
case TypeBoolean:
return unmarshalBool(info, data, value)
return unmarshalBool(data, value)
case TypeInt:
return unmarshalInt(data, value)
case TypeBigInt:
Expand Down Expand Up @@ -525,61 +526,19 @@ func decBigInt(data []byte) int64 {
int64(data[6])<<8 | int64(data[7])
}

func marshalBool(info TypeInfo, value interface{}) ([]byte, error) {
switch v := value.(type) {
case Marshaler:
return v.MarshalCQL(info)
case unsetColumn:
return nil, nil
case bool:
return encBool(v), nil
}

if value == nil {
return nil, nil
}

rv := reflect.ValueOf(value)
switch rv.Type().Kind() {
case reflect.Bool:
return encBool(rv.Bool()), nil
}
return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

func encBool(v bool) []byte {
if v {
return []byte{1}
}
return []byte{0}
}

func unmarshalBool(info TypeInfo, data []byte, value interface{}) error {
switch v := value.(type) {
case Unmarshaler:
return v.UnmarshalCQL(info, data)
case *bool:
*v = decBool(data)
return nil
}
rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
}
rv = rv.Elem()
switch rv.Type().Kind() {
case reflect.Bool:
rv.SetBool(decBool(data))
return nil
func marshalBool(value interface{}) ([]byte, error) {
data, err := boolean.Marshal(value)
if err != nil {
return nil, wrapMarshalError(err, "marshal error")
}
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
return data, nil
}

func decBool(v []byte) bool {
if len(v) == 0 {
return false
func unmarshalBool(data []byte, value interface{}) error {
if err := boolean.Unmarshal(data, value); err != nil {
return wrapUnmarshalError(err, "unmarshal error")
}
return v[0] != 0
return nil
}

func marshalFloat(value interface{}) ([]byte, error) {
Expand Down
41 changes: 0 additions & 41 deletions marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,6 @@ var marshalTests = []struct {
MarshalError error
UnmarshalError error
}{
{
NativeType{proto: 2, typ: TypeBoolean},
[]byte("\x00"),
false,
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeBoolean},
[]byte("\x01"),
true,
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeDecimal},
[]byte("\x00\x00\x00\x00\x00"),
Expand Down Expand Up @@ -303,33 +289,6 @@ var marshalTests = []struct {
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeBoolean},
[]byte("\x00"),
func() *bool {
b := false
return &b
}(),
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeBoolean},
[]byte("\x01"),
func() *bool {
b := true
return &b
}(),
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeBoolean},
[]byte(nil),
(*bool)(nil),
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeInet},
[]byte("\x7F\x00\x00\x01"),
Expand Down
24 changes: 24 additions & 0 deletions serialization/boolean/marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package boolean

import (
"reflect"
)

func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case bool:
return EncBool(v)
case *bool:
return EncBoolR(v)
default:
// Custom types (type MyBool bool) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}
45 changes: 45 additions & 0 deletions serialization/boolean/marshal_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package boolean

import (
"fmt"
"reflect"
)

func EncBool(v bool) ([]byte, error) {
return encBool(v), nil
}

func EncBoolR(v *bool) ([]byte, error) {
if v == nil {
return nil, nil
}
return encBool(*v), nil
}

func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.Bool:
return encBool(v.Bool()), nil
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal boolean: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal boolean: unsupported value type (%T)(%[1]v)", v.Interface())
}
}

func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}

func encBool(v bool) []byte {
if v {
return []byte{1}
}
return []byte{0}
}
29 changes: 29 additions & 0 deletions serialization/boolean/unmarshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package boolean

import (
"fmt"
"reflect"
)

func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil
case *bool:
return DecBool(data, v)
case **bool:
return DecBoolR(data, v)
default:
// Custom types (type MyBool bool) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal boolean: unsupported value type (%T)(%[1]v)", v)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}
108 changes: 108 additions & 0 deletions serialization/boolean/unmarshal_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package boolean

import (
"fmt"
"reflect"
)

var errWrongDataLen = fmt.Errorf("failed to unmarshal boolean: the length of the data should be 0 or 1")

func errNilReference(v interface{}) error {
return fmt.Errorf("failed to unmarshal boolean: can not unmarshal into nil reference(%T)(%[1]v)", v)
}

func DecBool(p []byte, v *bool) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = false
case 1:
*v = decBool(p)
default:
return errWrongDataLen
}
return nil
}

func DecBoolR(p []byte, v **bool) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(bool)
}
case 1:
val := decBool(p)
*v = &val
default:
return errWrongDataLen
}
return nil
}

func DecReflect(p []byte, v reflect.Value) error {
if v.IsNil() {
return errNilReference(v)
}

switch v = v.Elem(); v.Kind() {
case reflect.Bool:
return decReflectBool(p, v)
default:
return fmt.Errorf("failed to unmarshal boolean: unsupported value type (%T)(%[1]v)", v.Interface())
}
}

func DecReflectR(p []byte, v reflect.Value) error {
if v.IsNil() {
return errNilReference(v)
}

switch v.Type().Elem().Elem().Kind() {
case reflect.Bool:
return decReflectBoolR(p, v)
default:
return fmt.Errorf("failed to unmarshal boolean: unsupported value type (%T)(%[1]v)", v.Interface())
}
}

func decReflectBool(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetBool(false)
case 1:
v.SetBool(decBool(p))
default:
return errWrongDataLen
}
return nil
}

func decReflectBoolR(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
if p == nil {
v.Elem().Set(reflect.Zero(v.Type().Elem()))
} else {
val := reflect.New(v.Type().Elem().Elem())
v.Elem().Set(val)
}
case 1:
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetBool(decBool(p))
v.Elem().Set(val)
default:
return errWrongDataLen
}
return nil
}

func decBool(p []byte) bool {
return p[0] != 0
}
Loading

0 comments on commit fffa208

Please sign in to comment.