Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Web Identity in STSProfileCredentialsProvider. #2831

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Release
*#
*.iml
tags
.vs
.vscode

# CI Artifacts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,47 @@ namespace Aws
*/
STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration = std::chrono::minutes(60));

/**
* Use the provided profile name from the shared configuration file and a custom STS client.
*
* @param profileName The name of the profile in the shared configuration file.
* @param duration The duration, in minutes, of the role session, after which the credentials are expired.
* The value can range from 15 minutes up to the maximum session duration setting for the role. By default,
* the duration is set to 1 hour.
* Note: This credential provider refreshes the credentials 5 minutes before their expiration time. That
* ensures the credentials do not expire between the time they're checked and the time they're returned to
* the user.
* If the duration for the credentials is 5 minutes or less, the provider will refresh the credentials only
* when they expire.
* @param stsClientFactory A factory function that creates an STSClient with specific credentials.
* Using the overload where the function returns a shared_ptr is preferred.
*
*/
STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function<Aws::STS::STSClient*(const AWSCredentials&)> &stsClientFactory);

/**
* Use the provided profile name from the shared configuration file and a custom STS client.
*
* @param profileName The name of the profile in the shared configuration file.
* @param duration The duration, in minutes, of the role session, after which the credentials are expired.
* The value can range from 15 minutes up to the maximum session duration setting for the role. By default,
* the duration is set to 1 hour.
* Note: This credential provider refreshes the credentials 5 minutes before their expiration time. That
* ensures the credentials do not expire between the time they're checked and the time they're returned to
* the user.
* If the duration for the credentials is 5 minutes or less, the provider will refresh the credentials only
* when they expire.
* @param stsClientFactory A factory function that creates an STSClient with specific credentials.
*
*/
STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function<std::shared_ptr<Aws::STS::STSClient>(const AWSCredentials&)> &stsClientFactory);

/**
* Compatibility constructor to assist with overload resolution when passing nullptr for the client factory.
*
*/
STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, std::nullptr_t);

/**
* Fetches the credentials set from STS following the rules defined in the shared configuration file.
*/
Expand All @@ -67,14 +106,16 @@ namespace Aws
* Returns the assumed role credentials or empty credentials on error.
*/
AWSCredentials GetCredentialsFromSTS(const AWSCredentials& credentials, const Aws::String& roleARN);
AWSCredentials GetCredentialsFromWebIdentity(const Config::Profile& profile);
private:
AWSCredentials GetCredentialsFromSTSInternal(const Aws::String& roleArn, Aws::STS::STSClient* client);
AWSCredentials GetCredentialsFromWebIdentityInternal(const Config::Profile& profile, Aws::STS::STSClient* client);

Aws::String m_profileName;
AWSCredentials m_credentials;
const std::chrono::minutes m_duration;
const std::chrono::milliseconds m_reloadFrequency;
std::function<Aws::STS::STSClient*(const AWSCredentials&)> m_stsClientFactory;
std::function<std::shared_ptr<Aws::STS::STSClient>(const AWSCredentials&)> m_stsClientFactory;
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,26 @@

#include <aws/identity-management/auth/STSProfileCredentialsProvider.h>
#include <aws/sts/model/AssumeRoleRequest.h>
#include <aws/sts/model/AssumeRoleWithWebIdentityRequest.h>
#include <aws/sts/STSClient.h>
#include <aws/core/utils/logging/LogMacros.h>
#include <aws/core/utils/Outcome.h>
#include <aws/core/utils/UUID.h>

#include <fstream>
#include <utility>

using namespace Aws;
using namespace Aws::Auth;

constexpr char CLASS_TAG[] = "STSProfileCredentialsProvider";

template <typename T>
struct NoOpDeleter
{
void operator()(T*) {}
};

STSProfileCredentialsProvider::STSProfileCredentialsProvider()
: STSProfileCredentialsProvider(GetConfigProfileName(), std::chrono::minutes(60)/*duration*/, nullptr/*stsClientFactory*/)
{
Expand All @@ -27,8 +35,24 @@ STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String&
{
}

STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, std::nullptr_t)
: m_profileName(profileName),
m_duration(duration),
m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast<int64_t>(duration.count()))) - std::chrono::minutes(5)),
m_stsClientFactory(nullptr)
{
}

STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function<Aws::STS::STSClient*(const AWSCredentials&)> &stsClientFactory)
: m_profileName(profileName),
m_duration(duration),
m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast<int64_t>(duration.count()))) - std::chrono::minutes(5)),
m_stsClientFactory([=](const auto& credentials) {return std::shared_ptr<Aws::STS::STSClient>(stsClientFactory(credentials), NoOpDeleter<Aws::STS::STSClient>()); })
{
}

STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function<std::shared_ptr<Aws::STS::STSClient> (const AWSCredentials&)>& stsClientFactory)
: m_profileName(profileName),
m_duration(duration),
m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast<int64_t>(duration.count()))) - std::chrono::minutes(5)),
m_stsClientFactory(stsClientFactory)
Expand Down Expand Up @@ -66,25 +90,27 @@ enum class ProfileState
Process,
SourceProfile,
SelfReferencing, // special case of SourceProfile.
RoleARNWebIdentity
};

/*
* A valid profile can be in one of the following states. Any other state is considered invalid.
+---------+-----------+-----------+--------------+
| | | | |
| Role | Source | Process | Static |
| ARN | Profile | | Credentials |
+------------------------------------------------+
| | | | |
| false | false | false | TRUE |
| | | | |
| false | false | TRUE | false |
| | | | |
| TRUE | TRUE | false | false |
| | | | |
| TRUE | TRUE | false | TRUE |
| | | | |
+---------+-----------+-----------+--------------+
+---------+-----------+-----------+--------------+------------+
| | | | | |
| Role | Source | Process | Static | Web |
| ARN | Profile | | Credentials | Identity |
+------------------------------------------------+------------+
| | | | | |
| false | false | false | TRUE | false |
| | | | | |
| false | false | TRUE | false | false |
| | | | | |
| TRUE | TRUE | false | false | false |
| | | | | |
| TRUE | TRUE | false | TRUE | false |
| | | | | |
| TRUE | false | false | false | TRUE |
+---------+-----------+-----------+--------------+------------+

*/
static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLevelProfile)
Expand All @@ -93,6 +119,7 @@ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLe
constexpr int PROCESS_CREDENTIALS = 2;
constexpr int SOURCE_PROFILE = 4;
constexpr int ROLE_ARN = 8;
constexpr int WEB_IDENTITY_TOKEN_FILE = 16;

int state = 0;

Expand All @@ -116,6 +143,11 @@ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLe
state += ROLE_ARN;
}

if (!profile.GetValue("web_identity_token_file").empty())
{
state += WEB_IDENTITY_TOKEN_FILE;
}

if (topLevelProfile)
{
switch(state)
Expand All @@ -133,6 +165,8 @@ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLe
}
// source-profile over-rule static credentials in top-level profiles (except when self-referencing)
return ProfileState::SourceProfile;
case 24: // role arn && web identity
return ProfileState::RoleARNWebIdentity;
default:
// All other cases are considered malformed configuration.
return ProfileState::Invalid;
Expand All @@ -154,6 +188,8 @@ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLe
return ProfileState::SelfReferencing;
}
return ProfileState::Static; // static credentials over-rule source-profile (except when self-referencing)
case 24: // role arn && web identity
return ProfileState::RoleARNWebIdentity;
default:
// All other cases are considered malformed configuration.
return ProfileState::Invalid;
Expand Down Expand Up @@ -280,10 +316,14 @@ void STSProfileCredentialsProvider::Reload()

while (sourceProfiles.size() > 1)
{
const auto profile = sourceProfiles.back()->second;
const auto& profile = sourceProfiles.back()->second;
sourceProfiles.pop_back();
AWSCredentials stsCreds;
if (profile.GetCredentialProcess().empty())
if (CheckProfile(profile, false /*topLevelProfile*/) == ProfileState::RoleARNWebIdentity)
{
stsCreds = GetCredentialsFromWebIdentity(profile);
}
else if (profile.GetCredentialProcess().empty())
{
assert(!profile.GetCredentials().IsEmpty());
stsCreds = profile.GetCredentials();
Expand All @@ -294,7 +334,7 @@ void STSProfileCredentialsProvider::Reload()
}

// get the role arn from the profile at the top of the stack (which hasn't been popped out yet)
const auto arn = sourceProfiles.back()->second.GetRoleArn();
const auto& arn = sourceProfiles.back()->second.GetRoleArn();
const auto& assumedCreds = GetCredentialsFromSTS(stsCreds, arn);
sourceProfiles.back()->second.SetCredentials(assumedCreds);
}
Expand Down Expand Up @@ -337,9 +377,68 @@ AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromSTS(const AWSCre
{
using namespace Aws::STS::Model;
if (m_stsClientFactory) {
return GetCredentialsFromSTSInternal(roleArn, m_stsClientFactory(credentials));
auto client = m_stsClientFactory(credentials);
return GetCredentialsFromSTSInternal(roleArn, client.get());
}

Aws::STS::STSClient stsClient {credentials};
return GetCredentialsFromSTSInternal(roleArn, &stsClient);
}

AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromWebIdentityInternal(const Config::Profile& profile, Aws::STS::STSClient* client)
{
Aws::String roleSessionName = profile.GetValue("role_session_name");
if (roleSessionName.empty())
{
roleSessionName = Aws::Utils::UUID::PseudoRandomUUID();
}

Aws::String token;
{
auto& tokenPath = profile.GetValue("web_identity_token_file");
Aws::IFStream tokenFile(tokenPath);
if (tokenFile) {
token = Aws::String(
(std::istreambuf_iterator<char>(tokenFile)),
std::istreambuf_iterator<char>());
}
else {
AWS_LOGSTREAM_ERROR(CLASS_TAG, "Can't open token file: " << tokenPath);
return {};
}
}

using namespace Aws::STS::Model;
AssumeRoleWithWebIdentityRequest assumeRoleRequest;
assumeRoleRequest
.WithRoleArn(profile.GetRoleArn())
.WithRoleSessionName(roleSessionName)
.WithWebIdentityToken(token)
.WithDurationSeconds(static_cast<int>(std::chrono::seconds(m_duration).count()));
auto outcome = client->AssumeRoleWithWebIdentity(assumeRoleRequest);
if (outcome.IsSuccess())
{
const auto& modelCredentials = outcome.GetResult().GetCredentials();
return {modelCredentials.GetAccessKeyId(),
modelCredentials.GetSecretAccessKey(),
modelCredentials.GetSessionToken(),
modelCredentials.GetExpiration()};
}
else
{
AWS_LOGSTREAM_ERROR(CLASS_TAG, "Failed to assume role " << profile.GetRoleArn());
}
return {};
}

AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromWebIdentity(const Config::Profile& profile)
{
using namespace Aws::STS::Model;
if (m_stsClientFactory) {
auto client = m_stsClientFactory({});
return GetCredentialsFromWebIdentityInternal(profile, client.get());
}

Aws::STS::STSClient stsClient{AWSCredentials{}};
return GetCredentialsFromWebIdentityInternal(profile, &stsClient);
}
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithoutRoleARN)

STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) {
ADD_FAILURE() << "STS Service client should not be used in this scenario.";
return nullptr;
return (STSClient*)nullptr;
});

auto actualCredentials = credsProvider.GetAWSCredentials();
Expand Down Expand Up @@ -383,7 +383,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithoutSourceProfile)

STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) {
ADD_FAILURE() << "STS Service client should not be used in this scenario.";
return nullptr;
return (STSClient*)nullptr;
});

auto actualCredentials = credsProvider.GetAWSCredentials();
Expand All @@ -409,7 +409,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithNonExistentSourceProfile

STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) {
ADD_FAILURE() << "STS Service client should not be used in this scenario.";
return nullptr;
return (STSClient*)nullptr;
});

auto actualCredentials = credsProvider.GetAWSCredentials();
Expand Down Expand Up @@ -556,7 +556,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleSelfReferencingSourceProfile

Model::AssumeRoleResult mockResult;
mockResult.SetCredentials(stsCredentials);
Aws::UniquePtr<MockSTSClient> stsClient;
std::shared_ptr<MockSTSClient> stsClient;

int stsCallCounter = 0;

Expand All @@ -572,9 +572,9 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleSelfReferencingSourceProfile
EXPECT_STREQ(ACCESS_KEY_ID_2, creds.GetAWSAccessKeyId().c_str());
EXPECT_STREQ(SECRET_ACCESS_KEY_ID_2, creds.GetAWSSecretKey().c_str());
}
stsClient = Aws::MakeUnique<MockSTSClient>(CLASS_TAG, creds);
stsClient = Aws::MakeShared<MockSTSClient>(CLASS_TAG, creds);
stsClient->MockAssumeRole(mockResult);
return stsClient.get();
return stsClient;
});

auto actualCredentials = credsProvider.GetAWSCredentials();
Expand Down Expand Up @@ -614,7 +614,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleRecursivelyCircularReference

STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) {
ADD_FAILURE() << "STS Service client should not be used in this scenario.";
return nullptr;
return (STSClient*)nullptr;
});

auto actualCredentials = credsProvider.GetAWSCredentials();
Expand Down
Loading