From 90939a50977adf55366cedd0746843de1d156e51 Mon Sep 17 00:00:00 2001 From: Chao Xu Date: Mon, 5 Dec 2022 16:20:24 -0500 Subject: [PATCH] add decoding support. (#1) --- bcs/decode.go | 280 +++++++++++++++++++++++++++++++++++++++++++- bcs/encode.go | 73 +++++++++--- bcs/encode_test.go | 45 +++++++ bcs/marshaler.go | 9 -- bcs/uleb128.go | 71 ++++++----- bcs/uleb128_test.go | 5 +- bcs/unmarshaler.go | 11 ++ 7 files changed, 433 insertions(+), 61 deletions(-) create mode 100644 bcs/unmarshaler.go diff --git a/bcs/decode.go b/bcs/decode.go index 1eec05f..a0fcdcc 100644 --- a/bcs/decode.go +++ b/bcs/decode.go @@ -1,10 +1,280 @@ package bcs -import "fmt" +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "reflect" +) // Unmarshal unmarshal the bcs serialized data into v. -// It returns the number of bytes consumed and a possible error. -// If error is not nil, the consumed bytes will be 0. -func Unmarshal(data []byte, v any) (int, error) { - return 0, fmt.Errorf("unimplemented") +// +// Refer to notes in [Marshal] for details how data serialized/deserialized. +// +// During the unmarshalling process +// 1. if [Unmarshaler], use "UnmarshalBCS" method. +// 2. if not [Unmarshaler] but [Enum], use the specialization for [Enum]. +// 3. otherwise standard process. +func Unmarshal(data []byte, v any) error { + return NewDecoder(bytes.NewReader(data)).Decode(v) +} + +// Decoder takes an [io.Reader] and decodes value from it. +type Decoder struct { + r io.Reader +} + +// NewDecoder creates a new [Decoder] from an [io.Reader] +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{ + r: r, + } +} + +// Decode a value from the decoder. +// +// - If the value is [Unmarshaler], the corresponding UnmarshalBCS will be called. +// - If the value is [Enum], it will be special handled for [Enum] +func (d *Decoder) Decode(v any) error { + reflectValue := reflect.ValueOf(v) + if reflectValue.Kind() != reflect.Pointer || reflectValue.IsNil() { + return fmt.Errorf("not a pointer or nil pointer") + } + + return d.decode(reflectValue) +} + +func (d *Decoder) decode(v reflect.Value) error { + // if v cannot interface, ignore + if !v.CanInterface() { + return nil + } + + if i, isUnmarshaler := v.Interface().(Unmarshaler); isUnmarshaler { + _, err := i.UnmarshalBCS(d.r) + return err + } + + if _, isEnum := v.Interface().(Enum); isEnum { + switch v.Kind() { + case reflect.Pointer, reflect.Interface: + if v.IsNil() { + return fmt.Errorf("trying to decode into nil pointer/interface") + } + return d.decodeEnum(v.Elem()) + default: + return d.decodeEnum(v) + } + } + + switch v.Kind() { + case reflect.Pointer: + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + return d.decodeVanilla(v.Elem()) + + case reflect.Interface: + if v.IsNil() { + return fmt.Errorf("cannot decode into nil interface") + } + return d.decode(v.Elem()) + + case reflect.Chan, reflect.Func, reflect.Uintptr, reflect.UnsafePointer: + // silently ignore + return nil + default: + return d.decodeVanilla(v) + } +} + +func (d *Decoder) decodeVanilla(v reflect.Value) error { + kind := v.Kind() + + if !v.CanSet() { + return fmt.Errorf("cannot change value of kind %s", kind.String()) + } + + switch v.Kind() { + case reflect.Bool: + t, err := d.readByte() + if err != nil { + return nil + } + + if t == 0 { + v.SetBool(false) + } else { + v.SetBool(true) + } + + return nil + + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, // ints + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: // uints + return binary.Read(d.r, binary.LittleEndian, v.Addr) + + case reflect.Struct: + return d.decodeStruct(v) + + case reflect.Slice: + return d.decodeSlice(v) + + case reflect.Array: + return d.decodeArray(v) + + case reflect.String: + return d.decodeString(v) + + default: + return fmt.Errorf("unsupported vanilla decoding type: %s", kind.String()) + } +} + +func (d *Decoder) decodeString(v reflect.Value) error { + size, _, err := ULEB128Decode[int](d.r) + if err != nil { + return err + } + + tmp := make([]byte, size, size) + + read, err := d.r.Read(tmp) + if err != nil { + return err + } + + if size != read { + return fmt.Errorf("wrong number of bytes read for []byte, want: %d, got %d", size, read) + } + + v.Set(reflect.ValueOf(string(tmp))) + + return nil +} + +func (d *Decoder) readByte() (byte, error) { + b := make([]byte, 1, 1) + n, err := d.r.Read(b) + if err != nil { + return 0, err + } + if n == 0 { + return 0, fmt.Errorf("EOF") + } + + return b[0], nil +} + +func (d *Decoder) decodeStruct(v reflect.Value) error { + t := v.Type() + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + if !field.CanInterface() { + continue + } + tag, err := parseTagValue(t.Field(i).Tag.Get(tagName)) + if err != nil { + return err + } + // ignored + if tag&tagValue_Ignore != 0 { + continue + } + // optional + if tag&tagValue_Optional != 0 { + isOptional, err := d.readByte() + if err != nil { + return err + } + if isOptional == 0 { + field.Set(reflect.Zero(v.Type())) + } else { + field.Set(reflect.New(field.Type().Elem())) + err := d.decode(field.Elem()) + if err != nil { + return err + } + } + } + + if err := d.decode(field); err != nil { + return err + } + } + + return nil +} + +func (d *Decoder) decodeEnum(v reflect.Value) error { + if v.Kind() != reflect.Struct { + return fmt.Errorf("only support struct for Enum, got %s", v.Kind().String()) + } + enumId, _, err := ULEB128Decode[int](d.r) + if err != nil { + return err + } + + field := v.Field(enumId) + + return d.decode(field) +} + +func (d *Decoder) decodeByteSlice(v reflect.Value) error { + size, _, err := ULEB128Decode[int](d.r) + if err != nil { + return err + } + + tmp := make([]byte, size, size) + + read, err := d.r.Read(tmp) + if err != nil { + return err + } + + if size != read { + return fmt.Errorf("wrong number of bytes read for []byte, want: %d, got %d", size, read) + } + + v.Set(reflect.ValueOf(tmp)) + + return nil +} + +func (d *Decoder) decodeArray(v reflect.Value) error { + size := v.Len() + t := v.Type() + + for i := 0; i < size; i++ { + v.Index(i).Set(reflect.New(t.Elem())) + if err := d.decode(v.Index(i)); err != nil { + return err + } + } + + return nil +} + +func (d *Decoder) decodeSlice(v reflect.Value) error { + size, _, err := ULEB128Decode[int](d.r) + if err != nil { + return err + } + + t := v.Type() + tmp := reflect.MakeSlice(t, 0, size) + for i := 0; i < size; i++ { + ind := reflect.New(t.Elem()) + if err := d.decode(ind); err != nil { + return err + } + tmp = reflect.Append(tmp, ind) + } + + v.Set(tmp) + + return nil } diff --git a/bcs/encode.go b/bcs/encode.go index f2beeeb..cf91b72 100644 --- a/bcs/encode.go +++ b/bcs/encode.go @@ -22,9 +22,9 @@ func NewEncoder(w io.Writer) *Encoder { // Encode a value v into the encoder. // -// - If the value is [Marshaler], then the corresponding +// - If the value is [Marshaler], the corresponding // MarshalBCS implementation will be called. -// - If the value is [Enum], then it will be special handled for enum. +// - If the value is [Enum], it will be special handled for [Enum]. func (e *Encoder) Encode(v any) error { return e.encode(reflect.ValueOf(v)) } @@ -37,7 +37,7 @@ func (e *Encoder) encode(v reflect.Value) error { return nil } - // test for the two enums we defined. + // test for the two interfaces we defined. // 1. Marshaler // 2. Enum. i := v.Interface() @@ -63,25 +63,39 @@ func (e *Encoder) encode(v reflect.Value) error { reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: // all the uints // use little endian to encode those. return binary.Write(e.w, binary.LittleEndian, v.Interface()) - case reflect.Ptr: + + case reflect.Pointer: // pointer // if v is nil pointer, use the zero value for v. // we don't check for optional flag here. - // that should be checked when the container struct is encoded, + // that should be checked when the container struct is encoded // if this pointer is contained in a struct. - if v.IsNil() { - return e.encode(reflect.Indirect(reflect.New(v.Type()))) - } else { - return e.encode(reflect.Indirect(v)) + return e.encode(reflect.Indirect(v)) + + case reflect.Interface: + return e.encode(v.Elem()) + + case reflect.Slice: // slices + // check if the element is uint8 or byteslice + if byteSlice, ok := (v.Interface()).([]byte); ok { + return e.encodeByteSlice(byteSlice) } - case reflect.Slice: return e.encodeSlice(v) + + case reflect.Array: // encode array + return e.encodeArray(v) + case reflect.String: str := []byte(v.String()) return e.encodeByteSlice(str) + case reflect.Struct: return e.encodeStruct(v) + + case reflect.Chan, reflect.Func, reflect.Uintptr, reflect.UnsafePointer: // channel, func, pointers + return nil + default: - return fmt.Errorf("unsupported kind: %s", kind.String()) + return fmt.Errorf("unsupported kind: %s, consider make the field ignored by using - tag or provide a customized Marshaler implementation.", kind.String()) } } @@ -138,6 +152,17 @@ func (e *Encoder) encodeByteSlice(b []byte) error { return nil } +func (e *Encoder) encodeArray(v reflect.Value) error { + length := v.Len() + for i := 0; i < length; i++ { + if err := e.encode(v.Index(i)); err != nil { + return err + } + } + + return nil +} + func (e *Encoder) encodeSlice(v reflect.Value) error { length := v.Len() if _, err := e.w.Write(ULEB128Encode(length)); err != nil { @@ -158,6 +183,7 @@ func (e *Encoder) encodeStruct(v reflect.Value) error { for i := 0; i < v.NumField(); i++ { field := v.Field(i) + // if a field is not exported, ignore if !field.CanInterface() { continue } @@ -165,12 +191,14 @@ func (e *Encoder) encodeStruct(v reflect.Value) error { if err != nil { return err } + // ignored if tag&tagValue_Ignore != 0 { continue } + // optional if tag&tagValue_Optional != 0 { - if v.Kind() != reflect.Pointer && v.Kind() != reflect.Interface { + if field.Kind() != reflect.Pointer && field.Kind() != reflect.Interface { return fmt.Errorf("optional field can only be pointer or interface") } if field.IsNil() { @@ -182,11 +210,14 @@ func (e *Encoder) encodeStruct(v reflect.Value) error { if _, err := e.w.Write([]byte{1}); err != nil { return err } - if err := e.encode(reflect.Indirect(field)); err != nil { + if err := e.encode(field.Elem()); err != nil { return err } } - } else if err := e.encode(field); err != nil { + continue + } + // finally + if err := e.encode(field); err != nil { return err } } @@ -204,15 +235,23 @@ func (e *Encoder) encodeStruct(v reflect.Value) error { // - Use tag `-` to ignore fields. // - Unexported fields are ignored. // -// Note that bcs doesn't have schema, and field names are irrelavant. The fields +// Note that bcs doesn't have schema, and field names are irrelevant. The fields // of struct are serialized in the order that they are defined. // // Pointers are serialized as the type they point to. Nil pointers will be serialized // as zero value of the type they point to unless it's marked as `optional`. // +// Arrays are serialized as fixed length vector (or serialize the each object individually without prefixing +// the length of the array). +// +// Vanilla maps are not supported, however, the code will error if map is encountered to call out they are +// not supported and either ignore or provide a customized marshal function. +// +// Channels, functions are silently ignored. +// // During marshalling process, how v is marshalled depends on if v implemented [Marshaler] or [Enum] -// 1. if [Marshaler], use "MarshalBCS" method of the class. -// 2. if not [Marshaler] but [Enum], use specialization for [Enum] +// 1. if [Marshaler], use "MarshalBCS" method. +// 2. if not [Marshaler] but [Enum], use specialization for [Enum]. // 3. otherwise standard process. func Marshal(v any) ([]byte, error) { var b bytes.Buffer diff --git a/bcs/encode_test.go b/bcs/encode_test.go index cc2b339..528319e 100644 --- a/bcs/encode_test.go +++ b/bcs/encode_test.go @@ -59,6 +59,16 @@ type Wrapper struct { String string } +type WrapperWithWrongOptional struct { + Inner MyStruct + Outer string `bcs:"optional"` +} + +type WrapperWithOptional struct { + Inner MyStruct + Outer *string `bcs:"optional"` +} + // struct from [bcs repo] // // [bcs repo]: https://github.com/diem/bcs @@ -92,3 +102,38 @@ func TestMarshal_struct(t *testing.T) { t.Fatalf("want: %v\ngot: %v\n", wBytesExpected, wBytes) } } + +func TestMarshal_optional(t *testing.T) { + if _, err := bcs.Marshal(WrapperWithWrongOptional{}); err == nil { + t.Fatalf("optional should be pointer or interface") + } else { + t.Log(err.Error()) + } + optionalUnset := WrapperWithOptional{ + Inner: MyStruct{ + Boolean: true, + Bytes: []byte{0xC0, 0xDE}, + Label: "a", + }, + } + optionalUnsetBytes, err := bcs.Marshal(optionalUnset) + if err != nil { + t.Error(err) + } + optionalUnsetExpected := []byte{1, 2, 0xC0, 0xDE, 1, 98, 0} + if !sliceEqual(optionalUnsetBytes, optionalUnsetExpected) { + t.Errorf("want: %v\ngot: %v\n", optionalUnsetExpected, optionalUnsetBytes) + } + + optionalSet := optionalUnset + s := "123" + optionalSet.Outer = &s + optionalSetBytes, err := bcs.Marshal(optionalSet) + if err != nil { + t.Error(err) + } + optionalSetExpected := []byte{1, 2, 0xC0, 0xDE, 1, 98, 1, 3, 49, 50, 51} + if !sliceEqual(optionalSetBytes, optionalSetExpected) { + t.Errorf("want: %v\ngot: %v\n", optionalSetExpected, optionalSetBytes) + } +} diff --git a/bcs/marshaler.go b/bcs/marshaler.go index 22ae211..d73ba04 100644 --- a/bcs/marshaler.go +++ b/bcs/marshaler.go @@ -4,12 +4,3 @@ package bcs type Marshaler interface { MarshalBCS() ([]byte, error) } - -// Unmarshaler customizes the unmarshalling behavior for a type. -// -// This is different from many other unmarshalers in golang that it -// returns the bytes consumed and the error. When the error is not nil, -// the byte consumed is guaranteed to be 0. -type Unmarshaler interface { - UnmarshalBCS([]byte) (int, error) -} diff --git a/bcs/uleb128.go b/bcs/uleb128.go index 18f72e6..02fe45a 100644 --- a/bcs/uleb128.go +++ b/bcs/uleb128.go @@ -1,55 +1,70 @@ package bcs -import "fmt" +import ( + "encoding/binary" + "fmt" + "io" +) + +// MaxUleb128Length is the max possible number of bytes for an ULEB128 encoded integer. +// Go's widest integer is uint64, so the length is 10. +const MaxUleb128Length = 10 // ULEB128SupportedTypes is a contraint interface that limits the input to -// [ULEB128Encode] and [ULEB128Decode] to signed and unsigned integers except int8. +// [ULEB128Encode] and [ULEB128Decode] to signed and unsigned integers. type ULEB128SupportedTypes interface { - ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uint | ~int16 | ~int32 | ~int64 | ~int + ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uint | ~int8 | ~int16 | ~int32 | ~int64 | ~int } // ULEB128Encode converts an integer into []byte (see [wikipedia] and [bcs]) // +// This reuses [binary.PutUvarint] in standard library. +// // [wikipedia]: https://en.wikipedia.org/wiki/LEB128 // [bcs]: https://github.com/diem/bcs#uleb128-encoded-integers func ULEB128Encode[T ULEB128SupportedTypes](input T) []byte { - var result []byte - - for { - b := (byte)(input & 127) - input >>= 7 - - if input == 0 { - result = append(result, b) - break - } else { - result = append(result, b|128) - } - } - - return result + result := make([]byte, 10) + i := binary.PutUvarint(result, uint64(input)) + return result[:i] } -// ULEB128Decode decodes byte array into an integer, returns the decoded value, the number of bytes consumed, and a possible error. -// If error is returned, the number of bytes returned is guaranteed to be 0. -func ULEB128Decode[T ULEB128SupportedTypes](data []byte) (T, int, error) { +// ULEB128Decode decodes [io.Reader] into an integer, returns the resulted value, the number of byte read, and a possible error. +// +// [binary.ReadUvarint] is not used here because +// - it doesn't support returning the number of bytes read. +// - it accepts only [io.ByteReader], but the recommended way of creating one from [bufio.NewReader] will read more than 1 byte at the +// to fill the buffer. +func ULEB128Decode[T ULEB128SupportedTypes](r io.Reader) (T, int, error) { + buf := make([]byte, 1, 1) var v, shift T - for i := 0; i < len(data); i++ { - d := T(data[i]) + var n int + for n < 10 { + i, err := r.Read(buf) + if i == 0 { + return 0, n, fmt.Errorf("zero read in. possible EOF") + } + if err != nil { + return 0, n, err + } + n += i + + d := T(buf[0]) ld := d & 127 if (ld<>shift != ld { - return v, 0, fmt.Errorf("overflow at index %d: %v", i, ld) + return v, n, fmt.Errorf("overflow at index %d: %v", n-1, ld) } + ld <<= shift v = ld + v if v < ld { - return v, 0, fmt.Errorf("overflow after adding index %d: %v %v", i, ld, v) + return v, n, fmt.Errorf("overflow after adding index %d: %v %v", n-1, ld, v) } - if d < 128 { - return v, i + 1, nil + if d <= 127 { + return v, n, nil } + shift += 7 } - return v, 0, fmt.Errorf("failed to find the highest significant 7 bits: %v", v) + return 0, n, fmt.Errorf("failed to find most significant bytes after reading %d bytes", n) } diff --git a/bcs/uleb128_test.go b/bcs/uleb128_test.go index db0b6e0..e3ec918 100644 --- a/bcs/uleb128_test.go +++ b/bcs/uleb128_test.go @@ -1,6 +1,7 @@ package bcs_test import ( + "bytes" "testing" "github.com/fardream/go-bcs/bcs" @@ -32,7 +33,7 @@ func TestULEB128Encode(t *testing.T) { func TestULEB128Decode(t *testing.T) { for _, aCase := range uleb128Tests { - r, n, e := bcs.ULEB128Decode[uint32](aCase.Expected) + r, n, e := bcs.ULEB128Decode[uint32](bytes.NewReader(aCase.Expected)) if e != nil { t.Fatalf("failed to decode: %v", e) } @@ -45,7 +46,7 @@ func TestULEB128Decode(t *testing.T) { } for _, aCase := range uleb128Tests[3:] { - r, n, e := bcs.ULEB128Decode[uint8](aCase.Expected) + r, n, e := bcs.ULEB128Decode[uint8](bytes.NewReader(aCase.Expected)) if e == nil { t.Fatalf("should overflow: %d %d", r, n) } else { diff --git a/bcs/unmarshaler.go b/bcs/unmarshaler.go new file mode 100644 index 0000000..cea3474 --- /dev/null +++ b/bcs/unmarshaler.go @@ -0,0 +1,11 @@ +package bcs + +import "io" + +// Unmarshaler customizes the unmarshalling behavior for a type. +// +// Compared with other Unmarshalers in golang, the Unmarshaler here takes +// a [io.Reader] instead of []byte, since it is difficult to delimit the byte streams. +type Unmarshaler interface { + UnmarshalBCS(io.Reader) (int, error) +}