Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "clearomitted" directive #373

Merged
merged 2 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions _generated/clearomitted.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package _generated

import (
"encoding/json"
"time"
)

//go:generate msgp

//msgp:clearomitted

// check some specific cases for omitzero

type ClearOmitted0 struct {
AStruct ClearOmittedA `msg:"astruct,omitempty"` // leave this one omitempty
BStruct ClearOmittedA `msg:"bstruct,omitzero"` // and compare to this
AStructPtr *ClearOmittedA `msg:"astructptr,omitempty"` // a pointer case omitempty
BStructPtr *ClearOmittedA `msg:"bstructptr,omitzero"` // a pointer case omitzero
AExt OmitZeroExt `msg:"aext,omitzero"` // external type case

// more
APtrNamedStr *NamedStringCO `msg:"aptrnamedstr,omitzero"`
ANamedStruct NamedStructCO `msg:"anamedstruct,omitzero"`
APtrNamedStruct *NamedStructCO `msg:"aptrnamedstruct,omitzero"`
EmbeddableStructCO `msg:",flatten,omitzero"` // embed flat
EmbeddableStructCO2 `msg:"embeddablestruct2,omitzero"` // embed non-flat
ATime time.Time `msg:"atime,omitzero"`
ASlice []int `msg:"aslice,omitempty"`
AMap map[string]int `msg:"amap,omitempty"`
ABin []byte `msg:"abin,omitempty"`
AInt int `msg:"aint,omitempty"`
AString string `msg:"atring,omitempty"`
Adur time.Duration `msg:"adur,omitempty"`
AJSON json.Number `msg:"ajson,omitempty"`

ClearOmittedTuple ClearOmittedTuple `msg:"ozt"` // the inside of a tuple should ignore both omitempty and omitzero
}

type ClearOmittedA struct {
A string `msg:"a,omitempty"`
B NamedStringCO `msg:"b,omitzero"`
C NamedStringCO `msg:"c,omitzero"`
}

func (o *ClearOmittedA) IsZero() bool {
if o == nil {
return true
}
return *o == (ClearOmittedA{})
}

type NamedStructCO struct {
A string `msg:"a,omitempty"`
B string `msg:"b,omitempty"`
}

func (ns *NamedStructCO) IsZero() bool {
if ns == nil {
return true
}
return *ns == (NamedStructCO{})
}

type NamedStringCO string

func (ns *NamedStringCO) IsZero() bool {
if ns == nil {
return true
}
return *ns == ""
}

type EmbeddableStructCO struct {
SomeEmbed string `msg:"someembed2,omitempty"`
}

func (es EmbeddableStructCO) IsZero() bool { return es == (EmbeddableStructCO{}) }

type EmbeddableStructCO2 struct {
SomeEmbed2 string `msg:"someembed2,omitempty"`
}

func (es EmbeddableStructCO2) IsZero() bool { return es == (EmbeddableStructCO2{}) }

//msgp:tuple ClearOmittedTuple

// ClearOmittedTuple is flagged for tuple output, it should ignore all omitempty and omitzero functionality
// since it's fundamentally incompatible.
type ClearOmittedTuple struct {
FieldA string `msg:"fielda,omitempty"`
FieldB NamedStringCO `msg:"fieldb,omitzero"`
FieldC NamedStringCO `msg:"fieldc,omitzero"`
}

type ClearOmitted1 struct {
T1 ClearOmittedTuple `msg:"t1"`
}
93 changes: 93 additions & 0 deletions _generated/clearomitted_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package _generated

import (
"bytes"
"encoding/json"
"reflect"
"testing"
"time"

"github.com/tinylib/msgp/msgp"
)

func TestClearOmitted(t *testing.T) {
cleared := ClearOmitted0{}
encoded, err := cleared.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
vPtr := NamedStringCO("value")
filled := ClearOmitted0{
AStruct: ClearOmittedA{A: "something"},
BStruct: ClearOmittedA{A: "somthing"},
AStructPtr: &ClearOmittedA{A: "something"},
AExt: OmitZeroExt{25},
APtrNamedStr: &vPtr,
ANamedStruct: NamedStructCO{A: "value"},
APtrNamedStruct: &NamedStructCO{A: "sdf"},
EmbeddableStructCO: EmbeddableStructCO{"value"},
EmbeddableStructCO2: EmbeddableStructCO2{"value"},
ATime: time.Now(),
ASlice: []int{1, 2, 3},
AMap: map[string]int{"1": 1},
ABin: []byte{1, 2, 3},
ClearOmittedTuple: ClearOmittedTuple{FieldA: "value"},
AInt: 42,
AString: "value",
Adur: time.Second,
AJSON: json.Number(`43.0000000000002`),
}
dst := filled
_, err = dst.UnmarshalMsg(encoded)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(dst, cleared) {
t.Errorf("\n got=%#v\nwant=%#v", dst, cleared)
}
// Reset
dst = filled
err = dst.DecodeMsg(msgp.NewReader(bytes.NewReader(encoded)))
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(dst, cleared) {
t.Errorf("\n got=%#v\nwant=%#v", dst, cleared)
}

// Check that fields aren't accidentally zeroing fields.
wantJson, err := json.Marshal(filled)
if err != nil {
t.Fatal(err)
}
encoded, err = filled.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
dst = ClearOmitted0{}
_, err = dst.UnmarshalMsg(encoded)
if err != nil {
t.Fatal(err)
}
got, err := json.Marshal(dst)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(got, wantJson) {
t.Errorf("\n got=%#v\nwant=%#v", string(got), string(wantJson))
}
// Reset
dst = ClearOmitted0{}
err = dst.DecodeMsg(msgp.NewReader(bytes.NewReader(encoded)))
if err != nil {
t.Fatal(err)
}
got, err = json.Marshal(dst)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(got, wantJson) {
t.Errorf("\n got=%#v\nwant=%#v", string(got), string(wantJson))
}
t.Log("OK - got", string(got))
}
38 changes: 38 additions & 0 deletions gen/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,22 @@ func (d *decodeGen) structAsMap(s *Struct) {
d.p.declare(sz, u32)
d.assignAndCheck(sz, mapHeader)

oeCount := s.CountFieldTagPart("omitempty") + s.CountFieldTagPart("omitzero")
if !d.ctx.clearOmitted {
oeCount = 0
}
bm := bmask{
bitlen: oeCount,
varname: sz + "Mask",
}
if oeCount > 0 {
// Declare mask
d.p.printf("\n%s", bm.typeDecl())
d.p.printf("\n_ = %s", bm.varname)
}
// Index to field idx of each emitted
oeEmittedIdx := []int{}

d.p.printf("\nfor %s > 0 {\n%s--", sz, sz)
d.assignAndCheck("field", mapKey)
d.p.print("\nswitch msgp.UnsafeString(field) {")
Expand All @@ -123,6 +139,10 @@ func (d *decodeGen) structAsMap(s *Struct) {
}
SetIsAllowNil(fieldElem, anField)
next(d, fieldElem)
if oeCount > 0 && (s.Fields[i].HasTagPart("omitempty") || s.Fields[i].HasTagPart("omitzero")) {
d.p.printf("\n%s", bm.setStmt(len(oeEmittedIdx)))
oeEmittedIdx = append(oeEmittedIdx, i)
}
d.ctx.Pop()
if !d.p.ok() {
return
Expand All @@ -136,6 +156,24 @@ func (d *decodeGen) structAsMap(s *Struct) {

d.p.closeblock() // close switch
d.p.closeblock() // close for loop

if oeCount > 0 {
d.p.printf("\n// Clear omitted fields.\n")
d.p.printf("if %s {\n", bm.notAllSet())
for bitIdx, fieldIdx := range oeEmittedIdx {
fieldElem := s.Fields[fieldIdx].FieldElem

d.p.printf("if %s == 0 {\n", bm.readExpr(bitIdx))
fze := fieldElem.ZeroExpr()
if fze != "" {
d.p.printf("%s = %s\n", fieldElem.Varname(), fze)
} else {
d.p.printf("%s = %s{}\n", fieldElem.Varname(), fieldElem.TypeName())
}
d.p.printf("}\n")
}
d.p.printf("}")
}
}

func (d *decodeGen) gBase(b *BaseElem) {
Expand Down
11 changes: 11 additions & 0 deletions gen/elem.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,17 @@ func (s *Struct) AnyHasTagPart(pname string) bool {
return false
}

// CountFieldTagPart the count of HasTagPart(p) is true for any field.
func (s *Struct) CountFieldTagPart(pname string) int {
var n int
for _, sf := range s.Fields {
if sf.HasTagPart(pname) {
n++
}
}
return n
}

type StructField struct {
FieldTag string // the string inside the `msg:""` tag up to the first comma
FieldTagParts []string // the string inside the `msg:""` tag split by commas
Expand Down
30 changes: 27 additions & 3 deletions gen/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"io"
"strings"
)

const (
Expand Down Expand Up @@ -77,6 +78,7 @@ const (
type Printer struct {
gens []generator
CompactFloats bool
ClearOmitted bool
}

func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer {
Expand Down Expand Up @@ -145,7 +147,7 @@ func (p *Printer) Print(e Elem) error {
// collisions between idents created during SetVarname and idents created during Print,
// hence the separate prefixes.
resetIdent("zb")
err := g.Execute(e, Context{compFloats: p.CompactFloats})
err := g.Execute(e, Context{compFloats: p.CompactFloats, clearOmitted: p.ClearOmitted})
resetIdent("za")

if err != nil {
Expand All @@ -172,8 +174,9 @@ func (c contextVar) Arg() string {
}

type Context struct {
path []contextItem
compFloats bool
path []contextItem
compFloats bool
clearOmitted bool
}

func (c *Context) PushString(s string) {
Expand Down Expand Up @@ -501,3 +504,24 @@ func (b *bmask) setStmt(bitoffset int) string {

return buf.String()
}

// notAllSet returns a check against all fields having been set in set.
func (b *bmask) notAllSet() string {
var buf bytes.Buffer
buf.Grow(len(b.varname) + 16)
buf.WriteString(b.varname)
if b.bitlen > 64 {
var bytes []string
remain := b.bitlen
for remain >= 8 {
bytes = append(bytes, "0xff")
}
if remain > 0 {
bytes = append(bytes, fmt.Sprintf("0x%X", remain))
}
fmt.Fprintf(&buf, " != [%d]byte{%s}\n", (b.bitlen+63)/64, strings.Join(bytes, ","))
}
fmt.Fprintf(&buf, " != 0x%x", uint64(1<<b.bitlen)-1)

return buf.String()
}
37 changes: 37 additions & 0 deletions gen/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,22 @@ func (u *unmarshalGen) mapstruct(s *Struct) {
u.p.declare(sz, u32)
u.assignAndCheck(sz, mapHeader)

oeCount := s.CountFieldTagPart("omitempty") + s.CountFieldTagPart("omitzero")
if !u.ctx.clearOmitted {
oeCount = 0
}
bm := bmask{
bitlen: oeCount,
varname: sz + "Mask",
}
if oeCount > 0 {
// Declare mask
u.p.printf("\n%s", bm.typeDecl())
u.p.printf("\n_ = %s", bm.varname)
}
// Index to field idx of each emitted
oeEmittedIdx := []int{}

u.p.printf("\nfor %s > 0 {", sz)
u.p.printf("\n%s--; field, bts, err = msgp.ReadMapKeyZC(bts)", sz)
u.p.wrapErrCheck(u.ctx.ArgsStr())
Expand All @@ -122,13 +138,34 @@ func (u *unmarshalGen) mapstruct(s *Struct) {
SetIsAllowNil(fieldElem, anField)
next(u, fieldElem)
u.ctx.Pop()
if oeCount > 0 && (s.Fields[i].HasTagPart("omitempty") || s.Fields[i].HasTagPart("omitzero")) {
u.p.printf("\n%s", bm.setStmt(len(oeEmittedIdx)))
oeEmittedIdx = append(oeEmittedIdx, i)
}
if anField {
u.p.printf("\n}")
}
}
u.p.print("\ndefault:\nbts, err = msgp.Skip(bts)")
u.p.wrapErrCheck(u.ctx.ArgsStr())
u.p.print("\n}\n}") // close switch and for loop
if oeCount > 0 {
u.p.printf("\n// Clear omitted fields.\n")
u.p.printf("if %s {\n", bm.notAllSet())
for bitIdx, fieldIdx := range oeEmittedIdx {
fieldElem := s.Fields[fieldIdx].FieldElem

u.p.printf("if %s == 0 {\n", bm.readExpr(bitIdx))
fze := fieldElem.ZeroExpr()
if fze != "" {
u.p.printf("%s = %s\n", fieldElem.Varname(), fze)
} else {
u.p.printf("%s = %s{}\n", fieldElem.Varname(), fieldElem.TypeName())
}
u.p.printf("}\n")
}
u.p.printf("}")
}
}

func (u *unmarshalGen) gBase(b *BaseElem) {
Expand Down
Loading
Loading