diff --git a/apitest.go b/apitest.go index f436b33..0531fca 100644 --- a/apitest.go +++ b/apitest.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "net/http/httputil" + "net/textproto" "strings" "testing" ) @@ -41,7 +42,7 @@ func New(name ...string) *APITest { } response := &Response{ apiTest: apiTest, - headers: map[string]string{}, + headers: map[string][]string{}, } apiTest.request = request apiTest.response = response @@ -198,14 +199,16 @@ func (r *Request) QueryCollection(q map[string][]string) *Request { // Header is a builder method to set the request headers func (r *Request) Header(key, value string) *Request { - r.headers[key] = append(r.headers[key], value) + normalizedKey := textproto.CanonicalMIMEHeaderKey(key) + r.headers[normalizedKey] = append(r.headers[normalizedKey], value) return r } // Headers is a builder method to set the request headers func (r *Request) Headers(headers map[string]string) *Request { for k, v := range headers { - r.headers[k] = append(r.headers[k], v) + normalizedKey := textproto.CanonicalMIMEHeaderKey(k) + r.headers[normalizedKey] = append(r.headers[normalizedKey], v) } return r } @@ -233,7 +236,7 @@ func (r *Request) Expect(t *testing.T) *Response { type Response struct { status int body string - headers map[string]string + headers map[string][]string cookies []*Cookie cookiesPresent []string cookiesNotPresent []string @@ -271,9 +274,19 @@ func (r *Response) CookieNotPresent(cookieName string) *Response { return r } -// Headers is the expected response headers +// Header is a builder method to set the request headers +func (r *Response) Header(key, value string) *Response { + normalizedKey := textproto.CanonicalMIMEHeaderKey(key) + r.headers[normalizedKey] = append(r.headers[normalizedKey], value) + return r +} + +// Headers is a builder method to set the request headers func (r *Response) Headers(headers map[string]string) *Response { - r.headers = headers + for k, v := range headers { + normalizedKey := textproto.CanonicalMIMEHeaderKey(k) + r.headers[normalizedKey] = append(r.headers[textproto.CanonicalMIMEHeaderKey(normalizedKey)], v) + } return r } @@ -469,10 +482,19 @@ func responseCookies(response *httptest.ResponseRecorder) []*http.Cookie { } func (a *APITest) assertHeaders(res *httptest.ResponseRecorder) { - if a.response.headers != nil { - for k, v := range a.response.headers { - header := res.Header().Get(k) - assertEqual(a.t, v, header, fmt.Sprintf("'%s' header should be equal", k)) + for expectedHeader, expectedValues := range a.response.headers { + for _, expectedValue := range expectedValues { + found := false + result := res.Result() + for _, resValue := range result.Header[expectedHeader] { + if expectedValue == resValue { + found = true + break + } + } + if !found { + a.t.Fatalf("could not match header=%s", expectedHeader) + } } } } diff --git a/apitest_test.go b/apitest_test.go index d6ff926..8a6474c 100644 --- a/apitest_test.go +++ b/apitest_test.go @@ -310,11 +310,13 @@ func TestApiTest_MatchesResponseHttpCookies_OnlySuppliedFields(t *testing.T) { End() } -func TestApiTest_MatchesResponseHeaders(t *testing.T) { +func TestApiTest_MatchesResponseHeaders_WithMixedKeyCase(t *testing.T) { handler := http.NewServeMux() handler.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("ABC", "12345") w.Header().Set("DEF", "67890") + w.Header().Set("Authorization", "12345") + w.Header().Add("Authorization", "98765") w.WriteHeader(http.StatusOK) }) @@ -324,9 +326,11 @@ func TestApiTest_MatchesResponseHeaders(t *testing.T) { Expect(t). Status(http.StatusOK). Headers(map[string]string{ - "ABC": "12345", - "DEF": "67890", + "Abc": "12345", + "Def": "67890", }). + Header("Authorization", "12345"). + Header("authorization", "98765"). End() }