From e5c60ba404f2c2fcea2534d4dad3b192acc4ae1e Mon Sep 17 00:00:00 2001 From: Jeffrey Chien Date: Fri, 20 Oct 2023 08:56:47 -0400 Subject: [PATCH] Add ID to context to trace requests. Fix go.mod name. --- extension/awsmiddleware/README.md | 12 +-- extension/awsmiddleware/config.go | 30 +++--- extension/awsmiddleware/config_test.go | 24 +---- extension/awsmiddleware/doc.go | 2 +- extension/awsmiddleware/go.mod | 12 ++- extension/awsmiddleware/go.sum | 24 +++-- extension/awsmiddleware/middleware.go | 36 ++++--- extension/awsmiddleware/middleware_test.go | 103 +++++++++++++-------- extension/awsmiddleware/mock.go | 78 ++++++++++++++++ extension/awsmiddleware/wrapper.go | 35 ++++++- 10 files changed, 247 insertions(+), 109 deletions(-) create mode 100644 extension/awsmiddleware/mock.go diff --git a/extension/awsmiddleware/README.md b/extension/awsmiddleware/README.md index 67885f503174..3336eb98563b 100644 --- a/extension/awsmiddleware/README.md +++ b/extension/awsmiddleware/README.md @@ -3,27 +3,27 @@ An AWS middleware extension provides request and/or response handlers that can be configured on AWS SDK v1/v2 clients. Other components can configure their AWS SDK clients using the `awsmiddleware.ConfigureSDKv1` and `awsmiddleware.ConfigureSDKv2` functions. -The `awsmiddleware.Extension` interface extends `component.Extension` by adding the following methods: +The `awsmiddleware.Extension` interface extends `component.Extension` by adding the following method: ``` -RequestHandlers() []RequestHandler -ResponseHandlers() []ResponseHandler +Handlers() ([]RequestHandler, []ResponseHandler) ``` The `awsmiddleware.RequestHandler` interface contains the following methods: ``` ID() string Position() HandlerPosition -HandleRequest(r *http.Request) +HandleRequest(id string, r *http.Request) ``` The `awsmiddleware.ResponseHandler` interface contains the following methods: ``` ID() string Position() HandlerPosition -HandleResponse(r *http.Response) +HandleResponse(id string, r *http.Response) ``` - `ID` uniquely identifies a handler. Middleware will fail if there is clashing - `Position` determines whether the handler is appended to the front or back of the existing list. Insertion is done in the order of the handlers provided. -- `HandleRequest/Response` provides a hook to handle the request/response before and after they've been sent. \ No newline at end of file +- `HandleRequest/Response` provides a hook to handle the request/response before and after they've been sent along +with an attached ID. \ No newline at end of file diff --git a/extension/awsmiddleware/config.go b/extension/awsmiddleware/config.go index f221a347dd35..ac70006f523c 100644 --- a/extension/awsmiddleware/config.go +++ b/extension/awsmiddleware/config.go @@ -1,7 +1,7 @@ // Copyright The OpenTelemetry Authors // SPDX-License-Identifier: Apache-2.0 -package awsmiddleware // import "github.com/open-telemetry/opentelemetry-collector-contrib/extension/awsmiddleware" +package awsmiddleware // import "github.com/amazon-contributing/opentelemetry-collector-contrib/extension/awsmiddleware" import ( "fmt" @@ -9,19 +9,25 @@ import ( "go.opentelemetry.io/collector/component" ) -// Config defines the configuration for an AWS Middleware extension. -type Config struct { - // MiddlewareID is the ID of the Middleware extension. - MiddlewareID component.ID `mapstructure:"middleware"` -} +type ID = component.ID -// GetMiddleware retrieves the extension implementing Middleware based on the MiddlewareID. -func (c Config) GetMiddleware(extensions map[component.ID]component.Component) (Middleware, error) { - if ext, found := extensions[c.MiddlewareID]; found { - if mw, ok := ext.(Middleware); ok { - return mw, nil +// getMiddleware retrieves the extension implementing Middleware based on the middlewareID. +func getMiddleware(extensions map[component.ID]component.Component, middlewareID ID) (Middleware, error) { + if extension, found := extensions[middlewareID]; found { + if middleware, ok := extension.(Middleware); ok { + return middleware, nil } return nil, errNotMiddleware } - return nil, fmt.Errorf("failed to resolve AWS client handler %q: %w", c.MiddlewareID, errNotFound) + return nil, fmt.Errorf("failed to resolve AWS middleware %q: %w", middlewareID, errNotFound) +} + +// GetConfigurer retrieves the extension implementing Middleware based on the middlewareID and +// wraps it in a Configurer. +func GetConfigurer(extensions map[component.ID]component.Component, middlewareID ID) (*Configurer, error) { + middleware, err := getMiddleware(extensions, middlewareID) + if err != nil { + return nil, err + } + return NewConfigurer(middleware), nil } diff --git a/extension/awsmiddleware/config_test.go b/extension/awsmiddleware/config_test.go index cb317acd6b46..cda9c1ba3294 100644 --- a/extension/awsmiddleware/config_test.go +++ b/extension/awsmiddleware/config_test.go @@ -13,26 +13,8 @@ import ( "go.opentelemetry.io/collector/extension/extensiontest" ) -type testMiddlewareExtension struct { - component.StartFunc - component.ShutdownFunc - requestHandlers []RequestHandler - responseHandlers []ResponseHandler -} - -var _ Extension = (*testMiddlewareExtension)(nil) - -func (t *testMiddlewareExtension) RequestHandlers() []RequestHandler { - return t.requestHandlers -} - -func (t *testMiddlewareExtension) ResponseHandlers() []ResponseHandler { - return t.responseHandlers -} - -func TestGetMiddleware(t *testing.T) { +func TestGetConfigurer(t *testing.T) { id := component.NewID("test") - cfg := &Config{MiddlewareID: id} nopExtension, err := extensiontest.NewNopBuilder().Create(context.Background(), extensiontest.NewNopCreateSettings()) require.Error(t, err) testCases := map[string]struct { @@ -48,12 +30,12 @@ func TestGetMiddleware(t *testing.T) { wantErr: errNotMiddleware, }, "WithMiddlewareExtension": { - extensions: map[component.ID]component.Component{id: &testMiddlewareExtension{}}, + extensions: map[component.ID]component.Component{id: new(MockMiddlewareExtension)}, }, } for name, testCase := range testCases { t.Run(name, func(t *testing.T) { - got, err := cfg.GetMiddleware(testCase.extensions) + got, err := GetConfigurer(testCase.extensions, id) if testCase.wantErr != nil { assert.Error(t, err) assert.ErrorIs(t, err, testCase.wantErr) diff --git a/extension/awsmiddleware/doc.go b/extension/awsmiddleware/doc.go index dfc6a74dc27e..1426ff222520 100644 --- a/extension/awsmiddleware/doc.go +++ b/extension/awsmiddleware/doc.go @@ -3,4 +3,4 @@ // Package awsmiddleware defines an extension interface providing request and response handlers that can be // configured on AWS SDK clients. -package awsmiddleware // import "github.com/open-telemetry/opentelemetry-collector-contrib/extension/awsmiddleware" +package awsmiddleware // import "github.com/amazon-contributing/opentelemetry-collector-contrib/extension/awsmiddleware" diff --git a/extension/awsmiddleware/go.mod b/extension/awsmiddleware/go.mod index 2106c0e24556..b2d79308f3ee 100644 --- a/extension/awsmiddleware/go.mod +++ b/extension/awsmiddleware/go.mod @@ -1,12 +1,13 @@ -module github.com/open-telemetry/opentelemetry-collector-contrib/extension/awsmiddleware +module github.com/amazon-contributing/opentelemetry-collector-contrib/extension/awsmiddleware go 1.20 require ( github.com/aws/aws-sdk-go v1.45.24 - github.com/aws/aws-sdk-go-v2 v1.21.1 + github.com/aws/aws-sdk-go-v2 v1.21.2 github.com/aws/aws-sdk-go-v2/service/s3 v1.40.0 github.com/aws/smithy-go v1.15.0 + github.com/google/uuid v1.3.0 github.com/stretchr/testify v1.8.4 go.opentelemetry.io/collector/component v0.87.0 go.opentelemetry.io/collector/extension v0.87.0 @@ -14,12 +15,12 @@ require ( require ( github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.13 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.42 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.36 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.43 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.37 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.1.4 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.14 // indirect github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.36 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.36 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.37 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.15.4 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect @@ -32,6 +33,7 @@ require ( github.com/mitchellh/mapstructure v1.5.1-0.20220423185008-bf980b35cac4 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect go.opentelemetry.io/collector/config/configtelemetry v0.87.0 // indirect go.opentelemetry.io/collector/confmap v0.87.0 // indirect go.opentelemetry.io/collector/featuregate v1.0.0-rcv0016 // indirect diff --git a/extension/awsmiddleware/go.sum b/extension/awsmiddleware/go.sum index 8daaa48f656f..cc0cd3821a5e 100644 --- a/extension/awsmiddleware/go.sum +++ b/extension/awsmiddleware/go.sum @@ -1,16 +1,16 @@ github.com/aws/aws-sdk-go v1.45.24 h1:TZx/CizkmCQn8Rtsb11iLYutEQVGK5PK9wAhwouELBo= github.com/aws/aws-sdk-go v1.45.24/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= github.com/aws/aws-sdk-go-v2 v1.21.0/go.mod h1:/RfNgGmRxI+iFOB1OeJUyxiU+9s88k3pfHvDagGEp0M= -github.com/aws/aws-sdk-go-v2 v1.21.1 h1:wjHYshtPpYOZm+/mu3NhVgRRc0baM6LJZOmxPZ5Cwzs= -github.com/aws/aws-sdk-go-v2 v1.21.1/go.mod h1:ErQhvNuEMhJjweavOYhxVkn2RUx7kQXVATHrjKtxIpM= +github.com/aws/aws-sdk-go-v2 v1.21.2 h1:+LXZ0sgo8quN9UOKXXzAWRT3FWd4NxeXWOZom9pE7GA= +github.com/aws/aws-sdk-go-v2 v1.21.2/go.mod h1:ErQhvNuEMhJjweavOYhxVkn2RUx7kQXVATHrjKtxIpM= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.13 h1:OPLEkmhXf6xFPiz0bLeDArZIDx1NNS4oJyG4nv3Gct0= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.13/go.mod h1:gpAbvyDGQFozTEmlTFO8XcQKHzubdq0LzRyJpG6MiXM= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41/go.mod h1:CrObHAuPneJBlfEJ5T3szXOUkLEThaGfvnhTf33buas= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.42 h1:817VqVe6wvwE46xXy6YF5RywvjOX6U2zRQQ6IbQFK0s= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.42/go.mod h1:oDfgXoBBmj+kXnqxDDnIDnC56QBosglKp8ftRCTxR+0= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.43 h1:nFBQlGtkbPzp/NjZLuFxRqmT91rLJkgvsEQs68h962Y= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.43/go.mod h1:auo+PiyLl0n1l8A0e8RIeR8tOzYPfZZH/JNlrJ8igTQ= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.35/go.mod h1:SJC1nEVVva1g3pHAIdCp7QsRIkMmLAgoDquQ9Rr8kYw= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.36 h1:7ZApaXzWbo8slc+W5TynuUlB4z66g44h7uqa3/d/BsY= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.36/go.mod h1:rwr4WnmFi3RJO0M4dxbJtgi9BPLMpVBMX1nUte5ha9U= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.37 h1:JRVhO25+r3ar2mKGP7E0LDl8K9/G36gjlqca5iQbaqc= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.37/go.mod h1:Qe+2KtKml+FEsQF/DHmDV+xjtche/hwoF75EG4UlHW8= github.com/aws/aws-sdk-go-v2/internal/v4a v1.1.4 h1:6lJvvkQ9HmbHZ4h/IEwclwv2mrTW8Uq1SOB/kXy0mfw= github.com/aws/aws-sdk-go-v2/internal/v4a v1.1.4/go.mod h1:1PrKYwxTM+zjpw9Y41KFtoJCQrJ34Z47Y4VgVbfndjo= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.14 h1:m0QTSI6pZYJTk5WSKx3fm5cNW/DCicVzULBgU/6IyD0= @@ -18,8 +18,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.14/go.mod h1: github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.36 h1:eev2yZX7esGRjqRbnVk1UxMLw4CyVZDpZXRCcy75oQk= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.36/go.mod h1:lGnOkH9NJATw0XEPcAknFBj3zzNTEGRHtSw+CwC1YTg= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.35/go.mod h1:QGF2Rs33W5MaN9gYdEQOBBFPLwTZkEhRwI33f7KIG0o= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.36 h1:YXlm7LxwNlauqb2OrinWlcvtsflTzP8GaMvYfQBhoT4= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.36/go.mod h1:ou9ffqJ9hKOVZmjlC6kQ6oROAyG1M4yBKzR+9BKbDwk= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.37 h1:WWZA/I2K4ptBS1kg0kV1JbBtG/umed0vwHRrmcr9z7k= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.37/go.mod h1:vBmDnwWXWxNPFRMmG2m/3MKOe+xEcMDo1tanpaWCcck= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.15.4 h1:v0jkRigbSD6uOdwcaUQmgEwG1BkPfAPDqaeNt/29ghg= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.15.4/go.mod h1:LhTyt8J04LL+9cIt7pYJ5lbS/U98ZmXovLOR/4LUsk8= github.com/aws/aws-sdk-go-v2/service/s3 v1.40.0 h1:wl5dxN1NONhTDQD9uaEvNsDRX29cBmGED/nl0jkWlt4= @@ -40,6 +40,8 @@ github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= @@ -65,6 +67,11 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -153,5 +160,6 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/extension/awsmiddleware/middleware.go b/extension/awsmiddleware/middleware.go index 82e42e1d4b28..9ede32d48afb 100644 --- a/extension/awsmiddleware/middleware.go +++ b/extension/awsmiddleware/middleware.go @@ -1,7 +1,7 @@ // Copyright The OpenTelemetry Authors // SPDX-License-Identifier: Apache-2.0 -package awsmiddleware // import "github.com/open-telemetry/opentelemetry-collector-contrib/extension/awsmiddleware" +package awsmiddleware // import "github.com/amazon-contributing/opentelemetry-collector-contrib/extension/awsmiddleware" import ( "encoding" @@ -85,20 +85,19 @@ type handlerConfig interface { // RequestHandler allows for custom processing of requests. type RequestHandler interface { handlerConfig - HandleRequest(r *http.Request) + HandleRequest(id string, r *http.Request) } // ResponseHandler allows for custom processing of responses. type ResponseHandler interface { handlerConfig - HandleResponse(r *http.Response) + HandleResponse(id string, r *http.Response) } // Middleware defines the request and response handlers to be configured // on AWS Clients. type Middleware interface { - RequestHandlers() []RequestHandler - ResponseHandlers() []ResponseHandler + Handlers() ([]RequestHandler, []ResponseHandler) } // Extension is an extension that implements Middleware. @@ -107,17 +106,29 @@ type Extension interface { Middleware } +// Configurer wraps a Middleware and provides convenience functions +// for applying it to the AWS SDKs. +type Configurer struct { + Middleware +} + +// NewConfigurer wraps the Middleware. +func NewConfigurer(mw Middleware) *Configurer { + return &Configurer{Middleware: mw} +} + // ConfigureSDKv1 adds middleware to the AWS SDK v1. Request handlers are added to the // Build handler list and response handlers are added to the Unmarshal handler list. -func ConfigureSDKv1(mw Middleware, handlers *request.Handlers) error { +func (c *Configurer) ConfigureSDKv1(handlers *request.Handlers) error { var errs error - for _, handler := range mw.RequestHandlers() { + requestHandlers, responseHandlers := c.Middleware.Handlers() + for _, handler := range requestHandlers { if err := appendHandler(&handlers.Build, namedRequestHandler(handler), handler.Position()); err != nil { errs = errors.Join(errs, fmt.Errorf("%w (%q): %w", errInvalidHandler, handler.ID(), err)) } } - for _, handler := range mw.ResponseHandlers() { - if err := appendHandler(&handlers.Unmarshal, namedResponseHandler(handler), handler.Position()); err != nil { + for _, handler := range responseHandlers { + if err := appendHandler(&handlers.ValidateResponse, namedResponseHandler(handler), handler.Position()); err != nil { errs = errors.Join(errs, fmt.Errorf("%w (%q): %w", errInvalidHandler, handler.ID(), err)) } } @@ -126,9 +137,10 @@ func ConfigureSDKv1(mw Middleware, handlers *request.Handlers) error { // ConfigureSDKv2 adds middleware to the AWS SDK v2. Request handlers are added to the // Build step and response handlers are added to the Deserialize step. -func ConfigureSDKv2(mw Middleware, config *aws.Config) error { +func (c *Configurer) ConfigureSDKv2(config *aws.Config) error { var errs error - for _, handler := range mw.RequestHandlers() { + requestHandlers, responseHandlers := c.Middleware.Handlers() + for _, handler := range requestHandlers { relativePosition, err := toRelativePosition(handler.Position()) if err != nil { errs = errors.Join(errs, fmt.Errorf("%w (%q): %w", errInvalidHandler, handler.ID(), err)) @@ -136,7 +148,7 @@ func ConfigureSDKv2(mw Middleware, config *aws.Config) error { } config.APIOptions = append(config.APIOptions, withBuildOption(&requestMiddleware{RequestHandler: handler}, relativePosition)) } - for _, handler := range mw.ResponseHandlers() { + for _, handler := range responseHandlers { relativePosition, err := toRelativePosition(handler.Position()) if err != nil { errs = errors.Join(errs, fmt.Errorf("%w (%q): %w", errInvalidHandler, handler.ID(), err)) diff --git a/extension/awsmiddleware/middleware_test.go b/extension/awsmiddleware/middleware_test.go index 9a21a306600b..8915bd7c46a3 100644 --- a/extension/awsmiddleware/middleware_test.go +++ b/extension/awsmiddleware/middleware_test.go @@ -17,6 +17,7 @@ import ( "github.com/aws/aws-sdk-go/awstesting" s3v1 "github.com/aws/aws-sdk-go/service/s3" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -28,10 +29,12 @@ const ( type testHandler struct { id string position HandlerPosition - handleRequest func(r *http.Request) - handleResponse func(r *http.Response) + handleRequest func(id string, r *http.Request) + handleResponse func(id string, r *http.Response) start time.Time end time.Time + requestIDs []string + responseIDs []string } var _ RequestHandler = (*testHandler)(nil) @@ -45,17 +48,19 @@ func (t *testHandler) Position() HandlerPosition { return t.position } -func (t *testHandler) HandleRequest(r *http.Request) { +func (t *testHandler) HandleRequest(id string, r *http.Request) { t.start = time.Now() + t.requestIDs = append(t.requestIDs, id) if t.handleRequest != nil { - t.handleRequest(r) + t.handleRequest(id, r) } } -func (t *testHandler) HandleResponse(r *http.Response) { +func (t *testHandler) HandleResponse(id string, r *http.Response) { t.end = time.Now() + t.responseIDs = append(t.responseIDs, id) if t.handleResponse != nil { - t.handleResponse(r) + t.handleResponse(id, r) } } @@ -67,8 +72,8 @@ type recordOrder struct { order []string } -func (ro *recordOrder) handle(id string) func(*http.Request) { - return func(*http.Request) { +func (ro *recordOrder) Handle(id string) func(string, *http.Request) { + return func(string, *http.Request) { ro.order = append(ro.order, id) } } @@ -103,22 +108,25 @@ func TestInvalidHandlerPosition(t *testing.T) { } func TestInvalidHandlers(t *testing.T) { - invalidHandler := &testHandler{id: "invalid handler", position: -1} - testExtension := &testMiddlewareExtension{ - requestHandlers: []RequestHandler{invalidHandler}, - responseHandlers: []ResponseHandler{invalidHandler}, - } + handler := new(MockHandler) + handler.On("ID").Return("invalid handler") + handler.On("Position").Return(HandlerPosition(-1)) + middleware := new(MockMiddlewareExtension) + middleware.On("Handlers").Return([]RequestHandler{handler}, []ResponseHandler{handler}) + c := NewConfigurer(middleware) // v1 client := awstesting.NewClient() - err := ConfigureSDKv1(testExtension, &client.Handlers) + err := c.ConfigureSDKv1(&client.Handlers) assert.Error(t, err) assert.True(t, errors.Is(err, errInvalidHandler)) assert.True(t, errors.Is(err, errUnsupportedPosition)) // v2 - err = ConfigureSDKv2(testExtension, &awsv2.Config{}) + err = c.ConfigureSDKv2(&awsv2.Config{}) assert.Error(t, err) assert.True(t, errors.Is(err, errInvalidHandler)) assert.True(t, errors.Is(err, errUnsupportedPosition)) + handler.AssertNotCalled(t, "HandleRequest", mock.Anything, mock.Anything) + handler.AssertNotCalled(t, "HandleResponse", mock.Anything, mock.Anything) } func TestAppendOrder(t *testing.T) { @@ -161,19 +169,31 @@ func TestAppendOrder(t *testing.T) { } for name, testCase := range testCases { t.Run(name, func(t *testing.T) { - middleware := &testMiddlewareExtension{} recorder := &recordOrder{} + var requestHandlers []RequestHandler for _, handler := range testCase.requestHandlers { - handler.handleRequest = recorder.handle(handler.id) - middleware.requestHandlers = append(middleware.requestHandlers, handler) + handler.handleRequest = recorder.Handle(handler.id) + requestHandlers = append(requestHandlers, handler) } + handler := new(MockHandler) + handler.On("ID").Return("mock") + handler.On("Position").Return(After) + handler.On("HandleRequest", mock.Anything, mock.Anything) + handler.On("HandleResponse", mock.Anything, mock.Anything) + requestHandlers = append(requestHandlers, handler) + middleware := new(MockMiddlewareExtension) + middleware.On("Handlers").Return( + requestHandlers, + []ResponseHandler{handler}, + ) + c := NewConfigurer(middleware) // v1 client := awstesting.NewClient(&awsv1.Config{ Region: awsv1.String("mock-region"), DisableSSL: awsv1.Bool(true), Endpoint: awsv1.String(server.URL), }) - assert.NoError(t, ConfigureSDKv1(middleware, &client.Handlers)) + assert.NoError(t, c.ConfigureSDKv1(&client.Handlers)) s3v1Client := &s3v1.S3{Client: client} _, err := s3v1Client.ListBuckets(&s3v1.ListBucketsInput{}) require.NoError(t, err) @@ -181,7 +201,7 @@ func TestAppendOrder(t *testing.T) { recorder.order = nil // v2 cfg := awsv2.Config{Region: "us-east-1"} - assert.NoError(t, ConfigureSDKv2(middleware, &cfg)) + assert.NoError(t, c.ConfigureSDKv2(&cfg)) s3v2Client := s3v2.NewFromConfig(cfg, func(options *s3v2.Options) { options.BaseEndpoint = awsv2.String(server.URL) }) @@ -199,24 +219,26 @@ func TestConfigureSDKv1(t *testing.T) { Region: awsv1.String("mock-region"), DisableSSL: awsv1.Bool(true), Endpoint: awsv1.String(server.URL), + MaxRetries: awsv1.Int(0), }) require.Equal(t, 3, client.Handlers.Build.Len()) - require.Equal(t, 0, client.Handlers.Unmarshal.Len()) - assert.NoError(t, ConfigureSDKv1(middleware, &client.Handlers)) + require.Equal(t, 1, client.Handlers.ValidateResponse.Len()) + assert.NoError(t, NewConfigurer(middleware).ConfigureSDKv1(&client.Handlers)) assert.Equal(t, 5, client.Handlers.Build.Len()) - assert.Equal(t, 1, client.Handlers.Unmarshal.Len()) + assert.Equal(t, 2, client.Handlers.ValidateResponse.Len()) s3Client := &s3v1.S3{Client: client} output, err := s3Client.ListBuckets(&s3v1.ListBucketsInput{}) require.NoError(t, err) assert.NotNil(t, output) assert.GreaterOrEqual(t, recorder.Latency(), testLatency) + assert.Equal(t, recorder.requestIDs, recorder.responseIDs) } func TestConfigureSDKv2(t *testing.T) { middleware, recorder, server := setup(t) defer server.Close() - cfg := awsv2.Config{Region: "us-east-1"} - assert.NoError(t, ConfigureSDKv2(middleware, &cfg)) + cfg := awsv2.Config{Region: "us-east-1", RetryMaxAttempts: 0} + assert.NoError(t, NewConfigurer(middleware).ConfigureSDKv2(&cfg)) s3Client := s3v2.NewFromConfig(cfg, func(options *s3v2.Options) { options.BaseEndpoint = awsv2.String(server.URL) }) @@ -224,24 +246,27 @@ func TestConfigureSDKv2(t *testing.T) { require.NoError(t, err) assert.NotNil(t, output) assert.GreaterOrEqual(t, recorder.Latency(), testLatency) + assert.Equal(t, recorder.requestIDs, recorder.responseIDs) } -func setup(t *testing.T) (Middleware, *testHandler, *httptest.Server) { - t.Helper() - recorder := &testHandler{id: "LatencyTest", position: After} - middleware := &testMiddlewareExtension{ - requestHandlers: []RequestHandler{ - &testHandler{ - id: "UserAgentTest", - position: Before, - handleRequest: func(r *http.Request) { - r.Header.Set("User-Agent", testUserAgent) - }, - }, - recorder, +func userAgentHandler() RequestHandler { + return &testHandler{ + id: "test.UserAgent", + position: Before, + handleRequest: func(_ string, r *http.Request) { + r.Header.Set("User-Agent", testUserAgent) }, - responseHandlers: []ResponseHandler{recorder}, } +} + +func setup(t *testing.T) (Middleware, *testHandler, *httptest.Server) { + t.Helper() + recorder := &testHandler{id: "test.Latency", position: After} + middleware := new(MockMiddlewareExtension) + middleware.On("Handlers").Return( + []RequestHandler{userAgentHandler(), recorder}, + []ResponseHandler{recorder}, + ) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotUserAgent := r.Header.Get("User-Agent") assert.Contains(t, gotUserAgent, testUserAgent) diff --git a/extension/awsmiddleware/mock.go b/extension/awsmiddleware/mock.go new file mode 100644 index 000000000000..afb8849c438a --- /dev/null +++ b/extension/awsmiddleware/mock.go @@ -0,0 +1,78 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package awsmiddleware // import "github.com/amazon-contributing/opentelemetry-collector-contrib/extension/awsmiddleware" + +import ( + "net/http" + + "github.com/stretchr/testify/mock" + "go.opentelemetry.io/collector/component" +) + +// MockMiddlewareExtension mocks the Extension interface. +type MockMiddlewareExtension struct { + component.StartFunc + component.ShutdownFunc + mock.Mock +} + +var _ Extension = (*MockMiddlewareExtension)(nil) + +func (m *MockMiddlewareExtension) Handlers() ([]RequestHandler, []ResponseHandler) { + var requestHandlers []RequestHandler + var responseHandlers []ResponseHandler + args := m.Called() + arg := args.Get(0) + if arg != nil { + requestHandlers = arg.([]RequestHandler) + } + arg = args.Get(1) + if arg != nil { + responseHandlers = arg.([]ResponseHandler) + } + return requestHandlers, responseHandlers +} + +// MockHandler mocks the functions for both +// RequestHandler and ResponseHandler. +type MockHandler struct { + mock.Mock +} + +var _ RequestHandler = (*MockHandler)(nil) +var _ ResponseHandler = (*MockHandler)(nil) + +func (m *MockHandler) ID() string { + args := m.Called() + return args.String(0) +} + +func (m *MockHandler) Position() HandlerPosition { + args := m.Called() + return args.Get(0).(HandlerPosition) +} + +func (m *MockHandler) HandleRequest(id string, r *http.Request) { + m.Called(id, r) +} + +func (m *MockHandler) HandleResponse(id string, r *http.Response) { + m.Called(id, r) +} + +// MockExtensionsHost only mocks the GetExtensions function. +// All other functions are ignored and will panic if called. +type MockExtensionsHost struct { + component.Host + mock.Mock +} + +func (m *MockExtensionsHost) GetExtensions() map[component.ID]component.Component { + args := m.Called() + arg := args.Get(0) + if arg == nil { + return nil + } + return arg.(map[component.ID]component.Component) +} diff --git a/extension/awsmiddleware/wrapper.go b/extension/awsmiddleware/wrapper.go index 77aa87fd9b83..9509b3c80515 100644 --- a/extension/awsmiddleware/wrapper.go +++ b/extension/awsmiddleware/wrapper.go @@ -1,7 +1,7 @@ // Copyright The OpenTelemetry Authors // SPDX-License-Identifier: Apache-2.0 -package awsmiddleware // import "github.com/open-telemetry/opentelemetry-collector-contrib/extension/awsmiddleware" +package awsmiddleware // import "github.com/amazon-contributing/opentelemetry-collector-contrib/extension/awsmiddleware" import ( "context" @@ -9,17 +9,25 @@ import ( "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/smithy-go/middleware" "github.com/aws/smithy-go/transport/http" + "github.com/google/uuid" ) +type key struct{} + +var requestID key + func namedRequestHandler(handler RequestHandler) request.NamedHandler { return request.NamedHandler{Name: handler.ID(), Fn: func(r *request.Request) { - handler.HandleRequest(r.HTTPRequest) + ctx, id := setID(r.Context()) + r.SetContext(ctx) + handler.HandleRequest(id, r.HTTPRequest) }} } func namedResponseHandler(handler ResponseHandler) request.NamedHandler { return request.NamedHandler{Name: handler.ID(), Fn: func(r *request.Request) { - handler.HandleResponse(r.HTTPResponse) + id, _ := getID(r.Context()) + handler.HandleResponse(id, r.HTTPResponse) }} } @@ -32,7 +40,9 @@ var _ middleware.BuildMiddleware = (*requestMiddleware)(nil) func (r requestMiddleware) HandleBuild(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) (out middleware.BuildOutput, metadata middleware.Metadata, err error) { req, ok := in.Request.(*http.Request) if ok { - r.HandleRequest(req.Request) + var id string + ctx, id = setID(ctx) + r.HandleRequest(id, req.Request) } return next.HandleBuild(ctx, in) } @@ -53,7 +63,8 @@ func (r responseMiddleware) HandleDeserialize(ctx context.Context, in middleware out, metadata, err = next.HandleDeserialize(ctx, in) res, ok := out.RawResponse.(*http.Response) if ok { - r.HandleResponse(res.Response) + id, _ := getID(ctx) + r.HandleResponse(id, res.Response) } return } @@ -63,3 +74,17 @@ func withDeserializeOption(rmw *responseMiddleware, position middleware.Relative return stack.Deserialize.Add(rmw, position) } } + +func setID(ctx context.Context) (context.Context, string) { + id, ok := getID(ctx) + if !ok { + id = uuid.NewString() + return context.WithValue(ctx, requestID, id), id + } + return ctx, id +} + +func getID(ctx context.Context) (string, bool) { + id, ok := ctx.Value(requestID).(string) + return id, ok +}