Skip to content

Commit

Permalink
Add fallback shared config files for credential ordering. (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
jefchien authored Sep 6, 2023
1 parent 7f38429 commit ced73ba
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 3 deletions.
34 changes: 31 additions & 3 deletions internal/aws/awsutil/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/endpoints"
Expand Down Expand Up @@ -372,16 +373,43 @@ func getSTSRegionalEndpoint(r string) string {
}

func GetDefaultSession(logger *zap.Logger, cfg *AWSSessionSettings) (*session.Session, error) {
awsConfig := &aws.Config{
cfgFiles := getFallbackSharedConfigFiles(backwardsCompatibleUserHomeDir)
logger.Debug("Fallback shared config file(s)", zap.Strings("files", cfgFiles))
awsConfig := aws.Config{
Credentials: getRootCredentials(cfg),
}
result, serr := session.NewSession(awsConfig)
result, serr := session.NewSessionWithOptions(session.Options{
Config: awsConfig,
SharedConfigFiles: cfgFiles,
})
if serr != nil {
logger.Error("Error in creating session object waiting 15 seconds", zap.Error(serr))
time.Sleep(15 * time.Second)
result, serr = session.NewSession(awsConfig)
result, serr = session.NewSessionWithOptions(session.Options{
Config: awsConfig,
SharedConfigFiles: cfgFiles,
})
if serr != nil {
logger.Error("Retry failed for creating credential sessions", zap.Error(serr))
return result, serr
}
}
cred, err := result.Config.Credentials.Get()
if err != nil {
logger.Error("Failed to get credential from session", zap.Error(err))
} else {
logger.Debug("Using credential from session", zap.String("access-key", cred.AccessKeyID), zap.String("provider", cred.ProviderName))
}
if cred.ProviderName == ec2rolecreds.ProviderName {
var found []string
cfgFiles = getFallbackSharedConfigFiles(currentUserHomeDir)
for _, cfgFile := range cfgFiles {
if _, err = os.Stat(cfgFile); err == nil {
found = append(found, cfgFile)
}
}
if len(found) > 0 {
logger.Warn("Unused shared config file(s) found.", zap.Strings("files", found))
}
}
return result, serr
Expand Down
95 changes: 95 additions & 0 deletions internal/aws/awsutil/shared_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright The OpenTelemetry Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package awsutil // import "github.com/open-telemetry/opentelemetry-collector-contrib/internal/aws/awsutil"

import (
"os"
"os/user"
"path/filepath"
"strconv"
)

const (
envAwsSdkLoadConfig = "AWS_SDK_LOAD_CONFIG"
// nolint:gosec
envAwsSharedCredentialsFile = "AWS_SHARED_CREDENTIALS_FILE"
envAwsSharedConfigFile = "AWS_CONFIG_FILE"
)

// getFallbackSharedConfigFiles follows the same logic as the AWS SDK but takes a getUserHomeDir
// function.
func getFallbackSharedConfigFiles(userHomeDirProvider func() string) []string {
var sharedCredentialsFile, sharedConfigFile string
setFromEnvVal(&sharedCredentialsFile, envAwsSharedCredentialsFile)
setFromEnvVal(&sharedConfigFile, envAwsSharedConfigFile)
if sharedCredentialsFile == "" {
sharedCredentialsFile = defaultSharedCredentialsFile(userHomeDirProvider())
}
if sharedConfigFile == "" {
sharedConfigFile = defaultSharedConfig(userHomeDirProvider())
}
var cfgFiles []string
enableSharedConfig, _ := strconv.ParseBool(os.Getenv(envAwsSdkLoadConfig))
if enableSharedConfig {
cfgFiles = append(cfgFiles, sharedConfigFile)
}
return append(cfgFiles, sharedCredentialsFile)
}

func setFromEnvVal(dst *string, keys ...string) {
for _, k := range keys {
if v := os.Getenv(k); len(v) != 0 {
*dst = v
break
}
}
}

func defaultSharedCredentialsFile(dir string) string {
return filepath.Join(dir, ".aws", "credentials")
}

func defaultSharedConfig(dir string) string {
return filepath.Join(dir, ".aws", "config")
}

// backwardsCompatibleUserHomeDir provides the home directory based on
// environment variables.
//
// Based on v1.44.106 of the AWS SDK.
func backwardsCompatibleUserHomeDir() string {
home, _ := os.UserHomeDir()
return home
}

// currentUserHomeDir attempts to use the environment variables before falling
// back on the current user's home directory.
//
// Based on v1.44.332 of the AWS SDK.
func currentUserHomeDir() string {
var home string

home = backwardsCompatibleUserHomeDir()
if len(home) > 0 {
return home
}

currUser, _ := user.Current()
if currUser != nil {
home = currUser.HomeDir
}

return home
}
42 changes: 42 additions & 0 deletions internal/aws/awsutil/shared_config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright The OpenTelemetry Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package awsutil

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestGetFallbackSharedConfigFiles(t *testing.T) {
noOpGetUserHomeDir := func() string { return "home" }
t.Setenv(envAwsSdkLoadConfig, "true")
t.Setenv(envAwsSharedCredentialsFile, "credentials")
t.Setenv(envAwsSharedConfigFile, "config")

got := getFallbackSharedConfigFiles(noOpGetUserHomeDir)
assert.Equal(t, []string{"config", "credentials"}, got)

t.Setenv(envAwsSdkLoadConfig, "false")
got = getFallbackSharedConfigFiles(noOpGetUserHomeDir)
assert.Equal(t, []string{"credentials"}, got)

t.Setenv(envAwsSdkLoadConfig, "true")
t.Setenv(envAwsSharedCredentialsFile, "")
t.Setenv(envAwsSharedConfigFile, "")

got = getFallbackSharedConfigFiles(noOpGetUserHomeDir)
assert.Equal(t, []string{defaultSharedConfig("home"), defaultSharedCredentialsFile("home")}, got)
}

0 comments on commit ced73ba

Please sign in to comment.