diff --git a/README.md b/README.md index 285ba65..cc06362 100644 --- a/README.md +++ b/README.md @@ -333,7 +333,7 @@ A plugin is an interface whose methods get called during key events in a request - `OnRequestStart` is called just before the request is made - `OnRequestEnd` is called once the request has successfully executed -- `OnError` is called is the request failed +- `OnError` is called when the request failed Each method is called with the request object as an argument, with `OnRequestEnd`, and `OnError` additionally being called with the response and error instances respectively. For a simple example on how to write plugins, look at the [request logger plugin](/plugins/request_logger.go). diff --git a/httpclient/client.go b/httpclient/client.go index 6880776..f1c5e94 100644 --- a/httpclient/client.go +++ b/httpclient/client.go @@ -2,6 +2,7 @@ package httpclient import ( "bytes" + "context" "io" "io/ioutil" "net/http" @@ -136,15 +137,17 @@ func (c *Client) Do(request *http.Request) (*http.Response, error) { } multiErr := &valkyrie.MultiError{} + var err error + var shouldRetry bool var response *http.Response - for i := 0; i <= c.retryCount; i++ { + for i := 0; ; i++ { if response != nil { response.Body.Close() } c.reportRequestStart(request) - var err error + response, err = c.client.Do(request) if bodyReader != nil { // Reset the body reader after the request since at this point it's already read @@ -152,28 +155,61 @@ func (c *Client) Do(request *http.Request) (*http.Response, error) { _, _ = bodyReader.Seek(0, 0) } + shouldRetry, err = c.checkRetry(request.Context(), response, err) + if err != nil { multiErr.Push(err.Error()) c.reportError(request, err) - backoffTime := c.retrier.NextInterval(i) - time.Sleep(backoffTime) - continue + } else { + c.reportRequestEnd(request, response) + } + + if !shouldRetry { + break } - c.reportRequestEnd(request, response) - if response.StatusCode >= http.StatusInternalServerError { - backoffTime := c.retrier.NextInterval(i) - time.Sleep(backoffTime) - continue + if c.retryCount-i <= 0 { + break } - multiErr = &valkyrie.MultiError{} // Clear errors if any iteration succeeds - break + // Cancel the retry sleep if the request context is cancelled or deadline exceeded + timer := time.NewTimer(c.retrier.NextInterval(i)) + select { + case <-request.Context().Done(): + timer.Stop() + break + case <-timer.C: + } + } + + if !shouldRetry && err == nil { + // Clear errors if any iteration succeeds + multiErr = &valkyrie.MultiError{} } return response, multiErr.HasError() } +func (c *Client) checkRetry(ctx context.Context, resp *http.Response, err error) (bool, error) { + // do not retry on context.Canceled or context.DeadlineExceeded + if ctx.Err() != nil { + return false, ctx.Err() + } + + if err != nil { + return true, err + } + + // 429 Too Many Requests is recoverable. Sometimes the server puts + // a Retry-After response header to indicate when the server is + // available to start processing request from client. + if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= http.StatusInternalServerError { + return true, nil + } + + return false, nil +} + func (c *Client) reportRequestStart(request *http.Request) { for _, plugin := range c.plugins { plugin.OnRequestStart(request) diff --git a/httpclient/client_test.go b/httpclient/client_test.go index 83e94ba..0df713d 100644 --- a/httpclient/client_test.go +++ b/httpclient/client_test.go @@ -2,6 +2,7 @@ package httpclient import ( "bytes" + "context" "io/ioutil" "net/http" "net/http/httptest" @@ -15,6 +16,37 @@ import ( "github.com/stretchr/testify/require" ) +func TestHTTPRequestWithContextDoSuccess(t *testing.T) { + noOfRetries := 3 + backoffInterval := 1 * time.Millisecond + maximumJitterInterval := 10 * time.Millisecond + + client := NewClient( + WithRetryCount(noOfRetries), + WithRetrier(heimdall.NewRetrier(heimdall.NewConstantBackoff(backoffInterval, maximumJitterInterval))), + ) + + dummyHandler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ "response": "ok" }`)) + time.Sleep(100 * time.Millisecond) + } + + server := httptest.NewServer(http.HandlerFunc(dummyHandler)) + defer server.Close() + + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + + // define some user case context + subCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + req = req.WithContext(subCtx) + _, err = client.Do(req) + require.Contains(t, err.Error(), "context deadline exceeded") +} + func TestHTTPClientDoSuccess(t *testing.T) { client := NewClient(WithHTTPTimeout(10 * time.Millisecond)) @@ -201,6 +233,33 @@ func TestHTTPClientPatchSuccess(t *testing.T) { assert.Equal(t, "{ \"response\": \"ok\" }", respBody(t, response)) } +func TestHTTPClientGetRetriesOnTimeout(t *testing.T) { + count := 0 + noOfRetries := 3 + noOfCalls := noOfRetries + 1 + backoffInterval := 1 * time.Millisecond + maximumJitterInterval := 10 * time.Millisecond + + client := NewClient( + WithHTTPTimeout(3*time.Millisecond), + WithRetryCount(noOfRetries), + WithRetrier(heimdall.NewRetrier(heimdall.NewConstantBackoff(backoffInterval, maximumJitterInterval))), + ) + + dummyHandler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + count++ + time.Sleep(100 * time.Millisecond) + } + + server := httptest.NewServer(http.HandlerFunc(dummyHandler)) + defer server.Close() + + _, err := client.Get(server.URL, http.Header{}) + require.Contains(t, err.Error(), "context deadline exceeded") + assert.Equal(t, noOfCalls, count) +} + func TestHTTPClientGetRetriesOnFailure(t *testing.T) { count := 0 noOfRetries := 3 @@ -406,7 +465,6 @@ func TestHTTPClientGetReturnsNoErrorOn5xxFailure(t *testing.T) { response, err := client.Get(server.URL, http.Header{}) require.NoError(t, err) require.Equal(t, http.StatusInternalServerError, response.StatusCode) - } func TestHTTPClientGetReturnsErrorOnFailure(t *testing.T) { @@ -484,7 +542,8 @@ func TestCustomHTTPClientHeaderSuccess(t *testing.T) { client := NewClient( WithHTTPTimeout(10*time.Millisecond), WithHTTPClient(&myHTTPClient{ - client: http.Client{Timeout: 25 * time.Millisecond}}), + client: http.Client{Timeout: 25 * time.Millisecond}, + }), ) dummyHandler := func(w http.ResponseWriter, r *http.Request) { diff --git a/httpclient/options_test.go b/httpclient/options_test.go index a9eb5ba..d1bef15 100644 --- a/httpclient/options_test.go +++ b/httpclient/options_test.go @@ -128,6 +128,5 @@ func ExampleWithRetrier() { // Output: retry attempt 0 // retry attempt 1 // retry attempt 2 - // retry attempt 3 // error }