Skip to content

Commit

Permalink
deps: upgrade aws-sdk-go from v1 to v2
Browse files Browse the repository at this point in the history
  • Loading branch information
mismithhisler committed Dec 19, 2024
1 parent 30ab889 commit a1b3d73
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 89 deletions.
153 changes: 85 additions & 68 deletions client/fingerprint/env_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@
package fingerprint

import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"regexp"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
smithyHttp "github.com/aws/smithy-go/transport/http"

"github.com/hashicorp/go-cleanhttp"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/nomad/structs"
)
Expand Down Expand Up @@ -52,17 +53,13 @@ var ec2NetSpeedTable = map[*regexp.Regexp]int{
type EnvAWSFingerprint struct {
StaticFingerprinter

// endpoint for EC2 metadata as expected by AWS SDK
endpoint string

logger log.Logger
}

// NewEnvAWSFingerprint is used to create a fingerprint from AWS metadata
func NewEnvAWSFingerprint(logger log.Logger) Fingerprint {
f := &EnvAWSFingerprint{
logger: logger.Named("env_aws"),
endpoint: strings.TrimSuffix(os.Getenv("AWS_ENV_URL"), "/meta-data/"),
logger: logger.Named("env_aws"),
}
return f
}
Expand All @@ -77,12 +74,16 @@ func (f *EnvAWSFingerprint) Fingerprint(request *FingerprintRequest, response *F
timeout = 1 * time.Millisecond
}

ec2meta, err := ec2MetaClient(f.endpoint, timeout)
ctx, cancel := context.WithTimeout(context.TODO(), timeout)
defer cancel()

imdsClient, err := imdsClient(ctx)
if err != nil {
return fmt.Errorf("failed to setup ec2Metadata client: %v", err)
return fmt.Errorf("failed to setup IMDS client: %v", err)
}

if !isAWS(ec2meta) {
if !isAWS(ctx, imdsClient) {
f.logger.Debug("error querying AWS IDMS URL, skipping")
return nil
}

Expand All @@ -104,24 +105,24 @@ func (f *EnvAWSFingerprint) Fingerprint(request *FingerprintRequest, response *F
}

for k, unique := range keys {
resp, err := ec2meta.GetMetadata(k)
v := strings.TrimSpace(resp)
if v == "" {
f.logger.Debug("read an empty value", "attribute", k)
continue
} else if awsErr, ok := err.(awserr.RequestFailure); ok {
f.logger.Debug("could not read attribute value", "attribute", k, "error", awsErr)
resp, err := imdsClient.GetMetadata(ctx, &imds.GetMetadataInput{
Path: k,
})
if err := f.handleImdsError(err, k); err != nil {
return err
}
if resp == nil {
continue
} else if awsErr, ok := err.(awserr.Error); ok {
// if it's a URL error, assume we're not in an AWS environment
// TODO: better way to detect AWS? Check xen virtualization?
if _, ok := awsErr.OrigErr().(*url.Error); ok {
return nil
}
}

// not sure what other errors it would return
bytes, err := io.ReadAll(resp.Content)
if err != nil {
return err
}
v := string(bytes)
if v == "" {
f.logger.Debug("read an empty value", "attribute", k)
}

// assume we want blank entries
key := "platform.aws." + strings.ReplaceAll(k, "/", ".")
Expand All @@ -144,32 +145,32 @@ func (f *EnvAWSFingerprint) Fingerprint(request *FingerprintRequest, response *F
Device: "eth0",
IP: val,
CIDR: val + "/32",
MBits: f.throughput(request, ec2meta, val),
MBits: f.throughput(request, imdsClient, val),
},
}
}

// copy over IPv6 network specific information
if val, ok := response.Attributes["unique.platform.aws.mac"]; ok && val != "" {
k := "network/interfaces/macs/" + val + "/ipv6s"
addrsStr, err := ec2meta.GetMetadata(k)
addrsStr = strings.TrimSpace(addrsStr)
if addrsStr == "" {
f.logger.Debug("read an empty value", "attribute", k)
} else if awsErr, ok := err.(awserr.RequestFailure); ok {
f.logger.Debug("could not read attribute value", "attribute", k, "error", awsErr)
} else if awsErr, ok := err.(awserr.Error); ok {
// if it's a URL error, assume we're not in an AWS environment
// TODO: better way to detect AWS? Check xen virtualization?
if _, ok := awsErr.OrigErr().(*url.Error); ok {
return nil
}

// not sure what other errors it would return
resp, err := imdsClient.GetMetadata(ctx, &imds.GetMetadataInput{
Path: k,
})
if err := f.handleImdsError(err, k); err != nil {
return err
} else {
addrs := strings.SplitN(addrsStr, "\n", 2)
response.AddAttribute("unique.platform.aws.public-ipv6", addrs[0])
}
if resp != nil {
addrBytes, err := io.ReadAll(resp.Content)
if err != nil {
return err
}
addrsStr := string(addrBytes)
if addrsStr == "" {
f.logger.Debug("read an empty value", "attribute", k)
} else {
addrs := strings.SplitN(addrsStr, "\n", 2)
response.AddAttribute("unique.platform.aws.public-ipv6", addrs[0])
}
}
}

Expand All @@ -184,21 +185,41 @@ func (f *EnvAWSFingerprint) Fingerprint(request *FingerprintRequest, response *F
return nil
}

func (f *EnvAWSFingerprint) instanceType(ec2meta *ec2metadata.EC2Metadata) (string, error) {
response, err := ec2meta.GetMetadata("instance-type")
// See https://aws.github.io/aws-sdk-go-v2/docs/handling-errors for
// recommended error handling with aws-sdk-go-v2.
// See also: https://github.com/aws/aws-sdk-go-v2/issues/1306
func (f *EnvAWSFingerprint) handleImdsError(err error, attr string) error {
var apiErr *smithyHttp.ResponseError
if errors.As(err, &apiErr) {
// In the event of a request error while fetching attributes, just log and return nil.
// This will happen if attributes do not exist for this instance (ex. ipv6, public-ipv4s).
f.logger.Debug("could not read attribute value", "attribute", attr, "error", err)
return nil
}
return err
}

func (f *EnvAWSFingerprint) instanceType(client *imds.Client) (string, error) {
output, err := client.GetMetadata(context.TODO(), &imds.GetMetadataInput{
Path: "instance-type",
})
if err != nil {
return "", err
}
content, err := io.ReadAll(output.Content)
if err != nil {
return "", err
}
return strings.TrimSpace(response), nil
return string(content), nil
}

func (f *EnvAWSFingerprint) throughput(request *FingerprintRequest, ec2meta *ec2metadata.EC2Metadata, ip string) int {
func (f *EnvAWSFingerprint) throughput(request *FingerprintRequest, client *imds.Client, ip string) int {
throughput := request.Config.NetworkSpeed
if throughput != 0 {
return throughput
}

throughput = f.linkSpeed(ec2meta)
throughput = f.linkSpeed(client)
if throughput != 0 {
return throughput
}
Expand All @@ -215,8 +236,8 @@ func (f *EnvAWSFingerprint) throughput(request *FingerprintRequest, ec2meta *ec2
}

// EnvAWSFingerprint uses lookup table to approximate network speeds
func (f *EnvAWSFingerprint) linkSpeed(ec2meta *ec2metadata.EC2Metadata) int {
instanceType, err := f.instanceType(ec2meta)
func (f *EnvAWSFingerprint) linkSpeed(client *imds.Client) int {
instanceType, err := f.instanceType(client)
if err != nil {
f.logger.Error("error reading instance-type", "error", err)
return 0
Expand All @@ -233,26 +254,22 @@ func (f *EnvAWSFingerprint) linkSpeed(ec2meta *ec2metadata.EC2Metadata) int {
return netSpeed
}

func ec2MetaClient(endpoint string, timeout time.Duration) (*ec2metadata.EC2Metadata, error) {
func imdsClient(ctx context.Context) (*imds.Client, error) {
client := &http.Client{
Timeout: timeout,
Transport: cleanhttp.DefaultTransport(),
}

c := aws.NewConfig().WithHTTPClient(client).WithMaxRetries(0)
if endpoint != "" {
c = c.WithEndpoint(endpoint)
}

sess, err := session.NewSession(c)
cfg, err := config.LoadDefaultConfig(ctx, config.WithHTTPClient(client), config.WithRetryMaxAttempts(0))
if err != nil {
return nil, err
}
return ec2metadata.New(sess, c), nil
return imds.NewFromConfig(cfg), nil
}

func isAWS(ec2meta *ec2metadata.EC2Metadata) bool {
v, err := ec2meta.GetMetadata("ami-id")
v = strings.TrimSpace(v)
return err == nil && v != ""
// isAWS validates the client can reach IMDS. Fetching an ami-id must
// complete error free to be recognized as running on AWS EC2
func isAWS(ctx context.Context, client *imds.Client) bool {
_, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
Path: "ami-id",
})
return err == nil
}
48 changes: 28 additions & 20 deletions e2e/remotetasks/remotetasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
package remotetasks

import (
"context"
"fmt"
"os"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ecs"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ecs"
"github.com/hashicorp/nomad/api"
"github.com/hashicorp/nomad/e2e/e2eutil"
"github.com/hashicorp/nomad/e2e/framework"
Expand Down Expand Up @@ -78,7 +79,9 @@ func (tc *RemoteTasksTest) AfterEach(f *framework.F) {
func (tc *RemoteTasksTest) TestECSJob(f *framework.F) {
t := f.T()

ecsClient := ecsOrSkip(t, tc.Nomad())
ctx := context.Background()

ecsClient := ecsOrSkip(ctx, t, tc.Nomad())

jobID := "ecsjob-" + uuid.Generate()[0:8]
tc.jobIDs = append(tc.jobIDs, jobID)
Expand All @@ -92,7 +95,7 @@ func (tc *RemoteTasksTest) TestECSJob(f *framework.F) {
arn := arnForAlloc(t, tc.Nomad().Allocations(), allocID)

// Use ARN to lookup status of ECS task in AWS
ensureECSRunning(t, ecsClient, arn)
ensureECSRunning(ctx, t, ecsClient, arn)

t.Logf("Task %s is running!", arn)

Expand All @@ -102,10 +105,10 @@ func (tc *RemoteTasksTest) TestECSJob(f *framework.F) {
// Ensure it is stopped in ECS
input := ecs.DescribeTasksInput{
Cluster: aws.String("nomad-rtd-e2e"),
Tasks: []*string{aws.String(arn)},
Tasks: []string{arn},
}
testutil.WaitForResult(func() (bool, error) {
resp, err := ecsClient.DescribeTasks(&input)
resp, err := ecsClient.DescribeTasks(ctx, &input)
if err != nil {
return false, err
}
Expand All @@ -121,7 +124,9 @@ func (tc *RemoteTasksTest) TestECSJob(f *framework.F) {
func (tc *RemoteTasksTest) TestECSDrain(f *framework.F) {
t := f.T()

ecsClient := ecsOrSkip(t, tc.Nomad())
ctx := context.Background()

ecsClient := ecsOrSkip(ctx, t, tc.Nomad())

jobID := "ecsjob-" + uuid.Generate()[0:8]
tc.jobIDs = append(tc.jobIDs, jobID)
Expand All @@ -132,7 +137,7 @@ func (tc *RemoteTasksTest) TestECSDrain(f *framework.F) {
e2eutil.WaitForAllocsRunning(t, tc.Nomad(), []string{origAlloc})

arn := arnForAlloc(t, tc.Nomad().Allocations(), origAlloc)
ensureECSRunning(t, ecsClient, arn)
ensureECSRunning(ctx, t, ecsClient, arn)

t.Logf("Task %s is running! Now to drain the node.", arn)

Expand Down Expand Up @@ -197,7 +202,9 @@ func (tc *RemoteTasksTest) TestECSDrain(f *framework.F) {
func (tc *RemoteTasksTest) TestECSDeployment(f *framework.F) {
t := f.T()

ecsClient := ecsOrSkip(t, tc.Nomad())
ctx := context.Background()

ecsClient := ecsOrSkip(ctx, t, tc.Nomad())

jobID := "ecsjob-" + uuid.Generate()[0:8]
tc.jobIDs = append(tc.jobIDs, jobID)
Expand All @@ -211,7 +218,7 @@ func (tc *RemoteTasksTest) TestECSDeployment(f *framework.F) {
origARN := arnForAlloc(t, tc.Nomad().Allocations(), origAllocID)

// Use ARN to lookup status of ECS task in AWS
ensureECSRunning(t, ecsClient, origARN)
ensureECSRunning(ctx, t, ecsClient, origARN)

t.Logf("Task %s is running! Updating...", origARN)

Expand Down Expand Up @@ -271,10 +278,10 @@ func (tc *RemoteTasksTest) TestECSDeployment(f *framework.F) {
// Ensure original ARN is stopped in ECS
input := ecs.DescribeTasksInput{
Cluster: aws.String("nomad-rtd-e2e"),
Tasks: []*string{aws.String(origARN)},
Tasks: []string{origARN},
}
testutil.WaitForResult(func() (bool, error) {
resp, err := ecsClient.DescribeTasks(&input)
resp, err := ecsClient.DescribeTasks(ctx, &input)
if err != nil {
return false, err
}
Expand All @@ -287,12 +294,13 @@ func (tc *RemoteTasksTest) TestECSDeployment(f *framework.F) {

// ecsOrSkip returns an AWS ECS client or skips the test if ECS is unreachable
// by the test runner or the ECS remote task driver isn't healthy.
func ecsOrSkip(t *testing.T, nomadClient *api.Client) *ecs.ECS {
awsSession := session.Must(session.NewSession())
func ecsOrSkip(ctx context.Context, t *testing.T, nomadClient *api.Client) *ecs.Client {
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion("us-east-1"))
require.NoError(t, err)

ecsClient := ecs.New(awsSession, aws.NewConfig().WithRegion("us-east-1"))
ecsClient := ecs.NewFromConfig(cfg)

_, err := ecsClient.ListClusters(&ecs.ListClustersInput{})
_, err = ecsClient.ListClusters(ctx, &ecs.ListClustersInput{})
if err != nil {
t.Skipf("Skipping ECS Remote Task Driver Task. Error querying AWS ECS API: %v", err)
}
Expand Down Expand Up @@ -378,14 +386,14 @@ func arnForAlloc(t *testing.T, allocAPI *api.Allocations, allocID string) string
}

// ensureECSRunning asserts that the given ARN is a running ECS task.
func ensureECSRunning(t *testing.T, ecsClient *ecs.ECS, arn string) {
func ensureECSRunning(ctx context.Context, t *testing.T, ecsClient *ecs.Client, arn string) {
t.Logf("Ensuring ARN=%s is running", arn)
input := ecs.DescribeTasksInput{
Cluster: aws.String("nomad-rtd-e2e"),
Tasks: []*string{aws.String(arn)},
Tasks: []string{arn},
}
testutil.WaitForResult(func() (bool, error) {
resp, err := ecsClient.DescribeTasks(&input)
resp, err := ecsClient.DescribeTasks(ctx, &input)
if err != nil {
return false, err
}
Expand Down
Loading

0 comments on commit a1b3d73

Please sign in to comment.