diff --git a/httpclient/client.go b/httpclient/client.go index 5f03c5b..dd199ca 100644 --- a/httpclient/client.go +++ b/httpclient/client.go @@ -120,6 +120,20 @@ func (c *Client) Delete(url string, headers http.Header) (*http.Response, error) return c.Do(request) } +// Head makes a HTTP HEAD request with provided URL +func (c *Client) Head(url string, headers http.Header) (*http.Response, error) { + var response *http.Response + + request, err := http.NewRequest(http.MethodHead, url, nil) + if err != nil { + return response, errors.Wrap(err, "HEAD - request creation failed") + } + + request.Header = headers + + return c.Do(request) +} + // Do makes an HTTP request with the native `http.Do` interface func (c *Client) Do(request *http.Request) (*http.Response, error) { request.Close = true diff --git a/httpclient/client_test.go b/httpclient/client_test.go index dbca35d..e5a7eaa 100644 --- a/httpclient/client_test.go +++ b/httpclient/client_test.go @@ -131,6 +131,31 @@ func TestHTTPClientDeleteSuccess(t *testing.T) { assert.Equal(t, "{ \"response\": \"ok\" }", respBody(t, response)) } +func TestHTTPClientHeadSuccess(t *testing.T) { + client := NewClient(WithHTTPTimeout(10 * time.Millisecond)) + + dummyHandler := func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodHead, r.Method) + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + assert.Equal(t, "en", r.Header.Get("Accept-Language")) + + w.Header().Add("x-foo", "bar") + w.WriteHeader(http.StatusOK) + } + + server := httptest.NewServer(http.HandlerFunc(dummyHandler)) + defer server.Close() + + headers := http.Header{} + headers.Set("Content-Type", "application/json") + headers.Set("Accept-Language", "en") + + response, err := client.Head(server.URL, headers) + require.NoError(t, err, "should not have failed to make a HEAD request") + assert.Equal(t, "bar", response.Header.Get("x-foo")) + assert.Equal(t, http.StatusOK, response.StatusCode) +} + func TestHTTPClientPutSuccess(t *testing.T) { client := NewClient(WithHTTPTimeout(10 * time.Millisecond))