Skip to content

Commit

Permalink
Fix omitempty on aliased types (#377)
Browse files Browse the repository at this point in the history
Instead of doing return, insert an `if` block to skip emitting zero fields.

This also allows the check to be inserted at any level.

Fixes #376

Emitted code:

```Go
// MarshalMsg implements msgp.Marshaler
func (z TypeSamples) MarshalMsg(b []byte) (o []byte, err error) {
	o = msgp.Require(b, z.Msgsize())
	o = msgp.AppendArrayHeader(o, uint32(len(z)))
	for zb0004 := range z {
		// check for omitted fields
		zb0001Len := uint32(2)
		var zb0001Mask uint8 /* 2 bits */
		_ = zb0001Mask
		if z[zb0004].K == 0 {
			zb0001Len--
			zb0001Mask |= 0x1
		}
		if z[zb0004].V == 0 {
			zb0001Len--
			zb0001Mask |= 0x2
		}
		// variable map header, size zb0001Len
		o = append(o, 0x80|uint8(zb0001Len))

		// skip if no fields are to be emitted
		if zb0001Len != 0 {
			if (zb0001Mask & 0x1) == 0 { // if not omitted
				// string "k"
				o = append(o, 0xa1, 0x6b)
				o = msgp.AppendUint32(o, z[zb0004].K)
			}
			if (zb0001Mask & 0x2) == 0 { // if not omitted
				// string "v"
				o = append(o, 0xa1, 0x76)
				o = msgp.AppendUint32(o, z[zb0004].V)
			}
		}
	}
	return
}
```

If only 1 field, the check is omitted (and there is similar behavior on clearomitted).
  • Loading branch information
klauspost authored Oct 29, 2024
1 parent 4558fbf commit 1de3898
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 14 deletions.
7 changes: 7 additions & 0 deletions _generated/omitempty.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ type OmitEmpty0 struct {
ATime time.Time `msg:"atime,omitempty"`
}

type TypeSample struct {
K uint32 `msg:"k,omitempty"`
V uint32 `msg:"v,omitempty"`
}

type TypeSamples []TypeSample

type (
NamedBool bool
NamedInt int
Expand Down
33 changes: 33 additions & 0 deletions _generated/omitempty_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package _generated
import (
"bytes"
"io"
"reflect"
"testing"

"github.com/tinylib/msgp/msgp"
Expand Down Expand Up @@ -285,3 +286,35 @@ func BenchmarkNotOmitEmpty10AllFull(b *testing.B) {
}
}
}

func TestTypeAlias(t *testing.T) {
value := TypeSamples{TypeSample{}, TypeSample{K: 1, V: 2}}
encoded, err := value.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
var got TypeSamples
_, err = got.UnmarshalMsg(encoded)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(value, got) {
t.Errorf("UnmarshalMsg got %v want %v", value, got)
}
var buf bytes.Buffer
w := msgp.NewWriter(&buf)
err = value.EncodeMsg(w)
if err != nil {
t.Fatal(err)
}
w.Flush()
got = TypeSamples{}
r := msgp.NewReader(&buf)
err = got.DecodeMsg(r)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(value, got) {
t.Errorf("UnmarshalMsg got %v want %v", value, got)
}
}
8 changes: 6 additions & 2 deletions gen/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ func (d *decodeGen) structAsMap(s *Struct) {

if oeCount > 0 {
d.p.printf("\n// Clear omitted fields.\n")
d.p.printf("if %s {\n", bm.notAllSet())
if bm.bitlen > 1 {
d.p.printf("if %s {\n", bm.notAllSet())
}
for bitIdx, fieldIdx := range oeEmittedIdx {
fieldElem := s.Fields[fieldIdx].FieldElem

Expand All @@ -172,7 +174,9 @@ func (d *decodeGen) structAsMap(s *Struct) {
}
d.p.printf("}\n")
}
d.p.printf("}")
if bm.bitlen > 1 {
d.p.printf("}")
}
}
}

Expand Down
14 changes: 9 additions & 5 deletions gen/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package gen
import (
"fmt"
"io"
"strings"

"github.com/tinylib/msgp/msgp"
)
Expand Down Expand Up @@ -142,6 +141,7 @@ func (e *encodeGen) structmap(s *Struct) {

omitempty := s.AnyHasTagPart("omitempty")
omitzero := s.AnyHasTagPart("omitzero")
var closeZero bool
var fieldNVar string
if omitempty || omitzero {

Expand Down Expand Up @@ -175,9 +175,11 @@ func (e *encodeGen) structmap(s *Struct) {
return
}

// quick return for the case where the entire thing is empty, but only at the top level
if !strings.Contains(s.Varname(), ".") {
e.p.printf("\nif %s == 0 { return }", fieldNVar)
// Skip block, if no fields are set.
if nfields > 1 {
e.p.printf("\n\n// skip if no fields are to be emitted")
e.p.printf("\nif %s != 0 {", fieldNVar)
closeZero = true
}

} else {
Expand Down Expand Up @@ -226,7 +228,9 @@ func (e *encodeGen) structmap(s *Struct) {
if oeField || anField {
e.p.print("\n}") // close if statement
}

}
if closeZero {
e.p.printf("\n}") // close if statement
}
}

Expand Down
14 changes: 9 additions & 5 deletions gen/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package gen
import (
"fmt"
"io"
"strings"

"github.com/tinylib/msgp/msgp"
)
Expand Down Expand Up @@ -137,6 +136,7 @@ func (m *marshalGen) mapstruct(s *Struct) {

omitempty := s.AnyHasTagPart("omitempty")
omitzero := s.AnyHasTagPart("omitzero")
var closeZero bool
var fieldNVar string
if omitempty || omitzero {

Expand Down Expand Up @@ -169,9 +169,11 @@ func (m *marshalGen) mapstruct(s *Struct) {
return
}

// quick return for the case where the entire thing is empty, but only at the top level
if !strings.Contains(s.Varname(), ".") {
m.p.printf("\nif %s == 0 { return }", fieldNVar)
// Skip block, if no fields are set.
if nfields > 1 {
m.p.printf("\n\n// skip if no fields are to be emitted")
m.p.printf("\nif %s != 0 {", fieldNVar)
closeZero = true
}

} else {
Expand Down Expand Up @@ -222,7 +224,9 @@ func (m *marshalGen) mapstruct(s *Struct) {
if oeField || anField {
m.p.printf("\n}") // close if statement
}

}
if closeZero {
m.p.printf("\n}") // close if statement
}
}

Expand Down
8 changes: 6 additions & 2 deletions gen/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ func (u *unmarshalGen) mapstruct(s *Struct) {
u.p.print("\n}\n}") // close switch and for loop
if oeCount > 0 {
u.p.printf("\n// Clear omitted fields.\n")
u.p.printf("if %s {\n", bm.notAllSet())
if bm.bitlen > 1 {
u.p.printf("if %s {\n", bm.notAllSet())
}
for bitIdx, fieldIdx := range oeEmittedIdx {
fieldElem := s.Fields[fieldIdx].FieldElem

Expand All @@ -164,7 +166,9 @@ func (u *unmarshalGen) mapstruct(s *Struct) {
}
u.p.printf("}\n")
}
u.p.printf("}")
if bm.bitlen > 1 {
u.p.printf("}")
}
}
}

Expand Down

0 comments on commit 1de3898

Please sign in to comment.