Skip to content

Commit

Permalink
Add ability to set a per-attempt timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
justinrixx committed Oct 3, 2023
1 parent eda55af commit 1b9ff12
Show file tree
Hide file tree
Showing 5 changed files with 275 additions and 7 deletions.
36 changes: 36 additions & 0 deletions cancelreader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package rehttp

import (
"context"
"io"
"net/http"
)

type cancelReader struct {
io.ReadCloser

cancel context.CancelFunc
}

func (r cancelReader) Close() error {
r.cancel()
return r.ReadCloser.Close()
}

// injectCancelReader propagates the ability for the caller to cancel the request context
// once done with the response. If the transport cancels before the body stream is read,
// a race begins where the caller may be unable to read the response bytes before the stream
// is closed and an error is returned. This helper function wraps a response body in a
// io.ReadCloser that cancels the context once the body is closed, preventing a context leak.
// Solution based on https://github.com/go-kit/kit/issues/773.
func injectCancelReader(res *http.Response, cancel context.CancelFunc) *http.Response {
if res == nil {
return nil
}

res.Body = cancelReader{
ReadCloser: res.Body,
cancel: cancel,
}
return res
}
19 changes: 19 additions & 0 deletions perattempttimeout_post17.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//go:build go1.7
// +build go1.7

package rehttp

import (
"context"
"net/http"
"time"
)

func getRequestContext(req *http.Request) context.Context {
return req.Context()
}

func getPerAttemptTimeoutInfo(ctx context.Context, req *http.Request, timeout time.Duration) (*http.Request, context.CancelFunc) {
tctx, cancel := context.WithTimeout(ctx, timeout)
return req.WithContext(tctx), cancel
}
19 changes: 19 additions & 0 deletions perattempttimeout_pre17.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//go:build !go1.7
// +build !go1.7

package rehttp

import (
"context"
"net/http"
"time"
)

func getRequestContext(req *http.Request) context.Context {
return nil // req.Context() doesn't exist before 1.7
}

func getPerAttemptTimeoutInfo(ctx context.Context, req *http.Request, timeout time.Duration) (*http.Request, context.CancelFunc) {
// req.WithContext() doesn't exist before 1.7, so noop
return req, func() {}
}
41 changes: 34 additions & 7 deletions rehttp.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
// (https://golang.org/pkg/net/http/#Transport.CancelRequest).
//
// On Go1.7+, it uses the context returned by http.Request.Context
// to check for cancelled requests.
// to check for cancelled requests. Before Go1.7, PerAttemptTimeout
// has no effect.
//
// It should work on Go1.5, but only if there is no timeout set on the
// *http.Client. Go's stdlib will return an error on the first request
Expand All @@ -50,6 +51,7 @@ package rehttp

import (
"bytes"
"context"
"errors"
"io"
"io/ioutil"
Expand Down Expand Up @@ -280,6 +282,22 @@ type Transport struct {
// is non-nil.
PreventRetryWithBody bool

// PerAttemptTimeout can be optionally set to add per-attempt timeouts.
// These may be used in place of or in conjunction with overall timeouts.
// For example, a per-attempt timeout of 5s would mean an attempt will
// be canceled after 5s, then the delay fn will be consulted before
// potentially making another attempt, which will again be capped at 5s.
// This means that the overall duration may be up to
// (PerAttemptTimeout + delay) * n, where n is the maximum attempts.
// If using an overall timeout (whether on the http client or the request
// context), the request will stop at whichever timeout is reached first.
// Your RetryFn can determine if a request hit the per-attempt timeout by
// checking if attempt.Error == context.DeadlineExceeded (or use errors.Is
// on go 1.13+).
// time.Duration(0) signals that no per-attempt timeout should be used.
// Note that before go 1.7 this option has no effect.
PerAttemptTimeout time.Duration

// retry is a function that determines if the request should be retried.
// Unless a retry is prevented based on PreventRetryWithBody, all requests
// go through that function, even those that are typically considered
Expand All @@ -297,6 +315,9 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
var attempt int
preventRetry := req.Body != nil && req.Body != http.NoBody && t.PreventRetryWithBody

// used as a baseline to set fresh timeouts per-attempt if needed
ctx := getRequestContext(req)

// get the done cancellation channel for the context, will be nil
// for < go1.7.
done := contextForRequest(req)
Expand All @@ -317,19 +338,24 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
}

for {
res, err := t.RoundTripper.RoundTrip(req)
var cancel context.CancelFunc = func() {} // empty unless a timeout is set
reqWithTimeout := req
if t.PerAttemptTimeout != 0 {
reqWithTimeout, cancel = getPerAttemptTimeoutInfo(ctx, req, t.PerAttemptTimeout)
}
res, err := t.RoundTripper.RoundTrip(reqWithTimeout)
if preventRetry {
return res, err
return injectCancelReader(res, cancel), err
}

retry, delay := t.retry(Attempt{
Request: req,
Request: reqWithTimeout,
Response: res,
Index: attempt,
Error: err,
})
if !retry {
return res, err
return injectCancelReader(res, cancel), err
}

if br != nil {
Expand All @@ -338,15 +364,16 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
// to reset on the request is the body, if any.
if _, serr := br.Seek(0, 0); serr != nil {
// failed to retry, return the results
return res, err
return injectCancelReader(res, cancel), err
}
req.Body = ioutil.NopCloser(br)
reqWithTimeout.Body = ioutil.NopCloser(br)
}
// close the disposed response's body, if any
if res != nil {
io.Copy(ioutil.Discard, res.Body)
res.Body.Close()
}
cancel() // we're done with this response and won't be returning it, so it's safe to cancel immediately

select {
case <-time.After(delay):
Expand Down
167 changes: 167 additions & 0 deletions rehttp_server_post17_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
//go:build go1.7
// +build go1.7

package rehttp

import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
)

func TestTransport_RoundTripTimeouts(t *testing.T) {
// to keep track of any open server requests and ensure the correct number of requests were made
ch := make(chan bool, 4)

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ch <- true
time.Sleep(time.Millisecond * 100)
w.WriteHeader(http.StatusTooManyRequests)
}))
defer ts.Close()

tr := NewTransport(
http.DefaultTransport,
RetryAll(RetryMaxRetries(3), RetryAny(
RetryStatuses(http.StatusTooManyRequests), // retry 429s
func(attempt Attempt) bool { // retry context deadline exceeded errors
return attempt.Error != nil && attempt.Error == context.DeadlineExceeded // errors.Is requires go 1.13+
})),
ConstDelay(0),
)
tr.PerAttemptTimeout = time.Millisecond * 10 // short timeout

client := http.Client{
Transport: tr,
}

req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
if err != nil {
t.Errorf("error creating request: %s", err)
}

_, err = client.Do(req)
// make sure server has finished and got 4 attempts
<-ch
<-ch
<-ch
<-ch
if err == nil {
t.Error("expected timeout error doing request but got nil")
}

// now increase the timeout restriction
ch = make(chan bool, 4)
tr.PerAttemptTimeout = time.Second
res, err := client.Do(req)

// should have attempted 4 times without going over the timeout
<-ch
<-ch
<-ch
<-ch

if err != nil {
t.Errorf("got unexpected error doing request: %s", err)
}
if res == nil || res.StatusCode != http.StatusTooManyRequests {
t.Errorf("status code does not match expected: got %d, want %d", res.StatusCode, http.StatusTooManyRequests)
}

// now remove the timeout restriction
ch = make(chan bool, 4)
tr.PerAttemptTimeout = time.Duration(0)
res, err = client.Do(req)
// should have attempted 4 times without going over the timeout
<-ch
<-ch
<-ch
<-ch

if err != nil {
t.Errorf("got unexpected error doing request: %s", err)
}
if res == nil || res.StatusCode != http.StatusTooManyRequests {
t.Errorf("status code does not match expected: got %d, want %d", res.StatusCode, http.StatusTooManyRequests)
}
}

func TestTransport_RoundTripOverallTimeout(t *testing.T) {
ch := make(chan bool, 2)

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ch <- true
time.Sleep(time.Second * 2)
w.WriteHeader(http.StatusTooManyRequests)
}))
defer ts.Close()

tr := NewTransport(
http.DefaultTransport,
RetryAll(RetryMaxRetries(3), RetryAny(
RetryStatuses(http.StatusTooManyRequests), // retry 429s
func(attempt Attempt) bool { // retry context deadline exceeded errors
return attempt.Error != nil && attempt.Error == context.DeadlineExceeded // errors.Is requires go 1.13+
})),
ConstDelay(0),
)
tr.PerAttemptTimeout = time.Second

client := http.Client{
Transport: tr,
}

req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
if err != nil {
t.Errorf("error creating request: %s", err)
}

ctx, cancelFunc := context.WithTimeout(context.Background(), time.Millisecond*1500)
_, err = client.Do(req.WithContext(ctx))
cancelFunc()
// should only make 2 attempts
<-ch
<-ch
if err == nil {
t.Error("expected timeout error doing request but got nil")
}
}

// TestCancelReader is meant to test that the cancel reader is correctly
// preventing the race-case of being unable to read the body due to a
// preemptively-canceled context.
func TestCancelReader(t *testing.T) {
rt := NewTransport(http.DefaultTransport, RetryMaxRetries(1), ConstDelay(0))
rt.PerAttemptTimeout = time.Millisecond * 100
client := http.Client{
Transport: rt,
}

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Millisecond * 10)
w.WriteHeader(http.StatusOK)
// need a decent number of bytes to make the race case more likely to fail
// https://github.com/go-kit/kit/issues/773
w.Write(make([]byte, 102400))
}))
defer ts.Close()

ctx := context.Background()

req, _ := http.NewRequest(http.MethodGet, ts.URL, nil)
res, err := client.Do(req.WithContext(ctx))
if err != nil {
t.Fatalf("unexpected error creating request: %s", err)
}
defer res.Body.Close()
b, err := io.ReadAll(res.Body)
if err != nil {
t.Errorf("error reading response body: %s", err)
}
if len(b) != 102400 {
t.Errorf("response byte length does not match expected. got %d, want %d", len(b), 102400)
}
}

0 comments on commit 1b9ff12

Please sign in to comment.