From b723d66c6638e5a1c85808132766124f2ccb4933 Mon Sep 17 00:00:00 2001 From: Tom Bevan Date: Mon, 20 Dec 2021 15:29:25 +0000 Subject: [PATCH] fix patching nil values --- change_value.go | 89 +++++++++++++++++++++++++++---------------------- patch_test.go | 76 +++++++++++++++++++++++------------------ 2 files changed, 92 insertions(+), 73 deletions(-) diff --git a/change_value.go b/change_value.go index 529554b..f115da9 100644 --- a/change_value.go +++ b/change_value.go @@ -103,58 +103,67 @@ func (c ChangeValue) Len() int { //Set echos reflect set func (c *ChangeValue) Set(value reflect.Value, convertCompatibleTypes bool) { - if c != nil { - defer func() { - if r := recover(); r != nil { - c.AddError(NewError(r.(string))) - c.SetFlag(FlagFailed) + if c == nil { + return + } + defer func() { + if r := recover(); r != nil { + switch e := r.(type) { + case string: + c.AddError(NewError(e)) + case *reflect.ValueError: + c.AddError(NewError(e.Error())) } - }() - if c.HasFlag(OptionImmutable) { - c.SetFlag(FlagIgnored) - return + c.SetFlag(FlagFailed) } + }() - if convertCompatibleTypes { - if c.target.Kind() == reflect.Ptr && value.Kind() != reflect.Ptr { - if !value.Type().ConvertibleTo(c.target.Elem().Type()) { - c.AddError(fmt.Errorf("Value of type %s is not convertible to %s", value.Type().String(), c.target.Type().String())) - c.SetFlag(FlagFailed) - return - } + if c.HasFlag(OptionImmutable) { + c.SetFlag(FlagIgnored) + return + } - fmt.Println(c.target.Elem().Type()) + if convertCompatibleTypes { + if c.target.Kind() == reflect.Ptr && value.Kind() != reflect.Ptr { + if !value.IsValid() { + c.target.Set(reflect.Zero(c.target.Type())) + c.SetFlag(FlagApplied) + return + } else if !value.Type().ConvertibleTo(c.target.Elem().Type()) { + c.AddError(fmt.Errorf("Value of type %s is not convertible to %s", value.Type().String(), c.target.Type().String())) + c.SetFlag(FlagFailed) + return + } + + tv := reflect.New(c.target.Elem().Type()) + tv.Elem().Set(value.Convert(c.target.Elem().Type())) + c.target.Set(tv) + } else { + if !value.Type().ConvertibleTo(c.target.Type()) { + c.AddError(fmt.Errorf("Value of type %s is not convertible to %s", value.Type().String(), c.target.Type().String())) + c.SetFlag(FlagFailed) + return + } - tv := reflect.New(c.target.Elem().Type()) - tv.Elem().Set(value.Convert(c.target.Elem().Type())) + c.target.Set(value.Convert(c.target.Type())) + } + } else { + if value.IsValid() { + if c.target.Kind() == reflect.Ptr && value.Kind() != reflect.Ptr { + tv := reflect.New(value.Type()) + tv.Elem().Set(value) c.target.Set(tv) } else { - if !value.Type().ConvertibleTo(c.target.Type()) { - c.AddError(fmt.Errorf("Value of type %s is not convertible to %s", value.Type().String(), c.target.Type().String())) - c.SetFlag(FlagFailed) - return - } - - c.target.Set(value.Convert(c.target.Type())) - } - } else { - if value.IsValid() { - if c.target.Kind() == reflect.Ptr && value.Kind() != reflect.Ptr { - tv := reflect.New(value.Type()) - tv.Elem().Set(value) - c.target.Set(tv) - } else { - c.target.Set(value) - } - } else if !c.target.IsZero() { - t := c.target.Elem() - t.Set(reflect.Zero(t.Type())) + c.target.Set(value) } + } else if !c.target.IsZero() { + t := c.target.Elem() + t.Set(reflect.Zero(t.Type())) } - c.SetFlag(FlagApplied) } + c.SetFlag(FlagApplied) } //Index echo for index diff --git a/patch_test.go b/patch_test.go index bb69e7a..4f6b116 100644 --- a/patch_test.go +++ b/patch_test.go @@ -285,51 +285,61 @@ func TestPatch(t *testing.T) { assert.Equal(t, int(b.Bar), a.Bar) require.Equal(t, len(cl), len(pl)) }) -} -func TestPatchPointer(t *testing.T) { - type tps struct { - S *string - } + t.Run("pointer", func(t *testing.T) { + type tps struct { + S *string + } - str1 := "before" - str2 := "after" + str1 := "before" + str2 := "after" - t1 := tps{S: &str1} - t2 := tps{S: &str2} + t1 := tps{S: &str1} + t2 := tps{S: &str2} - changelog, err := diff.Diff(t1, t2) - assert.NoError(t, err) + changelog, err := diff.Diff(t1, t2) + assert.NoError(t, err) - patchLog := diff.Patch(changelog, &t1) - assert.False(t, patchLog.HasErrors()) -} + patchLog := diff.Patch(changelog, &t1) + assert.False(t, patchLog.HasErrors()) + }) -func TestPatchPointerConvertTypes(t *testing.T) { - type tps struct { - S *int - } + t.Run("pointer-with-converted-type", func(t *testing.T) { + type tps struct { + S *int + } + + val1 := 1 + val2 := 2 + + t1 := tps{S: &val1} + t2 := tps{S: &val2} - val1 := 1 - val2 := 2 + changelog, err := diff.Diff(t1, t2) + assert.NoError(t, err) - t1 := tps{S: &val1} - t2 := tps{S: &val2} + js, err := json.Marshal(changelog) + assert.NoError(t, err) - changelog, err := diff.Diff(t1, t2) - assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(js, &changelog)) - js, err := json.Marshal(changelog) - assert.NoError(t, err) + d, err := diff.NewDiffer(diff.ConvertCompatibleTypes()) + assert.NoError(t, err) - assert.NoError(t, json.Unmarshal(js, &changelog)) + assert.Equal(t, 1, *t1.S) - d, err := diff.NewDiffer(diff.ConvertCompatibleTypes()) - assert.NoError(t, err) + patchLog := d.Patch(changelog, &t1) + assert.False(t, patchLog.HasErrors()) + assert.Equal(t, 2, *t1.S) - assert.Equal(t, 1, *t1.S) + // test nil pointer + t1 = tps{S: &val1} + t2 = tps{S: nil} - patchLog := d.Patch(changelog, &t1) - assert.False(t, patchLog.HasErrors()) - assert.Equal(t, 2, *t1.S) + changelog, err = diff.Diff(t1, t2) + assert.NoError(t, err) + + patchLog = d.Patch(changelog, &t1) + assert.False(t, patchLog.HasErrors()) + }) }