diff --git a/mocks.go b/mocks.go index 9498b2b..48e6faa 100644 --- a/mocks.go +++ b/mocks.go @@ -398,6 +398,9 @@ var pathMatcher Matcher = func(r *http.Request, spec *MockRequest) error { var hostMatcher Matcher = func(r *http.Request, spec *MockRequest) error { receivedHost := r.Host + if receivedHost == "" { + receivedHost = r.URL.Host + } mockHost := spec.url.Host if mockHost == "" { return nil diff --git a/mocks_test.go b/mocks_test.go index 208722b..15e9072 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -8,6 +8,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "net/url" "reflect" "strings" "testing" @@ -54,19 +55,38 @@ func TestMocks_NewEmptyUnmatchedMockError_ExpectedErrorsString(t *testing.T) { } func TestMocks_HostMatcher(t *testing.T) { - tests := []struct { - requestUrl string + tests := map[string]struct { + request *http.Request mockUrl string expectedError error }{ - {"http://test.com", "https://test.com", nil}, - {"https://test.com", "https://testa.com", errors.New("received host test.com did not match mock host testa.com")}, - {"https://test.com", "", nil}, + "matching": { + request: httptest.NewRequest(http.MethodGet, "http://test.com", nil), + mockUrl: "https://test.com", + expectedError: nil, + }, + "not matching": { + request: httptest.NewRequest(http.MethodGet, "https://test.com", nil), + mockUrl: "https://testa.com", + expectedError: errors.New("received host test.com did not match mock host testa.com"), + }, + "no expected host": { + request: httptest.NewRequest(http.MethodGet, "https://test.com", nil), + mockUrl: "", + expectedError: nil, + }, + "matching using URL host": { + request: &http.Request{URL: &url.URL{ + Host: "test.com", + Path: "/", + }}, + mockUrl: "https://test.com", + expectedError: nil, + }, } - for _, test := range tests { - t.Run(fmt.Sprintf("%s %s", test.requestUrl, test.mockUrl), func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, test.requestUrl, nil) - matchError := hostMatcher(req, NewMock().Get(test.mockUrl)) + for name, test := range tests { + t.Run(name, func(t *testing.T) { + matchError := hostMatcher(test.request, NewMock().Get(test.mockUrl)) assert.Equal(t, test.expectedError, matchError) }) }