Skip to content

Commit

Permalink
Change Configure to take in either SDK version
Browse files Browse the repository at this point in the history
  • Loading branch information
jefchien committed Oct 23, 2023
1 parent 0e8671f commit 71546aa
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 47 deletions.
2 changes: 1 addition & 1 deletion exporter/awscloudwatchlogsexporter/exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (e *exporter) getLogPusher(pusherKey cwlogs.PusherKey) cwlogs.Pusher {

func (e *exporter) start(_ context.Context, host component.Host) error {
if e.Config.MiddlewareID != nil {
awsmiddleware.TryConfigureSDKv1(e.logger, host.GetExtensions(), *e.Config.MiddlewareID, e.svcStructuredLog.Handlers())
awsmiddleware.TryConfigure(e.logger, host, *e.Config.MiddlewareID, awsmiddleware.SDKv1(e.svcStructuredLog.Handlers()))
}
pusherKey := cwlogs.PusherKey{
LogGroupName: e.Config.LogGroupName,
Expand Down
2 changes: 1 addition & 1 deletion exporter/awsemfexporter/emf_exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func (emf *emfExporter) listPushers() []cwlogs.Pusher {

func (emf *emfExporter) start(_ context.Context, host component.Host) error {
if emf.config.MiddlewareID != nil {
awsmiddleware.TryConfigureSDKv1(emf.config.logger, host.GetExtensions(), *emf.config.MiddlewareID, emf.svcStructuredLog.Handlers())
awsmiddleware.TryConfigure(emf.config.logger, host, *emf.config.MiddlewareID, awsmiddleware.SDKv1(emf.svcStructuredLog.Handlers()))
}
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion exporter/awsxrayexporter/awsxray.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func newTracesExporter(
exporterhelper.WithStart(func(_ context.Context, host component.Host) error {
sender.Start()
if cfg.MiddlewareID != nil {
awsmiddleware.TryConfigureSDKv1(logger, host.GetExtensions(), *cfg.MiddlewareID, xrayClient.Handlers())
awsmiddleware.TryConfigure(logger, host, *cfg.MiddlewareID, awsmiddleware.SDKv1(xrayClient.Handlers()))
}
return nil
}),
Expand Down
4 changes: 2 additions & 2 deletions extension/awsmiddleware/README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# AWS Middleware

An AWS middleware extension provides request and/or response handlers that can be configured on AWS SDK v1/v2 clients.
Other components can configure their AWS SDK clients using `awsmiddleware.GetConfigurer` and the `ConfigureSDKv1` and
`ConfigureSDKv2` functions available on the `Configurer`.
Other components can configure their AWS SDK clients using `awsmiddleware.GetConfigurer` and passing the `SDKv1` or `SDKv2`
options into the `Configure` function available on the `Configurer`.

The `awsmiddleware.Extension` interface extends `component.Extension` by adding the following method:
```
Expand Down
21 changes: 4 additions & 17 deletions extension/awsmiddleware/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,15 @@
package awsmiddleware // import "github.com/amazon-contributing/opentelemetry-collector-contrib/extension/awsmiddleware"

import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws/request"
"go.opentelemetry.io/collector/component"
"go.uber.org/zap"
)

// TryConfigureSDKv1 is a helper function that will try to get the extension and configure the AWS SDKv1 handlers with it.
func TryConfigureSDKv1(logger *zap.Logger, extensions map[component.ID]component.Component, middlewareID ID, handlers *request.Handlers) {
if c, err := GetConfigurer(extensions, middlewareID); err != nil {
// TryConfigure is a helper function that will try to get the extension and configure the provided AWS SDK with it.
func TryConfigure(logger *zap.Logger, host component.Host, middlewareID component.ID, sdkVersion SDKVersion) {
if c, err := GetConfigurer(host.GetExtensions(), middlewareID); err != nil {
logger.Error("Unable to find AWS Middleware extension", zap.Error(err))
} else if err = c.ConfigureSDKv1(handlers); err != nil {
logger.Warn("Unable to configure middleware on AWS client", zap.Error(err))
} else {
logger.Debug("Configured middleware on AWS client", zap.String("middleware", middlewareID.String()))
}
}

// TryConfigureSDKv2 is a helper function that will try to get the extension and configure the AWS SDKv2 config with it.
func TryConfigureSDKv2(logger *zap.Logger, extensions map[component.ID]component.Component, middlewareID ID, cfg *aws.Config) {
if c, err := GetConfigurer(extensions, middlewareID); err != nil {
logger.Error("Unable to find AWS Middleware extension", zap.Error(err))
} else if err = c.ConfigureSDKv2(cfg); err != nil {
} else if err = c.Configure(sdkVersion); err != nil {
logger.Warn("Unable to configure middleware on AWS client", zap.Error(err))
} else {
logger.Debug("Configured middleware on AWS client", zap.String("middleware", middlewareID.String()))
Expand Down
43 changes: 43 additions & 0 deletions extension/awsmiddleware/helper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

package awsmiddleware

import (
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/collector/component"
"go.uber.org/zap"
"go.uber.org/zap/zaptest/observer"
)

func TestTryConfigure(t *testing.T) {
testCases := []SDKVersion{SDKv1(&request.Handlers{}), SDKv2(&aws.Config{})}
for _, testCase := range testCases {
id := component.NewID("test")
host := new(MockExtensionsHost)
host.On("GetExtensions").Return(nil).Once()
core, observed := observer.New(zap.DebugLevel)
logger := zap.New(core)
TryConfigure(logger, host, id, testCase)
require.Len(t, observed.FilterLevelExact(zap.ErrorLevel).TakeAll(), 1)

extensions := map[component.ID]component.Component{}
host.On("GetExtensions").Return(extensions)
handler := new(MockHandler)
handler.On("ID").Return("test")
handler.On("Position").Return(HandlerPosition(-1))
extension := new(MockMiddlewareExtension)
extension.On("Handlers").Return([]RequestHandler{handler}, nil).Once()
extensions[id] = extension
TryConfigure(logger, host, id, testCase)
require.Len(t, observed.FilterLevelExact(zap.WarnLevel).TakeAll(), 1)

extension.On("Handlers").Return(nil, nil).Once()
TryConfigure(logger, host, id, testCase)
require.Len(t, observed.FilterLevelExact(zap.DebugLevel).TakeAll(), 1)
}
}
23 changes: 18 additions & 5 deletions extension/awsmiddleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ var (
errNotMiddleware = errors.New("extension is not an AWS middleware")
errInvalidHandler = errors.New("invalid handler")
errUnsupportedPosition = errors.New("unsupported position")
errUnsupportedVersion = errors.New("unsupported SDK version")
)

// HandlerPosition is the relative position of a handler used during insertion.
Expand Down Expand Up @@ -118,11 +119,23 @@ func newConfigurer(requestHandlers []RequestHandler, responseHandlers []Response
return &Configurer{requestHandlers: requestHandlers, responseHandlers: responseHandlers}
}

// ConfigureSDKv1 adds middleware to the AWS SDK v1. Request handlers are added to the
// Configure configures the handlers on the provided AWS SDK.
func (c Configurer) Configure(sdkVersion SDKVersion) error {
switch v := sdkVersion.(type) {
case sdkVersion1:
return c.configureSDKv1(v.handlers)
case sdkVersion2:
return c.configureSDKv2(v.cfg)
default:
return fmt.Errorf("%w: %T", errUnsupportedVersion, v)
}
}

// configureSDKv1 adds middleware to the AWS SDK v1. Request handlers are added to the
// Build handler list and response handlers are added to the ValidateResponse handler list.
// Build will only be run once per request, but if there are errors, ValidateResponse will
// be run again for each configured retry.
func (c Configurer) ConfigureSDKv1(handlers *request.Handlers) error {
func (c Configurer) configureSDKv1(handlers *request.Handlers) error {
var errs error
for _, handler := range c.requestHandlers {
if err := appendHandler(&handlers.Build, namedRequestHandler(handler), handler.Position()); err != nil {
Expand All @@ -137,9 +150,9 @@ func (c Configurer) ConfigureSDKv1(handlers *request.Handlers) error {
return errs
}

// ConfigureSDKv2 adds middleware to the AWS SDK v2. Request handlers are added to the
// configureSDKv2 adds middleware to the AWS SDK v2. Request handlers are added to the
// Build step and response handlers are added to the Deserialize step.
func (c Configurer) ConfigureSDKv2(config *aws.Config) error {
func (c Configurer) configureSDKv2(config *aws.Config) error {
var errs error
for _, handler := range c.requestHandlers {
relativePosition, err := toRelativePosition(handler.Position())
Expand All @@ -160,7 +173,7 @@ func (c Configurer) ConfigureSDKv2(config *aws.Config) error {
return errs
}

// addHandlerToList adds the handler to the list based on the position.
// appendHandler adds the handler to the list based on the position.
func appendHandler(handlerList *request.HandlerList, handler request.NamedHandler, position HandlerPosition) error {
relativePosition, err := toRelativePosition(position)
if err != nil {
Expand Down
45 changes: 25 additions & 20 deletions extension/awsmiddleware/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package awsmiddleware

import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
Expand All @@ -14,6 +13,7 @@ import (
awsv2 "github.com/aws/aws-sdk-go-v2/aws"
s3v2 "github.com/aws/aws-sdk-go-v2/service/s3"
awsv1 "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
s3v1 "github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -114,19 +114,24 @@ func TestInvalidHandlers(t *testing.T) {
middleware := new(MockMiddlewareExtension)
middleware.On("Handlers").Return([]RequestHandler{handler}, []ResponseHandler{handler})
c := newConfigurer(middleware.Handlers())
// v1
client := awstesting.NewClient()
err := c.ConfigureSDKv1(&client.Handlers)
assert.Error(t, err)
assert.True(t, errors.Is(err, errInvalidHandler))
assert.True(t, errors.Is(err, errUnsupportedPosition))
// v2
err = c.ConfigureSDKv2(&awsv2.Config{})
testCases := []SDKVersion{SDKv1(&request.Handlers{}), SDKv2(&awsv2.Config{})}
for _, testCase := range testCases {
err := c.Configure(testCase)
assert.Error(t, err)
assert.ErrorIs(t, err, errInvalidHandler)
assert.ErrorIs(t, err, errUnsupportedPosition)
handler.AssertNotCalled(t, "HandleRequest", mock.Anything, mock.Anything)
handler.AssertNotCalled(t, "HandleResponse", mock.Anything, mock.Anything)
}
}

func TestConfigureUnsupported(t *testing.T) {
type unsupportedVersion struct {
SDKVersion
}
err := newConfigurer(nil, nil).Configure(unsupportedVersion{})
assert.Error(t, err)
assert.True(t, errors.Is(err, errInvalidHandler))
assert.True(t, errors.Is(err, errUnsupportedPosition))
handler.AssertNotCalled(t, "HandleRequest", mock.Anything, mock.Anything)
handler.AssertNotCalled(t, "HandleResponse", mock.Anything, mock.Anything)
assert.ErrorIs(t, err, errUnsupportedVersion)
}

func TestAppendOrder(t *testing.T) {
Expand Down Expand Up @@ -193,15 +198,15 @@ func TestAppendOrder(t *testing.T) {
DisableSSL: awsv1.Bool(true),
Endpoint: awsv1.String(server.URL),
})
assert.NoError(t, c.ConfigureSDKv1(&client.Handlers))
assert.NoError(t, c.Configure(SDKv1(&client.Handlers)))
s3v1Client := &s3v1.S3{Client: client}
_, err := s3v1Client.ListBuckets(&s3v1.ListBucketsInput{})
require.NoError(t, err)
assert.Equal(t, testCase.wantOrder, recorder.order)
recorder.order = nil
// v2
cfg := awsv2.Config{Region: "us-east-1"}
assert.NoError(t, c.ConfigureSDKv2(&cfg))
assert.NoError(t, c.Configure(SDKv2(&cfg)))
s3v2Client := s3v2.NewFromConfig(cfg, func(options *s3v2.Options) {
options.BaseEndpoint = awsv2.String(server.URL)
})
Expand All @@ -212,7 +217,7 @@ func TestAppendOrder(t *testing.T) {
}
}

func TestConfigureSDKv1(t *testing.T) {
func TestRoundTripSDKv1(t *testing.T) {
middleware, recorder, server := setup(t)
defer server.Close()
client := awstesting.NewClient(&awsv1.Config{
Expand All @@ -223,7 +228,7 @@ func TestConfigureSDKv1(t *testing.T) {
})
require.Equal(t, 3, client.Handlers.Build.Len())
require.Equal(t, 1, client.Handlers.ValidateResponse.Len())
assert.NoError(t, newConfigurer(middleware.Handlers()).ConfigureSDKv1(&client.Handlers))
assert.NoError(t, newConfigurer(middleware.Handlers()).Configure(SDKv1(&client.Handlers)))
assert.Equal(t, 5, client.Handlers.Build.Len())
assert.Equal(t, 2, client.Handlers.ValidateResponse.Len())
s3Client := &s3v1.S3{Client: client}
Expand All @@ -234,11 +239,11 @@ func TestConfigureSDKv1(t *testing.T) {
assert.Equal(t, recorder.requestIDs, recorder.responseIDs)
}

func TestConfigureSDKv2(t *testing.T) {
func TestRoundTripSDKv2(t *testing.T) {
middleware, recorder, server := setup(t)
defer server.Close()
cfg := awsv2.Config{Region: "us-east-1", RetryMaxAttempts: 0}
assert.NoError(t, newConfigurer(middleware.Handlers()).ConfigureSDKv2(&cfg))
cfg := awsv2.Config{Region: "mock-region", RetryMaxAttempts: 0}
assert.NoError(t, newConfigurer(middleware.Handlers()).Configure(SDKv2(&cfg)))
s3Client := s3v2.NewFromConfig(cfg, func(options *s3v2.Options) {
options.BaseEndpoint = awsv2.String(server.URL)
})
Expand Down
33 changes: 33 additions & 0 deletions extension/awsmiddleware/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

package awsmiddleware

import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws/request"
)

type SDKVersion interface {
unused()
}

type sdkVersion1 struct {
SDKVersion
handlers *request.Handlers
}

// SDKv1 takes in AWS SDKv1 client request handlers.
func SDKv1(handlers *request.Handlers) SDKVersion {
return sdkVersion1{handlers: handlers}
}

type sdkVersion2 struct {
SDKVersion
cfg *aws.Config
}

// SDKv2 takes in an AWS SDKv2 config.
func SDKv2(cfg *aws.Config) SDKVersion {
return sdkVersion2{cfg: cfg}
}

0 comments on commit 71546aa

Please sign in to comment.