diff --git a/support/http/httptest/client_expectation.go b/support/http/httptest/client_expectation.go index f056ae7ebf..e19bdc334f 100644 --- a/support/http/httptest/client_expectation.go +++ b/support/http/httptest/client_expectation.go @@ -1,6 +1,7 @@ package httptest import ( + "fmt" "net/http" "net/url" "strconv" @@ -85,6 +86,37 @@ func (ce *ClientExpectation) ReturnStringWithHeader( return ce.Return(httpmock.ResponderFromResponse(&cResp)) } +// ReturnMultipleResults registers multiple sequential responses for a given client expectation. +// Useful for testing retries +func (ce *ClientExpectation) ReturnMultipleResults(responseSets []ResponseData) *ClientExpectation { + var allResponses []httpmock.Responder + for _, response := range responseSets { + resp := http.Response{ + Status: strconv.Itoa(response.Status), + StatusCode: response.Status, + Body: httpmock.NewRespBodyFromString(response.Body), + Header: response.Header, + } + allResponses = append(allResponses, httpmock.ResponderFromResponse(&resp)) + } + responseIndex := 0 + ce.Client.MockTransport.RegisterResponder( + ce.Method, + ce.URL, + func(req *http.Request) (*http.Response, error) { + if responseIndex >= len(allResponses) { + panic(fmt.Errorf("no responses available")) + } + + resp := allResponses[responseIndex] + responseIndex++ + return resp(req) + }, + ) + + return ce +} + // ReturnJSONWithHeader causes this expectation to resolve to a json-based body with the provided // status code and response header. Panics when the provided body cannot be encoded to JSON. func (ce *ClientExpectation) ReturnJSONWithHeader( diff --git a/support/http/httptest/main.go b/support/http/httptest/main.go index 47a00b1991..18b986ba1b 100644 --- a/support/http/httptest/main.go +++ b/support/http/httptest/main.go @@ -67,3 +67,9 @@ func NewServer(t *testing.T, handler http.Handler) *Server { Expect: httpexpect.New(t, server.URL), } } + +type ResponseData struct { + Status int + Body string + Header http.Header +} diff --git a/utils/apiclient/client.go b/utils/apiclient/client.go new file mode 100644 index 0000000000..82501df3fc --- /dev/null +++ b/utils/apiclient/client.go @@ -0,0 +1,97 @@ +package apiclient + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "time" + + "github.com/stellar/go/support/log" +) + +const ( + defaultMaxRetries = 5 + defaultInitialBackoffTime = 1 * time.Second +) + +func isRetryableStatusCode(statusCode int) bool { + return statusCode == http.StatusTooManyRequests || statusCode == http.StatusServiceUnavailable +} + +func (c *APIClient) GetURL(endpoint string, queryParams url.Values) string { + return fmt.Sprintf("%s/%s?%s", c.BaseURL, endpoint, queryParams.Encode()) +} + +func (c *APIClient) CallAPI(reqParams RequestParams) (interface{}, error) { + if reqParams.QueryParams == nil { + reqParams.QueryParams = url.Values{} + } + + if reqParams.Headers == nil { + reqParams.Headers = map[string]interface{}{} + } + + if c.MaxRetries == 0 { + c.MaxRetries = defaultMaxRetries + } + + if c.InitialBackoffTime == 0 { + c.InitialBackoffTime = defaultInitialBackoffTime + } + + if reqParams.Endpoint == "" { + return nil, fmt.Errorf("Please set endpoint to query") + } + + url := c.GetURL(reqParams.Endpoint, reqParams.QueryParams) + reqBody, err := CreateRequestBody(reqParams.RequestType, url) + if err != nil { + return nil, fmt.Errorf("http request creation failed") + } + + SetAuthHeaders(reqBody, c.AuthType, c.AuthHeaders) + SetHeaders(reqBody, reqParams.Headers) + client := c.HTTP + if client == nil { + client = &http.Client{} + } + + var result interface{} + retries := 0 + + for retries <= c.MaxRetries { + resp, err := client.Do(reqBody) + if err != nil { + return nil, fmt.Errorf("http request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal JSON: %w", err) + } + + return result, nil + } else if isRetryableStatusCode(resp.StatusCode) { + retries++ + backoffDuration := c.InitialBackoffTime * time.Duration(1<