diff --git a/pester.go b/pester.go index ce326ed..4ef9312 100644 --- a/pester.go +++ b/pester.go @@ -198,6 +198,16 @@ func (c *Client) copyBody(src io.ReadCloser) ([]byte, error) { return b, nil } +// resetBody resets the Body and GetBody fields of an http.Request to new Readers over +// the originalBody. This is used to refresh http.Requests that may have had their +// bodies closed already. +func resetBody(request *http.Request, originalBody []byte) { + request.Body = io.NopCloser(bytes.NewBuffer(originalBody)) + request.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewBuffer(originalBody)), nil + } +} + // pester provides all the logic of retries, concurrency, backoff, and logging func (c *Client) pester(p params) (*http.Response, error) { resultCh := make(chan result) @@ -242,7 +252,6 @@ func (c *Client) pester(p params) (*http.Response, error) { // if we have a request body, we need to save it for later var ( - request *http.Request originalBody []byte err error ) @@ -252,23 +261,52 @@ func (c *Client) pester(p params) (*http.Response, error) { } else if p.body != nil { originalBody, err = c.copyBody(p.body) } + if err != nil { + return nil, err + } + // check to make sure that we aren't trying to use an unsupported method switch p.method { - case methodDo: - request = p.req - case methodGet, methodHead: - request, err = http.NewRequest(p.verb, p.url, nil) - case methodPostForm, methodPost: - request, err = http.NewRequest(http.MethodPost, p.url, ioutil.NopCloser(bytes.NewBuffer(originalBody))) + case methodDo, methodGet, methodHead, methodPostForm, methodPost: default: - err = ErrUnexpectedMethod - } - if err != nil { - return nil, err + return nil, ErrUnexpectedMethod } - if len(p.bodyType) > 0 { - request.Header.Set(headerKeyContentType, p.bodyType) + // provideRequest returns an HTTP request to be use when retrying. + // if concurrency is 1, it will return the same request that was supplied to the Do() method + // for Do() calls, otherwise it will generate a Clone() of the request each time it is called. + // For non-Do() calls, it creates a new request each time it is called. This re-creation behaviour + // is because requests are not supposed to be used again until the RoundTripper is finished + // with them, which cannot be guaranteed with concurrent callers + // https://pkg.go.dev/net/http#RoundTripper + provideRequest := func() (request *http.Request, err error) { + switch p.method { + case methodDo: + if concurrency > 1 { + request = p.req.Clone(p.req.Context()) + } else { + request = p.req + } + if request.Body != nil { + // reset the body since Clone() doesn't do that for us + // and we drained it earlier when performing the Copy + // ex: https://go.dev/play/p/jlc6A-fjaOi + resetBody(request, originalBody) + } + case methodGet, methodHead: + request, err = http.NewRequest(p.verb, p.url, nil) + case methodPostForm, methodPost: + request, err = http.NewRequest(http.MethodPost, p.url, bytes.NewBuffer(originalBody)) + } + if err != nil { + return + } + + if len(p.bodyType) > 0 { + request.Header.Set(headerKeyContentType, p.bodyType) + } + + return } AttemptLimit := c.MaxRetries @@ -279,9 +317,15 @@ func (c *Client) pester(p params) (*http.Response, error) { for n := 0; n < concurrency; n++ { c.wg.Add(1) totalSentRequests.Add(1) - go func(n int, req *http.Request) { + go func(n int) { defer c.wg.Done() defer totalSentRequests.Done() + req, err := provideRequest() + // couldn't get a request to use, so don't proceed + if err != nil { + multiplexCh <- result{err: err, req: n} + return + } for i := 1; i <= AttemptLimit; i++ { c.wg.Add(1) @@ -340,15 +384,19 @@ func (c *Client) pester(p params) (*http.Response, error) { case <-time.After(c.Backoff(i) + 1*time.Microsecond): // allow context cancellation to cancel during backoff case <-req.Context().Done(): + multiplexCh <- result{resp: resp, err: req.Context().Err()} return } - } - }(n, request) - // rehydrate the body (it is drained each read) - if request.Body != nil { - request.Body = ioutil.NopCloser(bytes.NewBuffer(originalBody)) - } + // we are about to retry, if we had a Body, we will need to restore it + // to a non-closed one in order to work reliably. If you do not do this, + // there are a number of curious edge cases depending on the type of the + // underlying reader: https://go.dev/play/p/gZLVUe2EXSE + if req.Body != nil { + resetBody(req, originalBody) + } + } + }(n) } // spin off the go routine so it can continually listen in on late results and close the response bodies diff --git a/pester_test.go b/pester_test.go index 21cf7f1..654dfc0 100644 --- a/pester_test.go +++ b/pester_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io/ioutil" "log" "net" "net/http" @@ -743,6 +744,171 @@ func TestRetriesNotAttemptedIfContextIsCancelled(t *testing.T) { } } +type roundTripperFunc func(r *http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} + +func TestRetriesContextCancelledDuringWait(t *testing.T) { + t.Parallel() + // in order for this test to work we need to be able to reliably put the client in a + // waiting state. To achieve this, we create a client that will fail fast + // via a custom RoundTripper that always fails and pair it with a custom BackoffStrategy + // that waits for a long time. This results in a client that should spend + // almost all of its time waiting. + + ctx, cancel := context.WithCancel(context.Background()) + + c := NewExtendedClient(&http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("always fail") + }), + Timeout: 5 * time.Second, + }) + c.MaxRetries = 2 + c.Backoff = func(retry int) time.Duration { + return 5 * time.Second + } + // req details don't really matter, round-tripper will fail it anyway + req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost", nil) + if err != nil { + t.Fatalf("unable to create request %v", err) + } + + // we want to perform the call in a goroutine so we can explicitly check for indefinite + // blocking behaviour. Since you cannot use t.Fatal/t.Error/etc. in a goroutine, we + // create a channel to communicate back to our main goroutine what happened + errReturn := make(chan error) + go func() { + // perform call in goroutine to check for indefinite blocks + _, err := c.Do(req) + errReturn <- err + }() + + // wait a hundred ms to let the client fail and get into a waiting state + <-time.After(100 * time.Millisecond) + // cancel our context + cancel() + + // if all has gone well, we should have aborted our wait period and the + // err channel should contain a Context-cancellation error + + select { + case recdErr := <-errReturn: + if recdErr == nil { + t.Fatal("nil error returned from Do(req) routine") + } + // check that it is the right error message + if context.Canceled != recdErr { + t.Fatalf("unexpected error returned: %v", recdErr) + } + case <-time.After(time.Second): + // give it a second, then treat this as failing to return + t.Fatal("failed to receive error return") + } +} + +func TestRetriesWithBodies_Do(t *testing.T) { + t.Parallel() + + const testContent = "TestRetriesWithBodies_Do" + // using a channel to route these errors back into this goroutine + // it is important that this channel have enough capacity to hold all + // of the errors that will be generated by the test so that we do not + // deadlock. Therefore, MaxAttempts must be the same size as the channel capacity + // and each execution must only put at most one error on the channel. + serverReqErrCh := make(chan error, 4) + port, closeFn, err := middlewareServer( + contentVerificationMiddleware(serverReqErrCh, testContent), + always500RequestMiddleware(), + ) + if err != nil { + t.Fatal("unable to start timeout server", err) + } + defer closeFn() + + <-time.After(2 * time.Second) + + iseUrl := fmt.Sprintf("http://localhost:%d", port) + + req, err := http.NewRequest("POST", iseUrl, strings.NewReader(testContent)) + if err != nil { + t.Fatalf("unable to create request %v", err) + } + + c := New() + c.MaxRetries = cap(serverReqErrCh) + c.KeepLog = true + c.Backoff = func(retry int) time.Duration { + // backoff isn't important for this test + return 0 + } + + resp, err := c.Do(req) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if resp == nil { + t.Error("response was unexpectedly nil") + } else if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("unexpected response StatusCode: %v", resp.StatusCode) + } + // we're done making requests, so close the return channel and drain it + close(serverReqErrCh) + for v := range serverReqErrCh { + if v != nil { + t.Errorf("unexpected error occurred when server processed request: %v", v) + } + } +} + +func TestRetriesWithBodies_POST(t *testing.T) { + t.Parallel() + + const testContent = "TestRetriesWithBodies_POST" + // using a channel to route these errors back into this goroutine + // it is important that this channel have enough capacity to hold all + // of the errors that will be generated by the test so that we do not + // deadlock. Therefore, MaxAttempts must be the same size as the channel capacity + // and each execution must only put at most one error on the channel. + serverReqErrCh := make(chan error, 4) + port, closeFn, err := middlewareServer( + contentVerificationMiddleware(serverReqErrCh, testContent), + always500RequestMiddleware(), + ) + if err != nil { + t.Fatal("unable to start timeout server", err) + } + defer closeFn() + + c := New() + c.MaxRetries = cap(serverReqErrCh) + c.KeepLog = true + c.Backoff = func(retry int) time.Duration { + // backoff isn't important for this test + return 0 + } + + iseUrl := fmt.Sprintf("http://localhost:%d", port) + resp, err := c.Post(iseUrl, "text/plain", strings.NewReader(testContent)) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if resp == nil { + t.Error("response was unexpectedly nil") + } else if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("unexpected response StatusCode: %v", resp.StatusCode) + } + // we're done making requests, so close the return channel and drain it + close(serverReqErrCh) + for v := range serverReqErrCh { + if v != nil { + t.Errorf("unexpected error occurred when server processed request: %v", v) + } + } +} + func withinEpsilon(got, want int64, epslion float64) bool { if want <= int64(epslion*float64(got)) || want >= int64(epslion*float64(got)) { return false @@ -880,3 +1046,61 @@ func serverWith400() (int, error) { return port, nil } + +func contentVerificationMiddleware(errorCh chan<- error, expectedContent string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + content, err := ioutil.ReadAll(r.Body) + defer r.Body.Close() + if err != nil { + errorCh <- err + } else if string(content) != expectedContent { + errorCh <- fmt.Errorf( + "unexpected body content: expected \"%v\", got \"%v\"", + expectedContent, + string(content), + ) + } + }) +} + +func always500RequestMiddleware() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 Internal Server Error")) + }) +} + +// middlewareServer stands up a server that accepts varags of middleware that conforms to the +// http.Handler interface +func middlewareServer(requestMiddleware ...http.Handler) (int, func(), error) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + for _, v := range requestMiddleware { + v.ServeHTTP(w, r) + } + }) + l, err := net.Listen("tcp", ":0") + if err != nil { + return -1, nil, fmt.Errorf("unable to secure listener %v", err) + } + server := &http.Server{ + Handler: mux, + } + go func() { + if err := server.Serve(l); err != nil && err != http.ErrServerClosed { + log.Fatalf("middleware-server error %v", err) + } + }() + + var port int + _, sport, err := net.SplitHostPort(l.Addr().String()) + if err == nil { + port, err = strconv.Atoi(sport) + } + + if err != nil { + return -1, nil, fmt.Errorf("unable to determine port %v", err) + } + + return port, func() { server.Close() }, nil +}