Skip to content

Commit

Permalink
Ensures that retried request dispatching does not reset the body with a
Browse files Browse the repository at this point in the history
non-nil empty buffer when the body is nil.

This should prevent observed "net/http: cannot rewind body after
connection loss" errors.
  • Loading branch information
Joseph Phillips authored and Joseph Phillips committed Jan 19, 2023
1 parent 7268ed0 commit fb34589
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 22 deletions.
23 changes: 14 additions & 9 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"bytes"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
"net/http"
"net/url"
Expand Down Expand Up @@ -62,15 +61,15 @@ func readAndClose(stream io.ReadCloser) ([]byte, error) {
return nil, nil
}
defer stream.Close()
return ioutil.ReadAll(stream)
return io.ReadAll(stream)
}

// dispatchRequest sends a request to the server, and interprets the response.
// Client-side errors will return an empty response and a non-nil error. For
// server-side errors however (i.e. responses with a non 2XX status code), the
// returned error will be ServerError and the returned body will reflect the
// server's response. If the server returns a 503 response with a 'Retry-after'
// header, the request will be transparenty retried.
// header, the request will be transparently retried.
func (client Client) dispatchRequest(request *http.Request) ([]byte, error) {
// First, store the request's body into a byte[] to be able to restore it
// after each request.
Expand All @@ -80,18 +79,21 @@ func (client Client) dispatchRequest(request *http.Request) ([]byte, error) {
}
for retry := 0; retry < NumberOfRetries; retry++ {
// Restore body before issuing request.
newBody := ioutil.NopCloser(bytes.NewReader(bodyContent))
request.Body = newBody
if request.Body != nil {
newBody := io.NopCloser(bytes.NewReader(bodyContent))
request.Body = newBody
}

body, err := client.dispatchSingleRequest(request)
// If this is a 503 response with a non-void "Retry-After" header: wait
// as instructed and retry the request.
if err != nil {
serverError, ok := errors.Cause(err).(ServerError)
if ok && serverError.StatusCode == http.StatusServiceUnavailable {
retry_time_int, errConv := strconv.Atoi(serverError.Header.Get(RetryAfterHeaderName))
retryTimeInt, errConv := strconv.Atoi(serverError.Header.Get(RetryAfterHeaderName))
if errConv == nil {
select {
case <-time.After(time.Duration(retry_time_int) * time.Second):
case <-time.After(time.Duration(retryTimeInt) * time.Second):
}
continue
}
Expand All @@ -100,8 +102,11 @@ func (client Client) dispatchRequest(request *http.Request) ([]byte, error) {
return body, err
}
// Restore body before issuing request.
newBody := ioutil.NopCloser(bytes.NewReader(bodyContent))
request.Body = newBody
if request.Body != nil {
newBody := io.NopCloser(bytes.NewReader(bodyContent))
request.Body = newBody
}

return client.dispatchSingleRequest(request)
}

Expand Down
51 changes: 38 additions & 13 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ package gomaasapi

import (
"bytes"
"crypto/tls"
"fmt"
"io/ioutil"
"io"
"net/http"
"net/url"
"strings"
Expand All @@ -20,23 +21,23 @@ type ClientSuite struct{}

var _ = gc.Suite(&ClientSuite{})

func (*ClientSuite) TestReadAndCloseReturnsEmptyStringForNil(c *gc.C) {
func (*ClientSuite) TestReadAndCloseReturnsNilForNilBuffer(c *gc.C) {
data, err := readAndClose(nil)
c.Assert(err, jc.ErrorIsNil)
c.Check(string(data), gc.Equals, "")
c.Check(data, gc.IsNil)
}

func (*ClientSuite) TestReadAndCloseReturnsContents(c *gc.C) {
content := "Stream contents."
stream := ioutil.NopCloser(strings.NewReader(content))
stream := io.NopCloser(strings.NewReader(content))

data, err := readAndClose(stream)
c.Assert(err, jc.ErrorIsNil)

c.Check(string(data), gc.Equals, content)
}

func (suite *ClientSuite) TestClientdispatchRequestReturnsServerError(c *gc.C) {
func (suite *ClientSuite) TestClientDispatchRequestReturnsServerError(c *gc.C) {
URI := "/some/url/?param1=test"
expectedResult := "expected:result"
server := newSingleServingServer(URI, expectedResult, http.StatusBadRequest, -1)
Expand All @@ -56,14 +57,15 @@ func (suite *ClientSuite) TestClientdispatchRequestReturnsServerError(c *gc.C) {
c.Check(string(result), gc.Equals, expectedResult)
}

func (suite *ClientSuite) TestClientdispatchRequestRetries503(c *gc.C) {
func (suite *ClientSuite) TestClientDispatchRequestRetries503(c *gc.C) {
URI := "/some/url/?param1=test"
server := newFlakyServer(URI, 503, NumberOfRetries)
defer server.Close()
client, err := NewAnonymousClient(server.URL, "1.0")
c.Assert(err, jc.ErrorIsNil)
content := "content"
request, err := http.NewRequest("GET", server.URL+URI, ioutil.NopCloser(strings.NewReader(content)))
request, err := http.NewRequest("GET", server.URL+URI, io.NopCloser(strings.NewReader(content)))
c.Assert(err, jc.ErrorIsNil)

_, err = client.dispatchRequest(request)

Expand All @@ -76,29 +78,52 @@ func (suite *ClientSuite) TestClientdispatchRequestRetries503(c *gc.C) {
c.Check(*server.requests, jc.DeepEquals, expectedRequestsContent)
}

func (suite *ClientSuite) TestClientdispatchRequestDoesntRetry200(c *gc.C) {
func (suite *ClientSuite) TestTLSClientDispatchRequestRetries503NilBody(c *gc.C) {
URI := "/some/path"
server := newFlakyTLSServer(URI, 503, NumberOfRetries)
defer server.Close()
client, err := NewAnonymousClient(server.URL, "2.0")
c.Assert(err, jc.ErrorIsNil)

client.HTTPClient = &http.Client{Transport: http.DefaultTransport}
client.HTTPClient.Transport.(*http.Transport).TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}

request, err := http.NewRequest("GET", server.URL+URI, nil)
c.Assert(err, jc.ErrorIsNil)

_, err = client.dispatchRequest(request)
c.Assert(err, jc.ErrorIsNil)

c.Check(*server.nbRequests, gc.Equals, NumberOfRetries+1)
}

func (suite *ClientSuite) TestClientDispatchRequestDoesntRetry200(c *gc.C) {
URI := "/some/url/?param1=test"
server := newFlakyServer(URI, 200, 10)
defer server.Close()
client, err := NewAnonymousClient(server.URL, "1.0")
c.Assert(err, jc.ErrorIsNil)

request, err := http.NewRequest("GET", server.URL+URI, nil)
c.Assert(err, jc.ErrorIsNil)

_, err = client.dispatchRequest(request)

c.Assert(err, jc.ErrorIsNil)
c.Check(*server.nbRequests, gc.Equals, 1)
}

func (suite *ClientSuite) TestClientdispatchRequestRetriesIsLimited(c *gc.C) {
func (suite *ClientSuite) TestClientDispatchRequestRetriesIsLimited(c *gc.C) {
URI := "/some/url/?param1=test"
// Make the server return 503 responses NumberOfRetries + 1 times.
server := newFlakyServer(URI, 503, NumberOfRetries+1)
defer server.Close()
client, err := NewAnonymousClient(server.URL, "1.0")
c.Assert(err, jc.ErrorIsNil)
request, err := http.NewRequest("GET", server.URL+URI, nil)
c.Assert(err, jc.ErrorIsNil)

_, err = client.dispatchRequest(request)

Expand All @@ -124,7 +149,7 @@ func (suite *ClientSuite) TestClientDispatchRequestReturnsNonServerError(c *gc.C
c.Check(result, gc.IsNil)
}

func (suite *ClientSuite) TestClientdispatchRequestSignsRequest(c *gc.C) {
func (suite *ClientSuite) TestClientDispatchRequestSignsRequest(c *gc.C) {
URI := "/some/url/?param1=test"
expectedResult := "expected:result"
server := newSingleServingServer(URI, expectedResult, http.StatusOK, -1)
Expand All @@ -141,7 +166,7 @@ func (suite *ClientSuite) TestClientdispatchRequestSignsRequest(c *gc.C) {
c.Check((*server.requestHeader)["Authorization"][0], gc.Matches, "^OAuth .*")
}

func (suite *ClientSuite) TestClientdispatchRequestUsesConfiguredHTTPClient(c *gc.C) {
func (suite *ClientSuite) TestClientDispatchRequestUsesConfiguredHTTPClient(c *gc.C) {
URI := "/some/url/"

server := newSingleServingServer(URI, "", 0, 2*time.Second)
Expand Down Expand Up @@ -217,7 +242,7 @@ func (suite *ClientSuite) TestClientPostSendsRequestWithParams(c *gc.C) {

// extractFileContent extracts from the request built using 'requestContent',
// 'requestHeader' and 'requestURL', the file named 'filename'.
func extractFileContent(requestContent string, requestHeader *http.Header, requestURL string, filename string) ([]byte, error) {
func extractFileContent(requestContent string, requestHeader *http.Header, requestURL string, _ string) ([]byte, error) {
// Recreate the request from server.requestContent to use the parsing
// utility from the http package (http.Request.FormFile).
request, err := http.NewRequest("POST", requestURL, bytes.NewBufferString(requestContent))
Expand All @@ -229,7 +254,7 @@ func extractFileContent(requestContent string, requestHeader *http.Header, reque
if err != nil {
return nil, err
}
fileContent, err := ioutil.ReadAll(file)
fileContent, err := io.ReadAll(file)
if err != nil {
return nil, err
}
Expand Down
30 changes: 30 additions & 0 deletions testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,36 @@ func newFlakyServer(uri string, code int, nbFlakyResponses int) *flakyServer {
return &flakyServer{server, &nbRequests, &requests}
}

func newFlakyTLSServer(uri string, code int, nbFlakyResponses int) *flakyServer {
nbRequests := 0
requests := make([][]byte, nbFlakyResponses+1)
var server *httptest.Server

handler := func(writer http.ResponseWriter, request *http.Request) {
nbRequests += 1
body, err := readAndClose(request.Body)
if err != nil {
panic(err)
}
requests[nbRequests-1] = body
if request.URL.String() != uri {
errorMsg := fmt.Sprintf("Error 404: page not found (expected '%v', got '%v').", uri, request.URL.String())
http.Error(writer, errorMsg, http.StatusNotFound)
} else if nbRequests <= nbFlakyResponses {
if code == http.StatusServiceUnavailable {
writer.Header().Set("Retry-After", "0")
}
writer.WriteHeader(code)
fmt.Fprint(writer, "flaky")
} else {
writer.WriteHeader(http.StatusOK)
fmt.Fprint(writer, "ok")
}
}
server = httptest.NewTLSServer(http.HandlerFunc(handler))
return &flakyServer{server, &nbRequests, &requests}
}

type simpleResponse struct {
status int
body string
Expand Down

0 comments on commit fb34589

Please sign in to comment.