diff --git a/main.go b/main.go index 496942c..0d7c05f 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "net/http" "os" "os/signal" + "strconv" "strings" "syscall" "time" @@ -61,12 +62,32 @@ func setupCollectors(logger log.Logger, configFile string) ([]prometheus.Collect level.Info(logger).Log("msg", "Configuring route53 with region", "region", config.Route53Config.Region) level.Info(logger).Log("msg", "Will VPC metrics be gathered?", "vpc-enabled", config.VpcConfig.Enabled) // Create a single session here, because we need the accountid, before we create the other configs + roleARN := os.Getenv("ROLE_ARN") + sessionName := os.Getenv("SESSION_NAME") awsConfig := aws.NewConfig().WithRegion("us-east-1") sess := session.Must(session.NewSession(awsConfig)) + durationSeconds := os.Getenv("TOKEN_DURATION") + convertedDurationSeconds, err := strconv.ParseInt(durationSeconds, 10, 64) + if err != nil { + return nil, err + } awsAccountId, err := getAwsAccountNumber(logger, sess) if err != nil { return collectors, err } + if pkg.LookUpEnvVar("ROLE_ARN") && pkg.LookUpEnvVar("SESSION_NAME") { + client := sts.New(sess) + aws_credentials := sts.Credentials{} + err = pkg.AssumeRole(client, roleARN, sessionName, convertedDurationSeconds, logger) + if err != nil { + return nil, err + } + start_index := strings.Index(roleARN, "::") + 2 + end_index := strings.LastIndex(roleARN, ":") + awsAccountId = roleARN[start_index:end_index] + go pkg.RefreshToken(client, &aws_credentials, roleARN, sessionName, convertedDurationSeconds, logger) + } + var vpcSessions []*session.Session if config.VpcConfig.Enabled { for _, region := range config.VpcConfig.Regions { diff --git a/openshift/aws-resource-exporter.yaml b/openshift/aws-resource-exporter.yaml index c620390..bd008c3 100644 --- a/openshift/aws-resource-exporter.yaml +++ b/openshift/aws-resource-exporter.yaml @@ -52,6 +52,16 @@ objects: value: ${AWS_REGION} - name: AWS_RESOURCE_EXPORTER_CONFIG_FILE value: /etc/aws-resource-exporter/aws-resource-exporter-config.yaml + - name: ROLE_ARN + valueFrom: + secretKeyRef: + name: ${SECRET_NAME} + key: role_arn + optional: true + - name: TOKEN_DURATION + value: ${TOKEN_DURATION} + - name: SESSION_NAME + value: ${SESSION_NAME} volumeMounts: - name: exporter-configuration mountPath: /etc/aws-resource-exporter/ @@ -134,3 +144,7 @@ parameters: route53: regions: "" timeout: 60s +- name: SESSION_NAME + value: TestSessionName +- name: TOKEN_DURATION + value: 60 diff --git a/pkg/awsclient/awsclient.go b/pkg/awsclient/awsclient.go index cbe2b4a..04cd40c 100644 --- a/pkg/awsclient/awsclient.go +++ b/pkg/awsclient/awsclient.go @@ -11,6 +11,8 @@ import ( "github.com/aws/aws-sdk-go/service/route53/route53iface" "github.com/aws/aws-sdk-go/service/servicequotas" "github.com/aws/aws-sdk-go/service/servicequotas/servicequotasiface" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go/service/sts/stsiface" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" @@ -41,6 +43,10 @@ type Client interface { GetHostedZoneLimitWithContext(ctx context.Context, input *route53.GetHostedZoneLimitInput, opts ...request.Option) (*route53.GetHostedZoneLimitOutput, error) } +type Sts interface { + AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) +} + type awsClient struct { ec2Client ec2iface.EC2API rdsClient rds.RDS @@ -48,6 +54,10 @@ type awsClient struct { route53Client route53iface.Route53API } +type awsSts struct { + sts stsiface.STSAPI +} + func (c *awsClient) DescribeTransitGatewaysWithContext(ctx aws.Context, input *ec2.DescribeTransitGatewaysInput, opts ...request.Option) (*ec2.DescribeTransitGatewaysOutput, error) { return c.ec2Client.DescribeTransitGatewaysWithContext(ctx, input, opts...) } @@ -130,6 +140,10 @@ func (c *awsClient) GetHostedZoneLimitWithContext(ctx context.Context, input *ro return c.route53Client.GetHostedZoneLimitWithContext(ctx, input, opts...) } +func (s *awsSts) AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) { + return s.sts.AssumeRole(input) +} + func NewClientFromSession(sess *session.Session) Client { return &awsClient{ ec2Client: ec2.New(sess), diff --git a/pkg/awsclient/mock/zz_generated.mock_client.go b/pkg/awsclient/mock/zz_generated.mock_client.go index b53fbf6..d5eba2f 100644 --- a/pkg/awsclient/mock/zz_generated.mock_client.go +++ b/pkg/awsclient/mock/zz_generated.mock_client.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: /Users/suzananesic/go/src/github.com/app-sre/aws-resource-exporter/pkg/awsclient/awsclient.go +// Source: ./awsclient.go // Package mock is a generated GoMock package. package mock @@ -15,6 +15,7 @@ import ( route53 "github.com/aws/aws-sdk-go/service/route53" servicequotas "github.com/aws/aws-sdk-go/service/servicequotas" gomock "github.com/golang/mock/gomock" + sts "github.com/aws/aws-sdk-go/service/sts" ) // MockClient is a mock of Client interface. @@ -221,3 +222,18 @@ func (mr *MockClientMockRecorder) ListHostedZonesWithContext(ctx, input interfac varargs := append([]interface{}{ctx, input}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListHostedZonesWithContext", reflect.TypeOf((*MockClient)(nil).ListHostedZonesWithContext), varargs...) } + +// AssumeRole mocks base method. +func (m *MockClient) AssumeRole(arg0 *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AssumeRole", arg0) + ret0, _ := ret[0].(*sts.AssumeRoleOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AssumeRole indicates an expected call of AssumeRole. +func (mr *MockClientMockRecorder) AssumeRole(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssumeRole", reflect.TypeOf((*MockClient)(nil).AssumeRole), arg0) +} \ No newline at end of file diff --git a/pkg/sts.go b/pkg/sts.go new file mode 100644 index 0000000..b30dc6a --- /dev/null +++ b/pkg/sts.go @@ -0,0 +1,53 @@ +package pkg + +import ( + "os" + "time" + + "github.com/app-sre/aws-resource-exporter/pkg/awsclient" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/go-kit/kit/log" +) + +func AssumeRole(client awsclient.Sts, roleARN, sessionName string, durationInSeconds int64, logger log.Logger) error { + roleInput := sts.AssumeRoleInput{ + RoleArn: &roleARN, + RoleSessionName: &sessionName, + DurationSeconds: &durationInSeconds, + } + result, err := client.AssumeRole(&roleInput) + if err != nil { + return err + } + + if err := os.Setenv("AWS_ACCESS_KEY_ID", *result.Credentials.AccessKeyId); err != nil { + return err + } + + if err := os.Setenv("AWS_SECRET_ACCESS_KEY", *result.Credentials.SecretAccessKey); err != nil { + return err + } + + if err := os.Setenv("AWS_SESSION_TOKEN", *result.Credentials.SessionToken); err != nil { + return err + } + return nil +} + +func LookUpEnvVar(key string) bool { + _, ok := os.LookupEnv(key) + return ok +} + +func RefreshToken(client awsclient.Sts, credentials *sts.Credentials, roleARN, sessionName string, durationInSeconds int64, logger log.Logger) error { + for { + expiration := credentials.Expiration + refreshWindow := time.Minute * 5 + if expiration != nil && expiration.Sub(time.Now()) < refreshWindow { + err := AssumeRole(client, roleARN, sessionName, durationInSeconds, logger) + if err != nil { + return err + } + } + } +} diff --git a/pkg/sts_test.go b/pkg/sts_test.go new file mode 100644 index 0000000..1d1a747 --- /dev/null +++ b/pkg/sts_test.go @@ -0,0 +1,38 @@ +package pkg + +import ( + "os" + "testing" + + "github.com/app-sre/aws-resource-exporter/pkg/awsclient/mock" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func TestTakeRoleSTS(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockSvc := mock.NewMockClient(ctrl) + + assumeRoleInput := &sts.AssumeRoleInput{ + RoleArn: aws.String("arn:aws:iam::123456789012:role/example"), + RoleSessionName: aws.String("session-name"), + DurationSeconds: aws.Int64(60), + } + assumeRoleOutput := &sts.AssumeRoleOutput{ + Credentials: &sts.Credentials{ + AccessKeyId: aws.String("access-key-id"), + SecretAccessKey: aws.String("secret-access-key"), + SessionToken: aws.String("session-token"), + }, + } + mockSvc.EXPECT().AssumeRole(assumeRoleInput).Return(assumeRoleOutput, nil) + err := AssumeRole(mockSvc, "arn:aws:iam::123456789012:role/example", "session-name", 60, nil) + assert.NoError(t, err) + assert.Equal(t, "access-key-id", os.Getenv("AWS_ACCESS_KEY_ID")) + assert.Equal(t, "secret-access-key", os.Getenv("AWS_SECRET_ACCESS_KEY")) + assert.Equal(t, "session-token", os.Getenv("AWS_SESSION_TOKEN")) +}