From 1de38982a83cce5ef2a03879df43a35f98519d2f Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Tue, 29 Oct 2024 11:09:32 -0700 Subject: [PATCH] Fix omitempty on aliased types (#377) Instead of doing return, insert an `if` block to skip emitting zero fields. This also allows the check to be inserted at any level. Fixes #376 Emitted code: ```Go // MarshalMsg implements msgp.Marshaler func (z TypeSamples) MarshalMsg(b []byte) (o []byte, err error) { o = msgp.Require(b, z.Msgsize()) o = msgp.AppendArrayHeader(o, uint32(len(z))) for zb0004 := range z { // check for omitted fields zb0001Len := uint32(2) var zb0001Mask uint8 /* 2 bits */ _ = zb0001Mask if z[zb0004].K == 0 { zb0001Len-- zb0001Mask |= 0x1 } if z[zb0004].V == 0 { zb0001Len-- zb0001Mask |= 0x2 } // variable map header, size zb0001Len o = append(o, 0x80|uint8(zb0001Len)) // skip if no fields are to be emitted if zb0001Len != 0 { if (zb0001Mask & 0x1) == 0 { // if not omitted // string "k" o = append(o, 0xa1, 0x6b) o = msgp.AppendUint32(o, z[zb0004].K) } if (zb0001Mask & 0x2) == 0 { // if not omitted // string "v" o = append(o, 0xa1, 0x76) o = msgp.AppendUint32(o, z[zb0004].V) } } } return } ``` If only 1 field, the check is omitted (and there is similar behavior on clearomitted). --- _generated/omitempty.go | 7 +++++++ _generated/omitempty_test.go | 33 +++++++++++++++++++++++++++++++++ gen/decode.go | 8 ++++++-- gen/encode.go | 14 +++++++++----- gen/marshal.go | 14 +++++++++----- gen/unmarshal.go | 8 ++++++-- 6 files changed, 70 insertions(+), 14 deletions(-) diff --git a/_generated/omitempty.go b/_generated/omitempty.go index 5dc8da11..cb3f33ef 100644 --- a/_generated/omitempty.go +++ b/_generated/omitempty.go @@ -54,6 +54,13 @@ type OmitEmpty0 struct { ATime time.Time `msg:"atime,omitempty"` } +type TypeSample struct { + K uint32 `msg:"k,omitempty"` + V uint32 `msg:"v,omitempty"` +} + +type TypeSamples []TypeSample + type ( NamedBool bool NamedInt int diff --git a/_generated/omitempty_test.go b/_generated/omitempty_test.go index f0ec87d6..5d077ebb 100644 --- a/_generated/omitempty_test.go +++ b/_generated/omitempty_test.go @@ -3,6 +3,7 @@ package _generated import ( "bytes" "io" + "reflect" "testing" "github.com/tinylib/msgp/msgp" @@ -285,3 +286,35 @@ func BenchmarkNotOmitEmpty10AllFull(b *testing.B) { } } } + +func TestTypeAlias(t *testing.T) { + value := TypeSamples{TypeSample{}, TypeSample{K: 1, V: 2}} + encoded, err := value.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + var got TypeSamples + _, err = got.UnmarshalMsg(encoded) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(value, got) { + t.Errorf("UnmarshalMsg got %v want %v", value, got) + } + var buf bytes.Buffer + w := msgp.NewWriter(&buf) + err = value.EncodeMsg(w) + if err != nil { + t.Fatal(err) + } + w.Flush() + got = TypeSamples{} + r := msgp.NewReader(&buf) + err = got.DecodeMsg(r) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(value, got) { + t.Errorf("UnmarshalMsg got %v want %v", value, got) + } +} diff --git a/gen/decode.go b/gen/decode.go index eb060e26..55977d4a 100644 --- a/gen/decode.go +++ b/gen/decode.go @@ -159,7 +159,9 @@ func (d *decodeGen) structAsMap(s *Struct) { if oeCount > 0 { d.p.printf("\n// Clear omitted fields.\n") - d.p.printf("if %s {\n", bm.notAllSet()) + if bm.bitlen > 1 { + d.p.printf("if %s {\n", bm.notAllSet()) + } for bitIdx, fieldIdx := range oeEmittedIdx { fieldElem := s.Fields[fieldIdx].FieldElem @@ -172,7 +174,9 @@ func (d *decodeGen) structAsMap(s *Struct) { } d.p.printf("}\n") } - d.p.printf("}") + if bm.bitlen > 1 { + d.p.printf("}") + } } } diff --git a/gen/encode.go b/gen/encode.go index 4e654c26..680d7bd0 100644 --- a/gen/encode.go +++ b/gen/encode.go @@ -3,7 +3,6 @@ package gen import ( "fmt" "io" - "strings" "github.com/tinylib/msgp/msgp" ) @@ -142,6 +141,7 @@ func (e *encodeGen) structmap(s *Struct) { omitempty := s.AnyHasTagPart("omitempty") omitzero := s.AnyHasTagPart("omitzero") + var closeZero bool var fieldNVar string if omitempty || omitzero { @@ -175,9 +175,11 @@ func (e *encodeGen) structmap(s *Struct) { return } - // quick return for the case where the entire thing is empty, but only at the top level - if !strings.Contains(s.Varname(), ".") { - e.p.printf("\nif %s == 0 { return }", fieldNVar) + // Skip block, if no fields are set. + if nfields > 1 { + e.p.printf("\n\n// skip if no fields are to be emitted") + e.p.printf("\nif %s != 0 {", fieldNVar) + closeZero = true } } else { @@ -226,7 +228,9 @@ func (e *encodeGen) structmap(s *Struct) { if oeField || anField { e.p.print("\n}") // close if statement } - + } + if closeZero { + e.p.printf("\n}") // close if statement } } diff --git a/gen/marshal.go b/gen/marshal.go index 59a6e6ec..6fb95ec6 100644 --- a/gen/marshal.go +++ b/gen/marshal.go @@ -3,7 +3,6 @@ package gen import ( "fmt" "io" - "strings" "github.com/tinylib/msgp/msgp" ) @@ -137,6 +136,7 @@ func (m *marshalGen) mapstruct(s *Struct) { omitempty := s.AnyHasTagPart("omitempty") omitzero := s.AnyHasTagPart("omitzero") + var closeZero bool var fieldNVar string if omitempty || omitzero { @@ -169,9 +169,11 @@ func (m *marshalGen) mapstruct(s *Struct) { return } - // quick return for the case where the entire thing is empty, but only at the top level - if !strings.Contains(s.Varname(), ".") { - m.p.printf("\nif %s == 0 { return }", fieldNVar) + // Skip block, if no fields are set. + if nfields > 1 { + m.p.printf("\n\n// skip if no fields are to be emitted") + m.p.printf("\nif %s != 0 {", fieldNVar) + closeZero = true } } else { @@ -222,7 +224,9 @@ func (m *marshalGen) mapstruct(s *Struct) { if oeField || anField { m.p.printf("\n}") // close if statement } - + } + if closeZero { + m.p.printf("\n}") // close if statement } } diff --git a/gen/unmarshal.go b/gen/unmarshal.go index f4a9652a..37e066fd 100644 --- a/gen/unmarshal.go +++ b/gen/unmarshal.go @@ -151,7 +151,9 @@ func (u *unmarshalGen) mapstruct(s *Struct) { u.p.print("\n}\n}") // close switch and for loop if oeCount > 0 { u.p.printf("\n// Clear omitted fields.\n") - u.p.printf("if %s {\n", bm.notAllSet()) + if bm.bitlen > 1 { + u.p.printf("if %s {\n", bm.notAllSet()) + } for bitIdx, fieldIdx := range oeEmittedIdx { fieldElem := s.Fields[fieldIdx].FieldElem @@ -164,7 +166,9 @@ func (u *unmarshalGen) mapstruct(s *Struct) { } u.p.printf("}\n") } - u.p.printf("}") + if bm.bitlen > 1 { + u.p.printf("}") + } } }