From 50f7eb738305cadf0c9c54b8a2e59e0c3aa94ea5 Mon Sep 17 00:00:00 2001 From: Chao Xu Date: Sun, 17 Sep 2023 16:28:38 -0400 Subject: [PATCH] Rework decoder. (#3) --- bcs/decode.go | 229 +++++++++++++++++++++++++------------ bcs/decode_test.go | 156 +++++++++++++++++++++++++ bcs/encode_test.go | 2 +- bcs/enum_unmarshal_test.go | 37 ++++++ bcs/uint128.go | 18 +++ bcs/unmarshaler.go | 1 + go.sum | 0 7 files changed, 369 insertions(+), 74 deletions(-) create mode 100644 bcs/decode_test.go create mode 100644 bcs/enum_unmarshal_test.go delete mode 100644 go.sum diff --git a/bcs/decode.go b/bcs/decode.go index 22a7dff..9a12168 100644 --- a/bcs/decode.go +++ b/bcs/decode.go @@ -8,7 +8,7 @@ import ( "reflect" ) -// Unmarshal unmarshal the bcs serialized data into v. +// Unmarshal unmarshals the bcs serialized data into v. // // Refer to notes in [Marshal] for details how data serialized/deserialized. // @@ -20,9 +20,15 @@ func Unmarshal(data []byte, v any) error { return NewDecoder(bytes.NewReader(data)).Decode(v) } +// UnmarshalWithSize unmarshals the bcs serialized data into v, and returns the number of bytes consumed. +func UnmarshalWithSize(data []byte, v any) (int, error) { + return NewDecoder(bytes.NewReader(data)).DecodeWithSize(v) +} + // Decoder takes an [io.Reader] and decodes value from it. type Decoder struct { - r io.Reader + r io.Reader + byteBuffer [1]byte } // NewDecoder creates a new [Decoder] from an [io.Reader] @@ -37,30 +43,46 @@ func NewDecoder(r io.Reader) *Decoder { // - If the value is [Unmarshaler], the corresponding UnmarshalBCS will be called. // - If the value is [Enum], it will be special handled for [Enum] func (d *Decoder) Decode(v any) error { + _, err := d.DecodeWithSize(v) + return err +} + +// DecodeWithSize decodes a value from the decoder, and returns the number of bytes it consumed from the decoder. +// +// - If the value is [Unmarshaler], the corresponding UnmarshalBCS will be called. 
+// - If the value is [Enum], it will be special handled for [Enum] +func (d *Decoder) DecodeWithSize(v any) (int, error) { reflectValue := reflect.ValueOf(v) if reflectValue.Kind() != reflect.Pointer || reflectValue.IsNil() { - return fmt.Errorf("not a pointer or nil pointer") + return 0, fmt.Errorf("not a pointer or nil pointer") } return d.decode(reflectValue) } -func (d *Decoder) decode(v reflect.Value) error { +// decode is the main lifter, it first checks if a value can be [reflect.Value.CanInterface], +// then checks if the value implements [Unmarshaler] or [Enum], and then switch on the kind of the value: +// - pointer, create a new one and decode into its element. +// - interface, decode into element. +// - function, channel, unsafe pointers, ignore +// - otherwise call [decodeVanilla]. +func (d *Decoder) decode(v reflect.Value) (int, error) { // if v cannot interface, ignore if !v.CanInterface() { - return nil + return 0, nil } + // Unmarshaler if i, isUnmarshaler := v.Interface().(Unmarshaler); isUnmarshaler { - _, err := i.UnmarshalBCS(d.r) - return err + return i.UnmarshalBCS(d.r) } + // Enum if _, isEnum := v.Interface().(Enum); isEnum { switch v.Kind() { case reflect.Pointer, reflect.Interface: if v.IsNil() { - return fmt.Errorf("trying to decode into nil pointer/interface") + return 0, fmt.Errorf("trying to decode into nil pointer/interface") } return d.decodeEnum(v.Elem()) default: @@ -68,39 +90,40 @@ func (d *Decoder) decode(v reflect.Value) error { } } + // switch kind switch v.Kind() { case reflect.Pointer: if v.IsNil() { v.Set(reflect.New(v.Type().Elem())) } - return d.decodeVanilla(v.Elem()) + return d.decode(v.Elem()) case reflect.Interface: if v.IsNil() { - return fmt.Errorf("cannot decode into nil interface") + return 0, fmt.Errorf("cannot decode into nil interface") } return d.decode(v.Elem()) case reflect.Chan, reflect.Func, reflect.Uintptr, reflect.UnsafePointer: // silently ignore - return nil + return 0, nil default: return 
d.decodeVanilla(v) } } -func (d *Decoder) decodeVanilla(v reflect.Value) error { +// decodeVanilla decodes bool, ints, slice, struct, array, and string. +func (d *Decoder) decodeVanilla(v reflect.Value) (int, error) { kind := v.Kind() - if !v.CanSet() { - return fmt.Errorf("cannot change value of kind %s", kind.String()) + return 0, fmt.Errorf("cannot change value of kind %s", kind.String()) } - switch v.Kind() { + switch kind { case reflect.Bool: - t, err := d.readByte() + t, n, err := d.readByte() if err != nil { - return nil + return n, err } if t == 0 { @@ -109,17 +132,23 @@ func (d *Decoder) decodeVanilla(v reflect.Value) error { v.SetBool(true) } - return nil + return n, nil - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, // ints - reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: // uints - return binary.Read(d.r, binary.LittleEndian, v.Addr) + case reflect.Int8, reflect.Uint8: + return 1, binary.Read(d.r, binary.LittleEndian, v.Addr().Interface()) + case reflect.Int16, reflect.Uint16: + return 2, binary.Read(d.r, binary.LittleEndian, v.Addr().Interface()) + case reflect.Int32, reflect.Uint32: + return 4, binary.Read(d.r, binary.LittleEndian, v.Addr().Interface()) + case reflect.Int64, reflect.Uint64: + return 8, binary.Read(d.r, binary.LittleEndian, v.Addr().Interface()) case reflect.Struct: return d.decodeStruct(v) case reflect.Slice: - if v.Elem().Kind() == reflect.Uint8 { + sliceType := v.Type().Elem() + if sliceType.Kind() == reflect.Uint8 { return d.decodeByteSlice(v) } @@ -132,48 +161,57 @@ func (d *Decoder) decodeVanilla(v reflect.Value) error { return d.decodeString(v) default: - return fmt.Errorf("unsupported vanilla decoding type: %s", kind.String()) + return 0, fmt.Errorf("unsupported vanilla decoding type: %s", kind.String()) } } -func (d *Decoder) decodeString(v reflect.Value) error { - size, _, err := ULEB128Decode[int](d.r) +// decodeString +func (d *Decoder) decodeString(v reflect.Value) (int, error) { + size, 
n, err := ULEB128Decode[int](d.r) if err != nil { - return err + return n, err + } + + if size == 0 { + v.SetString("") + return n, nil } tmp := make([]byte, size) read, err := d.r.Read(tmp) + n += read if err != nil { - return err + return n, err } if size != read { - return fmt.Errorf("wrong number of bytes read for []byte, want: %d, got %d", size, read) + return n, fmt.Errorf("wrong number of bytes read for string, want: %d, got %d", size, read) } - v.Set(reflect.ValueOf(string(tmp))) + v.SetString(string(tmp)) - return nil + return n, nil } -func (d *Decoder) readByte() (byte, error) { - b := make([]byte, 1) +// readByte reads one byte from the input, error if no byte is read. +func (d *Decoder) readByte() (byte, int, error) { + b := d.byteBuffer[:] n, err := d.r.Read(b) if err != nil { - return 0, err + return 0, n, err } if n == 0 { - return 0, fmt.Errorf("EOF") + return 0, n, io.ErrUnexpectedEOF } - return b[0], nil + return b[0], n, nil } -func (d *Decoder) decodeStruct(v reflect.Value) error { +func (d *Decoder) decodeStruct(v reflect.Value) (int, error) { t := v.Type() + var n int for i := 0; i < v.NumField(); i++ { field := v.Field(i) if !field.CanInterface() { @@ -181,7 +219,7 @@ func (d *Decoder) decodeStruct(v reflect.Value) error { } tag, err := parseTagValue(t.Field(i).Tag.Get(tagName)) if err != nil { - return err + return n, err } // ignored if tag&tagValue_Ignore != 0 { @@ -189,96 +227,141 @@ func (d *Decoder) decodeStruct(v reflect.Value) error { } // optional if tag&tagValue_Optional != 0 { - isOptional, err := d.readByte() + isOptional, k, err := d.readByte() + n += k if err != nil { - return err + return n, err } if isOptional == 0 { field.Set(reflect.Zero(v.Type())) } else { field.Set(reflect.New(field.Type().Elem())) - err := d.decode(field.Elem()) + k, err := d.decode(field.Elem()) + n += k if err != nil { - return err + return n, err } } + + continue } - if err := d.decode(field); err != nil { - return err + k, err := d.decode(field) + n 
+= k + if err != nil { + return n, err } } - return nil + return n, nil } -func (d *Decoder) decodeEnum(v reflect.Value) error { +func (d *Decoder) decodeEnum(v reflect.Value) (int, error) { if v.Kind() != reflect.Struct { - return fmt.Errorf("only support struct for Enum, got %s", v.Kind().String()) + return 0, fmt.Errorf("only support struct for Enum, got %s", v.Kind().String()) } - enumId, _, err := ULEB128Decode[int](d.r) + enumId, n, err := ULEB128Decode[int](d.r) if err != nil { - return err + return n, err } field := v.Field(enumId) - return d.decode(field) + k, err := d.decode(field) + n += k + + return n, err } -func (d *Decoder) decodeByteSlice(v reflect.Value) error { - size, _, err := ULEB128Decode[int](d.r) +func (d *Decoder) decodeByteSlice(v reflect.Value) (int, error) { + size, n, err := ULEB128Decode[int](d.r) if err != nil { - return err + return n, err } tmp := make([]byte, size) read, err := d.r.Read(tmp) + n += read if err != nil { - return err + return n, err } if size != read { - return fmt.Errorf("wrong number of bytes read for []byte, want: %d, got %d", size, read) + return n, fmt.Errorf("wrong number of bytes read for []byte, want: %d, got %d", size, read) } v.Set(reflect.ValueOf(tmp)) - return nil + return n, nil } -func (d *Decoder) decodeArray(v reflect.Value) error { +func (d *Decoder) decodeArray(v reflect.Value) (int, error) { size := v.Len() t := v.Type() - - for i := 0; i < size; i++ { - v.Index(i).Set(reflect.New(t.Elem())) - if err := d.decode(v.Index(i)); err != nil { - return err + elementType := t.Elem() + + var n int + if elementType.Kind() == reflect.Pointer { + for i := 0; i < size; i++ { + idx := reflect.New(elementType.Elem()) + k, err := d.decode(idx.Elem()) + n += k + if err != nil { + return n, err + } + v.Index(i).Set(idx) + } + } else { + for i := 0; i < size; i++ { + idx := reflect.New(elementType) + k, err := d.decode(idx.Elem()) + n += k + if err != nil { + return n, err + } + v.Index(i).Set(idx.Elem()) } } - 
return nil + return n, nil } -func (d *Decoder) decodeSlice(v reflect.Value) error { - size, _, err := ULEB128Decode[int](d.r) +func (d *Decoder) decodeSlice(v reflect.Value) (int, error) { + // get the length of the slice. + size, n, err := ULEB128Decode[int](d.r) if err != nil { - return err + return n, err } - t := v.Type() - tmp := reflect.MakeSlice(t, 0, size) - for i := 0; i < size; i++ { - ind := reflect.New(t.Elem()) - if err := d.decode(ind); err != nil { - return err + // element type of the slice + elementType := v.Type().Elem() + // make a new slice + tmp := reflect.MakeSlice(v.Type(), 0, size) + + if elementType.Kind() == reflect.Pointer { + for i := 0; i < size; i++ { + ind := reflect.New(elementType.Elem()) + k, err := d.decode(ind.Elem()) + n += k + if err != nil { + return n, err + } + tmp = reflect.Append(tmp, ind) + } + } else { + for i := 0; i < size; i++ { + ind := reflect.New(elementType) + k, err := d.decode(ind.Elem()) + n += k + if err != nil { + return n, err + } + tmp = reflect.Append(tmp, ind.Elem()) } - tmp = reflect.Append(tmp, ind) } v.Set(tmp) - return nil + return n, nil } diff --git a/bcs/decode_test.go b/bcs/decode_test.go new file mode 100644 index 0000000..983480f --- /dev/null +++ b/bcs/decode_test.go @@ -0,0 +1,156 @@ +package bcs_test + +import ( + "fmt" + "testing" + + "github.com/fardream/go-bcs/bcs" +) + +func runVanillaCaseTest[T bool | uint8 | int8 | int16 | uint16 | int32 | uint32 | int64 | uint64 | string](v any, exp []byte) error { + x, ok := v.(T) + if !ok { + return runVanillaSliceCaseTest[T](v, exp) + } + + nv := new(T) + n, err := bcs.UnmarshalWithSize(exp, nv) + if err != nil { + return err + } + if *nv != x { + return fmt.Errorf("want value: %v, got value: %v", x, *nv) + } + if n != len(exp) { + return fmt.Errorf("want length: %d, got length: %d", len(exp), n) + } + + return nil +} + +func runVanillaSliceCaseTest[T bool | uint8 | int8 | int16 | uint16 | int32 | uint32 | int64 | uint64 | string](v any, exp 
[]byte) error { + x, ok := v.([]T) + if !ok { + return nil + } + + nv := make([]T, 0) + n, err := bcs.UnmarshalWithSize(exp, &nv) + if err != nil { + return err + } + + if len(nv) != len(x) { + return fmt.Errorf("want length: %d, got %d", len(x), len(nv)) + } + + if n != len(exp) { + return fmt.Errorf("want parsed length: %d, got %d", len(exp), n) + } + + for i := 0; i < len(x); i++ { + if nv[i] != x[i] { + return fmt.Errorf("diff at %d %v %v", i, nv[i], x[i]) + } + } + + return nil +} + +func TestUnmarshalWithSize_BasicTypes(t *testing.T) { + for _, aCase := range basicMarshalTests { + if err := runVanillaCaseTest[bool](aCase.input, aCase.expected); err != nil { + t.Fatal(err) + } + if err := runVanillaCaseTest[uint8](aCase.input, aCase.expected); err != nil { + t.Fatal(err) + } + if err := runVanillaCaseTest[int8](aCase.input, aCase.expected); err != nil { + t.Fatal(err) + } + if err := runVanillaCaseTest[uint16](aCase.input, aCase.expected); err != nil { + t.Fatal(err) + } + if err := runVanillaCaseTest[int16](aCase.input, aCase.expected); err != nil { + t.Fatal(err) + } + if err := runVanillaCaseTest[int32](aCase.input, aCase.expected); err != nil { + t.Fatal(err) + } + if err := runVanillaCaseTest[uint32](aCase.input, aCase.expected); err != nil { + t.Fatal(err) + } + if err := runVanillaCaseTest[int64](aCase.input, aCase.expected); err != nil { + t.Fatal(err) + } + if err := runVanillaCaseTest[uint64](aCase.input, aCase.expected); err != nil { + t.Fatal(err) + } + if err := runVanillaCaseTest[string](aCase.input, aCase.expected); err != nil { + t.Fatal(err) + } + } +} + +type UnmarshalStruct struct { + WrapperWithOptional + StructArray [2]*MyStruct +} + +type UnmarshalCase struct { + v *UnmarshalStruct + expected []byte + errNotNil bool +} + +var unmarshalCases = []*UnmarshalCase{ + { + v: &UnmarshalStruct{ + WrapperWithOptional: WrapperWithOptional{ + Inner: MyStruct{Bytes: []byte{9, 2}}, + Outer: new(string), + }, + StructArray: [2]*MyStruct{ + { + Boolean: 
true, + Bytes: []byte{1, 2, 3}, + Label: "what", + }, + { + Boolean: false, + }, + }, + }, + errNotNil: false, + expected: []byte{0, 2, 9, 2, 0, 1, 0, 1, 3, 1, 2, 3, 4, 119, 104, 97, 116, 0, 0, 0}, + }, +} + +func TestUnmarshalWithSize(t *testing.T) { + for _, v := range unmarshalCases { + m, err := bcs.Marshal(v.v) + if err != nil { + t.Error(err) + } + if !sliceEqual(m, v.expected) { + t.Errorf("want: %v, got %v", v.expected, m) + } + nv := new(UnmarshalStruct) + n, err := bcs.UnmarshalWithSize(v.expected, nv) + if err != nil { + t.Error(err) + } + if n != len(v.expected) { + t.Errorf("want parsed length: %d, got: %d", len(v.expected), n) + } + + nb, err := bcs.Marshal(nv) + if err != nil { + t.Fatal(err) + } + + if !sliceEqual(nb, v.expected) { + t.Fatalf("want: %v, got: %v", v.expected, nb) + } + } +} diff --git a/bcs/encode_test.go b/bcs/encode_test.go index 528319e..63abe63 100644 --- a/bcs/encode_test.go +++ b/bcs/encode_test.go @@ -22,6 +22,7 @@ var utf8Encoded = []byte{ // // [bcs repo]: https://github.com/diem/bcs var basicMarshalTests = []BasicTypeTest{ + {input: []uint16{1, 2}, expected: []byte{2, 1, 0, 2, 0}}, {input: false, expected: []byte{0}}, {input: true, expected: []byte{1}}, {input: uint8(1), expected: []byte{1}}, @@ -32,7 +33,6 @@ var basicMarshalTests = []BasicTypeTest{ {input: uint32(305419896), expected: []byte{0x78, 0x56, 0x34, 0x12}}, {input: int64(-1311768467750121216), expected: []byte{0x00, 0x11, 0x32, 0x54, 0x87, 0xa9, 0xcb, 0xed}}, {input: uint64(1311768467750121216), expected: []byte{0x00, 0xef, 0xcd, 0xab, 0x78, 0x56, 0x34, 0x12}}, - {input: []uint16{1, 2}, expected: []byte{2, 1, 0, 2, 0}}, {input: utf8Str, expected: utf8Encoded}, } diff --git a/bcs/enum_unmarshal_test.go b/bcs/enum_unmarshal_test.go new file mode 100644 index 0000000..e262d53 --- /dev/null +++ b/bcs/enum_unmarshal_test.go @@ -0,0 +1,37 @@ +package bcs_test + +import ( + "testing" + + "github.com/fardream/go-bcs/bcs" +) + +func TestEnum_Unmarshal(t *testing.T) { + 
cases := [][]byte{ + {0, 42}, + {0, 0}, + {4, 3, 97, 98, 99}, + } + + for _, v := range cases { + e := &EnumExample{} + + n, err := bcs.UnmarshalWithSize(v, e) + if err != nil { + t.Error(err) + } + + if n != len(v) { + t.Errorf("want parsed length: %d, got: %d", len(v), n) + } + + nb, err := bcs.Marshal(e) + if err != nil { + t.Error(err) + } + + if !sliceEqual(nb, v) { + t.Errorf("want %v, got %v", v, nb) + } + } +} diff --git a/bcs/uint128.go b/bcs/uint128.go index 85dcf5b..27c2373 100644 --- a/bcs/uint128.go +++ b/bcs/uint128.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "encoding/json" "fmt" + "io" "math/big" ) @@ -17,6 +18,7 @@ var ( _ json.Marshaler = (*Uint128)(nil) _ json.Unmarshaler = (*Uint128)(nil) _ Marshaler = (*Uint128)(nil) + _ Unmarshaler = (*Uint128)(nil) ) func (i Uint128) Big() *big.Int { @@ -120,6 +122,22 @@ func (i Uint128) MarshalBCS() ([]byte, error) { return r, nil } +func (i *Uint128) UnmarshalBCS(r io.Reader) (int, error) { + buf := make([]byte, 16) + n, err := r.Read(buf) + if err != nil { + return n, err + } + if n != 16 { + return n, fmt.Errorf("failed to read 16 bytes for Uint128 (read %d bytes)", n) + } + + i.lo = binary.LittleEndian.Uint64(buf[0:8]) + i.hi = binary.LittleEndian.Uint64(buf[8:16]) + + return n, nil +} + func (i *Uint128) Cmp(j *Uint128) int { switch { case i.hi > j.hi || (i.hi == j.hi && i.lo > j.lo): diff --git a/bcs/unmarshaler.go b/bcs/unmarshaler.go index 3e80499..37325e3 100644 --- a/bcs/unmarshaler.go +++ b/bcs/unmarshaler.go @@ -6,6 +6,7 @@ import "io" // // Compared with other Unmarshalers in golang, the Unmarshaler here takes // a [io.Reader] instead of []byte, since it is difficult to delimit the byte streams without unmarshalling. +// Method [UnmarshalBCS] returns the number of bytes read, and potentially an error. type Unmarshaler interface { UnmarshalBCS(io.Reader) (int, error) } diff --git a/go.sum b/go.sum deleted file mode 100644 index e69de29..0000000