Skip to content

Commit

Permalink
Merge pull request #26 from steinfletcher/feature/assert
Browse files Browse the repository at this point in the history
Support multiple assert functions
  • Loading branch information
steinfletcher authored Mar 5, 2019
2 parents e7806a3 + bc085aa commit d2178ca
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 10 deletions.
21 changes: 15 additions & 6 deletions apitest.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ type Response struct {
jsonPathExpression string
jsonPathAssert func(interface{})
apiTest *APITest
assert Assert
assert []Assert
}

// Assert is a user defined custom assertion function
Expand Down Expand Up @@ -378,7 +378,7 @@ func (r *Response) Status(s int) *Response {
// Assert allows the consumer to provide a user defined function containing their own
// custom assertions
func (r *Response) Assert(fn func(*http.Response, *http.Request) error) *Response {
r.assert = fn
r.assert = append(r.assert, fn)
return r.apiTest.response
}

Expand Down Expand Up @@ -530,13 +530,22 @@ func (a *APITest) run() {
a.assertResponse(res)
a.assertHeaders(res)
a.assertCookies(res)
err := a.assertFunc(res, req)
if err != nil {
a.t.Fatal(err.Error())
}
}

if a.response.assert != nil {
err := a.response.assert(res.Result(), req)
if err != nil {
a.t.Fatal(err.Error())
func (a *APITest) assertFunc(res *httptest.ResponseRecorder, req *http.Request) error {
if len(a.response.assert) > 0 {
for _, assertFn := range a.response.assert {
err := assertFn(res.Result(), req)
if err != nil {
return err
}
}
}
return nil
}

func (a *APITest) runTest() (*httptest.ResponseRecorder, *http.Request) {
Expand Down
35 changes: 35 additions & 0 deletions apitest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package apitest

import (
"bytes"
"errors"
"fmt"
"github.com/stretchr/testify/assert"
"io/ioutil"
Expand Down Expand Up @@ -410,6 +411,40 @@ func TestApiTest_CustomAssert(t *testing.T) {
End()
}

func TestApiTest_SupportsMultipleCustomAsserts(t *testing.T) {
test := New().
Patch("/hello").
Expect(t).
Assert(IsSuccess).
Assert(IsSuccess)

assert.Len(t, test.assert, 2)
}

func TestApiTest_AssertFunc(t *testing.T) {
tests := []struct {
statusCode int
expectedErr error
}{
{200, nil},
{400, errors.New("not success. Status code=400")},
}
for _, test := range tests {
t.Run(fmt.Sprintf("status: %d", test.statusCode), func(t *testing.T) {
res := httptest.NewRecorder()
res.Code = test.statusCode
apitTest := New().
Patch("/hello").
Expect(t).
Assert(IsSuccess)

err := apitTest.apiTest.assertFunc(res, nil)

assert.Equal(t, test.expectedErr, err)
})
}
}

func TestApiTest_Report(t *testing.T) {
getUser := NewMock().
Get("http://localhost:8080").
Expand Down
8 changes: 4 additions & 4 deletions assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,30 @@ package apitest

import (
"errors"
"fmt"
"net/http"
"strconv"
)

// IsSuccess is a convenience function to assert on a range of happy path status codes
var IsSuccess Assert = func(response *http.Response, request *http.Request) error {
if response.StatusCode >= 200 && response.StatusCode < 400 {
return nil
}
return errors.New("not a client error. Status code=" + strconv.Itoa(response.StatusCode))
return errors.New(fmt.Sprintf("not success. Status code=%d", response.StatusCode))
}

// IsClientError is a convenience function to assert on a range of client error status codes
var IsClientError Assert = func(response *http.Response, request *http.Request) error {
if response.StatusCode >= 400 && response.StatusCode < 500 {
return nil
}
return errors.New("not a client error. Status code=" + strconv.Itoa(response.StatusCode))
return errors.New(fmt.Sprintf("not a client error. Status code=%d", response.StatusCode))
}

// IsServerError is a convenience function to assert on a range of server error status codes
var IsServerError Assert = func(response *http.Response, request *http.Request) error {
if response.StatusCode >= 500 {
return nil
}
return errors.New("not a server error. Status code=" + strconv.Itoa(response.StatusCode))
return errors.New(fmt.Sprintf("not a server error. Status code=%d", response.StatusCode))
}

0 comments on commit d2178ca

Please sign in to comment.