diff --git a/internal/sdk/doc_test.go b/internal/sdk/doc_test.go index 34678b3..7817d56 100644 --- a/internal/sdk/doc_test.go +++ b/internal/sdk/doc_test.go @@ -3,6 +3,7 @@ package sdk import ( "io" "net/http" + "strconv" ) type FakeHTTPClient struct { @@ -19,17 +20,25 @@ func (f *FakeHTTPClient) Do(request *http.Request) (*http.Response, error) { /*////////////////////////////////////////////////////////////////////////*/ type FakeMultiHTTPClient struct { - requests []*http.Request - bodies []string - responses []*http.Response - errors []error - call int + requests []*http.Request + headers []*http.Header + bodies []string + responses []*http.Response + errors []error + call int + headerKey string + rateLimitTime int } func (f *FakeMultiHTTPClient) Do(request *http.Request) (*http.Response, error) { defer f.increment() f.simulateServerReadingRequestBody(request) f.requests = append(f.requests, request) + response := f.responses[f.call] + if response.StatusCode == 429 { + response.Header = http.Header{} + response.Header.Set(f.headerKey, strconv.Itoa(f.rateLimitTime)) + } return f.responses[f.call], f.errors[f.call] } diff --git a/internal/sdk/retry_client.go b/internal/sdk/retry_client.go index 6c69bac..360b735 100644 --- a/internal/sdk/retry_client.go +++ b/internal/sdk/retry_client.go @@ -3,9 +3,8 @@ package sdk import ( "bytes" "io" - "math/rand" "net/http" - "sync" + "strconv" "time" ) @@ -14,20 +13,18 @@ type RetryClient struct { inner HTTPClient maxRetries int sleeper func(time.Duration) - lock *sync.Mutex - rand *rand.Rand + rateLimit int } -func NewRetryClient(inner HTTPClient, maxRetries int, rand *rand.Rand, sleeper func(time.Duration)) HTTPClient { +func NewRetryClient(inner HTTPClient, maxRetries int, sleeper func(time.Duration)) HTTPClient { if maxRetries == 0 { return inner } return &RetryClient{ inner: inner, maxRetries: maxRetries, - lock: new(sync.Mutex), - rand: rand, sleeper: sleeper, + rateLimit: -1, } } @@ -80,8 +77,19 @@ func (r *RetryClient) handleHttpStatusCode(response *http.Response, attempt *int return false } if response.StatusCode == http.StatusTooManyRequests { - r.sleeper(time.Second * time.Duration(r.random(backOffRateLimit))) - *attempt = 1 + if response.Header != nil { + if i, err := strconv.Atoi(response.Header.Get("Retry-After")); err == nil { + r.rateLimit = i + *attempt = 0 + return true + } + } + if *attempt == 0 { + r.rateLimit = 1 + } else { + r.rateLimit += 1 + } + *attempt = 0 } return true } @@ -102,24 +110,21 @@ func (r *RetryClient) readBody(response *http.Response) bool { } func (r *RetryClient) backOff(attempt int) bool { - if attempt == 0 { - return true - } if attempt > r.maxRetries { return false } - backOffCap := max(0, min(maxBackOffDuration, attempt)) - backOff := time.Second * time.Duration(r.random(backOffCap)) + backOffCap := 0 + if r.rateLimit != -1 { + backOffCap = r.rateLimit + } else { + backOffCap = max(0, min(maxBackOffDuration, attempt)) + } + backOff := time.Second * time.Duration(backOffCap) r.sleeper(backOff) return true } -func (r *RetryClient) random(cap int) int { - r.lock.Lock() - defer r.lock.Unlock() - return r.rand.Intn(cap) -} - +// TODO: delete in favor of built-in function after upgrading to Go 1.21 func max(x, y int) int { if x > y { return x @@ -127,6 +132,7 @@ func max(x, y int) int { return y } +// TODO: delete in favor of built-in function after upgrading to Go 1.21 func min(x, y int) int { if x < y { return x @@ -135,6 +141,5 @@ func min(x, y int) int { } const ( - backOffRateLimit = 5 maxBackOffDuration = 10 ) diff --git a/internal/sdk/retry_client_test.go b/internal/sdk/retry_client_test.go index bafc690..fa765c3 100644 --- a/internal/sdk/retry_client_test.go +++ b/internal/sdk/retry_client_test.go @@ -3,7 +3,6 @@ package sdk import ( "errors" "io" - "math/rand" "net/http" "strings" "testing" @@ -22,6 +21,7 @@ type RetryClientFixture struct { inner *FakeMultiHTTPClient response *http.Response err error + header http.Header naps []time.Duration } @@ -32,7 +32,7 @@ func (f *RetryClientFixture) TestRequestBodyCannotBeBuffered_ErrorReturnedImmedi } func (f *RetryClientFixture) sendErrorProneRequest() (*http.Response, error) { f.inner = &FakeMultiHTTPClient{} - client := NewRetryClient(f.inner, 10, rand.New(rand.NewSource(0)), f.sleep).(*RetryClient) + client := NewRetryClient(f.inner, 10, f.sleep).(*RetryClient) request, _ := http.NewRequest("POST", "/", &ErrorProneReadCloser{readError: errors.New("GOPHERS!")}) return client.Do(request) } @@ -76,7 +76,7 @@ func (f *RetryClientFixture) assertRequestWasSuccessful() { func (f *RetryClientFixture) assertBackOffStrategyWasObserved() { f.So(f.inner.call, should.Equal, 5) f.So(f.naps, should.Resemble, - []time.Duration{0 * time.Second, 0 * time.Second, 1 * time.Second, 2 * time.Second}) + []time.Duration{0 * time.Second, 1 * time.Second, 2 * time.Second, 3 * time.Second, 4 * time.Second}) } /**************************************************************************/ @@ -90,6 +90,28 @@ func (f *RetryClientFixture) TestRetryOnBadResponseUntilSuccess() { f.assertBackOffStrategyWasObserved() } +func (f *RetryClientFixture) TestPost404ErrorDoesNotRetry() { + f.inner = NewFailingHTTPClient(404, 429) + + f.response, f.err = f.sendPostWithRetry(1) + + if f.So(f.response, should.NotBeNil) { + f.So(f.response.StatusCode, should.Equal, 404) + } + f.So(f.err, should.BeNil) +} + +func (f *RetryClientFixture) TestGet404ErrorDoesNotRetry() { + f.inner = NewFailingHTTPClient(404, 429) + + f.response, f.err = f.sendGetWithRetry(1) + + if f.So(f.response, should.NotBeNil) { + f.So(f.response.StatusCode, should.Equal, 404) + } + f.So(f.err, should.BeNil) +} + /**************************************************************************/ func (f *RetryClientFixture) TestFailureReturnedIfRetryExceeded() { @@ -111,7 +133,7 @@ func (f *RetryClientFixture) assertInternalServerError() { func (f *RetryClientFixture) TestNoRetryRequestedReturnsInnerClientInstead() { inner := &FakeHTTPClient{} - client := NewRetryClient(inner, 0, rand.New(rand.NewSource(0)), f.sleep) + client := NewRetryClient(inner, 0, f.sleep) f.So(client, should.Equal, inner) } @@ -125,9 +147,8 @@ func (f *RetryClientFixture) TestBackOffNeverToExceedHardCodedMaximum() { f.So(f.err, should.BeNil) f.So(f.inner.call, should.Equal, retries) - f.So(f.naps[0], should.Equal, 0) - for i := 1; i < len(f.naps); i++ { - f.So(f.naps[i], should.BeBetweenOrEqual, 0, time.Second*time.Duration(min(i, maxBackOffDuration))) + for i := 0; i < len(f.naps); i++ { + f.So(f.naps[i], should.Equal, time.Second*time.Duration(min(i, maxBackOffDuration))) } } @@ -142,9 +163,9 @@ func (f *RetryClientFixture) TestBackOffRateLimitedGet() { f.So(f.err, should.BeNil) if f.So(f.inner.call, should.Equal, 11) { var napTotal time.Duration - for i := 0; i < 10; i++ { + for i := 0; i < len(f.naps); i++ { napTotal += f.naps[i] - f.So(f.naps[i], should.BeBetweenOrEqual, 0, backOffRateLimit*time.Second) + f.So(f.naps[i], should.Equal, time.Second*time.Duration(min(i, maxBackOffDuration))) } f.So(napTotal, should.BeGreaterThan, time.Second*5) } @@ -161,23 +182,88 @@ func (f *RetryClientFixture) TestBackOffRateLimitedPost() { f.So(f.err, should.BeNil) if f.So(f.inner.call, should.Equal, 11) { var napTotal time.Duration - for i := 0; i < 10; i++ { + for i := 0; i < len(f.naps); i++ { + napTotal += f.naps[i] + f.So(f.naps[i], should.Equal, time.Second*time.Duration(min(i, maxBackOffDuration))) + } + f.So(napTotal, should.BeGreaterThan, time.Second*5) + } +} + +func (f *RetryClientFixture) TestRateLimitHeaderSetSleep() { + maxRetries := 3 + f.inner = NewFailingHTTPClient(429, 429, 429, 429, http.StatusOK) + rateLimitTime := 7 + f.inner.headerKey = "Retry-After" + f.inner.rateLimitTime = rateLimitTime + f.inner.responses[4].Body = io.NopCloser(strings.NewReader("Alohomora")) + f.response, f.err = f.sendPostWithRetry(maxRetries) + + f.So(f.err, should.BeNil) + if f.So(f.inner.call, should.Equal, 5) { + var napTotal time.Duration + f.So(f.naps[0], should.Equal, time.Duration(0)) + for i := 1; i < len(f.naps); i++ { + napTotal += f.naps[i] + f.So(f.naps[i], should.Equal, time.Second*time.Duration(rateLimitTime)) + } + f.So(napTotal, should.BeGreaterThan, time.Second*4) + } +} + +func (f *RetryClientFixture) TestRateLimitHeaderSetSleep12Sec() { + maxRetries := 9 + f.inner = NewFailingHTTPClient(429, 429, 429, 429, 429, 429, 429, 429, 429, 429, 429, 429, 429, http.StatusOK) + rateLimitTime := 12 + f.inner.headerKey = "Retry-After" + f.inner.rateLimitTime = rateLimitTime + f.inner.responses[13].Body = io.NopCloser(strings.NewReader("Alohomora")) + f.response, f.err = f.sendPostWithRetry(maxRetries) + + f.So(f.err, should.BeNil) + if f.So(f.inner.call, should.Equal, 14) { + var napTotal time.Duration + f.So(f.naps[0], should.Equal, time.Duration(0)) + for i := 1; i < len(f.naps); i++ { napTotal += f.naps[i] - f.So(f.naps[i], should.BeBetweenOrEqual, 0, backOffRateLimit*time.Second) + f.So(f.naps[i], should.Equal, time.Second*time.Duration(rateLimitTime)) } f.So(napTotal, should.BeGreaterThan, time.Second*5) } } +func (f *RetryClientFixture) TestRateLimitNoHeaderSetSleep() { + maxRetries := 3 + f.inner = NewFailingHTTPClient(429, 429, 429, 429, http.StatusOK) + f.inner.headerKey = "Invalid-Header" + f.inner.rateLimitTime = 12 + f.inner.responses[4].Body = io.NopCloser(strings.NewReader("Alohomora")) + f.response, f.err = f.sendPostWithRetry(maxRetries) + + f.So(f.err, should.BeNil) + if f.So(f.inner.call, should.Equal, 5) { + f.So(f.naps, should.Resemble, + []time.Duration{0 * time.Second, 1 * time.Second, 2 * time.Second, 3 * time.Second, 4 * time.Second}) + } +} + /**************************************************************************/ func (f *RetryClientFixture) sendGetWithRetry(retries int) (*http.Response, error) { - client := NewRetryClient(f.inner, retries, rand.New(rand.NewSource(0)), f.sleep).(*RetryClient) + if len(f.inner.responses) <= retries { + f.T().Fatalf("The number of retries is greater than or equal to the number of status codes provided. Please ensure that the number of retries is less than the number of status codes provided.") + } + + client := NewRetryClient(f.inner, retries, f.sleep).(*RetryClient) request, _ := http.NewRequest("GET", "/?body=request", nil) return client.Do(request) } func (f *RetryClientFixture) sendPostWithRetry(retries int) (*http.Response, error) { - client := NewRetryClient(f.inner, retries, rand.New(rand.NewSource(0)), f.sleep).(*RetryClient) + if len(f.inner.responses) <= retries { + f.T().Fatalf("The number of retries is greater than or equal to the number of status codes provided. Please ensure that the number of retries is less than the number of status codes provided.") + } + + client := NewRetryClient(f.inner, retries, f.sleep).(*RetryClient) request, _ := http.NewRequest("POST", "/", strings.NewReader("request")) return client.Do(request) } diff --git a/wireup/builder.go b/wireup/builder.go index a01be03..1ac5a17 100644 --- a/wireup/builder.go +++ b/wireup/builder.go @@ -3,7 +3,6 @@ package wireup import ( "crypto/tls" "fmt" - "math/rand" "net" "net/http" "net/url" @@ -184,7 +183,7 @@ func (b *clientBuilder) buildHTTPClient() (wrapped internal.HTTPClient) { wrapped = b.buildClient() wrapped = internal.NewTracingClient(wrapped, b.trace) wrapped = internal.NewDebugOutputClient(wrapped, b.debug) - wrapped = internal.NewRetryClient(wrapped, b.retries, rand.New(rand.NewSource(time.Now().UnixNano())), time.Sleep) + wrapped = internal.NewRetryClient(wrapped, b.retries, time.Sleep) wrapped = internal.NewSigningClient(wrapped, b.credential) wrapped = internal.NewCustomHeadersClient(wrapped, b.headers) wrapped = internal.NewBaseURLClient(wrapped, b.baseURL)