diff --git a/exporter/awscloudwatchlogsexporter/exporter.go b/exporter/awscloudwatchlogsexporter/exporter.go index db7bd9230e43..3777dca218cc 100644 --- a/exporter/awscloudwatchlogsexporter/exporter.go +++ b/exporter/awscloudwatchlogsexporter/exporter.go @@ -146,7 +146,7 @@ func (e *exporter) getLogPusher(pusherKey cwlogs.PusherKey) cwlogs.Pusher { func (e *exporter) start(_ context.Context, host component.Host) error { if e.Config.MiddlewareID != nil { - awsmiddleware.TryConfigureSDKv1(e.logger, host.GetExtensions(), *e.Config.MiddlewareID, e.svcStructuredLog.Handlers()) + awsmiddleware.TryConfigure(e.logger, host, *e.Config.MiddlewareID, awsmiddleware.SDKv1(e.svcStructuredLog.Handlers())) } pusherKey := cwlogs.PusherKey{ LogGroupName: e.Config.LogGroupName, diff --git a/exporter/awsemfexporter/emf_exporter.go b/exporter/awsemfexporter/emf_exporter.go index 11128f267a38..ba342bb6b004 100644 --- a/exporter/awsemfexporter/emf_exporter.go +++ b/exporter/awsemfexporter/emf_exporter.go @@ -197,7 +197,7 @@ func (emf *emfExporter) listPushers() []cwlogs.Pusher { func (emf *emfExporter) start(_ context.Context, host component.Host) error { if emf.config.MiddlewareID != nil { - awsmiddleware.TryConfigureSDKv1(emf.config.logger, host.GetExtensions(), *emf.config.MiddlewareID, emf.svcStructuredLog.Handlers()) + awsmiddleware.TryConfigure(emf.config.logger, host, *emf.config.MiddlewareID, awsmiddleware.SDKv1(emf.svcStructuredLog.Handlers())) } return nil } diff --git a/exporter/awsxrayexporter/awsxray.go b/exporter/awsxrayexporter/awsxray.go index 32798ff399d5..3454b5e663a0 100644 --- a/exporter/awsxrayexporter/awsxray.go +++ b/exporter/awsxrayexporter/awsxray.go @@ -88,7 +88,7 @@ func newTracesExporter( exporterhelper.WithStart(func(_ context.Context, host component.Host) error { sender.Start() if cfg.MiddlewareID != nil { - awsmiddleware.TryConfigureSDKv1(logger, host.GetExtensions(), *cfg.MiddlewareID, xrayClient.Handlers()) + awsmiddleware.TryConfigure(logger, host, *cfg.MiddlewareID, awsmiddleware.SDKv1(xrayClient.Handlers())) } return nil }), diff --git a/extension/awsmiddleware/README.md b/extension/awsmiddleware/README.md index 644e43b06bf3..b8b6692cd905 100644 --- a/extension/awsmiddleware/README.md +++ b/extension/awsmiddleware/README.md @@ -1,8 +1,8 @@ # AWS Middleware An AWS middleware extension provides request and/or response handlers that can be configured on AWS SDK v1/v2 clients. -Other components can configure their AWS SDK clients using `awsmiddleware.GetConfigurer` and the `ConfigureSDKv1` and -`ConfigureSDKv2` functions available on the `Configurer`. +Other components can configure their AWS SDK clients using `awsmiddleware.GetConfigurer` and passing the `SDKv1` or `SDKv2` +options into the `Configure` function available on the `Configurer`. The `awsmiddleware.Extension` interface extends `component.Extension` by adding the following method: ``` diff --git a/extension/awsmiddleware/helper.go b/extension/awsmiddleware/helper.go index 44e92630b8cf..7b47fbd197e8 100644 --- a/extension/awsmiddleware/helper.go +++ b/extension/awsmiddleware/helper.go @@ -4,28 +4,15 @@ package awsmiddleware // import "github.com/amazon-contributing/opentelemetry-collector-contrib/extension/awsmiddleware" import ( - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go/aws/request" "go.opentelemetry.io/collector/component" "go.uber.org/zap" ) -// TryConfigureSDKv1 is a helper function that will try to get the extension and configure the AWS SDKv1 handlers with it. -func TryConfigureSDKv1(logger *zap.Logger, extensions map[component.ID]component.Component, middlewareID ID, handlers *request.Handlers) { - if c, err := GetConfigurer(extensions, middlewareID); err != nil { +// TryConfigure is a helper function that will try to get the extension and configure the provided AWS SDK with it. +func TryConfigure(logger *zap.Logger, host component.Host, middlewareID component.ID, sdkVersion SDKVersion) { + if c, err := GetConfigurer(host.GetExtensions(), middlewareID); err != nil { logger.Error("Unable to find AWS Middleware extension", zap.Error(err)) - } else if err = c.ConfigureSDKv1(handlers); err != nil { - logger.Warn("Unable to configure middleware on AWS client", zap.Error(err)) - } else { - logger.Debug("Configured middleware on AWS client", zap.String("middleware", middlewareID.String())) - } -} - -// TryConfigureSDKv2 is a helper function that will try to get the extension and configure the AWS SDKv2 config with it. -func TryConfigureSDKv2(logger *zap.Logger, extensions map[component.ID]component.Component, middlewareID ID, cfg *aws.Config) { - if c, err := GetConfigurer(extensions, middlewareID); err != nil { - logger.Error("Unable to find AWS Middleware extension", zap.Error(err)) - } else if err = c.ConfigureSDKv2(cfg); err != nil { + } else if err = c.Configure(sdkVersion); err != nil { logger.Warn("Unable to configure middleware on AWS client", zap.Error(err)) } else { logger.Debug("Configured middleware on AWS client", zap.String("middleware", middlewareID.String())) diff --git a/extension/awsmiddleware/helper_test.go b/extension/awsmiddleware/helper_test.go new file mode 100644 index 000000000000..bd5f5adc535a --- /dev/null +++ b/extension/awsmiddleware/helper_test.go @@ -0,0 +1,43 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package awsmiddleware + +import ( + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/collector/component" + "go.uber.org/zap" + "go.uber.org/zap/zaptest/observer" +) + +func TestTryConfigure(t *testing.T) { + testCases := []SDKVersion{SDKv1(&request.Handlers{}), SDKv2(&aws.Config{})} + for _, testCase := range testCases { + id := component.NewID("test") + host := new(MockExtensionsHost) + host.On("GetExtensions").Return(nil).Once() + core, observed := observer.New(zap.DebugLevel) + logger := zap.New(core) + TryConfigure(logger, host, id, testCase) + require.Len(t, observed.FilterLevelExact(zap.ErrorLevel).TakeAll(), 1) + + extensions := map[component.ID]component.Component{} + host.On("GetExtensions").Return(extensions) + handler := new(MockHandler) + handler.On("ID").Return("test") + handler.On("Position").Return(HandlerPosition(-1)) + extension := new(MockMiddlewareExtension) + extension.On("Handlers").Return([]RequestHandler{handler}, nil).Once() + extensions[id] = extension + TryConfigure(logger, host, id, testCase) + require.Len(t, observed.FilterLevelExact(zap.WarnLevel).TakeAll(), 1) + + extension.On("Handlers").Return(nil, nil).Once() + TryConfigure(logger, host, id, testCase) + require.Len(t, observed.FilterLevelExact(zap.DebugLevel).TakeAll(), 1) + } +} diff --git a/extension/awsmiddleware/middleware.go b/extension/awsmiddleware/middleware.go index 1d0154358754..0d02f3769538 100644 --- a/extension/awsmiddleware/middleware.go +++ b/extension/awsmiddleware/middleware.go @@ -20,6 +20,7 @@ var ( errNotMiddleware = errors.New("extension is not an AWS middleware") errInvalidHandler = errors.New("invalid handler") errUnsupportedPosition = errors.New("unsupported position") + errUnsupportedVersion = errors.New("unsupported SDK version") ) // HandlerPosition is the relative position of a handler used during insertion. @@ -118,11 +119,23 @@ func newConfigurer(requestHandlers []RequestHandler, responseHandlers []Response return &Configurer{requestHandlers: requestHandlers, responseHandlers: responseHandlers} } -// ConfigureSDKv1 adds middleware to the AWS SDK v1. Request handlers are added to the +// Configure configures the handlers on the provided AWS SDK. +func (c Configurer) Configure(sdkVersion SDKVersion) error { + switch v := sdkVersion.(type) { + case sdkVersion1: + return c.configureSDKv1(v.handlers) + case sdkVersion2: + return c.configureSDKv2(v.cfg) + default: + return fmt.Errorf("%w: %T", errUnsupportedVersion, v) + } +} + +// configureSDKv1 adds middleware to the AWS SDK v1. Request handlers are added to the // Build handler list and response handlers are added to the ValidateResponse handler list. // Build will only be run once per request, but if there are errors, ValidateResponse will // be run again for each configured retry. -func (c Configurer) ConfigureSDKv1(handlers *request.Handlers) error { +func (c Configurer) configureSDKv1(handlers *request.Handlers) error { var errs error for _, handler := range c.requestHandlers { if err := appendHandler(&handlers.Build, namedRequestHandler(handler), handler.Position()); err != nil { @@ -137,9 +150,9 @@ func (c Configurer) ConfigureSDKv1(handlers *request.Handlers) error { return errs } -// ConfigureSDKv2 adds middleware to the AWS SDK v2. Request handlers are added to the +// configureSDKv2 adds middleware to the AWS SDK v2. Request handlers are added to the // Build step and response handlers are added to the Deserialize step. -func (c Configurer) ConfigureSDKv2(config *aws.Config) error { +func (c Configurer) configureSDKv2(config *aws.Config) error { var errs error for _, handler := range c.requestHandlers { relativePosition, err := toRelativePosition(handler.Position()) @@ -160,7 +173,7 @@ func (c Configurer) ConfigureSDKv2(config *aws.Config) error { return errs } -// addHandlerToList adds the handler to the list based on the position. +// appendHandler adds the handler to the list based on the position. func appendHandler(handlerList *request.HandlerList, handler request.NamedHandler, position HandlerPosition) error { relativePosition, err := toRelativePosition(position) if err != nil { diff --git a/extension/awsmiddleware/middleware_test.go b/extension/awsmiddleware/middleware_test.go index 0ac775c42252..18b0af148d24 100644 --- a/extension/awsmiddleware/middleware_test.go +++ b/extension/awsmiddleware/middleware_test.go @@ -5,7 +5,6 @@ package awsmiddleware import ( "context" - "errors" "net/http" "net/http/httptest" "testing" @@ -14,6 +13,7 @@ import ( awsv2 "github.com/aws/aws-sdk-go-v2/aws" s3v2 "github.com/aws/aws-sdk-go-v2/service/s3" awsv1 "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/awstesting" s3v1 "github.com/aws/aws-sdk-go/service/s3" "github.com/stretchr/testify/assert" @@ -114,19 +114,24 @@ func TestInvalidHandlers(t *testing.T) { middleware := new(MockMiddlewareExtension) middleware.On("Handlers").Return([]RequestHandler{handler}, []ResponseHandler{handler}) c := newConfigurer(middleware.Handlers()) - // v1 - client := awstesting.NewClient() - err := c.ConfigureSDKv1(&client.Handlers) - assert.Error(t, err) - assert.True(t, errors.Is(err, errInvalidHandler)) - assert.True(t, errors.Is(err, errUnsupportedPosition)) - // v2 - err = c.ConfigureSDKv2(&awsv2.Config{}) + testCases := []SDKVersion{SDKv1(&request.Handlers{}), SDKv2(&awsv2.Config{})} + for _, testCase := range testCases { + err := c.Configure(testCase) + assert.Error(t, err) + assert.ErrorIs(t, err, errInvalidHandler) + assert.ErrorIs(t, err, errUnsupportedPosition) + handler.AssertNotCalled(t, "HandleRequest", mock.Anything, mock.Anything) + handler.AssertNotCalled(t, "HandleResponse", mock.Anything, mock.Anything) + } +} + +func TestConfigureUnsupported(t *testing.T) { + type unsupportedVersion struct { + SDKVersion + } + err := newConfigurer(nil, nil).Configure(unsupportedVersion{}) assert.Error(t, err) - assert.True(t, errors.Is(err, errInvalidHandler)) - assert.True(t, errors.Is(err, errUnsupportedPosition)) - handler.AssertNotCalled(t, "HandleRequest", mock.Anything, mock.Anything) - handler.AssertNotCalled(t, "HandleResponse", mock.Anything, mock.Anything) + assert.ErrorIs(t, err, errUnsupportedVersion) } func TestAppendOrder(t *testing.T) { @@ -193,7 +198,7 @@ func TestAppendOrder(t *testing.T) { DisableSSL: awsv1.Bool(true), Endpoint: awsv1.String(server.URL), }) - assert.NoError(t, c.ConfigureSDKv1(&client.Handlers)) + assert.NoError(t, c.Configure(SDKv1(&client.Handlers))) s3v1Client := &s3v1.S3{Client: client} _, err := s3v1Client.ListBuckets(&s3v1.ListBucketsInput{}) require.NoError(t, err) @@ -201,7 +206,7 @@ func TestAppendOrder(t *testing.T) { recorder.order = nil // v2 cfg := awsv2.Config{Region: "us-east-1"} - assert.NoError(t, c.ConfigureSDKv2(&cfg)) + assert.NoError(t, c.Configure(SDKv2(&cfg))) s3v2Client := s3v2.NewFromConfig(cfg, func(options *s3v2.Options) { options.BaseEndpoint = awsv2.String(server.URL) }) @@ -212,7 +217,7 @@ func TestAppendOrder(t *testing.T) { } } -func TestConfigureSDKv1(t *testing.T) { +func TestRoundTripSDKv1(t *testing.T) { middleware, recorder, server := setup(t) defer server.Close() client := awstesting.NewClient(&awsv1.Config{ @@ -223,7 +228,7 @@ func TestConfigureSDKv1(t *testing.T) { }) require.Equal(t, 3, client.Handlers.Build.Len()) require.Equal(t, 1, client.Handlers.ValidateResponse.Len()) - assert.NoError(t, newConfigurer(middleware.Handlers()).ConfigureSDKv1(&client.Handlers)) + assert.NoError(t, newConfigurer(middleware.Handlers()).Configure(SDKv1(&client.Handlers))) assert.Equal(t, 5, client.Handlers.Build.Len()) assert.Equal(t, 2, client.Handlers.ValidateResponse.Len()) s3Client := &s3v1.S3{Client: client} @@ -234,11 +239,11 @@ func TestConfigureSDKv1(t *testing.T) { assert.Equal(t, recorder.requestIDs, recorder.responseIDs) } -func TestConfigureSDKv2(t *testing.T) { +func TestRoundTripSDKv2(t *testing.T) { middleware, recorder, server := setup(t) defer server.Close() - cfg := awsv2.Config{Region: "us-east-1", RetryMaxAttempts: 0} - assert.NoError(t, newConfigurer(middleware.Handlers()).ConfigureSDKv2(&cfg)) + cfg := awsv2.Config{Region: "mock-region", RetryMaxAttempts: 0} + assert.NoError(t, newConfigurer(middleware.Handlers()).Configure(SDKv2(&cfg))) s3Client := s3v2.NewFromConfig(cfg, func(options *s3v2.Options) { options.BaseEndpoint = awsv2.String(server.URL) }) diff --git a/extension/awsmiddleware/options.go b/extension/awsmiddleware/options.go new file mode 100644 index 000000000000..fccaf75eb26f --- /dev/null +++ b/extension/awsmiddleware/options.go @@ -0,0 +1,33 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package awsmiddleware + +import ( + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go/aws/request" +) + +type SDKVersion interface { + unused() +} + +type sdkVersion1 struct { + SDKVersion + handlers *request.Handlers +} + +// SDKv1 takes in AWS SDKv1 client request handlers. +func SDKv1(handlers *request.Handlers) SDKVersion { + return sdkVersion1{handlers: handlers} +} + +type sdkVersion2 struct { + SDKVersion + cfg *aws.Config +} + +// SDKv2 takes in an AWS SDKv2 config. +func SDKv2(cfg *aws.Config) SDKVersion { + return sdkVersion2{cfg: cfg} +}