diff --git a/diff.go b/diff.go index 8ac8bd0..64e3dcd 100644 --- a/diff.go +++ b/diff.go @@ -23,6 +23,56 @@ const ( DELETE = "delete" ) +// DiffType represents an enum with all the supported diff types +type DiffType uint8 + +const ( + UNSUPPORTED DiffType = iota + STRUCT + SLICE + ARRAY + STRING + BOOL + INT + UINT + FLOAT + MAP + PTR + INTERFACE +) + +func (t DiffType) String() string { + switch t { + case STRUCT: + return "STRUCT" + case SLICE: + return "SLICE" + case ARRAY: + return "ARRAY" + case STRING: + return "STRING" + case BOOL: + return "BOOL" + case INT: + return "INT" + case UINT: + return "UINT" + case FLOAT: + return "FLOAT" + case MAP: + return "MAP" + case PTR: + return "PTR" + case INTERFACE: + return "INTERFACE" + default: + return "UNSUPPORTED" + } +} + +// DiffFunc represents the built-in diff functions +type DiffFunc func([]string, reflect.Value, reflect.Value, interface{}) error + // Differ a configurable diff instance type Differ struct { TagName string @@ -53,7 +103,7 @@ type Change struct { // ValueDiffer is an interface for custom differs type ValueDiffer interface { Match(a, b reflect.Value) bool - Diff(cl *Changelog, path []string, a, b reflect.Value) error + Diff(dt DiffType, df DiffFunc, cl *Changelog, path []string, a, b reflect.Value, parent interface{}) error InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) } @@ -134,6 +184,35 @@ func (cl *Changelog) Filter(path []string) Changelog { return ncl } +func (d *Differ) getDiffType(a, b reflect.Value) (DiffType, DiffFunc) { + switch { + case are(a, b, reflect.Struct, reflect.Invalid): + return STRUCT, d.diffStruct + case are(a, b, reflect.Slice, reflect.Invalid): + return SLICE, d.diffSlice + case are(a, b, reflect.Array, reflect.Invalid): + return ARRAY, d.diffSlice + case are(a, b, reflect.String, reflect.Invalid): + return STRING, d.diffString + case are(a, b, reflect.Bool, reflect.Invalid): + return BOOL, d.diffBool + case are(a, b, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Invalid): + return INT, d.diffInt + case are(a, b, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Invalid): + return UINT, d.diffUint + case are(a, b, reflect.Float32, reflect.Float64, reflect.Invalid): + return FLOAT, d.diffFloat + case are(a, b, reflect.Map, reflect.Invalid): + return MAP, d.diffMap + case are(a, b, reflect.Ptr, reflect.Invalid): + return PTR, d.diffPtr + case are(a, b, reflect.Interface, reflect.Invalid): + return INTERFACE, d.diffInterface + default: + return UNSUPPORTED, nil + } +} + // Diff returns a changelog of all mutated values from both func (d *Differ) Diff(a, b interface{}) (Changelog, error) { // reset the state of the diff @@ -160,11 +239,14 @@ func (d *Differ) diff(path []string, a, b reflect.Value, parent interface{}) err return ErrTypeMismatch } + // get the diff type and the corresponding built-int diff function to handle this type + diffType, diffFunc := d.getDiffType(a, b) + // first go through custom diff functions if len(d.customValueDiffers) > 0 { for _, vd := range d.customValueDiffers { if vd.Match(a, b) { - err := vd.Diff(&d.cl, path, a, b) + err := vd.Diff(diffType, diffFunc, &d.cl, path, a, b, parent) if err != nil { return err } @@ -174,32 +256,11 @@ func (d *Differ) diff(path []string, a, b reflect.Value, parent interface{}) err } // then built-in diff functions - switch { - case are(a, b, reflect.Struct, reflect.Invalid): - return d.diffStruct(path, a, b) - case are(a, b, reflect.Slice, reflect.Invalid): - return d.diffSlice(path, a, b) - case are(a, b, reflect.Array, reflect.Invalid): - return d.diffSlice(path, a, b) - case are(a, b, reflect.String, reflect.Invalid): - return d.diffString(path, a, b, parent) - case are(a, b, reflect.Bool, reflect.Invalid): - return d.diffBool(path, a, b, parent) - case are(a, b, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Invalid): - return d.diffInt(path, a, b, parent) - case are(a, b, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Invalid): - return d.diffUint(path, a, b, parent) - case are(a, b, reflect.Float32, reflect.Float64, reflect.Invalid): - return d.diffFloat(path, a, b, parent) - case are(a, b, reflect.Map, reflect.Invalid): - return d.diffMap(path, a, b) - case are(a, b, reflect.Ptr, reflect.Invalid): - return d.diffPtr(path, a, b, parent) - case are(a, b, reflect.Interface, reflect.Invalid): - return d.diffInterface(path, a, b, parent) - default: + if diffType == UNSUPPORTED { return errors.New("unsupported type: " + a.Kind().String()) } + + return diffFunc(path, a, b, parent) } func (cl *Changelog) Add(t string, path []string, ftco ...interface{}) { diff --git a/diff_map.go b/diff_map.go index dcb8efb..02b1c1b 100644 --- a/diff_map.go +++ b/diff_map.go @@ -11,7 +11,7 @@ import ( "github.com/vmihailenco/msgpack" ) -func (d *Differ) diffMap(path []string, a, b reflect.Value) error { +func (d *Differ) diffMap(path []string, a, b reflect.Value, parent interface{}) error { if a.Kind() == reflect.Invalid { return d.mapValues(CREATE, path, b) } diff --git a/diff_slice.go b/diff_slice.go index f6b5455..3fd281b 100644 --- a/diff_slice.go +++ b/diff_slice.go @@ -8,7 +8,7 @@ import ( "reflect" ) -func (d *Differ) diffSlice(path []string, a, b reflect.Value) error { +func (d *Differ) diffSlice(path []string, a, b reflect.Value, parent interface{}) error { if a.Kind() == reflect.Invalid { d.cl.Add(CREATE, path, nil, exportInterface(b)) return nil diff --git a/diff_struct.go b/diff_struct.go index b58f3e8..fb14c57 100644 --- a/diff_struct.go +++ b/diff_struct.go @@ -9,7 +9,7 @@ import ( "time" ) -func (d *Differ) diffStruct(path []string, a, b reflect.Value) error { +func (d *Differ) diffStruct(path []string, a, b reflect.Value, parent interface{}) error { if AreType(a, b, reflect.TypeOf(time.Time{})) { return d.diffTime(path, a, b) } diff --git a/diff_test.go b/diff_test.go index a075fac..22a6655 100644 --- a/diff_test.go +++ b/diff_test.go @@ -6,6 +6,7 @@ package diff_test import ( "reflect" + "strings" "sync" "testing" "time" @@ -907,7 +908,7 @@ func (o *testTypeDiffer) InsertParentDiffer(dfunc func(path []string, a, b refle func (o *testTypeDiffer) Match(a, b reflect.Value) bool { return diff.AreType(a, b, reflect.TypeOf(testType(""))) } -func (o *testTypeDiffer) Diff(cl *diff.Changelog, path []string, a, b reflect.Value) error { +func (o *testTypeDiffer) Diff(dt diff.DiffType, df diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, parent interface{}) error { if a.String() != "custom" && b.String() != "match" { cl.Add(diff.UPDATE, path, a.Interface(), b.Interface()) } @@ -944,6 +945,54 @@ func TestCustomDiffer(t *testing.T) { assert.Len(t, cl, 1) } +type testStringInterceptorDiffer struct { + DiffFunc (func(path []string, a, b reflect.Value, p interface{}) error) +} + +func (o *testStringInterceptorDiffer) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) { + o.DiffFunc = dfunc +} + +func (o *testStringInterceptorDiffer) Match(a, b reflect.Value) bool { + return diff.AreType(a, b, reflect.TypeOf(testType(""))) +} +func (o *testStringInterceptorDiffer) Diff(dt diff.DiffType, df diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, parent interface{}) error { + if dt.String() == "STRING" { + // intercept the data + aValue, aOk := a.Interface().(testType) + bValue, bOk := b.Interface().(testType) + + if aOk && bOk { + if aValue == "avalue" { + aValue = testType(strings.ToUpper(string(aValue))) + a = reflect.ValueOf(aValue) + } + + if bValue == "bvalue" { + bValue = testType(strings.ToUpper(string(aValue))) + b = reflect.ValueOf(bValue) + } + } + } + + // continue the diff logic passing the updated a/b values + return df(path, a, b, parent) +} + +func TestStringInterceptorDiffer(t *testing.T) { + d, err := diff.NewDiffer( + diff.CustomValueDiffers( + &testStringInterceptorDiffer{}, + ), + ) + require.Nil(t, err) + + cl, err := d.Diff(testType("avalue"), testType("bvalue")) + require.Nil(t, err) + + assert.Len(t, cl, 0) +} + type RecursiveTestStruct struct { Id int Children []RecursiveTestStruct @@ -961,7 +1010,7 @@ func (o *recursiveTestStructDiffer) Match(a, b reflect.Value) bool { return diff.AreType(a, b, reflect.TypeOf(RecursiveTestStruct{})) } -func (o *recursiveTestStructDiffer) Diff(cl *diff.Changelog, path []string, a, b reflect.Value) error { +func (o *recursiveTestStructDiffer) Diff(dt diff.DiffType, df diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, parent interface{}) error { if a.Kind() == reflect.Invalid { cl.Add(diff.CREATE, path, nil, b.Interface()) return nil