From b832b0bf5a69e6845f2b8152349ffb8161bdd9eb Mon Sep 17 00:00:00 2001 From: zonewave Date: Mon, 17 Apr 2023 17:54:03 +0800 Subject: [PATCH] feat(unset): support unset value to zero. (#42) --- go.mod | 2 + go.sum | 11 ++++ unset.go | 90 +++++++++++++++++++++++++++++++++ unset_test.go | 136 ++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 239 insertions(+) create mode 100644 unset.go create mode 100644 unset_test.go diff --git a/go.mod b/go.mod index 2b71726..a49c24f 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/creasty/defaults go 1.14 + +require github.com/stretchr/testify v1.7.0 diff --git a/go.sum b/go.sum index e69de29..acb88a4 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,11 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/unset.go b/unset.go new file mode 100644 index 0000000..a117454 --- /dev/null +++ b/unset.go @@ -0,0 +1,90 @@ +package defaults + +import ( + "reflect" +) + +const ( + unsetFlag = "unset" + // unsetRecursion means that the field will be unset recursively + unsetRecursion = "walk" +) + +func MustUnset(ptr interface{}) { + if err := Unset(ptr); err != nil { + panic(err) + } +} +func Unset(obj interface{}) error { + v := indirect(reflect.ValueOf(obj)) + if v.Kind() != reflect.Struct { + return errInvalidType + } + unsetStruct(v) + return nil +} +func unsetStruct(obj reflect.Value) { + t := obj.Type() + for i := 0; i < t.NumField(); i++ { + unsetVal := t.Field(i).Tag.Get(unsetFlag) + if unsetVal == "-" { + continue + } + unsetField(obj.Field(i), unsetVal == unsetRecursion) + } + return +} + +func indirect(v reflect.Value) reflect.Value { + finalValue := v + for finalValue.Kind() == reflect.Ptr { + finalValue = finalValue.Elem() + } + return finalValue +} + +func unsetField(field reflect.Value, unsetWalk bool) { + if !field.CanSet() { + return + } + + isInitial := isInitialValue(field) + if isInitial { + return + } + if !unsetWalk { + field.Set(reflect.Zero(field.Type())) + return + } + field = indirect(field) + if !field.CanSet() { + return + } + + switch field.Kind() { + default: + field.Set(reflect.Zero(field.Type())) + case reflect.Struct: + unsetStruct(field) + case reflect.Slice: + for j := 0; j < field.Len(); j++ { + unsetField(field.Index(j), true) + } + case reflect.Map: + for _, e := range field.MapKeys() { + var mapValue = field.MapIndex(e) + switch mapValue.Kind() { + case reflect.Ptr: + unsetField(indirect(mapValue), true) + case reflect.Struct, reflect.Slice, reflect.Map: + ref := reflect.New(mapValue.Type()) + ref.Elem().Set(mapValue) + unsetField(ref.Elem(), true) + field.SetMapIndex(e, ref.Elem().Convert(mapValue.Type())) + default: + field.SetMapIndex(e, reflect.Zero(mapValue.Type())) + } + } + } + return +} diff --git a/unset_test.go b/unset_test.go new file mode 100644 index 0000000..4ef7f78 --- /dev/null +++ b/unset_test.go @@ -0,0 +1,136 @@ +package defaults + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestUnset(t *testing.T) { + t.Run("sample unset", func(t *testing.T) { + s := &Sample{} + MustSet(s) + err := Unset(s) + require.NoError(t, err) + require.Equal(t, s, &Sample{}) + }) + + t.Run("errInvalidType", func(t *testing.T) { + tmp := 8 + require.Equal(t, Unset(&tmp), errInvalidType) + require.Panics(t, func() { MustUnset(&tmp) }) + }) + + t.Run("not reset by -", func(t *testing.T) { + s := &struct { + IgnoreMe string `default:"-" unset:"-"` + }{ + IgnoreMe: "test", + } + MustUnset(s) + require.Equal(t, s.IgnoreMe, "test") + }) + + t.Run("sampleUnset test", func(t *testing.T) { + s := &SampleUnset{} + MustSet(s) + var testNumPtr *int = nil + s.SliceOfIntPtrPtr[1] = &testNumPtr + s.private = StructUnset{WithDefault: "test"} + MustUnset(s) + + // struct + require.Equal(t, s.private.WithDefault, "test") + require.Equal(t, 0, s.Struct.Foo) + require.Equal(t, 0, s.Struct.Bar) + require.Nil(t, s.Struct.BarPtr) + require.Equal(t, 0, *s.Struct.BarPtrWithWalk) + require.Empty(t, s.StructPtrNoWalk) + require.Equal(t, "foo", s.StructPtr.EmbeddedUnset.String) + require.Equal(t, 0, s.StructPtr.EmbeddedUnset.Int) + require.Equal(t, "foo", s.StructPtr.Struct.String) + require.Equal(t, 0, s.StructPtr.Struct.Int) + + // slice + require.Equal(t, 0, s.SliceOfInt[0]) + require.Equal(t, 0, *s.SliceOfIntPtr[0]) + require.Equal(t, (*int)(nil), *s.SliceOfIntPtrPtr[1]) + require.Equal(t, 0, s.SliceOfStruct[0].Foo) + require.Equal(t, 0, s.SliceOfStructPtr[0].Foo) + require.Equal(t, 0, s.SliceOfSliceInt[0][0]) + require.Equal(t, 0, *s.SliceOfSliceIntPtr[0][0]) + require.Equal(t, 0, s.SliceOfSliceStruct[0][0].Foo) + require.Equal(t, 0, s.SliceOfMapOfInt[0]["int1"]) + require.Equal(t, 0, s.SliceOfMapOfStruct[0]["Struct3"].Foo) + require.Equal(t, 0, s.SliceOfMapOfStructPtr[0]["Struct3"].Foo) + require.Nil(t, s.SliceSetNil) + + // map + require.Equal(t, 0, s.MapOfInt["int1"]) + require.Equal(t, 0, *s.MapOfIntPtr["int1"]) + require.Equal(t, 0, s.MapOfStruct["Struct3"].Foo) + require.Equal(t, 0, s.MapOfStructPtr["Struct3"].Foo) + require.Equal(t, 0, s.MapOfSliceInt["slice1"][0]) + require.Equal(t, 0, *s.MapOfSliceIntPtr["slice1"][0]) + require.Equal(t, 0, s.MapOfSliceStruct["slice1"][0].Foo) + require.Equal(t, 0, s.MapOfMapOfInt["map1"]["int1"]) + require.Equal(t, 0, s.MapOfMapOfInt["map1"]["int1"]) + require.Equal(t, 0, s.MapOfMapOfStruct["map1"]["Struct3"].Foo, 0) + require.Nil(t, s.MapSetNil) + + // map embed + require.Equal(t, "foo", s.MapOfStruct["Struct3"].String) + require.Equal(t, "foo", s.MapOfStructPtr["Struct3"].String) + require.Equal(t, "foo", s.MapOfSliceStruct["slice1"][0].String) + require.Equal(t, "foo", s.MapOfMapOfStruct["map1"]["Struct3"].String) + + }) +} + +type EmbeddedUnset struct { + Int int `default:"1"` + String string `default:"foo" unset:"-"` +} + +type StructUnset struct { + EmbeddedUnset `default:"{}" unset:"walk"` + Foo int `default:"1"` + Bar int `default:"1"` + BarPtr *int `default:"1"` + BarPtrWithWalk *int `default:"1" unset:"walk"` + WithDefault string `default:"foo"` + Struct EmbeddedUnset ` unset:"walk"` +} + +type SampleUnset struct { + private StructUnset `default:"{}" unset:"walk"` + Struct StructUnset `default:"{}" unset:"walk"` + StructPtr *StructUnset `default:"{}" unset:"walk"` + StructPtrNoWalk *StructUnset `default:"{}"` + + SliceOfInt []int `default:"[1,2,3]" unset:"walk"` + SliceOfIntPtr []*int `default:"[1,2,3]" unset:"walk"` + SliceOfIntPtrPtr []**int `default:"[1,2,3]" unset:"walk"` + SliceOfStruct []StructUnset `default:"[{\"Foo\":123}]" unset:"walk"` + SliceOfStructPtr []*StructUnset `default:"[{\"Foo\":123}]" unset:"walk"` + SliceOfMapOfInt []map[string]int `default:"[{\"int1\": 1}]" unset:"walk"` + SliceOfMapOfStruct []map[string]StructUnset `default:"[{\"Struct3\": {\"Foo\":123}}]" unset:"walk"` + SliceOfMapOfStructPtr []map[string]*StructUnset `default:"[{\"Struct3\": {\"Foo\":123}}]" unset:"walk"` + SliceOfSliceInt [][]int `default:"[[1,2,3]]" unset:"walk"` + SliceOfSliceIntPtr [][]*int `default:"[[1,2,3]]" unset:"walk"` + SliceOfSliceStruct [][]StructUnset `default:"[[{\"Foo\":123}]]" unset:"walk"` + SliceOfSliceStructPtr [][]*StructUnset `default:"[[{\"Foo\":123}]]" unset:"walk"` + SliceSetNil []StructUnset `default:"[{\"Foo\":123}]"` + + MapOfInt map[string]int `default:"{\"int1\": 1}" unset:"walk"` + MapOfIntPtr map[string]*int `default:"{\"int1\": 1}" unset:"walk"` + MapOfStruct map[string]StructUnset `default:"{\"Struct3\": {\"Foo\":123}}" unset:"walk" ` + MapOfStructPtr map[string]*StructUnset `default:"{\"Struct3\": {\"Foo\":123}}" unset:"walk"` + MapOfSliceInt map[string][]int `default:"{\"slice1\": [1,2,3]}" unset:"walk"` + MapOfSliceIntPtr map[string][]*int `default:"{\"slice1\": [1,2,3]}" unset:"walk"` + MapOfSliceStruct map[string][]StructUnset `default:"{\"slice1\": [{\"Foo\":123}]}" unset:"walk"` + MapOfSliceStructPtr map[string][]*StructUnset `default:"{\"slice1\": [{\"Foo\":123}]}" unset:"walk"` + MapOfMapOfInt map[string]map[string]int `default:"{\"map1\": {\"int1\": 1}}" unset:"walk"` + MapOfMapOfStruct map[string]map[string]StructUnset `default:"{\"map1\": {\"Struct3\": {\"Foo\":123}}}" unset:"walk"` + + MapSetNil map[string]StructUnset `default:"{\"Struct3\": {\"Foo\":123}}"` +}