diff --git a/router/handler.go b/router/handler.go index 4e76b47..67be2d5 100644 --- a/router/handler.go +++ b/router/handler.go @@ -60,13 +60,13 @@ func getEndpoint(m reflect.Method) *endpoint { return nil } - // Functions must have inputs like (c context.Context, d *data.Data), plus one input for the receiver - if m.Type.NumIn() != 3 { + // Functions must have inputs like (c context.Context, req *data.Request, resp *data.Response), plus one input for the receiver + if m.Type.NumIn() != 4 { return nil } - // Functions must return either 1 or 2 values - if m.Type.NumOut() == 0 || m.Type.NumOut() > 2 { + // Functions must return an error value + if m.Type.NumOut() != 1 { return nil } @@ -74,6 +74,6 @@ func getEndpoint(m reflect.Method) *endpoint { Name: m.Name, HandlerFunc: m.Func, In: m.Type.In(2), - Out: m.Type.Out(0), + Out: m.Type.In(3), } } diff --git a/router/router.go b/router/router.go index 12f5f32..c125fbc 100644 --- a/router/router.go +++ b/router/router.go @@ -73,26 +73,24 @@ func (s *router) Handle(path string, data []byte) ([]byte, error) { return nil, fmt.Errorf("decode request: %w", err) } + respValue := reflect.New(method.Out) + ret := method.HandlerFunc.Call([]reflect.Value{ reflect.ValueOf(handler.Instance), reflect.ValueOf(context.Background()), *in, + respValue, }) - if len(ret) == 2 && !ret[1].IsNil() { + if !ret[0].IsNil() { return nil, &HandlerError{ endpoint: method, handler: handler, - err: ret[1].Interface().(error), + err: ret[0].Interface().(error), } } - retval := ret[0] - if len(ret) == 2 { - retval = ret[1] - } - - outdata, err := s.codec.Marshal(retval.Interface()) + outdata, err := s.codec.Marshal(respValue.Interface()) if err != nil { return nil, fmt.Errorf("encode response: %w", err) }