From 21178779dcfd8849e8f35decb13cde8f84e1b840 Mon Sep 17 00:00:00 2001 From: Abbey Nelson Date: Mon, 14 Aug 2023 14:45:54 -0600 Subject: [PATCH] Added functionality and tests for the 429 retry-header. Removed randomization from sleep times, now increments sleep time by 1 second for each attempt. Deleted unused code and fields, added notes for future upgrades --- internal/sdk/doc_test.go | 19 +++-- internal/sdk/retry_client.go | 47 +++++++------ internal/sdk/retry_client_test.go | 112 ++++++++++++++++++++++++++---- wireup/builder.go | 3 +- 4 files changed, 140 insertions(+), 41 deletions(-) diff --git a/internal/sdk/doc_test.go b/internal/sdk/doc_test.go index 34678b3..7817d56 100644 --- a/internal/sdk/doc_test.go +++ b/internal/sdk/doc_test.go @@ -3,6 +3,7 @@ package sdk import ( "io" "net/http" + "strconv" ) type FakeHTTPClient struct { @@ -19,17 +20,25 @@ func (f *FakeHTTPClient) Do(request *http.Request) (*http.Response, error) { /*////////////////////////////////////////////////////////////////////////*/ type FakeMultiHTTPClient struct { - requests []*http.Request - bodies []string - responses []*http.Response - errors []error - call int + requests []*http.Request + headers []*http.Header + bodies []string + responses []*http.Response + errors []error + call int + headerKey string + rateLimitTime int } func (f *FakeMultiHTTPClient) Do(request *http.Request) (*http.Response, error) { defer f.increment() f.simulateServerReadingRequestBody(request) f.requests = append(f.requests, request) + response := f.responses[f.call] + if response.StatusCode == 429 { + response.Header = http.Header{} + response.Header.Set(f.headerKey, strconv.Itoa(f.rateLimitTime)) + } return f.responses[f.call], f.errors[f.call] } diff --git a/internal/sdk/retry_client.go b/internal/sdk/retry_client.go index 6c69bac..360b735 100644 --- a/internal/sdk/retry_client.go +++ b/internal/sdk/retry_client.go @@ -3,9 +3,8 @@ package sdk import ( "bytes" "io" - "math/rand" "net/http" - "sync" + "strconv" "time" ) @@ -14,20 +13,18 @@ type RetryClient struct { inner HTTPClient maxRetries int sleeper func(time.Duration) - lock *sync.Mutex - rand *rand.Rand + rateLimit int } -func NewRetryClient(inner HTTPClient, maxRetries int, rand *rand.Rand, sleeper func(time.Duration)) HTTPClient { +func NewRetryClient(inner HTTPClient, maxRetries int, sleeper func(time.Duration)) HTTPClient { if maxRetries == 0 { return inner } return &RetryClient{ inner: inner, maxRetries: maxRetries, - lock: new(sync.Mutex), - rand: rand, sleeper: sleeper, + rateLimit: -1, } } @@ -80,8 +77,19 @@ func (r *RetryClient) handleHttpStatusCode(response *http.Response, attempt *int return false } if response.StatusCode == http.StatusTooManyRequests { - r.sleeper(time.Second * time.Duration(r.random(backOffRateLimit))) - *attempt = 1 + if response.Header != nil { + if i, err := strconv.Atoi(response.Header.Get("Retry-After")); err == nil { + r.rateLimit = i + *attempt = 0 + return true + } + } + if *attempt == 0 { + r.rateLimit = 1 + } else { + r.rateLimit += 1 + } + *attempt = 0 } return true } @@ -102,24 +110,21 @@ func (r *RetryClient) readBody(response *http.Response) bool { } func (r *RetryClient) backOff(attempt int) bool { - if attempt == 0 { - return true - } if attempt > r.maxRetries { return false } - backOffCap := max(0, min(maxBackOffDuration, attempt)) - backOff := time.Second * time.Duration(r.random(backOffCap)) + backOffCap := 0 + if r.rateLimit != -1 { + backOffCap = r.rateLimit + } else { + backOffCap = max(0, min(maxBackOffDuration, attempt)) + } + backOff := time.Second * time.Duration(backOffCap) r.sleeper(backOff) return true } -func (r *RetryClient) random(cap int) int { - r.lock.Lock() - defer r.lock.Unlock() - return r.rand.Intn(cap) -} - +// TODO: delete in favor of built-in function after upgrading to Go 1.21 func max(x, y int) int { if x > y { return x @@ -127,6 +132,7 @@ func max(x, y int) int { return y } +// TODO: delete in favor of built-in function after upgrading to Go 1.21 func min(x, y int) int { if x < y { return x @@ -135,6 +141,5 @@ func min(x, y int) int { } const ( - backOffRateLimit = 5 maxBackOffDuration = 10 ) diff --git a/internal/sdk/retry_client_test.go b/internal/sdk/retry_client_test.go index bafc690..fa765c3 100644 --- a/internal/sdk/retry_client_test.go +++ b/internal/sdk/retry_client_test.go @@ -3,7 +3,6 @@ package sdk import ( "errors" "io" - "math/rand" "net/http" "strings" "testing" @@ -22,6 +21,7 @@ type RetryClientFixture struct { inner *FakeMultiHTTPClient response *http.Response err error + header http.Header naps []time.Duration } @@ -32,7 +32,7 @@ func (f *RetryClientFixture) TestRequestBodyCannotBeBuffered_ErrorReturnedImmedi } func (f *RetryClientFixture) sendErrorProneRequest() (*http.Response, error) { f.inner = &FakeMultiHTTPClient{} - client := NewRetryClient(f.inner, 10, rand.New(rand.NewSource(0)), f.sleep).(*RetryClient) + client := NewRetryClient(f.inner, 10, f.sleep).(*RetryClient) request, _ := http.NewRequest("POST", "/", &ErrorProneReadCloser{readError: errors.New("GOPHERS!")}) return client.Do(request) } @@ -76,7 +76,7 @@ func (f *RetryClientFixture) assertRequestWasSuccessful() { func (f *RetryClientFixture) assertBackOffStrategyWasObserved() { f.So(f.inner.call, should.Equal, 5) f.So(f.naps, should.Resemble, - []time.Duration{0 * time.Second, 0 * time.Second, 1 * time.Second, 2 * time.Second}) + []time.Duration{0 * time.Second, 1 * time.Second, 2 * time.Second, 3 * time.Second, 4 * time.Second}) } /**************************************************************************/ @@ -90,6 +90,28 @@ func (f *RetryClientFixture) TestRetryOnBadResponseUntilSuccess() { f.assertBackOffStrategyWasObserved() } +func (f *RetryClientFixture) TestPost404ErrorDoesNotRetry() { + f.inner = NewFailingHTTPClient(404, 429) + + f.response, f.err = f.sendPostWithRetry(1) + + if f.So(f.response, should.NotBeNil) { + f.So(f.response.StatusCode, should.Equal, 404) + } + f.So(f.err, should.BeNil) +} + +func (f *RetryClientFixture) TestGet404ErrorDoesNotRetry() { + f.inner = NewFailingHTTPClient(404, 429) + + f.response, f.err = f.sendGetWithRetry(1) + + if f.So(f.response, should.NotBeNil) { + f.So(f.response.StatusCode, should.Equal, 404) + } + f.So(f.err, should.BeNil) +} + /**************************************************************************/ func (f *RetryClientFixture) TestFailureReturnedIfRetryExceeded() { @@ -111,7 +133,7 @@ func (f *RetryClientFixture) assertInternalServerError() { func (f *RetryClientFixture) TestNoRetryRequestedReturnsInnerClientInstead() { inner := &FakeHTTPClient{} - client := NewRetryClient(inner, 0, rand.New(rand.NewSource(0)), f.sleep) + client := NewRetryClient(inner, 0, f.sleep) f.So(client, should.Equal, inner) } @@ -125,9 +147,8 @@ func (f *RetryClientFixture) TestBackOffNeverToExceedHardCodedMaximum() { f.So(f.err, should.BeNil) f.So(f.inner.call, should.Equal, retries) - f.So(f.naps[0], should.Equal, 0) - for i := 1; i < len(f.naps); i++ { - f.So(f.naps[i], should.BeBetweenOrEqual, 0, time.Second*time.Duration(min(i, maxBackOffDuration))) + for i := 0; i < len(f.naps); i++ { + f.So(f.naps[i], should.Equal, time.Second*time.Duration(min(i, maxBackOffDuration))) } } @@ -142,9 +163,9 @@ func (f *RetryClientFixture) TestBackOffRateLimitedGet() { f.So(f.err, should.BeNil) if f.So(f.inner.call, should.Equal, 11) { var napTotal time.Duration - for i := 0; i < 10; i++ { + for i := 0; i < len(f.naps); i++ { napTotal += f.naps[i] - f.So(f.naps[i], should.BeBetweenOrEqual, 0, backOffRateLimit*time.Second) + f.So(f.naps[i], should.Equal, time.Second*time.Duration(min(i, maxBackOffDuration))) } f.So(napTotal, should.BeGreaterThan, time.Second*5) } @@ -161,23 +182,88 @@ func (f *RetryClientFixture) TestBackOffRateLimitedPost() { f.So(f.err, should.BeNil) if f.So(f.inner.call, should.Equal, 11) { var napTotal time.Duration - for i := 0; i < 10; i++ { + for i := 0; i < len(f.naps); i++ { + napTotal += f.naps[i] + f.So(f.naps[i], should.Equal, time.Second*time.Duration(min(i, maxBackOffDuration))) + } + f.So(napTotal, should.BeGreaterThan, time.Second*5) + } +} + +func (f *RetryClientFixture) TestRateLimitHeaderSetSleep() { + maxRetries := 3 + f.inner = NewFailingHTTPClient(429, 429, 429, 429, http.StatusOK) + rateLimitTime := 7 + f.inner.headerKey = "Retry-After" + f.inner.rateLimitTime = rateLimitTime + f.inner.responses[4].Body = io.NopCloser(strings.NewReader("Alohomora")) + f.response, f.err = f.sendPostWithRetry(maxRetries) + + f.So(f.err, should.BeNil) + if f.So(f.inner.call, should.Equal, 5) { + var napTotal time.Duration + f.So(f.naps[0], should.Equal, time.Duration(0)) + for i := 1; i < len(f.naps); i++ { + napTotal += f.naps[i] + f.So(f.naps[i], should.Equal, time.Second*time.Duration(rateLimitTime)) + } + f.So(napTotal, should.BeGreaterThan, time.Second*4) + } +} + +func (f *RetryClientFixture) TestRateLimitHeaderSetSleep12Sec() { + maxRetries := 9 + f.inner = NewFailingHTTPClient(429, 429, 429, 429, 429, 429, 429, 429, 429, 429, 429, 429, 429, http.StatusOK) + rateLimitTime := 12 + f.inner.headerKey = "Retry-After" + f.inner.rateLimitTime = rateLimitTime + f.inner.responses[13].Body = io.NopCloser(strings.NewReader("Alohomora")) + f.response, f.err = f.sendPostWithRetry(maxRetries) + + f.So(f.err, should.BeNil) + if f.So(f.inner.call, should.Equal, 14) { + var napTotal time.Duration + f.So(f.naps[0], should.Equal, time.Duration(0)) + for i := 1; i < len(f.naps); i++ { napTotal += f.naps[i] - f.So(f.naps[i], should.BeBetweenOrEqual, 0, backOffRateLimit*time.Second) + f.So(f.naps[i], should.Equal, time.Second*time.Duration(rateLimitTime)) } f.So(napTotal, should.BeGreaterThan, time.Second*5) } } +func (f *RetryClientFixture) TestRateLimitNoHeaderSetSleep() { + maxRetries := 3 + f.inner = NewFailingHTTPClient(429, 429, 429, 429, http.StatusOK) + f.inner.headerKey = "Invalid-Header" + f.inner.rateLimitTime = 12 + f.inner.responses[4].Body = io.NopCloser(strings.NewReader("Alohomora")) + f.response, f.err = f.sendPostWithRetry(maxRetries) + + f.So(f.err, should.BeNil) + if f.So(f.inner.call, should.Equal, 5) { + f.So(f.naps, should.Resemble, + []time.Duration{0 * time.Second, 1 * time.Second, 2 * time.Second, 3 * time.Second, 4 * time.Second}) + } +} + /**************************************************************************/ func (f *RetryClientFixture) sendGetWithRetry(retries int) (*http.Response, error) { - client := NewRetryClient(f.inner, retries, rand.New(rand.NewSource(0)), f.sleep).(*RetryClient) + if len(f.inner.responses) <= retries { + f.T().Fatalf("The number of retries is greater than or equal to the number of status codes provided. Please ensure that the number of retries is less than the number of status codes provided.") + } + + client := NewRetryClient(f.inner, retries, f.sleep).(*RetryClient) request, _ := http.NewRequest("GET", "/?body=request", nil) return client.Do(request) } func (f *RetryClientFixture) sendPostWithRetry(retries int) (*http.Response, error) { - client := NewRetryClient(f.inner, retries, rand.New(rand.NewSource(0)), f.sleep).(*RetryClient) + if len(f.inner.responses) <= retries { + f.T().Fatalf("The number of retries is greater than or equal to the number of status codes provided. Please ensure that the number of retries is less than the number of status codes provided.") + } + + client := NewRetryClient(f.inner, retries, f.sleep).(*RetryClient) request, _ := http.NewRequest("POST", "/", strings.NewReader("request")) return client.Do(request) } diff --git a/wireup/builder.go b/wireup/builder.go index a01be03..1ac5a17 100644 --- a/wireup/builder.go +++ b/wireup/builder.go @@ -3,7 +3,6 @@ package wireup import ( "crypto/tls" "fmt" - "math/rand" "net" "net/http" "net/url" @@ -184,7 +183,7 @@ func (b *clientBuilder) buildHTTPClient() (wrapped internal.HTTPClient) { wrapped = b.buildClient() wrapped = internal.NewTracingClient(wrapped, b.trace) wrapped = internal.NewDebugOutputClient(wrapped, b.debug) - wrapped = internal.NewRetryClient(wrapped, b.retries, rand.New(rand.NewSource(time.Now().UnixNano())), time.Sleep) + wrapped = internal.NewRetryClient(wrapped, b.retries, time.Sleep) wrapped = internal.NewSigningClient(wrapped, b.credential) wrapped = internal.NewCustomHeadersClient(wrapped, b.headers) wrapped = internal.NewBaseURLClient(wrapped, b.baseURL)