From 3e4ac04e29e66d46ad8193868b0f2efc989edd56 Mon Sep 17 00:00:00 2001 From: Chao Xu Date: Mon, 18 Sep 2023 19:53:53 -0400 Subject: [PATCH] update tag and remove `WithSize` functions (#4) * update tag * remove `WithSize` functions. --- bcs/decode.go | 75 +++++++++++++++----------------------- bcs/decode_test.go | 10 ++--- bcs/encode.go | 20 +++++----- bcs/enum_example_test.go | 1 + bcs/enum_unmarshal_test.go | 2 +- bcs/tag.go | 18 +++++++-- 6 files changed, 60 insertions(+), 66 deletions(-) diff --git a/bcs/decode.go b/bcs/decode.go index 9a12168..1d59359 100644 --- a/bcs/decode.go +++ b/bcs/decode.go @@ -16,42 +16,28 @@ import ( // 1. if [Unmarshaler], use "UnmarshalBCS" method. // 2. if not [Unmarshaler] but [Enum], use the specialization for [Enum]. // 3. otherwise standard process. -func Unmarshal(data []byte, v any) error { +func Unmarshal(data []byte, v any) (int, error) { return NewDecoder(bytes.NewReader(data)).Decode(v) } -// UnmarshalWithSize unmarshals the bcs serialized ata into v, and returns the number of bytes consumed. -func UnmarshalWithSize(data []byte, v any) (int, error) { - return NewDecoder(bytes.NewReader(data)).DecodeWithSize(v) -} - // Decoder takes an [io.Reader] and decodes value from it. type Decoder struct { - r io.Reader + reader io.Reader byteBuffer [1]byte } // NewDecoder creates a new [Decoder] from an [io.Reader] func NewDecoder(r io.Reader) *Decoder { return &Decoder{ - r: r, + reader: r, } } -// Decode a value from the decoder. -// -// - If the value is [Unmarshaler], the corresponding UnmarshalBCS will be called. -// - If the value is [Enum], it will be special handled for [Enum] -func (d *Decoder) Decode(v any) error { - _, err := d.DecodeWithSize(v) - return err -} - // DecodeWithSize decodes a value from the decoder, and returns the number of bytes it consumed from the decoder. // // - If the value is [Unmarshaler], the corresponding UnmarshalBCS will be called. // - If the value is [Enum], it will be special handled for [Enum] -func (d *Decoder) DecodeWithSize(v any) (int, error) { +func (d *Decoder) Decode(v any) (int, error) { reflectValue := reflect.ValueOf(v) if reflectValue.Kind() != reflect.Pointer || reflectValue.IsNil() { return 0, fmt.Errorf("not a pointer or nil pointer") @@ -74,7 +60,7 @@ func (d *Decoder) decode(v reflect.Value) (int, error) { // Unmarshaler if i, isUnmarshaler := v.Interface().(Unmarshaler); isUnmarshaler { - return i.UnmarshalBCS(d.r) + return i.UnmarshalBCS(d.reader) } // Enum @@ -135,13 +121,13 @@ func (d *Decoder) decodeVanilla(v reflect.Value) (int, error) { return n, nil case reflect.Int8, reflect.Uint8: - return 1, binary.Read(d.r, binary.LittleEndian, v.Addr().Interface()) + return 1, binary.Read(d.reader, binary.LittleEndian, v.Addr().Interface()) case reflect.Int16, reflect.Uint16: - return 2, binary.Read(d.r, binary.LittleEndian, v.Addr().Interface()) + return 2, binary.Read(d.reader, binary.LittleEndian, v.Addr().Interface()) case reflect.Int32, reflect.Uint32: - return 4, binary.Read(d.r, binary.LittleEndian, v.Addr().Interface()) + return 4, binary.Read(d.reader, binary.LittleEndian, v.Addr().Interface()) case reflect.Int64, reflect.Uint64: - return 8, binary.Read(d.r, binary.LittleEndian, v.Addr().Interface()) + return 8, binary.Read(d.reader, binary.LittleEndian, v.Addr().Interface()) case reflect.Struct: return d.decodeStruct(v) @@ -167,7 +153,7 @@ func (d *Decoder) decodeVanilla(v reflect.Value) (int, error) { // decodeString func (d *Decoder) decodeString(v reflect.Value) (int, error) { - size, n, err := ULEB128Decode[int](d.r) + size, n, err := ULEB128Decode[int](d.reader) if err != nil { return n, err } @@ -179,7 +165,7 @@ func (d *Decoder) decodeString(v reflect.Value) (int, error) { tmp := make([]byte, size) - read, err := d.r.Read(tmp) + read, err := d.reader.Read(tmp) n += read if err != nil { return n, err @@ -197,7 +183,7 @@ func (d *Decoder) decodeString(v reflect.Value) (int, error) { // readByte reads one byte from the input, error if no byte is read. func (d *Decoder) readByte() (byte, int, error) { b := d.byteBuffer[:] - n, err := d.r.Read(b) + n, err := d.reader.Read(b) if err != nil { return 0, n, err } @@ -212,21 +198,22 @@ func (d *Decoder) decodeStruct(v reflect.Value) (int, error) { t := v.Type() var n int + +fieldLoop: for i := 0; i < v.NumField(); i++ { field := v.Field(i) if !field.CanInterface() { - continue + continue fieldLoop } tag, err := parseTagValue(t.Field(i).Tag.Get(tagName)) if err != nil { return n, err } - // ignored - if tag&tagValue_Ignore != 0 { - continue - } - // optional - if tag&tagValue_Optional != 0 { + + switch { + case tag.isIgnored(): // ignored + continue fieldLoop + case tag.isOptional(): // optional isOptional, k, err := d.readByte() n += k if err != nil { @@ -242,14 +229,12 @@ func (d *Decoder) decodeStruct(v reflect.Value) (int, error) { return n, err } } - - continue - } - - k, err := d.decode(field) - n += k - if err != nil { - return n, err + default: + k, err := d.decode(field) + n += k + if err != nil { + return n, err + } } } @@ -260,7 +245,7 @@ func (d *Decoder) decodeEnum(v reflect.Value) (int, error) { if v.Kind() != reflect.Struct { return 0, fmt.Errorf("only support struct for Enum, got %s", v.Kind().String()) } - enumId, n, err := ULEB128Decode[int](d.r) + enumId, n, err := ULEB128Decode[int](d.reader) if err != nil { return n, err } @@ -274,14 +259,14 @@ func (d *Decoder) decodeEnum(v reflect.Value) (int, error) { } func (d *Decoder) decodeByteSlice(v reflect.Value) (int, error) { - size, n, err := ULEB128Decode[int](d.r) + size, n, err := ULEB128Decode[int](d.reader) if err != nil { return n, err } tmp := make([]byte, size) - read, err := d.r.Read(tmp) + read, err := d.reader.Read(tmp) n += read if err != nil { return n, err @@ -329,7 +314,7 @@ func (d *Decoder) decodeArray(v reflect.Value) (int, error) { func (d *Decoder) decodeSlice(v reflect.Value) (int, error) { // get the length of the slice. - size, n, err := ULEB128Decode[int](d.r) + size, n, err := ULEB128Decode[int](d.reader) if err != nil { return n, err } diff --git a/bcs/decode_test.go b/bcs/decode_test.go index 983480f..d6471f1 100644 --- a/bcs/decode_test.go +++ b/bcs/decode_test.go @@ -14,7 +14,7 @@ func runVanillaCaseTest[T bool | uint8 | int8 | int16 | uint16 | int32 | uint32 } nv := new(T) - n, err := bcs.UnmarshalWithSize(exp, nv) + n, err := bcs.Unmarshal(exp, nv) if err != nil { return err } @@ -35,7 +35,7 @@ func runVanillaSliceCaseTest[T bool | uint8 | int8 | int16 | uint16 | int32 | ui } nv := make([]T, 0) - n, err := bcs.UnmarshalWithSize(exp, &nv) + n, err := bcs.Unmarshal(exp, &nv) if err != nil { return err } @@ -57,7 +57,7 @@ func runVanillaSliceCaseTest[T bool | uint8 | int8 | int16 | uint16 | int32 | ui return nil } -func TestUnmarshalWithSize_BasicTypes(t *testing.T) { +func TestUnmarshal_BasicTypes(t *testing.T) { for _, aCase := range basicMarshalTests { if err := runVanillaCaseTest[bool](aCase.input, aCase.expected); err != nil { t.Fatal(err) @@ -126,7 +126,7 @@ var unmarshalCases = []*UnmarshalCase{ }, } -func TestUnmarshalWithSize(t *testing.T) { +func TestUnmarshal(t *testing.T) { for _, v := range unmarshalCases { m, err := bcs.Marshal(v.v) if err != nil { @@ -136,7 +136,7 @@ func TestUnmarshalWithSize(t *testing.T) { t.Errorf("want: %v, got %v", v.expected, m) } nv := new(UnmarshalStruct) - n, err := bcs.UnmarshalWithSize(v.expected, nv) + n, err := bcs.Unmarshal(v.expected, nv) if err != nil { t.Error(err) } diff --git a/bcs/encode.go b/bcs/encode.go index bc5cfa3..a666dac 100644 --- a/bcs/encode.go +++ b/bcs/encode.go @@ -116,7 +116,7 @@ func (e *Encoder) encodeEnum(v reflect.Value) error { if err != nil { return err } - if tag&tagValue_Ignore > 0 { + if tag.isIgnored() { continue } fieldKind := field.Kind() @@ -191,13 +191,10 @@ func (e *Encoder) encodeStruct(v reflect.Value) error { if err != nil { return err } - // ignored - if tag&tagValue_Ignore != 0 { + switch { + case tag.isIgnored(): continue - } - - // optional - if tag&tagValue_Optional != 0 { + case tag.isOptional(): if field.Kind() != reflect.Pointer && field.Kind() != reflect.Interface { return fmt.Errorf("optional field can only be pointer or interface") } @@ -215,10 +212,11 @@ func (e *Encoder) encodeStruct(v reflect.Value) error { } } continue - } - // finally - if err := e.encode(field); err != nil { - return err + default: + // finally + if err := e.encode(field); err != nil { + return err + } } } diff --git a/bcs/enum_example_test.go b/bcs/enum_example_test.go index 1728591..6b56002 100644 --- a/bcs/enum_example_test.go +++ b/bcs/enum_example_test.go @@ -9,6 +9,7 @@ import ( type AnotherStruct struct { S string } + type EnumExample struct { V0 *uint8 V1 *uint16 `bcs:"-"` diff --git a/bcs/enum_unmarshal_test.go b/bcs/enum_unmarshal_test.go index e262d53..dae0b08 100644 --- a/bcs/enum_unmarshal_test.go +++ b/bcs/enum_unmarshal_test.go @@ -16,7 +16,7 @@ func TestEnum_Unmarshal(t *testing.T) { for _, v := range cases { e := &EnumExample{} - n, err := bcs.UnmarshalWithSize(v, e) + n, err := bcs.Unmarshal(v, e) if err != nil { t.Error(err) } diff --git a/bcs/tag.go b/bcs/tag.go index 70c810d..a6392b4 100644 --- a/bcs/tag.go +++ b/bcs/tag.go @@ -7,13 +7,15 @@ import ( const tagName = "bcs" +type tagValue int64 + const ( - tagValue_Optional int64 = 1 << iota // optional - tagValue_Ignore // - + tagValue_Optional tagValue = 1 << iota // optional + tagValue_Ignore // - ) -func parseTagValue(tag string) (int64, error) { - var r int64 +func parseTagValue(tag string) (tagValue, error) { + var r tagValue tagSegs := strings.Split(tag, ",") for _, seg := range tagSegs { seg := strings.TrimSpace(seg) @@ -32,3 +34,11 @@ func parseTagValue(tag string) (int64, error) { return r, nil } + +func (t tagValue) isOptional() bool { + return t&tagValue_Optional != 0 +} + +func (t tagValue) isIgnored() bool { + return t&tagValue_Ignore != 0 +}