diff --git a/VERSION b/VERSION index ceddfb2..e3b86dd 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.15 +0.0.16 diff --git a/go.mod b/go.mod index 92f2d2c..68a6d91 100644 --- a/go.mod +++ b/go.mod @@ -11,11 +11,14 @@ require ( golang.org/x/tools v0.25.0 ) +require golang.org/x/sys v0.25.0 // indirect + require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/sirupsen/logrus v1.9.3 golang.org/x/mod v0.21.0 // indirect golang.org/x/sync v0.8.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/go.sum b/go.sum index bf9c26b..5b24475 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,11 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= @@ -48,6 +51,9 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -65,5 +71,6 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/test/reflect.go b/test/reflect.go index 5573362..dee1215 100644 --- a/test/reflect.go +++ b/test/reflect.go @@ -5,48 +5,66 @@ import ( "unsafe" ) -// Error creates an Accessor for the given error to access and modify its +// Error creates an accessor/build for the given error to access and modify its // unexported fields by field name. -// -// Example: -// -// err := test.Error(errors.New("error message")).Set("text", "new message").Get("") -// fmt.Println(err.Error()) // Output: new message -// -// err := test.Error(errors.New("error message")).Set("text", "new message").Get("text") -// fmt.Println(err) // Output: new message func Error(err error) *Accessor[error] { return NewAccessor[error](err) } // Accessor allows you to access and modify unexported fields of a struct. type Accessor[T any] struct { - target T + target T + wrapped bool } -// NewAccessor creates a generic accessor for the given target. +// NewAccessor creates a generic accessor/builder for a given target struct. +// If the target is a pointer to a struct (template), the instance is stored +// and modified. If the target is a struct, a pointer to a new instance of is +// created, since a struct cannot be modified by reflection. func NewAccessor[T any](target T) *Accessor[T] { - return &Accessor[T]{target: target} + value := reflect.ValueOf(target) + if value.Kind() == reflect.Ptr && value.Elem().Kind() == reflect.Struct { + return &Accessor[T]{ + target: value.Interface().(T), + } + } else if value.Kind() == reflect.Struct { + target = reflect.New(value.Type()).Interface().(T) + + return &Accessor[T]{ + target: target, + wrapped: true, + } + } + panic("target must be a struct or pointer to struct") } -// Set sets the value of the accessor target's field with the given name. +// Set sets the value of the field with the given name. If the name is empty, +// and of the same type the stored target instance is replaced by the given +// value. func (a *Accessor[T]) Set(name string, value any) *Accessor[T] { - field := reflect.ValueOf(a.target).Elem().FieldByName(name) - // #nosec G103,G115 // This is a safe use of unsafe.Pointer. - reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())). - Elem().Set(reflect.ValueOf(value)) - + if name != "" { + field := reflect.ValueOf(a.target).Elem().FieldByName(name) + // #nosec G103,G115 // This is a safe use of unsafe.Pointer. + reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())). + Elem().Set(reflect.ValueOf(value)) + } else if reflect.TypeOf(a.target) == reflect.TypeOf(value) { + a.target = value.(T) + } else { + panic("target must of compatible struct pointer type") + } return a } -// Get returns the value of the field with the given name. If the name is empty, -// it returns the accessor target itself. +// Get returns the value of the field with the given name. If the name is +// empty, the stored target instance is returned. func (a *Accessor[T]) Get(name string) any { - if name == "" { - return a.target + if name != "" { + field := reflect.ValueOf(a.target).Elem().FieldByName(name) + // #nosec G103,G115 // This is a safe use of unsafe.Pointer. + return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())). + Elem().Interface() + } else if a.wrapped { + return reflect.ValueOf(a.target).Elem().Interface() } - field := reflect.ValueOf(a.target).Elem().FieldByName(name) - // #nosec G103,G115 // This is a safe use of unsafe.Pointer. - return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())). - Elem().Interface() + return a.target } diff --git a/test/reflect_test.go b/test/reflect_test.go index b7d5fbb..f669633 100644 --- a/test/reflect_test.go +++ b/test/reflect_test.go @@ -5,19 +5,129 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/tkrop/go-testing/mock" "github.com/tkrop/go-testing/test" ) +type Struct struct{ s string } + +func NewStruct(s string) Struct { return Struct{s: s} } +func NewPtrStruct(s string) *Struct { return &Struct{s: s} } + +type testAccessorParam struct { + target any + setup func(*test.Accessor[any]) + expect mock.SetupFunc + check func(test.Test, *test.Accessor[any]) +} + +var testAccessorParams = map[string]testAccessorParam{ + "test invalid type": { + target: int(1), + expect: test.Panic("target must be a struct or pointer to struct"), + }, + + "test struct get is empty - no copy possible": { + target: NewStruct("test get"), + check: func(t test.Test, a *test.Accessor[any]) { + assert.Equal(t, "", a.Get("s")) + assert.Equal(t, NewStruct(""), a.Get("")) + }, + }, + + "test struct set": { + target: NewStruct("test set"), + setup: func(a *test.Accessor[any]) { + a.Set("s", "test set first"). + Set("s", "test set final") + }, + check: func(t test.Test, a *test.Accessor[any]) { + assert.Equal(t, "test set final", a.Get("s")) + assert.Equal(t, NewStruct("test set final"), a.Get("")) + }, + }, + + "test struct reset no pointer": { + target: NewStruct("test reset"), + setup: func(a *test.Accessor[any]) { + a.Set("s", "test reset first"). + Set("", NewStruct("test reset final")) + }, + expect: test.Panic("target must of compatible struct pointer type"), + }, + + "test struct reset pointer": { + target: NewStruct("test reset"), + setup: func(a *test.Accessor[any]) { + a.Set("s", "test reset first"). + Set("", NewPtrStruct("test reset final")) + }, + check: func(t test.Test, a *test.Accessor[any]) { + assert.Equal(t, "test reset final", a.Get("s")) + assert.Equal(t, NewStruct("test reset final"), a.Get("")) + }, + }, + + "test ptr get": { + target: NewPtrStruct("test get"), + check: func(t test.Test, a *test.Accessor[any]) { + assert.Equal(t, "test get", a.Get("s")) + assert.Equal(t, NewPtrStruct("test get"), a.Get("")) + }, + }, + + "test ptr set": { + target: NewPtrStruct("test set"), + setup: func(a *test.Accessor[any]) { + a.Set("s", "test set first"). + Set("s", "test set final") + }, + check: func(t test.Test, a *test.Accessor[any]) { + assert.Equal(t, "test set final", a.Get("s")) + assert.Equal(t, NewPtrStruct("test set final"), a.Get("")) + }, + }, + + "test ptr reset": { + target: NewPtrStruct("test reset"), + setup: func(a *test.Accessor[any]) { + a.Set("s", "test reset first"). + Set("", NewPtrStruct("test reset final")) + }, + check: func(t test.Test, a *test.Accessor[any]) { + assert.Equal(t, "test reset final", a.Get("s")) + assert.Equal(t, NewPtrStruct("test reset final"), a.Get("")) + }, + }, +} + +func TestAccessor(t *testing.T) { + test.Map(t, testAccessorParams). + Run(func(t test.Test, param testAccessorParam) { + // Given + mock.NewMocks(t).Expect(param.expect) + accessor := test.NewAccessor(param.target) + + // When + if param.setup != nil { + param.setup(accessor) + } + + // The + param.check(t, accessor) + }) +} + type testErrorParam struct { error error setup func(*test.Accessor[error]) - test func(test.Test, *test.Accessor[error]) + check func(test.Test, *test.Accessor[error]) } var testErrorParams = map[string]testErrorParam{ "test get": { error: errors.New("test get"), - test: func(t test.Test, a *test.Accessor[error]) { + check: func(t test.Test, a *test.Accessor[error]) { assert.Equal(t, "test get", a.Get("s")) assert.Equal(t, errors.New("test get"), a.Get("")) }, @@ -29,7 +139,19 @@ var testErrorParams = map[string]testErrorParam{ a.Set("s", "test set first"). Set("s", "test set final") }, - test: func(t test.Test, a *test.Accessor[error]) { + check: func(t test.Test, a *test.Accessor[error]) { + assert.Equal(t, "test set final", a.Get("s")) + assert.Equal(t, errors.New("test set final"), a.Get("")) + }, + }, + + "test reset": { + error: errors.New("test set"), + setup: func(a *test.Accessor[error]) { + a.Set("s", "test set first"). + Set("", errors.New("test set final")) + }, + check: func(t test.Test, a *test.Accessor[error]) { assert.Equal(t, "test set final", a.Get("s")) assert.Equal(t, errors.New("test set final"), a.Get("")) }, @@ -48,6 +170,6 @@ func TestError(t *testing.T) { } // The - param.test(t, accessor) + param.check(t, accessor) }) }