From fb34589af575d6a46e698094825e80b48d046b67 Mon Sep 17 00:00:00 2001 From: Joseph Phillips Date: Thu, 19 Jan 2023 13:34:59 +0100 Subject: [PATCH] Ensures that retried request dispatching does not reset the body with a non-nil empty buffer when the body is nil. This should prevent observed "net/http: cannot rewind body after connection loss" errors. --- client.go | 23 ++++++++++++++--------- client_test.go | 51 +++++++++++++++++++++++++++++++++++++------------- testing.go | 30 +++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 22 deletions(-) diff --git a/client.go b/client.go index e3fe3a7..ed31da2 100644 --- a/client.go +++ b/client.go @@ -7,7 +7,6 @@ import ( "bytes" "fmt" "io" - "io/ioutil" "mime/multipart" "net/http" "net/url" @@ -62,7 +61,7 @@ func readAndClose(stream io.ReadCloser) ([]byte, error) { return nil, nil } defer stream.Close() - return ioutil.ReadAll(stream) + return io.ReadAll(stream) } // dispatchRequest sends a request to the server, and interprets the response. @@ -70,7 +69,7 @@ func readAndClose(stream io.ReadCloser) ([]byte, error) { // server-side errors however (i.e. responses with a non 2XX status code), the // returned error will be ServerError and the returned body will reflect the // server's response. If the server returns a 503 response with a 'Retry-after' -// header, the request will be transparenty retried. +// header, the request will be transparently retried. func (client Client) dispatchRequest(request *http.Request) ([]byte, error) { // First, store the request's body into a byte[] to be able to restore it // after each request. @@ -80,18 +79,21 @@ func (client Client) dispatchRequest(request *http.Request) ([]byte, error) { } for retry := 0; retry < NumberOfRetries; retry++ { // Restore body before issuing request. - newBody := ioutil.NopCloser(bytes.NewReader(bodyContent)) - request.Body = newBody + if request.Body != nil { + newBody := io.NopCloser(bytes.NewReader(bodyContent)) + request.Body = newBody + } + body, err := client.dispatchSingleRequest(request) // If this is a 503 response with a non-void "Retry-After" header: wait // as instructed and retry the request. if err != nil { serverError, ok := errors.Cause(err).(ServerError) if ok && serverError.StatusCode == http.StatusServiceUnavailable { - retry_time_int, errConv := strconv.Atoi(serverError.Header.Get(RetryAfterHeaderName)) + retryTimeInt, errConv := strconv.Atoi(serverError.Header.Get(RetryAfterHeaderName)) if errConv == nil { select { - case <-time.After(time.Duration(retry_time_int) * time.Second): + case <-time.After(time.Duration(retryTimeInt) * time.Second): } continue } @@ -100,8 +102,11 @@ func (client Client) dispatchRequest(request *http.Request) ([]byte, error) { return body, err } // Restore body before issuing request. - newBody := ioutil.NopCloser(bytes.NewReader(bodyContent)) - request.Body = newBody + if request.Body != nil { + newBody := io.NopCloser(bytes.NewReader(bodyContent)) + request.Body = newBody + } + return client.dispatchSingleRequest(request) } diff --git a/client_test.go b/client_test.go index dea6ab4..54d9320 100644 --- a/client_test.go +++ b/client_test.go @@ -5,8 +5,9 @@ package gomaasapi import ( "bytes" + "crypto/tls" "fmt" - "io/ioutil" + "io" "net/http" "net/url" "strings" @@ -20,15 +21,15 @@ type ClientSuite struct{} var _ = gc.Suite(&ClientSuite{}) -func (*ClientSuite) TestReadAndCloseReturnsEmptyStringForNil(c *gc.C) { +func (*ClientSuite) TestReadAndCloseReturnsNilForNilBuffer(c *gc.C) { data, err := readAndClose(nil) c.Assert(err, jc.ErrorIsNil) - c.Check(string(data), gc.Equals, "") + c.Check(data, gc.IsNil) } func (*ClientSuite) TestReadAndCloseReturnsContents(c *gc.C) { content := "Stream contents." - stream := ioutil.NopCloser(strings.NewReader(content)) + stream := io.NopCloser(strings.NewReader(content)) data, err := readAndClose(stream) c.Assert(err, jc.ErrorIsNil) @@ -36,7 +37,7 @@ func (*ClientSuite) TestReadAndCloseReturnsContents(c *gc.C) { c.Check(string(data), gc.Equals, content) } -func (suite *ClientSuite) TestClientdispatchRequestReturnsServerError(c *gc.C) { +func (suite *ClientSuite) TestClientDispatchRequestReturnsServerError(c *gc.C) { URI := "/some/url/?param1=test" expectedResult := "expected:result" server := newSingleServingServer(URI, expectedResult, http.StatusBadRequest, -1) @@ -56,14 +57,15 @@ func (suite *ClientSuite) TestClientdispatchRequestReturnsServerError(c *gc.C) { c.Check(string(result), gc.Equals, expectedResult) } -func (suite *ClientSuite) TestClientdispatchRequestRetries503(c *gc.C) { +func (suite *ClientSuite) TestClientDispatchRequestRetries503(c *gc.C) { URI := "/some/url/?param1=test" server := newFlakyServer(URI, 503, NumberOfRetries) defer server.Close() client, err := NewAnonymousClient(server.URL, "1.0") c.Assert(err, jc.ErrorIsNil) content := "content" - request, err := http.NewRequest("GET", server.URL+URI, ioutil.NopCloser(strings.NewReader(content))) + request, err := http.NewRequest("GET", server.URL+URI, io.NopCloser(strings.NewReader(content))) + c.Assert(err, jc.ErrorIsNil) _, err = client.dispatchRequest(request) @@ -76,7 +78,28 @@ func (suite *ClientSuite) TestClientdispatchRequestRetries503(c *gc.C) { c.Check(*server.requests, jc.DeepEquals, expectedRequestsContent) } -func (suite *ClientSuite) TestClientdispatchRequestDoesntRetry200(c *gc.C) { +func (suite *ClientSuite) TestTLSClientDispatchRequestRetries503NilBody(c *gc.C) { + URI := "/some/path" + server := newFlakyTLSServer(URI, 503, NumberOfRetries) + defer server.Close() + client, err := NewAnonymousClient(server.URL, "2.0") + c.Assert(err, jc.ErrorIsNil) + + client.HTTPClient = &http.Client{Transport: http.DefaultTransport} + client.HTTPClient.Transport.(*http.Transport).TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + + request, err := http.NewRequest("GET", server.URL+URI, nil) + c.Assert(err, jc.ErrorIsNil) + + _, err = client.dispatchRequest(request) + c.Assert(err, jc.ErrorIsNil) + + c.Check(*server.nbRequests, gc.Equals, NumberOfRetries+1) +} + +func (suite *ClientSuite) TestClientDispatchRequestDoesntRetry200(c *gc.C) { URI := "/some/url/?param1=test" server := newFlakyServer(URI, 200, 10) defer server.Close() @@ -84,6 +107,7 @@ func (suite *ClientSuite) TestClientdispatchRequestDoesntRetry200(c *gc.C) { c.Assert(err, jc.ErrorIsNil) request, err := http.NewRequest("GET", server.URL+URI, nil) + c.Assert(err, jc.ErrorIsNil) _, err = client.dispatchRequest(request) @@ -91,7 +115,7 @@ func (suite *ClientSuite) TestClientdispatchRequestDoesntRetry200(c *gc.C) { c.Check(*server.nbRequests, gc.Equals, 1) } -func (suite *ClientSuite) TestClientdispatchRequestRetriesIsLimited(c *gc.C) { +func (suite *ClientSuite) TestClientDispatchRequestRetriesIsLimited(c *gc.C) { URI := "/some/url/?param1=test" // Make the server return 503 responses NumberOfRetries + 1 times. server := newFlakyServer(URI, 503, NumberOfRetries+1) @@ -99,6 +123,7 @@ func (suite *ClientSuite) TestClientdispatchRequestRetriesIsLimited(c *gc.C) { client, err := NewAnonymousClient(server.URL, "1.0") c.Assert(err, jc.ErrorIsNil) request, err := http.NewRequest("GET", server.URL+URI, nil) + c.Assert(err, jc.ErrorIsNil) _, err = client.dispatchRequest(request) @@ -124,7 +149,7 @@ func (suite *ClientSuite) TestClientDispatchRequestReturnsNonServerError(c *gc.C c.Check(result, gc.IsNil) } -func (suite *ClientSuite) TestClientdispatchRequestSignsRequest(c *gc.C) { +func (suite *ClientSuite) TestClientDispatchRequestSignsRequest(c *gc.C) { URI := "/some/url/?param1=test" expectedResult := "expected:result" server := newSingleServingServer(URI, expectedResult, http.StatusOK, -1) @@ -141,7 +166,7 @@ func (suite *ClientSuite) TestClientdispatchRequestSignsRequest(c *gc.C) { c.Check((*server.requestHeader)["Authorization"][0], gc.Matches, "^OAuth .*") } -func (suite *ClientSuite) TestClientdispatchRequestUsesConfiguredHTTPClient(c *gc.C) { +func (suite *ClientSuite) TestClientDispatchRequestUsesConfiguredHTTPClient(c *gc.C) { URI := "/some/url/" server := newSingleServingServer(URI, "", 0, 2*time.Second) @@ -217,7 +242,7 @@ func (suite *ClientSuite) TestClientPostSendsRequestWithParams(c *gc.C) { // extractFileContent extracts from the request built using 'requestContent', // 'requestHeader' and 'requestURL', the file named 'filename'. -func extractFileContent(requestContent string, requestHeader *http.Header, requestURL string, filename string) ([]byte, error) { +func extractFileContent(requestContent string, requestHeader *http.Header, requestURL string, _ string) ([]byte, error) { // Recreate the request from server.requestContent to use the parsing // utility from the http package (http.Request.FormFile). request, err := http.NewRequest("POST", requestURL, bytes.NewBufferString(requestContent)) @@ -229,7 +254,7 @@ func extractFileContent(requestContent string, requestHeader *http.Header, reque if err != nil { return nil, err } - fileContent, err := ioutil.ReadAll(file) + fileContent, err := io.ReadAll(file) if err != nil { return nil, err } diff --git a/testing.go b/testing.go index 908a1a5..494216d 100644 --- a/testing.go +++ b/testing.go @@ -84,6 +84,36 @@ func newFlakyServer(uri string, code int, nbFlakyResponses int) *flakyServer { return &flakyServer{server, &nbRequests, &requests} } +func newFlakyTLSServer(uri string, code int, nbFlakyResponses int) *flakyServer { + nbRequests := 0 + requests := make([][]byte, nbFlakyResponses+1) + var server *httptest.Server + + handler := func(writer http.ResponseWriter, request *http.Request) { + nbRequests += 1 + body, err := readAndClose(request.Body) + if err != nil { + panic(err) + } + requests[nbRequests-1] = body + if request.URL.String() != uri { + errorMsg := fmt.Sprintf("Error 404: page not found (expected '%v', got '%v').", uri, request.URL.String()) + http.Error(writer, errorMsg, http.StatusNotFound) + } else if nbRequests <= nbFlakyResponses { + if code == http.StatusServiceUnavailable { + writer.Header().Set("Retry-After", "0") + } + writer.WriteHeader(code) + fmt.Fprint(writer, "flaky") + } else { + writer.WriteHeader(http.StatusOK) + fmt.Fprint(writer, "ok") + } + } + server = httptest.NewTLSServer(http.HandlerFunc(handler)) + return &flakyServer{server, &nbRequests, &requests} +} + type simpleResponse struct { status int body string