diff --git a/apitest.go b/apitest.go index 5db3e46..45d7e36 100644 --- a/apitest.go +++ b/apitest.go @@ -322,7 +322,7 @@ type Response struct { jsonPathExpression string jsonPathAssert func(interface{}) apiTest *APITest - assert Assert + assert []Assert } // Assert is a user defined custom assertion function @@ -378,7 +378,7 @@ func (r *Response) Status(s int) *Response { // Assert allows the consumer to provide a user defined function containing their own // custom assertions func (r *Response) Assert(fn func(*http.Response, *http.Request) error) *Response { - r.assert = fn + r.assert = append(r.assert, fn) return r.apiTest.response } @@ -530,13 +530,22 @@ func (a *APITest) run() { a.assertResponse(res) a.assertHeaders(res) a.assertCookies(res) + err := a.assertFunc(res, req) + if err != nil { + a.t.Fatal(err.Error()) + } +} - if a.response.assert != nil { - err := a.response.assert(res.Result(), req) - if err != nil { - a.t.Fatal(err.Error()) +func (a *APITest) assertFunc(res *httptest.ResponseRecorder, req *http.Request) error { + if len(a.response.assert) > 0 { + for _, assertFn := range a.response.assert { + err := assertFn(res.Result(), req) + if err != nil { + return err + } } } + return nil } func (a *APITest) runTest() (*httptest.ResponseRecorder, *http.Request) { diff --git a/apitest_test.go b/apitest_test.go index 17bceab..2805c0e 100644 --- a/apitest_test.go +++ b/apitest_test.go @@ -2,6 +2,7 @@ package apitest import ( "bytes" + "errors" "fmt" "github.com/stretchr/testify/assert" "io/ioutil" @@ -410,6 +411,40 @@ func TestApiTest_CustomAssert(t *testing.T) { End() } +func TestApiTest_SupportsMultipleCustomAsserts(t *testing.T) { + test := New(). + Patch("/hello"). + Expect(t). + Assert(IsSuccess). + Assert(IsSuccess) + + assert.Len(t, test.assert, 2) +} + +func TestApiTest_AssertFunc(t *testing.T) { + tests := []struct { + statusCode int + expectedErr error + }{ + {200, nil}, + {400, errors.New("not success. Status code=400")}, + } + for _, test := range tests { + t.Run(fmt.Sprintf("status: %d", test.statusCode), func(t *testing.T) { + res := httptest.NewRecorder() + res.Code = test.statusCode + apitTest := New(). + Patch("/hello"). + Expect(t). + Assert(IsSuccess) + + err := apitTest.apiTest.assertFunc(res, nil) + + assert.Equal(t, test.expectedErr, err) + }) + } +} + func TestApiTest_Report(t *testing.T) { getUser := NewMock(). Get("http://localhost:8080"). diff --git a/assert.go b/assert.go index b045b3e..0fb8b33 100644 --- a/assert.go +++ b/assert.go @@ -2,8 +2,8 @@ package apitest import ( "errors" + "fmt" "net/http" - "strconv" ) // IsSuccess is a convenience function to assert on a range of happy path status codes @@ -11,7 +11,7 @@ var IsSuccess Assert = func(response *http.Response, request *http.Request) erro if response.StatusCode >= 200 && response.StatusCode < 400 { return nil } - return errors.New("not a client error. Status code=" + strconv.Itoa(response.StatusCode)) + return errors.New(fmt.Sprintf("not success. Status code=%d", response.StatusCode)) } // IsClientError is a convenience function to assert on a range of client error status codes @@ -19,7 +19,7 @@ var IsClientError Assert = func(response *http.Response, request *http.Request) if response.StatusCode >= 400 && response.StatusCode < 500 { return nil } - return errors.New("not a client error. Status code=" + strconv.Itoa(response.StatusCode)) + return errors.New(fmt.Sprintf("not a client error. Status code=%d", response.StatusCode)) } // IsServerError is a convenience function to assert on a range of server error status codes @@ -27,5 +27,5 @@ var IsServerError Assert = func(response *http.Response, request *http.Request) if response.StatusCode >= 500 { return nil } - return errors.New("not a server error. Status code=" + strconv.Itoa(response.StatusCode)) + return errors.New(fmt.Sprintf("not a server error. Status code=%d", response.StatusCode)) }