diff --git a/deps/rabbitmq_aws/include/rabbitmq_aws.hrl b/deps/rabbitmq_aws/include/rabbitmq_aws.hrl index 36e40cbd0a6d..ed3560a46ed0 100644 --- a/deps/rabbitmq_aws/include/rabbitmq_aws.hrl +++ b/deps/rabbitmq_aws/include/rabbitmq_aws.hrl @@ -43,6 +43,9 @@ -define(METADATA_TOKEN, "X-aws-ec2-metadata-token"). +-define(LINEAR_BACK_OFF_MILLIS, 1000). +-define(MAX_RETRIES, 30). + -type access_key() :: nonempty_string(). -type secret_access_key() :: nonempty_string(). -type expiration() :: calendar:datetime() | undefined. diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws.erl b/deps/rabbitmq_aws/src/rabbitmq_aws.erl index 5ba4e61b5037..300eb20a8332 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws.erl @@ -17,7 +17,8 @@ has_credentials/0, set_region/1, ensure_imdsv2_token_valid/0, - api_get_request/2]). + api_get_request/2, + api_get_request_with_retries/4]). %% gen-server exports -export([start_link/0, @@ -543,10 +544,19 @@ ensure_credentials_valid() -> %% @end api_get_request(Service, Path) -> rabbit_log:debug("Invoking AWS request {Service: ~p; Path: ~p}...", [Service, Path]), + api_get_request_with_retries(Service, Path, ?LINEAR_BACK_OFF_MILLIS, ?MAX_RETRIES). + +api_get_request_with_retries(Service, Path, SleepTime, Retries) -> ensure_credentials_valid(), case get(Service, Path) of {ok, {_Headers, Payload}} -> rabbit_log:debug("AWS request: ~s~nResponse: ~p", [Path, Payload]), - {ok, Payload}; + {ok, Payload}; {error, {credentials, _}} -> {error, credentials}; - {error, Message, _} -> {error, Message} + {error, Message, _} -> + case Retries of + 0 -> {error, Message}; + _ -> rabbit_log:warning("Encountered error when calling api ~p, retries remaining ~p", [Message, Retries]), + timer:sleep(SleepTime), + api_get_request_with_retries(Service, Path, SleepTime, Retries - 1) + end end. diff --git a/deps/rabbitmq_aws/test/src/rabbitmq_aws_tests.erl b/deps/rabbitmq_aws/test/src/rabbitmq_aws_tests.erl index b73b091e5ba3..9f1774b87e0f 100644 --- a/deps/rabbitmq_aws/test/src/rabbitmq_aws_tests.erl +++ b/deps/rabbitmq_aws/test/src/rabbitmq_aws_tests.erl @@ -468,17 +468,64 @@ api_get_request_test_() -> ?assertEqual({ok, [{"data","value"}]}, Result), meck:validate(httpc) end + } + ] + }. + +api_get_request_with_retries_test_() -> + { + foreach, + fun () -> + meck:new(httpc, []), + meck:new(rabbitmq_aws_config, []), + [httpc, rabbitmq_aws_config] + end, + fun meck:unload/1, + [ + {"AWS service API request succeeded", + fun() -> + State = #state{access_key = "ExpiredKey", + secret_access_key = "ExpiredAccessKey", + region = "us-east-1", + expiration = {{3016, 4, 1}, {12, 0, 0}}}, + meck:expect(httpc, request, 4, {ok, {{"HTTP/1.0", 200, "OK"}, [{"content-type", "application/json"}], "{\"data\": \"value\"}"}}), + {ok, Pid} = rabbitmq_aws:start_link(), + rabbitmq_aws:set_region("us-east-1"), + rabbitmq_aws:set_credentials(State), + Result = rabbitmq_aws:api_get_request_with_retries("AWS", "API", 1, 1), + ok = gen_server:stop(Pid), + ?assertEqual({ok, [{"data","value"}]}, Result), + meck:validate(httpc) + end }, {"AWS service API request failed - credentials", fun() -> meck:expect(rabbitmq_aws_config, credentials, 0, {error, undefined}), {ok, Pid} = rabbitmq_aws:start_link(), rabbitmq_aws:set_region("us-east-1"), - Result = rabbitmq_aws:api_get_request("AWS", "API"), + Result = rabbitmq_aws:api_get_request_with_retries("AWS", "API", 1, 2), ok = gen_server:stop(Pid), ?assertEqual({error, credentials}, Result) end }, + {"AWS service API request failed credentials - should retry", + fun() -> + State = #state{access_key = "ExpiredKey", + secret_access_key = "ExpiredAccessKey", + region = "us-east-1", + expiration = {{3016, 4, 1}, {12, 0, 0}}}, + meck:expect(httpc, request, 4, meck:seq([ + {error, undefined}, + {ok, {{"HTTP/1.0", 200, "OK"}, [{"content-type", "application/json"}], "{\"data\": \"value\"}"}}])), + {ok, Pid} = rabbitmq_aws:start_link(), + rabbitmq_aws:set_region("us-east-1"), + rabbitmq_aws:set_credentials(State), + Result = rabbitmq_aws:api_get_request_with_retries("AWS", "API", 1, 1), + ok = gen_server:stop(Pid), + ?assertEqual({ok, [{"data","value"}]}, Result), + meck:validate(httpc) + end + }, {"AWS service API request failed - API error", fun() -> State = #state{access_key = "ExpiredKey", @@ -489,7 +536,7 @@ api_get_request_test_() -> {ok, Pid} = rabbitmq_aws:start_link(), rabbitmq_aws:set_region("us-east-1"), rabbitmq_aws:set_credentials(State), - Result = rabbitmq_aws:api_get_request("AWS", "API"), + Result = rabbitmq_aws:api_get_request_with_retries("AWS", "API", 1, 1), ok = gen_server:stop(Pid), ?assertEqual({error, "invalid input"}, Result), meck:validate(httpc)