diff --git a/httpclient/client.go b/httpclient/client.go index 5f03c5b..3d40cda 100644 --- a/httpclient/client.go +++ b/httpclient/client.go @@ -2,6 +2,7 @@ package httpclient import ( "bytes" + "context" "io" "io/ioutil" "net/http" @@ -138,11 +139,32 @@ func (c *Client) Do(request *http.Request) (*http.Response, error) { multiErr := &valkyrie.MultiError{} var response *http.Response +outter: for i := 0; i <= c.retryCount; i++ { if response != nil { response.Body.Close() } + // Wait before retrying. + if i > 0 { + backoffTime := c.retrier.NextInterval(i - 1) + ctx, cancel := context.WithTimeout(context.Background(), backoffTime) + + select { + case <-ctx.Done(): + cancel() + + case <-request.Context().Done(): + cancel() + + multiErr.Push(request.Context().Err().Error()) + c.reportError(request, request.Context().Err()) + + // If the request context has already been cancelled, don't retry + break outter + } + } + c.reportRequestStart(request) var err error response, err = c.client.Do(request) @@ -155,19 +177,18 @@ func (c *Client) Do(request *http.Request) (*http.Response, error) { if err != nil { multiErr.Push(err.Error()) c.reportError(request, err) - backoffTime := c.retrier.NextInterval(i) - time.Sleep(backoffTime) + continue } + c.reportRequestEnd(request, response) if response.StatusCode >= http.StatusInternalServerError { - backoffTime := c.retrier.NextInterval(i) - time.Sleep(backoffTime) continue } - multiErr = &valkyrie.MultiError{} // Clear errors if any iteration succeeds + // Clear errors if any iteration succeeds + multiErr = &valkyrie.MultiError{} break } diff --git a/httpclient/client_test.go b/httpclient/client_test.go index dbca35d..0e01dba 100644 --- a/httpclient/client_test.go +++ b/httpclient/client_test.go @@ -2,6 +2,7 @@ package httpclient import ( "bytes" + "context" "io/ioutil" "net/http" "net/http/httptest" @@ -417,6 +418,75 @@ func TestHTTPClientGetReturnsErrorOnFailure(t *testing.T) { assert.Nil(t, response) } +func TestHTTPClientDontRetryWhenContextIsCancelled(t *testing.T) { + noOfRetries := 3 + // Set a huge backoffInterval that we won't have to wait anyway + backoffInterval := 1 * time.Hour + maximumJitterInterval := 1 * time.Millisecond + + client := NewClient( + WithHTTPTimeout(10*time.Millisecond), + WithRetryCount(noOfRetries), + WithRetrier(heimdall.NewRetrier(heimdall.NewConstantBackoff(backoffInterval, maximumJitterInterval))), + ) + + tt := []struct { + Title string + CancelTimeout time.Duration + NotNilResponse bool + }{ + { + Title: "Cancel directly", + CancelTimeout: 0 * time.Millisecond, + NotNilResponse: false, + }, + { + Title: "Cancel afterwards", + CancelTimeout: 10 * time.Millisecond, + NotNilResponse: true, + }, + } + + for _, test := range tt { + test := test + t.Run(test.Title, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + + dummyHandler := func(w http.ResponseWriter, r *http.Request) { + if test.CancelTimeout == 0 { + cancel() + } else { + go func() { + time.Sleep(test.CancelTimeout) + cancel() + }() + } + + w.WriteHeader(http.StatusInternalServerError) + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ "response": "something went wrong" }`)) + } + + server := httptest.NewServer(http.HandlerFunc(dummyHandler)) + defer server.Close() + + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + + response, err := client.Do(req.WithContext(ctx)) + require.Error(t, err, "should have failed to make request") + + if test.NotNilResponse { + require.NotNil(t, response) + } else { + require.Nil(t, response) + } + }) + } +} + func TestPluginMethodsCalled(t *testing.T) { client := NewClient(WithHTTPTimeout(10 * time.Millisecond)) mockPlugin := &MockPlugin{} diff --git a/httpclient/options_test.go b/httpclient/options_test.go index b92ce71..2e93673 100644 --- a/httpclient/options_test.go +++ b/httpclient/options_test.go @@ -128,6 +128,5 @@ func ExampleWithRetrier() { // Output: retry attempt 0 // retry attempt 1 // retry attempt 2 - // retry attempt 3 // error }