diff --git a/pubsub/aws/aws.go b/pubsub/aws/aws.go index 74de9570a..858657ad0 100644 --- a/pubsub/aws/aws.go +++ b/pubsub/aws/aws.go @@ -15,7 +15,6 @@ import ( "github.com/aws/aws-sdk-go/service/sns/snsiface" "github.com/aws/aws-sdk-go/service/sqs" "github.com/aws/aws-sdk-go/service/sqs/sqsiface" - "github.com/aws/aws-sdk-go/service/sts" "github.com/golang/protobuf/proto" "golang.org/x/net/context" ) @@ -43,12 +42,17 @@ func NewPublisher(cfg SNSConfig) (pubsub.Publisher, error) { return p, errors.New("SNS region is required") } + sess, err := session.NewSession() + if err != nil { + return p, err + } + var creds *credentials.Credentials if cfg.AccessKey != "" { creds = credentials.NewStaticCredentials(cfg.AccessKey, cfg.SecretKey, cfg.SessionToken) } else if cfg.RoleARN != "" { var err error - creds, err = requestRoleCredentials(creds, cfg.RoleARN, cfg.MFASerialNumber) + creds, err = requestRoleCredentials(sess, cfg.RoleARN, cfg.MFASerialNumber) if err != nil { return p, err } @@ -56,11 +60,6 @@ func NewPublisher(cfg SNSConfig) (pubsub.Publisher, error) { creds = credentials.NewEnvCredentials() } - sess, err := session.NewSession() - if err != nil { - return p, err - } - p.sns = sns.New(sess, &aws.Config{ Credentials: creds, Region: &cfg.Region, @@ -199,12 +198,17 @@ func NewSubscriber(cfg SQSConfig) (pubsub.Subscriber, error) { return s, errors.New("sqs queue name or url is required") } + sess, err := session.NewSession() + if err != nil { + return s, err + } + var creds *credentials.Credentials if cfg.AccessKey != "" { creds = credentials.NewStaticCredentials(cfg.AccessKey, cfg.SecretKey, cfg.SessionToken) } else if cfg.RoleARN != "" { var err error - creds, err = requestRoleCredentials(creds, cfg.RoleARN, cfg.MFASerialNumber) + creds, err = requestRoleCredentials(sess, cfg.RoleARN, cfg.MFASerialNumber) if err != nil { return s, err } @@ -212,10 +216,6 @@ func NewSubscriber(cfg SQSConfig) (pubsub.Subscriber, error) { creds = credentials.NewEnvCredentials() } - sess, err := session.NewSession() - if err != nil { - return s, err - } s.sqs = sqs.New(sess, &aws.Config{ Credentials: creds, Region: &cfg.Region, @@ -399,24 +399,15 @@ func (s *subscriber) Err() error { // requestRoleCredentials return the credentials from AssumeRoleProvider to assume the role // referenced by the roleARN. If MFASerialNumber is specified, prompt for MFA token from stdin. -func requestRoleCredentials(creds *credentials.Credentials, roleARN string, MFASerialNumber string) (*credentials.Credentials, error) { +func requestRoleCredentials(sess *session.Session, roleARN string, MFASerialNumber string) (*credentials.Credentials, error) { if roleARN == "" { return nil, errors.New("role ARN is required") } - sess, err := session.NewSessionWithOptions(session.Options{ - Config: *aws.NewConfig().WithCredentials(creds), - }) - if err != nil { - return nil, err - } - assumeRole := &stscreds.AssumeRoleProvider{ - Client: sts.New(sess), - RoleARN: roleARN, - Duration: stscreds.DefaultDuration, - } - if MFASerialNumber != "" { - assumeRole.SerialNumber = &MFASerialNumber - assumeRole.TokenProvider = stscreds.StdinTokenProvider - } - return credentials.NewCredentials(assumeRole), nil + + return stscreds.NewCredentials(sess, roleARN, func(provider *stscreds.AssumeRoleProvider) { + if MFASerialNumber != "" { + provider.SerialNumber = &MFASerialNumber + provider.TokenProvider = stscreds.StdinTokenProvider + } + }), nil }