diff --git a/client.go b/client.go index eab8d87..e3fe3a7 100644 --- a/client.go +++ b/client.go @@ -31,8 +31,9 @@ const ( // Client represents a way to communicating with a MAAS API instance. // It is stateless, so it can have concurrent requests in progress. type Client struct { - APIURL *url.URL - Signer OAuthSigner + APIURL *url.URL + Signer OAuthSigner + HTTPClient *http.Client } // ServerError is an http error (or at least, a non-2xx result) received from @@ -106,7 +107,11 @@ func (client Client) dispatchRequest(request *http.Request) ([]byte, error) { func (client Client) dispatchSingleRequest(request *http.Request) ([]byte, error) { client.Signer.OAuthSign(request) - httpClient := http.Client{} + httpClient := &http.Client{} + if client.HTTPClient != nil { + httpClient = client.HTTPClient + } + // See https://code.google.com/p/go/issues/detail?id=4677 // We need to force the connection to close each time so that we don't // hit the above Go bug. diff --git a/client_test.go b/client_test.go index 73b8f41..dea6ab4 100644 --- a/client_test.go +++ b/client_test.go @@ -10,6 +10,7 @@ import ( "net/http" "net/url" "strings" + "time" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" @@ -38,7 +39,7 @@ func (*ClientSuite) TestReadAndCloseReturnsContents(c *gc.C) { func (suite *ClientSuite) TestClientdispatchRequestReturnsServerError(c *gc.C) { URI := "/some/url/?param1=test" expectedResult := "expected:result" - server := newSingleServingServer(URI, expectedResult, http.StatusBadRequest) + server := newSingleServingServer(URI, expectedResult, http.StatusBadRequest, -1) defer server.Close() client, err := NewAnonymousClient(server.URL, "1.0") c.Assert(err, jc.ErrorIsNil) @@ -126,7 +127,7 @@ func (suite *ClientSuite) TestClientDispatchRequestReturnsNonServerError(c *gc.C func (suite *ClientSuite) TestClientdispatchRequestSignsRequest(c *gc.C) { URI := "/some/url/?param1=test" expectedResult := "expected:result" - server := newSingleServingServer(URI, expectedResult, http.StatusOK) + server := newSingleServingServer(URI, expectedResult, http.StatusOK, -1) defer server.Close() client, err := NewAuthenticatedClient(server.URL, "the:api:key") c.Assert(err, jc.ErrorIsNil) @@ -140,13 +141,32 @@ func (suite *ClientSuite) TestClientdispatchRequestSignsRequest(c *gc.C) { c.Check((*server.requestHeader)["Authorization"][0], gc.Matches, "^OAuth .*") } +func (suite *ClientSuite) TestClientdispatchRequestUsesConfiguredHTTPClient(c *gc.C) { + URI := "/some/url/" + + server := newSingleServingServer(URI, "", 0, 2*time.Second) + defer server.Close() + + client, err := NewAnonymousClient(server.URL, "2.0") + c.Assert(err, jc.ErrorIsNil) + + client.HTTPClient = &http.Client{Timeout: time.Second} + + request, err := http.NewRequest("GET", server.URL+URI, nil) + c.Assert(err, jc.ErrorIsNil) + + _, err = client.dispatchRequest(request) + + c.Assert(err, gc.ErrorMatches, `Get "http://127.0.0.1:\d+/some/url/": context deadline exceeded \(Client\.Timeout exceeded while awaiting headers\)`) +} + func (suite *ClientSuite) TestClientGetFormatsGetParameters(c *gc.C) { URI, err := url.Parse("/some/url") c.Assert(err, jc.ErrorIsNil) expectedResult := "expected:result" params := url.Values{"test": {"123"}} fullURI := URI.String() + "?test=123" - server := newSingleServingServer(fullURI, expectedResult, http.StatusOK) + server := newSingleServingServer(fullURI, expectedResult, http.StatusOK, -1) defer server.Close() client, err := NewAnonymousClient(server.URL, "1.0") c.Assert(err, jc.ErrorIsNil) @@ -162,7 +182,7 @@ func (suite *ClientSuite) TestClientGetFormatsOperationAsGetParameter(c *gc.C) { c.Assert(err, jc.ErrorIsNil) expectedResult := "expected:result" fullURI := URI.String() + "?op=list" - server := newSingleServingServer(fullURI, expectedResult, http.StatusOK) + server := newSingleServingServer(fullURI, expectedResult, http.StatusOK, -1) defer server.Close() client, err := NewAnonymousClient(server.URL, "1.0") c.Assert(err, jc.ErrorIsNil) @@ -179,7 +199,7 @@ func (suite *ClientSuite) TestClientPostSendsRequestWithParams(c *gc.C) { expectedResult := "expected:result" fullURI := URI.String() + "?op=list" params := url.Values{"test": {"123"}} - server := newSingleServingServer(fullURI, expectedResult, http.StatusOK) + server := newSingleServingServer(fullURI, expectedResult, http.StatusOK, -1) defer server.Close() client, err := NewAnonymousClient(server.URL, "1.0") c.Assert(err, jc.ErrorIsNil) @@ -221,7 +241,7 @@ func (suite *ClientSuite) TestClientPostSendsMultipartRequest(c *gc.C) { c.Assert(err, jc.ErrorIsNil) expectedResult := "expected:result" fullURI := URI.String() + "?op=add" - server := newSingleServingServer(fullURI, expectedResult, http.StatusOK) + server := newSingleServingServer(fullURI, expectedResult, http.StatusOK, -1) defer server.Close() client, err := NewAnonymousClient(server.URL, "1.0") c.Assert(err, jc.ErrorIsNil) @@ -242,7 +262,7 @@ func (suite *ClientSuite) TestClientPutSendsRequest(c *gc.C) { c.Assert(err, jc.ErrorIsNil) expectedResult := "expected:result" params := url.Values{"test": {"123"}} - server := newSingleServingServer(URI.String(), expectedResult, http.StatusOK) + server := newSingleServingServer(URI.String(), expectedResult, http.StatusOK, -1) defer server.Close() client, err := NewAnonymousClient(server.URL, "1.0") c.Assert(err, jc.ErrorIsNil) @@ -258,7 +278,7 @@ func (suite *ClientSuite) TestClientDeleteSendsRequest(c *gc.C) { URI, err := url.Parse("/some/url") c.Assert(err, jc.ErrorIsNil) expectedResult := "expected:result" - server := newSingleServingServer(URI.String(), expectedResult, http.StatusOK) + server := newSingleServingServer(URI.String(), expectedResult, http.StatusOK, -1) defer server.Close() client, err := NewAnonymousClient(server.URL, "1.0") c.Assert(err, jc.ErrorIsNil) diff --git a/controller.go b/controller.go index 71821ed..e13b17d 100644 --- a/controller.go +++ b/controller.go @@ -41,8 +41,9 @@ var ( // ControllerArgs is an argument struct for passing the required parameters // to the NewController method. type ControllerArgs struct { - BaseURL string - APIKey string + BaseURL string + APIKey string + HTTPClient *http.Client } // NewController creates an authenticated client to the MAAS API, and @@ -59,7 +60,7 @@ func NewController(args ControllerArgs) (Controller, error) { if !supportedVersion(apiVersion) { return nil, NewUnsupportedVersionError("version %s", apiVersion) } - return newControllerWithVersion(base, apiVersion, args.APIKey) + return newControllerWithVersion(base, apiVersion, args.APIKey, args.HTTPClient) } return newControllerUnknownVersion(args) } @@ -73,7 +74,7 @@ func supportedVersion(value string) bool { return false } -func newControllerWithVersion(baseURL, apiVersion, apiKey string) (Controller, error) { +func newControllerWithVersion(baseURL, apiVersion, apiKey string, httpClient *http.Client) (Controller, error) { major, minor, err := version.ParseMajorMinor(apiVersion) // We should not get an error here. See the test. if err != nil { @@ -89,6 +90,8 @@ func newControllerWithVersion(baseURL, apiVersion, apiKey string) (Controller, e // is an unexpected error and return now. return nil, NewUnexpectedError(err) } + + client.HTTPClient = httpClient controllerVersion := version.Number{ Major: major, Minor: minor, @@ -111,7 +114,7 @@ func newControllerUnknownVersion(args ControllerArgs) (Controller, error) { // some time in the future, we will try the most up to date version and then // work our way backwards. for _, apiVersion := range supportedAPIVersions { - controller, err := newControllerWithVersion(args.BaseURL, apiVersion, args.APIKey) + controller, err := newControllerWithVersion(args.BaseURL, apiVersion, args.APIKey, args.HTTPClient) switch { case err == nil: return controller, nil diff --git a/testing.go b/testing.go index 54d67aa..908a1a5 100644 --- a/testing.go +++ b/testing.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "strings" + "time" ) type singleServingServer struct { @@ -18,7 +19,7 @@ type singleServingServer struct { // newSingleServingServer creates a single-serving test http server which will // return only one response as defined by the passed arguments. -func newSingleServingServer(uri string, response string, code int) *singleServingServer { +func newSingleServingServer(uri string, response string, code int, delay time.Duration) *singleServingServer { var requestContent string var requestHeader http.Header var requested bool @@ -36,6 +37,7 @@ func newSingleServingServer(uri string, response string, code int) *singleServin errorMsg := fmt.Sprintf("Error 404: page not found (expected '%v', got '%v').", uri, request.URL.String()) http.Error(writer, errorMsg, http.StatusNotFound) } else { + time.Sleep(delay) writer.WriteHeader(code) fmt.Fprint(writer, response) }