Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[extension/awsmiddleware] Change Configure to take in either SDK version #133

Merged
merged 1 commit into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion exporter/awscloudwatchlogsexporter/exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (e *exporter) getLogPusher(pusherKey cwlogs.PusherKey) cwlogs.Pusher {

func (e *exporter) start(_ context.Context, host component.Host) error {
if e.Config.MiddlewareID != nil {
awsmiddleware.TryConfigureSDKv1(e.logger, host.GetExtensions(), *e.Config.MiddlewareID, e.svcStructuredLog.Handlers())
awsmiddleware.TryConfigure(e.logger, host, *e.Config.MiddlewareID, awsmiddleware.SDKv1(e.svcStructuredLog.Handlers()))
}
pusherKey := cwlogs.PusherKey{
LogGroupName: e.Config.LogGroupName,
Expand Down
2 changes: 1 addition & 1 deletion exporter/awsemfexporter/emf_exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func (emf *emfExporter) listPushers() []cwlogs.Pusher {

func (emf *emfExporter) start(_ context.Context, host component.Host) error {
if emf.config.MiddlewareID != nil {
awsmiddleware.TryConfigureSDKv1(emf.config.logger, host.GetExtensions(), *emf.config.MiddlewareID, emf.svcStructuredLog.Handlers())
awsmiddleware.TryConfigure(emf.config.logger, host, *emf.config.MiddlewareID, awsmiddleware.SDKv1(emf.svcStructuredLog.Handlers()))
}
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion exporter/awsxrayexporter/awsxray.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func newTracesExporter(
exporterhelper.WithStart(func(_ context.Context, host component.Host) error {
sender.Start()
if cfg.MiddlewareID != nil {
awsmiddleware.TryConfigureSDKv1(logger, host.GetExtensions(), *cfg.MiddlewareID, xrayClient.Handlers())
awsmiddleware.TryConfigure(logger, host, *cfg.MiddlewareID, awsmiddleware.SDKv1(xrayClient.Handlers()))
}
return nil
}),
Expand Down
4 changes: 2 additions & 2 deletions extension/awsmiddleware/README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# AWS Middleware

An AWS middleware extension provides request and/or response handlers that can be configured on AWS SDK v1/v2 clients.
Other components can configure their AWS SDK clients using `awsmiddleware.GetConfigurer` and the `ConfigureSDKv1` and
`ConfigureSDKv2` functions available on the `Configurer`.
Other components can configure their AWS SDK clients using `awsmiddleware.GetConfigurer` and passing the `SDKv1` or `SDKv2`
options into the `Configure` function available on the `Configurer`.

The `awsmiddleware.Extension` interface extends `component.Extension` by adding the following method:
```
Expand Down
21 changes: 4 additions & 17 deletions extension/awsmiddleware/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,15 @@
package awsmiddleware // import "github.com/amazon-contributing/opentelemetry-collector-contrib/extension/awsmiddleware"

import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws/request"
"go.opentelemetry.io/collector/component"
"go.uber.org/zap"
)

// TryConfigureSDKv1 is a helper function that will try to get the extension and configure the AWS SDKv1 handlers with it.
func TryConfigureSDKv1(logger *zap.Logger, extensions map[component.ID]component.Component, middlewareID ID, handlers *request.Handlers) {
if c, err := GetConfigurer(extensions, middlewareID); err != nil {
// TryConfigure is a helper function that will try to get the extension and configure the provided AWS SDK with it.
func TryConfigure(logger *zap.Logger, host component.Host, middlewareID component.ID, sdkVersion SDKVersion) {
if c, err := GetConfigurer(host.GetExtensions(), middlewareID); err != nil {
logger.Error("Unable to find AWS Middleware extension", zap.Error(err))
} else if err = c.ConfigureSDKv1(handlers); err != nil {
logger.Warn("Unable to configure middleware on AWS client", zap.Error(err))
} else {
logger.Debug("Configured middleware on AWS client", zap.String("middleware", middlewareID.String()))
}
}

// TryConfigureSDKv2 is a helper function that will try to get the extension and configure the AWS SDKv2 config with it.
func TryConfigureSDKv2(logger *zap.Logger, extensions map[component.ID]component.Component, middlewareID ID, cfg *aws.Config) {
if c, err := GetConfigurer(extensions, middlewareID); err != nil {
logger.Error("Unable to find AWS Middleware extension", zap.Error(err))
} else if err = c.ConfigureSDKv2(cfg); err != nil {
} else if err = c.Configure(sdkVersion); err != nil {
logger.Warn("Unable to configure middleware on AWS client", zap.Error(err))
} else {
logger.Debug("Configured middleware on AWS client", zap.String("middleware", middlewareID.String()))
Expand Down
43 changes: 43 additions & 0 deletions extension/awsmiddleware/helper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

package awsmiddleware

import (
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/collector/component"
"go.uber.org/zap"
"go.uber.org/zap/zaptest/observer"
)

func TestTryConfigure(t *testing.T) {
testCases := []SDKVersion{SDKv1(&request.Handlers{}), SDKv2(&aws.Config{})}
for _, testCase := range testCases {
id := component.NewID("test")
host := new(MockExtensionsHost)
host.On("GetExtensions").Return(nil).Once()
core, observed := observer.New(zap.DebugLevel)
logger := zap.New(core)
TryConfigure(logger, host, id, testCase)
require.Len(t, observed.FilterLevelExact(zap.ErrorLevel).TakeAll(), 1)

extensions := map[component.ID]component.Component{}
host.On("GetExtensions").Return(extensions)
handler := new(MockHandler)
handler.On("ID").Return("test")
handler.On("Position").Return(HandlerPosition(-1))
extension := new(MockMiddlewareExtension)
extension.On("Handlers").Return([]RequestHandler{handler}, nil).Once()
extensions[id] = extension
TryConfigure(logger, host, id, testCase)
require.Len(t, observed.FilterLevelExact(zap.WarnLevel).TakeAll(), 1)

extension.On("Handlers").Return(nil, nil).Once()
TryConfigure(logger, host, id, testCase)
require.Len(t, observed.FilterLevelExact(zap.DebugLevel).TakeAll(), 1)
}
}
23 changes: 18 additions & 5 deletions extension/awsmiddleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ var (
errNotMiddleware = errors.New("extension is not an AWS middleware")
errInvalidHandler = errors.New("invalid handler")
errUnsupportedPosition = errors.New("unsupported position")
errUnsupportedVersion = errors.New("unsupported SDK version")
)

// HandlerPosition is the relative position of a handler used during insertion.
Expand Down Expand Up @@ -118,11 +119,23 @@ func newConfigurer(requestHandlers []RequestHandler, responseHandlers []Response
return &Configurer{requestHandlers: requestHandlers, responseHandlers: responseHandlers}
}

// ConfigureSDKv1 adds middleware to the AWS SDK v1. Request handlers are added to the
// Configure configures the handlers on the provided AWS SDK.
func (c Configurer) Configure(sdkVersion SDKVersion) error {
switch v := sdkVersion.(type) {
case sdkVersion1:
return c.configureSDKv1(v.handlers)
case sdkVersion2:
return c.configureSDKv2(v.cfg)
default:
return fmt.Errorf("%w: %T", errUnsupportedVersion, v)
}
}

// configureSDKv1 adds middleware to the AWS SDK v1. Request handlers are added to the
// Build handler list and response handlers are added to the ValidateResponse handler list.
// Build will only be run once per request, but if there are errors, ValidateResponse will
// be run again for each configured retry.
func (c Configurer) ConfigureSDKv1(handlers *request.Handlers) error {
func (c Configurer) configureSDKv1(handlers *request.Handlers) error {
var errs error
for _, handler := range c.requestHandlers {
if err := appendHandler(&handlers.Build, namedRequestHandler(handler), handler.Position()); err != nil {
Expand All @@ -137,9 +150,9 @@ func (c Configurer) ConfigureSDKv1(handlers *request.Handlers) error {
return errs
}

// ConfigureSDKv2 adds middleware to the AWS SDK v2. Request handlers are added to the
// configureSDKv2 adds middleware to the AWS SDK v2. Request handlers are added to the
// Build step and response handlers are added to the Deserialize step.
func (c Configurer) ConfigureSDKv2(config *aws.Config) error {
func (c Configurer) configureSDKv2(config *aws.Config) error {
var errs error
for _, handler := range c.requestHandlers {
relativePosition, err := toRelativePosition(handler.Position())
Expand All @@ -160,7 +173,7 @@ func (c Configurer) ConfigureSDKv2(config *aws.Config) error {
return errs
}

// addHandlerToList adds the handler to the list based on the position.
// appendHandler adds the handler to the list based on the position.
func appendHandler(handlerList *request.HandlerList, handler request.NamedHandler, position HandlerPosition) error {
relativePosition, err := toRelativePosition(position)
if err != nil {
Expand Down
45 changes: 25 additions & 20 deletions extension/awsmiddleware/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package awsmiddleware

import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
Expand All @@ -14,6 +13,7 @@ import (
awsv2 "github.com/aws/aws-sdk-go-v2/aws"
s3v2 "github.com/aws/aws-sdk-go-v2/service/s3"
awsv1 "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
s3v1 "github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -114,19 +114,24 @@ func TestInvalidHandlers(t *testing.T) {
middleware := new(MockMiddlewareExtension)
middleware.On("Handlers").Return([]RequestHandler{handler}, []ResponseHandler{handler})
c := newConfigurer(middleware.Handlers())
// v1
client := awstesting.NewClient()
err := c.ConfigureSDKv1(&client.Handlers)
assert.Error(t, err)
assert.True(t, errors.Is(err, errInvalidHandler))
assert.True(t, errors.Is(err, errUnsupportedPosition))
// v2
err = c.ConfigureSDKv2(&awsv2.Config{})
testCases := []SDKVersion{SDKv1(&request.Handlers{}), SDKv2(&awsv2.Config{})}
for _, testCase := range testCases {
err := c.Configure(testCase)
assert.Error(t, err)
assert.ErrorIs(t, err, errInvalidHandler)
assert.ErrorIs(t, err, errUnsupportedPosition)
handler.AssertNotCalled(t, "HandleRequest", mock.Anything, mock.Anything)
handler.AssertNotCalled(t, "HandleResponse", mock.Anything, mock.Anything)
}
}

func TestConfigureUnsupported(t *testing.T) {
type unsupportedVersion struct {
SDKVersion
}
err := newConfigurer(nil, nil).Configure(unsupportedVersion{})
assert.Error(t, err)
assert.True(t, errors.Is(err, errInvalidHandler))
assert.True(t, errors.Is(err, errUnsupportedPosition))
handler.AssertNotCalled(t, "HandleRequest", mock.Anything, mock.Anything)
handler.AssertNotCalled(t, "HandleResponse", mock.Anything, mock.Anything)
assert.ErrorIs(t, err, errUnsupportedVersion)
}

func TestAppendOrder(t *testing.T) {
Expand Down Expand Up @@ -193,15 +198,15 @@ func TestAppendOrder(t *testing.T) {
DisableSSL: awsv1.Bool(true),
Endpoint: awsv1.String(server.URL),
})
assert.NoError(t, c.ConfigureSDKv1(&client.Handlers))
assert.NoError(t, c.Configure(SDKv1(&client.Handlers)))
s3v1Client := &s3v1.S3{Client: client}
_, err := s3v1Client.ListBuckets(&s3v1.ListBucketsInput{})
require.NoError(t, err)
assert.Equal(t, testCase.wantOrder, recorder.order)
recorder.order = nil
// v2
cfg := awsv2.Config{Region: "us-east-1"}
assert.NoError(t, c.ConfigureSDKv2(&cfg))
assert.NoError(t, c.Configure(SDKv2(&cfg)))
s3v2Client := s3v2.NewFromConfig(cfg, func(options *s3v2.Options) {
options.BaseEndpoint = awsv2.String(server.URL)
})
Expand All @@ -212,7 +217,7 @@ func TestAppendOrder(t *testing.T) {
}
}

func TestConfigureSDKv1(t *testing.T) {
func TestRoundTripSDKv1(t *testing.T) {
middleware, recorder, server := setup(t)
defer server.Close()
client := awstesting.NewClient(&awsv1.Config{
Expand All @@ -223,7 +228,7 @@ func TestConfigureSDKv1(t *testing.T) {
})
require.Equal(t, 3, client.Handlers.Build.Len())
require.Equal(t, 1, client.Handlers.ValidateResponse.Len())
assert.NoError(t, newConfigurer(middleware.Handlers()).ConfigureSDKv1(&client.Handlers))
assert.NoError(t, newConfigurer(middleware.Handlers()).Configure(SDKv1(&client.Handlers)))
assert.Equal(t, 5, client.Handlers.Build.Len())
assert.Equal(t, 2, client.Handlers.ValidateResponse.Len())
s3Client := &s3v1.S3{Client: client}
Expand All @@ -234,11 +239,11 @@ func TestConfigureSDKv1(t *testing.T) {
assert.Equal(t, recorder.requestIDs, recorder.responseIDs)
}

func TestConfigureSDKv2(t *testing.T) {
func TestRoundTripSDKv2(t *testing.T) {
middleware, recorder, server := setup(t)
defer server.Close()
cfg := awsv2.Config{Region: "us-east-1", RetryMaxAttempts: 0}
assert.NoError(t, newConfigurer(middleware.Handlers()).ConfigureSDKv2(&cfg))
cfg := awsv2.Config{Region: "mock-region", RetryMaxAttempts: 0}
assert.NoError(t, newConfigurer(middleware.Handlers()).Configure(SDKv2(&cfg)))
s3Client := s3v2.NewFromConfig(cfg, func(options *s3v2.Options) {
options.BaseEndpoint = awsv2.String(server.URL)
})
Expand Down
33 changes: 33 additions & 0 deletions extension/awsmiddleware/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

package awsmiddleware

import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws/request"
)

type SDKVersion interface {
unused()
}

type sdkVersion1 struct {
SDKVersion
handlers *request.Handlers
}

// SDKv1 takes in AWS SDKv1 client request handlers.
func SDKv1(handlers *request.Handlers) SDKVersion {
return sdkVersion1{handlers: handlers}
}

type sdkVersion2 struct {
SDKVersion
cfg *aws.Config
}

// SDKv2 takes in an AWS SDKv2 config.
func SDKv2(cfg *aws.Config) SDKVersion {
return sdkVersion2{cfg: cfg}
}
Loading