diff --git a/cmd/kelpie/mock.go.tmpl b/cmd/kelpie/mock.go.tmpl index df8fb99..c927fac 100644 --- a/cmd/kelpie/mock.go.tmpl +++ b/cmd/kelpie/mock.go.tmpl @@ -19,7 +19,7 @@ {{- end -}} {{- define "observationCallback" -}} -func({{ template "parameterWithTypeList" .Parameters }}) {{ template "resultTypeList" .Results }} +func({{ template "parameterWithTypeList" .Parameters }}){{ if .Results }} {{ template "resultTypeList" .Results }}{{ end }} {{- end -}} {{- define "matcherTypeParams" -}} @@ -58,12 +58,17 @@ type Instance struct { {{- range $method := .Methods }} -func (m *Instance) {{ $method.Name }}({{ template "parameterWithTypeList" $method.Parameters }}) ({{ template "resultWithTypeList" $method.Results }}) { +func (m *Instance) {{ $method.Name }}({{ template "parameterWithTypeList" $method.Parameters }}){{ if $method.Results }} ({{ template "resultWithTypeList" $method.Results }}){{ end }} { expectation := m.mock.Call("{{ $method.Name }}", {{ template "parameterList" $method.Parameters }}) if expectation != nil { if expectation.ObserveFn != nil { observe := expectation.ObserveFn.({{ template "observationCallback" $method }}) + {{- if $method.Results }} return observe({{ template "parameterList" $method.Parameters }}) + {{- else }} + observe({{ template "parameterList" $method.Parameters }}) + return + {{- end }} } if expectation.PanicArg != nil { @@ -121,6 +126,8 @@ func (e *{{ $method.Name }}Expectation) CreateExpectation() *mocking.Expectation return &e.expectation } +{{- if $method.Results }} + func (a *{{ $method.Name }}MethodMatcher) Return({{ template "resultWithTypeList" $method.Results }}) *{{ $method.Name }}Expectation { return &{{ $method.Name }}Expectation{ expectation: mocking.Expectation{ @@ -129,6 +136,7 @@ func (a *{{ $method.Name }}MethodMatcher) Return({{ template "resultWithTypeList }, } } +{{- end }} func (a *{{ $method.Name }}MethodMatcher) Panic(arg any) *{{ $method.Name }}Expectation { return &{{ $method.Name }}Expectation{ diff --git a/examples/mocks/accountservice/accountservice.go b/examples/mocks/accountservice/accountservice.go index 7bbf7b3..23f4ce9 100644 --- a/examples/mocks/accountservice/accountservice.go +++ b/examples/mocks/accountservice/accountservice.go @@ -44,6 +44,23 @@ func (m *Instance) SendActivationEmail(emailAddress string) (r0 bool) { return } +func (m *Instance) DisableAccount(id uint) { + expectation := m.mock.Call("DisableAccount", id) + if expectation != nil { + if expectation.ObserveFn != nil { + observe := expectation.ObserveFn.(func(id uint)) + observe(id) + return + } + + if expectation.PanicArg != nil { + panic(expectation.PanicArg) + } + } + + return +} + func (m *Mock) Instance() *Instance { return &m.instance } @@ -107,3 +124,54 @@ func (a *SendActivationEmailMethodMatcher) When(observe func(emailAddress string }, } } + +type DisableAccountMethodMatcher struct { + matcher mocking.MethodMatcher +} + +func (m *DisableAccountMethodMatcher) CreateMethodMatcher() *mocking.MethodMatcher { + return &m.matcher +} + +func DisableAccount[P0 uint | mocking.Matcher[uint]](id P0) *DisableAccountMethodMatcher { + result := DisableAccountMethodMatcher{ + matcher: mocking.MethodMatcher{ + MethodName: "DisableAccount", + ArgumentMatchers: make([]mocking.ArgumentMatcher, 1), + }, + } + + if matcher, ok := any(id).(mocking.Matcher[uint]); ok { + result.matcher.ArgumentMatchers[0] = matcher + } else { + result.matcher.ArgumentMatchers[0] = kelpie.ExactMatch(any(id).(uint)) + } + + return &result +} + +type DisableAccountExpectation struct { + expectation mocking.Expectation +} + +func (e *DisableAccountExpectation) CreateExpectation() *mocking.Expectation { + return &e.expectation +} + +func (a *DisableAccountMethodMatcher) Panic(arg any) *DisableAccountExpectation { + return &DisableAccountExpectation{ + expectation: mocking.Expectation{ + MethodMatcher: &a.matcher, + PanicArg: arg, + }, + } +} + +func (a *DisableAccountMethodMatcher) When(observe func(id uint)) *DisableAccountExpectation { + return &DisableAccountExpectation{ + expectation: mocking.Expectation{ + MethodMatcher: &a.matcher, + ObserveFn: observe, + }, + } +} diff --git a/examples/result_test.go b/examples/result_test.go index 87fcb49..f1543a0 100644 --- a/examples/result_test.go +++ b/examples/result_test.go @@ -3,6 +3,7 @@ package examples import ( "testing" + "github.com/adamconnelly/kelpie" "github.com/adamconnelly/kelpie/examples/mocks/accountservice" "github.com/stretchr/testify/suite" ) @@ -10,6 +11,7 @@ import ( //go:generate go run ../cmd/kelpie generate --source-file result_test.go --package github.com/adamconnelly/kelpie/examples --interfaces AccountService type AccountService interface { SendActivationEmail(emailAddress string) bool + DisableAccount(id uint) } type ResultTests struct { @@ -53,6 +55,21 @@ func (t *ResultTests) Test_CustomAction() { t.Equal("a@b.com", recipientAddress) } +func (t *ResultTests) Test_CanMockMethodsWithNoReturnArgs() { + // Arrange + var accountID uint + mock := accountservice.NewMock() + mock.Setup(accountservice.DisableAccount(kelpie.Any[uint]()).When(func(id uint) { + accountID = id + })) + + // Act + mock.Instance().DisableAccount(uint(123)) + + // Assert + t.Equal(uint(123), accountID) +} + func TestResults(t *testing.T) { suite.Run(t, new(ResultTests)) } diff --git a/parser/parser.go b/parser/parser.go index a0b9d0a..599177c 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -83,18 +83,20 @@ func Parse(reader io.Reader, packageName string, filter InterfaceFilter) ([]Mock } } - for _, result := range funcType.Results.List { - if len(result.Names) > 0 { - for _, resultName := range result.Names { + if funcType.Results != nil { + for _, result := range funcType.Results.List { + if len(result.Names) > 0 { + for _, resultName := range result.Names { + methodDefinition.Results = append(methodDefinition.Results, ResultDefinition{ + Name: resultName.Name, + Type: result.Type.(*ast.Ident).Name, + }) + } + } else { methodDefinition.Results = append(methodDefinition.Results, ResultDefinition{ - Name: resultName.Name, Type: result.Type.(*ast.Ident).Name, }) } - } else { - methodDefinition.Results = append(methodDefinition.Results, ResultDefinition{ - Type: result.Type.(*ast.Ident).Name, - }) } } diff --git a/parser/parser_test.go b/parser/parser_test.go index b77e656..7e1fd8e 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -120,6 +120,40 @@ type NotificationService interface { t.Equal("error", broadcastNotification.Results[1].Type) } +func (t *ParserTests) Test_Parse_SupportsMethodsWithNoResults() { + // Arrange + input := `package test + +type NotificationService interface { + Block(recipient string) +}` + + t.interfaceFilter.Setup(interfacefilter.Include("github.com/adamconnelly/kelpie/tests.UserService").Return(false)) + + // Act + result, err := parser.Parse(strings.NewReader(input), "github.com/adamconnelly/kelpie/tests", t.interfaceFilter.Instance()) + + // Assert + t.NoError(err) + t.Len(result, 1) + + notificationService := slices.FirstOrPanic(result, func(mock parser.MockedInterface) bool { + return mock.Name == "NotificationService" + }) + t.Equal("notificationservice", notificationService.PackageName) + t.Len(notificationService.Methods, 1) + + block := slices.FirstOrPanic(notificationService.Methods, func(method parser.MethodDefinition) bool { + return method.Name == "Block" + }) + + t.Len(block.Parameters, 1) + t.Equal("recipient", block.Parameters[0].Name) + t.Equal("string", block.Parameters[0].Type) + + t.Len(block.Results, 0) +} + // TODO: what about empty interfaces? Return a warning? func TestParser(t *testing.T) {