Skip to content

Commit

Permalink
add external id config for role assumption
Browse files Browse the repository at this point in the history
  • Loading branch information
binarymatt committed Dec 6, 2024
1 parent dd600c1 commit 10da12d
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 20 deletions.
2 changes: 2 additions & 0 deletions internal/aws/awsutil/awsconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ type AWSSessionSettings struct {
ResourceARN string `mapstructure:"resource_arn"`
// IAM role to upload segments to a different account.
RoleARN string `mapstructure:"role_arn"`
// External ID to verify third party role assumption
ExternalID string `mapstructure:"external_id"`
}

func CreateDefaultSessionConfig() AWSSessionSettings {
Expand Down
32 changes: 19 additions & 13 deletions internal/aws/awsutil/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
)

type ConnAttr interface {
newAWSSession(logger *zap.Logger, roleArn string, region string) (*session.Session, error)
newAWSSession(logger *zap.Logger, roleArn string, externalID string, region string) (*session.Session, error)
getEC2Region(s *session.Session) (string, error)
}

Expand Down Expand Up @@ -145,7 +145,7 @@ func GetAWSConfigSession(logger *zap.Logger, cn ConnAttr, cfg *AWSSessionSetting
logger.Error(msg)
return nil, nil, awserr.New("NoAwsRegion", msg, nil)
}
s, err = cn.newAWSSession(logger, cfg.RoleARN, awsRegion)
s, err = cn.newAWSSession(logger, cfg.RoleARN, cfg.ExternalID, awsRegion)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -193,7 +193,7 @@ func ProxyServerTransport(logger *zap.Logger, config *AWSSessionSettings) (*http
return transport, nil
}

func (c *Conn) newAWSSession(logger *zap.Logger, roleArn string, region string) (*session.Session, error) {
func (c *Conn) newAWSSession(logger *zap.Logger, roleArn, externalID string, region string) (*session.Session, error) {
var s *session.Session
var err error
if roleArn == "" {
Expand All @@ -202,7 +202,7 @@ func (c *Conn) newAWSSession(logger *zap.Logger, roleArn string, region string)
return s, err
}
} else {
stsCreds, _ := getSTSCreds(logger, region, roleArn)
stsCreds, _ := getSTSCreds(logger, region, roleArn, externalID)

s, err = session.NewSession(&aws.Config{
Credentials: stsCreds,
Expand All @@ -218,13 +218,13 @@ func (c *Conn) newAWSSession(logger *zap.Logger, roleArn string, region string)
// getSTSCreds gets STS credentials from regional endpoint. ErrCodeRegionDisabledException is received if the
// STS regional endpoint is disabled. In this case STS credentials are fetched from STS primary regional endpoint
// in the respective AWS partition.
func getSTSCreds(logger *zap.Logger, region string, roleArn string) (*credentials.Credentials, error) {
func getSTSCreds(logger *zap.Logger, region string, roleArn, externalID string) (*credentials.Credentials, error) {
t, err := GetDefaultSession(logger)
if err != nil {
return nil, err
}

stsCred := getSTSCredsFromRegionEndpoint(logger, t, region, roleArn)
stsCred := getSTSCredsFromRegionEndpoint(logger, t, region, roleArn, externalID)
// Make explicit call to fetch credentials.
_, err = stsCred.Get()
if err != nil {
Expand All @@ -234,7 +234,7 @@ func getSTSCreds(logger *zap.Logger, region string, roleArn string) (*credential

if awsErr.Code() == sts.ErrCodeRegionDisabledException {
logger.Error("Region ", zap.String("region", region), zap.Error(awsErr))
stsCred = getSTSCredsFromPrimaryRegionEndpoint(logger, t, roleArn, region)
stsCred = getSTSCredsFromPrimaryRegionEndpoint(logger, t, roleArn, externalID, region)
}
}
}
Expand All @@ -245,7 +245,7 @@ func getSTSCreds(logger *zap.Logger, region string, roleArn string) (*credential
// AWS STS recommends that you provide both the Region and endpoint when you make calls to a Regional endpoint.
// Reference: https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_temp_enable-regions_writing_code
func getSTSCredsFromRegionEndpoint(logger *zap.Logger, sess *session.Session, region string,
roleArn string,
roleArn, externalID string,
) *credentials.Credentials {
regionalEndpoint := getSTSRegionalEndpoint(region)
// if regionalEndpoint is "", the STS endpoint is Global endpoint for classic regions except ap-east-1 - (HKG)
Expand All @@ -254,23 +254,29 @@ func getSTSCredsFromRegionEndpoint(logger *zap.Logger, sess *session.Session, re
c := &aws.Config{Region: aws.String(region), Endpoint: &regionalEndpoint}
st := sts.New(sess, c)
logger.Info("STS Endpoint ", zap.String("endpoint", st.Endpoint))
return stscreds.NewCredentialsWithClient(st, roleArn)
options := []func(*stscreds.AssumeRoleProvider){}
if externalID != "" {
options = append(options, func(arp *stscreds.AssumeRoleProvider) {
arp.ExternalID = aws.String(externalID)
})
}
return stscreds.NewCredentialsWithClient(st, roleArn, options...)
}

// getSTSCredsFromPrimaryRegionEndpoint fetches STS credentials for provided roleARN from primary region endpoint in
// the respective partition.
func getSTSCredsFromPrimaryRegionEndpoint(logger *zap.Logger, t *session.Session, roleArn string,
func getSTSCredsFromPrimaryRegionEndpoint(logger *zap.Logger, t *session.Session, roleArn, externalID string,
region string,
) *credentials.Credentials {
logger.Info("Credentials for provided RoleARN being fetched from STS primary region endpoint.")
partitionID := getPartition(region)
switch partitionID {
case endpoints.AwsPartitionID:
return getSTSCredsFromRegionEndpoint(logger, t, endpoints.UsEast1RegionID, roleArn)
return getSTSCredsFromRegionEndpoint(logger, t, endpoints.UsEast1RegionID, roleArn, externalID)
case endpoints.AwsCnPartitionID:
return getSTSCredsFromRegionEndpoint(logger, t, endpoints.CnNorth1RegionID, roleArn)
return getSTSCredsFromRegionEndpoint(logger, t, endpoints.CnNorth1RegionID, roleArn, externalID)
case endpoints.AwsUsGovPartitionID:
return getSTSCredsFromRegionEndpoint(logger, t, endpoints.UsGovWest1RegionID, roleArn)
return getSTSCredsFromRegionEndpoint(logger, t, endpoints.UsGovWest1RegionID, roleArn, externalID)
}

return nil
Expand Down
16 changes: 9 additions & 7 deletions internal/aws/awsutil/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (c *mockConn) getEC2Region(_ *session.Session) (string, error) {
return ec2Region, nil
}

func (c *mockConn) newAWSSession(_ *zap.Logger, _ string, _ string) (*session.Session, error) {
func (c *mockConn) newAWSSession(_ *zap.Logger, _ string, _ string, _ string) (*session.Session, error) {
return c.sn, nil
}

Expand Down Expand Up @@ -104,15 +104,16 @@ func TestGetAWSConfigSessionWithEC2RegionErr(t *testing.T) {
func TestNewAWSSessionWithErr(t *testing.T) {
logger := zap.NewNop()
roleArn := "fake_arn"
externalID := ""
region := "fake_region"
t.Setenv("AWS_EC2_METADATA_DISABLED", "true")
t.Setenv("AWS_STS_REGIONAL_ENDPOINTS", "fake")
conn := &Conn{}
se, err := conn.newAWSSession(logger, roleArn, region)
se, err := conn.newAWSSession(logger, roleArn, externalID, region)
assert.Error(t, err)
assert.Nil(t, se)
roleArn = ""
se, err = conn.newAWSSession(logger, roleArn, region)
se, err = conn.newAWSSession(logger, roleArn, externalID, region)
assert.Error(t, err)
assert.Nil(t, se)
t.Setenv("AWS_SDK_LOAD_CONFIG", "true")
Expand All @@ -132,10 +133,10 @@ func TestGetSTSCredsFromPrimaryRegionEndpoint(t *testing.T) {
regions := []string{"us-east-1", "us-gov-west-1", "cn-north-1"}

for _, region := range regions {
creds := getSTSCredsFromPrimaryRegionEndpoint(logger, session, "", region)
creds := getSTSCredsFromPrimaryRegionEndpoint(logger, session, "", "", region)
assert.NotNil(t, creds)
}
creds := getSTSCredsFromPrimaryRegionEndpoint(logger, session, "", "fake_region")
creds := getSTSCredsFromPrimaryRegionEndpoint(logger, session, "", "", "fake_region")
assert.Nil(t, creds)
}

Expand All @@ -150,9 +151,10 @@ func TestGetSTSCreds(t *testing.T) {
logger := zap.NewNop()
region := "fake_region"
roleArn := ""
_, err := getSTSCreds(logger, region, roleArn)
externalID := ""
_, err := getSTSCreds(logger, region, roleArn, externalID)
assert.NoError(t, err)
t.Setenv("AWS_STS_REGIONAL_ENDPOINTS", "fake")
_, err = getSTSCreds(logger, region, roleArn)
_, err = getSTSCreds(logger, region, roleArn, externalID)
assert.Error(t, err)
}

0 comments on commit 10da12d

Please sign in to comment.