From 5dcfb0f0c43ef5c474599fe9163c22bf12eed82e Mon Sep 17 00:00:00 2001 From: Theodore Tsirpanis Date: Fri, 26 Jan 2024 23:49:42 +0200 Subject: [PATCH 1/4] Add constructor overload to `STSProfileCredentialsProvider` where the client factory returns a shared pointer. --- .../auth/STSProfileCredentialsProvider.h | 41 ++++++++++++++++++- .../auth/STSProfileCredentialsProvider.cpp | 25 ++++++++++- 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/src/aws-cpp-sdk-identity-management/include/aws/identity-management/auth/STSProfileCredentialsProvider.h b/src/aws-cpp-sdk-identity-management/include/aws/identity-management/auth/STSProfileCredentialsProvider.h index 4b90bb01ec7..524554bf14a 100644 --- a/src/aws-cpp-sdk-identity-management/include/aws/identity-management/auth/STSProfileCredentialsProvider.h +++ b/src/aws-cpp-sdk-identity-management/include/aws/identity-management/auth/STSProfileCredentialsProvider.h @@ -52,8 +52,47 @@ namespace Aws */ STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration = std::chrono::minutes(60)); + /** + * Use the provided profile name from the shared configuration file and a custom STS client. + * + * @param profileName The name of the profile in the shared configuration file. + * @param duration The duration, in minutes, of the role session, after which the credentials are expired. + * The value can range from 15 minutes up to the maximum session duration setting for the role. By default, + * the duration is set to 1 hour. + * Note: This credential provider refreshes the credentials 5 minutes before their expiration time. That + * ensures the credentials do not expire between the time they're checked and the time they're returned to + * the user. + * If the duration for the credentials is 5 minutes or less, the provider will refresh the credentials only + * when they expire. + * @param stsClientFactory A factory function that creates an STSClient with specific credentials. + * Using the overload where the function returns a shared_ptr is preferred. + * + */ STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function &stsClientFactory); + /** + * Use the provided profile name from the shared configuration file and a custom STS client. + * + * @param profileName The name of the profile in the shared configuration file. + * @param duration The duration, in minutes, of the role session, after which the credentials are expired. + * The value can range from 15 minutes up to the maximum session duration setting for the role. By default, + * the duration is set to 1 hour. + * Note: This credential provider refreshes the credentials 5 minutes before their expiration time. That + * ensures the credentials do not expire between the time they're checked and the time they're returned to + * the user. + * If the duration for the credentials is 5 minutes or less, the provider will refresh the credentials only + * when they expire. + * @param stsClientFactory A factory function that creates an STSClient with specific credentials. + * + */ + STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function(const AWSCredentials&)> &stsClientFactory); + + /** + * Compatibility constructor to assist with overload resolution when passing nullptr for the client factory. + * + */ + STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, std::nullptr_t); + /** * Fetches the credentials set from STS following the rules defined in the shared configuration file. */ @@ -74,7 +113,7 @@ namespace Aws AWSCredentials m_credentials; const std::chrono::minutes m_duration; const std::chrono::milliseconds m_reloadFrequency; - std::function m_stsClientFactory; + std::function(const AWSCredentials&)> m_stsClientFactory; }; } } diff --git a/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp b/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp index fd82b678fba..a362eccd541 100644 --- a/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp +++ b/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp @@ -17,6 +17,12 @@ using namespace Aws::Auth; constexpr char CLASS_TAG[] = "STSProfileCredentialsProvider"; +template +struct NoOpDeleter +{ + void operator()(T*) {} +}; + STSProfileCredentialsProvider::STSProfileCredentialsProvider() : STSProfileCredentialsProvider(GetConfigProfileName(), std::chrono::minutes(60)/*duration*/, nullptr/*stsClientFactory*/) { @@ -27,8 +33,24 @@ STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& { } +STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, std::nullptr_t) + : m_profileName(profileName), + m_duration(duration), + m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast(duration.count()))) - std::chrono::minutes(5)), + m_stsClientFactory(nullptr) +{ +} + STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function &stsClientFactory) : m_profileName(profileName), + m_duration(duration), + m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast(duration.count()))) - std::chrono::minutes(5)), + m_stsClientFactory([=](const auto& credentials) {return std::shared_ptr(stsClientFactory(credentials), NoOpDeleter()); }) +{ +} + +STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function (const AWSCredentials&)>& stsClientFactory) + : m_profileName(profileName), m_duration(duration), m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast(duration.count()))) - std::chrono::minutes(5)), m_stsClientFactory(stsClientFactory) @@ -337,7 +359,8 @@ AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromSTS(const AWSCre { using namespace Aws::STS::Model; if (m_stsClientFactory) { - return GetCredentialsFromSTSInternal(roleArn, m_stsClientFactory(credentials)); + auto client = m_stsClientFactory(credentials); + return GetCredentialsFromSTSInternal(roleArn, client.get()); } Aws::STS::STSClient stsClient {credentials}; From 174e23d338385dc5115fb811f384a701905ab3d4 Mon Sep 17 00:00:00 2001 From: Theodore Tsirpanis Date: Fri, 26 Jan 2024 23:23:45 +0200 Subject: [PATCH 2/4] Update tests. --- .../auth/STSProfileCredentialsProviderTest.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/aws-cpp-sdk-identity-management-tests/auth/STSProfileCredentialsProviderTest.cpp b/tests/aws-cpp-sdk-identity-management-tests/auth/STSProfileCredentialsProviderTest.cpp index 197535a6a2e..cf234107db0 100644 --- a/tests/aws-cpp-sdk-identity-management-tests/auth/STSProfileCredentialsProviderTest.cpp +++ b/tests/aws-cpp-sdk-identity-management-tests/auth/STSProfileCredentialsProviderTest.cpp @@ -313,7 +313,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithoutRoleARN) STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) { ADD_FAILURE() << "STS Service client should not be used in this scenario."; - return nullptr; + return (STSClient*)nullptr; }); auto actualCredentials = credsProvider.GetAWSCredentials(); @@ -383,7 +383,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithoutSourceProfile) STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) { ADD_FAILURE() << "STS Service client should not be used in this scenario."; - return nullptr; + return (STSClient*)nullptr; }); auto actualCredentials = credsProvider.GetAWSCredentials(); @@ -409,7 +409,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithNonExistentSourceProfile STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) { ADD_FAILURE() << "STS Service client should not be used in this scenario."; - return nullptr; + return (STSClient*)nullptr; }); auto actualCredentials = credsProvider.GetAWSCredentials(); @@ -556,7 +556,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleSelfReferencingSourceProfile Model::AssumeRoleResult mockResult; mockResult.SetCredentials(stsCredentials); - Aws::UniquePtr stsClient; + std::shared_ptr stsClient; int stsCallCounter = 0; @@ -572,9 +572,9 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleSelfReferencingSourceProfile EXPECT_STREQ(ACCESS_KEY_ID_2, creds.GetAWSAccessKeyId().c_str()); EXPECT_STREQ(SECRET_ACCESS_KEY_ID_2, creds.GetAWSSecretKey().c_str()); } - stsClient = Aws::MakeUnique(CLASS_TAG, creds); + stsClient = Aws::MakeShared(CLASS_TAG, creds); stsClient->MockAssumeRole(mockResult); - return stsClient.get(); + return stsClient; }); auto actualCredentials = credsProvider.GetAWSCredentials(); @@ -614,7 +614,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleRecursivelyCircularReference STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) { ADD_FAILURE() << "STS Service client should not be used in this scenario."; - return nullptr; + return (STSClient*)nullptr; }); auto actualCredentials = credsProvider.GetAWSCredentials(); From 592b3dfb227d885f5fd06ec4815fb58434cd4962 Mon Sep 17 00:00:00 2001 From: Theodore Tsirpanis Date: Fri, 26 Jan 2024 23:49:13 +0200 Subject: [PATCH 3/4] Ignore the `.vs` folder. --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index e968c3ad323..20bb0ab57fb 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ Release *# *.iml tags +.vs .vscode # CI Artifacts From f4ef98abaa7928709f9f0a29b466fc2fcfaba560 Mon Sep 17 00:00:00 2001 From: Theodore Tsirpanis Date: Sat, 27 Jan 2024 01:45:26 +0200 Subject: [PATCH 4/4] Support Web Identity in `STSProfileCredentialsProvider`. --- .../auth/STSProfileCredentialsProvider.h | 2 + .../auth/STSProfileCredentialsProvider.cpp | 112 +++++++++++++++--- 2 files changed, 96 insertions(+), 18 deletions(-) diff --git a/src/aws-cpp-sdk-identity-management/include/aws/identity-management/auth/STSProfileCredentialsProvider.h b/src/aws-cpp-sdk-identity-management/include/aws/identity-management/auth/STSProfileCredentialsProvider.h index 524554bf14a..5790efddb33 100644 --- a/src/aws-cpp-sdk-identity-management/include/aws/identity-management/auth/STSProfileCredentialsProvider.h +++ b/src/aws-cpp-sdk-identity-management/include/aws/identity-management/auth/STSProfileCredentialsProvider.h @@ -106,8 +106,10 @@ namespace Aws * Returns the assumed role credentials or empty credentials on error. */ AWSCredentials GetCredentialsFromSTS(const AWSCredentials& credentials, const Aws::String& roleARN); + AWSCredentials GetCredentialsFromWebIdentity(const Config::Profile& profile); private: AWSCredentials GetCredentialsFromSTSInternal(const Aws::String& roleArn, Aws::STS::STSClient* client); + AWSCredentials GetCredentialsFromWebIdentityInternal(const Config::Profile& profile, Aws::STS::STSClient* client); Aws::String m_profileName; AWSCredentials m_credentials; diff --git a/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp b/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp index a362eccd541..1908e522bb3 100644 --- a/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp +++ b/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp @@ -5,11 +5,13 @@ #include #include +#include #include #include #include #include +#include #include using namespace Aws; @@ -88,25 +90,27 @@ enum class ProfileState Process, SourceProfile, SelfReferencing, // special case of SourceProfile. + RoleARNWebIdentity }; /* * A valid profile can be in one of the following states. Any other state is considered invalid. - +---------+-----------+-----------+--------------+ -| | | | | -| Role | Source | Process | Static | -| ARN | Profile | | Credentials | -+------------------------------------------------+ -| | | | | -| false | false | false | TRUE | -| | | | | -| false | false | TRUE | false | -| | | | | -| TRUE | TRUE | false | false | -| | | | | -| TRUE | TRUE | false | TRUE | -| | | | | -+---------+-----------+-----------+--------------+ ++---------+-----------+-----------+--------------+------------+ +| | | | | | +| Role | Source | Process | Static | Web | +| ARN | Profile | | Credentials | Identity | ++------------------------------------------------+------------+ +| | | | | | +| false | false | false | TRUE | false | +| | | | | | +| false | false | TRUE | false | false | +| | | | | | +| TRUE | TRUE | false | false | false | +| | | | | | +| TRUE | TRUE | false | TRUE | false | +| | | | | | +| TRUE | false | false | false | TRUE | ++---------+-----------+-----------+--------------+------------+ */ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLevelProfile) @@ -115,6 +119,7 @@ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLe constexpr int PROCESS_CREDENTIALS = 2; constexpr int SOURCE_PROFILE = 4; constexpr int ROLE_ARN = 8; + constexpr int WEB_IDENTITY_TOKEN_FILE = 16; int state = 0; @@ -138,6 +143,11 @@ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLe state += ROLE_ARN; } + if (!profile.GetValue("web_identity_token_file").empty()) + { + state += WEB_IDENTITY_TOKEN_FILE; + } + if (topLevelProfile) { switch(state) @@ -155,6 +165,8 @@ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLe } // source-profile over-rule static credentials in top-level profiles (except when self-referencing) return ProfileState::SourceProfile; + case 24: // role arn && web identity + return ProfileState::RoleARNWebIdentity; default: // All other cases are considered malformed configuration. return ProfileState::Invalid; @@ -176,6 +188,8 @@ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLe return ProfileState::SelfReferencing; } return ProfileState::Static; // static credentials over-rule source-profile (except when self-referencing) + case 24: // role arn && web identity + return ProfileState::RoleARNWebIdentity; default: // All other cases are considered malformed configuration. return ProfileState::Invalid; @@ -302,10 +316,14 @@ void STSProfileCredentialsProvider::Reload() while (sourceProfiles.size() > 1) { - const auto profile = sourceProfiles.back()->second; + const auto& profile = sourceProfiles.back()->second; sourceProfiles.pop_back(); AWSCredentials stsCreds; - if (profile.GetCredentialProcess().empty()) + if (CheckProfile(profile, false /*topLevelProfile*/) == ProfileState::RoleARNWebIdentity) + { + stsCreds = GetCredentialsFromWebIdentity(profile); + } + else if (profile.GetCredentialProcess().empty()) { assert(!profile.GetCredentials().IsEmpty()); stsCreds = profile.GetCredentials(); @@ -316,7 +334,7 @@ void STSProfileCredentialsProvider::Reload() } // get the role arn from the profile at the top of the stack (which hasn't been popped out yet) - const auto arn = sourceProfiles.back()->second.GetRoleArn(); + const auto& arn = sourceProfiles.back()->second.GetRoleArn(); const auto& assumedCreds = GetCredentialsFromSTS(stsCreds, arn); sourceProfiles.back()->second.SetCredentials(assumedCreds); } @@ -366,3 +384,61 @@ AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromSTS(const AWSCre Aws::STS::STSClient stsClient {credentials}; return GetCredentialsFromSTSInternal(roleArn, &stsClient); } + +AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromWebIdentityInternal(const Config::Profile& profile, Aws::STS::STSClient* client) +{ + Aws::String roleSessionName = profile.GetValue("role_session_name"); + if (roleSessionName.empty()) + { + roleSessionName = Aws::Utils::UUID::PseudoRandomUUID(); + } + + Aws::String token; + { + auto& tokenPath = profile.GetValue("web_identity_token_file"); + Aws::IFStream tokenFile(tokenPath); + if (tokenFile) { + token = Aws::String( + (std::istreambuf_iterator(tokenFile)), + std::istreambuf_iterator()); + } + else { + AWS_LOGSTREAM_ERROR(CLASS_TAG, "Can't open token file: " << tokenPath); + return {}; + } + } + + using namespace Aws::STS::Model; + AssumeRoleWithWebIdentityRequest assumeRoleRequest; + assumeRoleRequest + .WithRoleArn(profile.GetRoleArn()) + .WithRoleSessionName(roleSessionName) + .WithWebIdentityToken(token) + .WithDurationSeconds(static_cast(std::chrono::seconds(m_duration).count())); + auto outcome = client->AssumeRoleWithWebIdentity(assumeRoleRequest); + if (outcome.IsSuccess()) + { + const auto& modelCredentials = outcome.GetResult().GetCredentials(); + return {modelCredentials.GetAccessKeyId(), + modelCredentials.GetSecretAccessKey(), + modelCredentials.GetSessionToken(), + modelCredentials.GetExpiration()}; + } + else + { + AWS_LOGSTREAM_ERROR(CLASS_TAG, "Failed to assume role " << profile.GetRoleArn()); + } + return {}; +} + +AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromWebIdentity(const Config::Profile& profile) +{ + using namespace Aws::STS::Model; + if (m_stsClientFactory) { + auto client = m_stsClientFactory({}); + return GetCredentialsFromWebIdentityInternal(profile, client.get()); + } + + Aws::STS::STSClient stsClient{AWSCredentials{}}; + return GetCredentialsFromWebIdentityInternal(profile, &stsClient); +}