Skip to content

Commit

Permalink
update tag and remove WithSize functions (#4)
Browse files Browse the repository at this point in the history
* update tag

* remove `WithSize` functions.
  • Loading branch information
fardream authored Sep 18, 2023
1 parent 50f7eb7 commit 3e4ac04
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 66 deletions.
75 changes: 30 additions & 45 deletions bcs/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,42 +16,28 @@ import (
// 1. if [Unmarshaler], use "UnmarshalBCS" method.
// 2. if not [Unmarshaler] but [Enum], use the specialization for [Enum].
// 3. otherwise standard process.
func Unmarshal(data []byte, v any) error {
func Unmarshal(data []byte, v any) (int, error) {
return NewDecoder(bytes.NewReader(data)).Decode(v)
}

// UnmarshalWithSize unmarshals the bcs serialized ata into v, and returns the number of bytes consumed.
func UnmarshalWithSize(data []byte, v any) (int, error) {
return NewDecoder(bytes.NewReader(data)).DecodeWithSize(v)
}

// Decoder takes an [io.Reader] and decodes value from it.
type Decoder struct {
r io.Reader
reader io.Reader
byteBuffer [1]byte
}

// NewDecoder creates a new [Decoder] from an [io.Reader]
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{
r: r,
reader: r,
}
}

// Decode a value from the decoder.
//
// - If the value is [Unmarshaler], the corresponding UnmarshalBCS will be called.
// - If the value is [Enum], it will be special handled for [Enum]
func (d *Decoder) Decode(v any) error {
_, err := d.DecodeWithSize(v)
return err
}

// DecodeWithSize decodes a value from the decoder, and returns the number of bytes it consumed from the decoder.
//
// - If the value is [Unmarshaler], the corresponding UnmarshalBCS will be called.
// - If the value is [Enum], it will be special handled for [Enum]
func (d *Decoder) DecodeWithSize(v any) (int, error) {
func (d *Decoder) Decode(v any) (int, error) {
reflectValue := reflect.ValueOf(v)
if reflectValue.Kind() != reflect.Pointer || reflectValue.IsNil() {
return 0, fmt.Errorf("not a pointer or nil pointer")
Expand All @@ -74,7 +60,7 @@ func (d *Decoder) decode(v reflect.Value) (int, error) {

// Unmarshaler
if i, isUnmarshaler := v.Interface().(Unmarshaler); isUnmarshaler {
return i.UnmarshalBCS(d.r)
return i.UnmarshalBCS(d.reader)
}

// Enum
Expand Down Expand Up @@ -135,13 +121,13 @@ func (d *Decoder) decodeVanilla(v reflect.Value) (int, error) {
return n, nil

case reflect.Int8, reflect.Uint8:
return 1, binary.Read(d.r, binary.LittleEndian, v.Addr().Interface())
return 1, binary.Read(d.reader, binary.LittleEndian, v.Addr().Interface())
case reflect.Int16, reflect.Uint16:
return 2, binary.Read(d.r, binary.LittleEndian, v.Addr().Interface())
return 2, binary.Read(d.reader, binary.LittleEndian, v.Addr().Interface())
case reflect.Int32, reflect.Uint32:
return 4, binary.Read(d.r, binary.LittleEndian, v.Addr().Interface())
return 4, binary.Read(d.reader, binary.LittleEndian, v.Addr().Interface())
case reflect.Int64, reflect.Uint64:
return 8, binary.Read(d.r, binary.LittleEndian, v.Addr().Interface())
return 8, binary.Read(d.reader, binary.LittleEndian, v.Addr().Interface())

case reflect.Struct:
return d.decodeStruct(v)
Expand All @@ -167,7 +153,7 @@ func (d *Decoder) decodeVanilla(v reflect.Value) (int, error) {

// decodeString
func (d *Decoder) decodeString(v reflect.Value) (int, error) {
size, n, err := ULEB128Decode[int](d.r)
size, n, err := ULEB128Decode[int](d.reader)
if err != nil {
return n, err
}
Expand All @@ -179,7 +165,7 @@ func (d *Decoder) decodeString(v reflect.Value) (int, error) {

tmp := make([]byte, size)

read, err := d.r.Read(tmp)
read, err := d.reader.Read(tmp)
n += read
if err != nil {
return n, err
Expand All @@ -197,7 +183,7 @@ func (d *Decoder) decodeString(v reflect.Value) (int, error) {
// readByte reads one byte from the input, error if no byte is read.
func (d *Decoder) readByte() (byte, int, error) {
b := d.byteBuffer[:]
n, err := d.r.Read(b)
n, err := d.reader.Read(b)
if err != nil {
return 0, n, err
}
Expand All @@ -212,21 +198,22 @@ func (d *Decoder) decodeStruct(v reflect.Value) (int, error) {
t := v.Type()

var n int

fieldLoop:
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if !field.CanInterface() {
continue
continue fieldLoop
}
tag, err := parseTagValue(t.Field(i).Tag.Get(tagName))
if err != nil {
return n, err
}
// ignored
if tag&tagValue_Ignore != 0 {
continue
}
// optional
if tag&tagValue_Optional != 0 {

switch {
case tag.isIgnored(): // ignored
continue fieldLoop
case tag.isOptional(): // optional
isOptional, k, err := d.readByte()
n += k
if err != nil {
Expand All @@ -242,14 +229,12 @@ func (d *Decoder) decodeStruct(v reflect.Value) (int, error) {
return n, err
}
}

continue
}

k, err := d.decode(field)
n += k
if err != nil {
return n, err
default:
k, err := d.decode(field)
n += k
if err != nil {
return n, err
}
}
}

Expand All @@ -260,7 +245,7 @@ func (d *Decoder) decodeEnum(v reflect.Value) (int, error) {
if v.Kind() != reflect.Struct {
return 0, fmt.Errorf("only support struct for Enum, got %s", v.Kind().String())
}
enumId, n, err := ULEB128Decode[int](d.r)
enumId, n, err := ULEB128Decode[int](d.reader)
if err != nil {
return n, err
}
Expand All @@ -274,14 +259,14 @@ func (d *Decoder) decodeEnum(v reflect.Value) (int, error) {
}

func (d *Decoder) decodeByteSlice(v reflect.Value) (int, error) {
size, n, err := ULEB128Decode[int](d.r)
size, n, err := ULEB128Decode[int](d.reader)
if err != nil {
return n, err
}

tmp := make([]byte, size)

read, err := d.r.Read(tmp)
read, err := d.reader.Read(tmp)
n += read
if err != nil {
return n, err
Expand Down Expand Up @@ -329,7 +314,7 @@ func (d *Decoder) decodeArray(v reflect.Value) (int, error) {

func (d *Decoder) decodeSlice(v reflect.Value) (int, error) {
// get the length of the slice.
size, n, err := ULEB128Decode[int](d.r)
size, n, err := ULEB128Decode[int](d.reader)
if err != nil {
return n, err
}
Expand Down
10 changes: 5 additions & 5 deletions bcs/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func runVanillaCaseTest[T bool | uint8 | int8 | int16 | uint16 | int32 | uint32
}

nv := new(T)
n, err := bcs.UnmarshalWithSize(exp, nv)
n, err := bcs.Unmarshal(exp, nv)
if err != nil {
return err
}
Expand All @@ -35,7 +35,7 @@ func runVanillaSliceCaseTest[T bool | uint8 | int8 | int16 | uint16 | int32 | ui
}

nv := make([]T, 0)
n, err := bcs.UnmarshalWithSize(exp, &nv)
n, err := bcs.Unmarshal(exp, &nv)
if err != nil {
return err
}
Expand All @@ -57,7 +57,7 @@ func runVanillaSliceCaseTest[T bool | uint8 | int8 | int16 | uint16 | int32 | ui
return nil
}

func TestUnmarshalWithSize_BasicTypes(t *testing.T) {
func TestUnmarshal_BasicTypes(t *testing.T) {
for _, aCase := range basicMarshalTests {
if err := runVanillaCaseTest[bool](aCase.input, aCase.expected); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -126,7 +126,7 @@ var unmarshalCases = []*UnmarshalCase{
},
}

func TestUnmarshalWithSize(t *testing.T) {
func TestUnmarshal(t *testing.T) {
for _, v := range unmarshalCases {
m, err := bcs.Marshal(v.v)
if err != nil {
Expand All @@ -136,7 +136,7 @@ func TestUnmarshalWithSize(t *testing.T) {
t.Errorf("want: %v, got %v", v.expected, m)
}
nv := new(UnmarshalStruct)
n, err := bcs.UnmarshalWithSize(v.expected, nv)
n, err := bcs.Unmarshal(v.expected, nv)
if err != nil {
t.Error(err)
}
Expand Down
20 changes: 9 additions & 11 deletions bcs/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (e *Encoder) encodeEnum(v reflect.Value) error {
if err != nil {
return err
}
if tag&tagValue_Ignore > 0 {
if tag.isIgnored() {
continue
}
fieldKind := field.Kind()
Expand Down Expand Up @@ -191,13 +191,10 @@ func (e *Encoder) encodeStruct(v reflect.Value) error {
if err != nil {
return err
}
// ignored
if tag&tagValue_Ignore != 0 {
switch {
case tag.isIgnored():
continue
}

// optional
if tag&tagValue_Optional != 0 {
case tag.isOptional():
if field.Kind() != reflect.Pointer && field.Kind() != reflect.Interface {
return fmt.Errorf("optional field can only be pointer or interface")
}
Expand All @@ -215,10 +212,11 @@ func (e *Encoder) encodeStruct(v reflect.Value) error {
}
}
continue
}
// finally
if err := e.encode(field); err != nil {
return err
default:
// finally
if err := e.encode(field); err != nil {
return err
}
}
}

Expand Down
1 change: 1 addition & 0 deletions bcs/enum_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
type AnotherStruct struct {
S string
}

type EnumExample struct {
V0 *uint8
V1 *uint16 `bcs:"-"`
Expand Down
2 changes: 1 addition & 1 deletion bcs/enum_unmarshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func TestEnum_Unmarshal(t *testing.T) {
for _, v := range cases {
e := &EnumExample{}

n, err := bcs.UnmarshalWithSize(v, e)
n, err := bcs.Unmarshal(v, e)
if err != nil {
t.Error(err)
}
Expand Down
18 changes: 14 additions & 4 deletions bcs/tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import (

const tagName = "bcs"

type tagValue int64

const (
tagValue_Optional int64 = 1 << iota // optional
tagValue_Ignore // -
tagValue_Optional tagValue = 1 << iota // optional
tagValue_Ignore // -
)

func parseTagValue(tag string) (int64, error) {
var r int64
func parseTagValue(tag string) (tagValue, error) {
var r tagValue
tagSegs := strings.Split(tag, ",")
for _, seg := range tagSegs {
seg := strings.TrimSpace(seg)
Expand All @@ -32,3 +34,11 @@ func parseTagValue(tag string) (int64, error) {

return r, nil
}

func (t tagValue) isOptional() bool {
return t&tagValue_Optional != 0
}

func (t tagValue) isIgnored() bool {
return t&tagValue_Ignore != 0
}

0 comments on commit 3e4ac04

Please sign in to comment.