Skip to content

Commit

Permalink
Add order/position test.
Browse files Browse the repository at this point in the history
  • Loading branch information
jefchien committed Oct 16, 2023
1 parent 8a62538 commit 506bd65
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 21 deletions.
29 changes: 29 additions & 0 deletions extension/awsmiddleware/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# AWS Middleware

An AWS middleware extension provides request and/or response handlers that can be configured on AWS SDK v1/v2 clients.
Other components can configure their AWS SDK clients using the `awsmiddleware.ConfigureSDKv1` and `awsmiddleware.ConfigureSDKv2` functions.

The `awsmiddleware.Extension` interface extends `component.Extension` by adding the following methods:
```
RequestHandlers() []RequestHandler
ResponseHandlers() []ResponseHandler
```

The `awsmiddleware.RequestHandler` interface contains the following methods:
```
ID() string
Position() HandlerPosition
HandleRequest(r *http.Request)
```

The `awsmiddleware.ResponseHandler` interface contains the following methods:
```
ID() string
Position() HandlerPosition
HandleResponse(r *http.Response)
```

- `ID` uniquely identifies a handler. Middleware will fail if there is clashing
- `Position` determines whether the handler is appended to the front or back of the existing list. Insertion is done
in the order of the handlers provided.
- `HandleRequest/Response` provides a hook to handle the request/response before and after they've been sent.
28 changes: 11 additions & 17 deletions extension/awsmiddleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ func (h HandlerPosition) String() string {
}
}

// MarshalText converts the position into a string. Returns an error
// if unsupported.
// MarshalText converts the position into a byte slice.
// Returns an error if unsupported.
func (h HandlerPosition) MarshalText() (text []byte, err error) {
s := h.String()
if s == "" {
Expand All @@ -73,9 +73,9 @@ func (h *HandlerPosition) UnmarshalText(text []byte) error {
return nil
}

// metadata is used to differentiate between handlers and determine
// handlerConfig is used to differentiate between handlers and determine
// relative positioning within their groups.
type metadata interface {
type handlerConfig interface {
// ID must be unique. It cannot clash with existing middleware.
ID() string
// Position to insert the handler.
Expand All @@ -84,13 +84,13 @@ type metadata interface {

// RequestHandler allows for custom processing of requests.
type RequestHandler interface {
metadata
handlerConfig
HandleRequest(r *http.Request)
}

// ResponseHandler allows for custom processing of responses.
type ResponseHandler interface {
metadata
handlerConfig
HandleResponse(r *http.Response)
}

Expand All @@ -112,12 +112,12 @@ type Extension interface {
func ConfigureSDKv1(mw Middleware, handlers *request.Handlers) error {
var errs error
for _, handler := range mw.RequestHandlers() {
if err := addHandlerToList(&handlers.Build, namedRequestHandler(handler), handler.Position()); err != nil {
if err := appendHandler(&handlers.Build, namedRequestHandler(handler), handler.Position()); err != nil {
errs = errors.Join(errs, fmt.Errorf("%w (%q): %w", errInvalidHandler, handler.ID(), err))
}
}
for _, handler := range mw.ResponseHandlers() {
if err := addHandlerToList(&handlers.Unmarshal, namedResponseHandler(handler), handler.Position()); err != nil {
if err := appendHandler(&handlers.Unmarshal, namedResponseHandler(handler), handler.Position()); err != nil {
errs = errors.Join(errs, fmt.Errorf("%w (%q): %w", errInvalidHandler, handler.ID(), err))
}
}
Expand All @@ -134,27 +134,21 @@ func ConfigureSDKv2(mw Middleware, config *aws.Config) error {
errs = errors.Join(errs, fmt.Errorf("%w (%q): %w", errInvalidHandler, handler.ID(), err))
continue
}
rmw := &requestMiddleware{RequestHandler: handler, position: relativePosition}
config.APIOptions = append(config.APIOptions, func(stack *middleware.Stack) error {
return stack.Build.Add(rmw, rmw.position)
})
config.APIOptions = append(config.APIOptions, withBuildOption(&requestMiddleware{RequestHandler: handler}, relativePosition))
}
for _, handler := range mw.ResponseHandlers() {
relativePosition, err := toRelativePosition(handler.Position())
if err != nil {
errs = errors.Join(errs, fmt.Errorf("%w (%q): %w", errInvalidHandler, handler.ID(), err))
continue
}
rmw := &responseMiddleware{ResponseHandler: handler, position: relativePosition}
config.APIOptions = append(config.APIOptions, func(stack *middleware.Stack) error {
return stack.Deserialize.Add(rmw, rmw.position)
})
config.APIOptions = append(config.APIOptions, withDeserializeOption(&responseMiddleware{ResponseHandler: handler}, relativePosition))
}
return errs
}

// addHandlerToList adds the handler to the list based on the position.
func addHandlerToList(handlerList *request.HandlerList, handler request.NamedHandler, position HandlerPosition) error {
func appendHandler(handlerList *request.HandlerList, handler request.NamedHandler, position HandlerPosition) error {
relativePosition, err := toRelativePosition(position)
if err != nil {
return err
Expand Down
87 changes: 85 additions & 2 deletions extension/awsmiddleware/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ func (t *testHandler) Latency() time.Duration {
return t.end.Sub(t.start)
}

type recordOrder struct {
order []string
}

func (ro *recordOrder) handle(id string) func(*http.Request) {
return func(*http.Request) {
ro.order = append(ro.order, id)
}
}

func TestHandlerPosition(t *testing.T) {
testCases := []struct {
position HandlerPosition
Expand Down Expand Up @@ -111,8 +121,80 @@ func TestInvalidHandlers(t *testing.T) {
assert.True(t, errors.Is(err, errUnsupportedPosition))
}

func TestAppendOrder(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
testCases := map[string]struct {
requestHandlers []*testHandler
wantOrder []string
}{
"WithBothBefore": {
requestHandlers: []*testHandler{
{id: "1", position: Before},
{id: "2", position: Before},
},
wantOrder: []string{"2", "1"},
},
"WithBothAfter": {
requestHandlers: []*testHandler{
{id: "1", position: After},
{id: "2", position: After},
},
wantOrder: []string{"1", "2"},
},
"WithBeforeAfter": {
requestHandlers: []*testHandler{
{id: "1", position: Before},
{id: "2", position: After},
},
wantOrder: []string{"1", "2"},
},
"WithAfterBefore": {
requestHandlers: []*testHandler{
{id: "1", position: After},
{id: "2", position: Before},
},
wantOrder: []string{"2", "1"},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
middleware := &testMiddlewareExtension{}
recorder := &recordOrder{}
for _, handler := range testCase.requestHandlers {
handler.handleRequest = recorder.handle(handler.id)
middleware.requestHandlers = append(middleware.requestHandlers, handler)
}
// v1
client := awstesting.NewClient(&awsv1.Config{
Region: awsv1.String("mock-region"),
DisableSSL: awsv1.Bool(true),
Endpoint: awsv1.String(server.URL),
})
assert.NoError(t, ConfigureSDKv1(middleware, &client.Handlers))
s3v1Client := &s3v1.S3{Client: client}
_, err := s3v1Client.ListBuckets(&s3v1.ListBucketsInput{})
require.NoError(t, err)
assert.Equal(t, testCase.wantOrder, recorder.order)
recorder.order = nil
// v2
cfg := awsv2.Config{Region: "us-east-1"}
assert.NoError(t, ConfigureSDKv2(middleware, &cfg))
s3v2Client := s3v2.NewFromConfig(cfg, func(options *s3v2.Options) {
options.BaseEndpoint = awsv2.String(server.URL)
})
_, err = s3v2Client.ListBuckets(context.Background(), &s3v2.ListBucketsInput{})
require.NoError(t, err)
assert.Equal(t, testCase.wantOrder, recorder.order)
})
}
}

func TestConfigureSDKv1(t *testing.T) {
middleware, recorder, server := setup(t)
defer server.Close()
client := awstesting.NewClient(&awsv1.Config{
Region: awsv1.String("mock-region"),
DisableSSL: awsv1.Bool(true),
Expand All @@ -131,9 +213,10 @@ func TestConfigureSDKv1(t *testing.T) {
}

func TestConfigureSDKv2(t *testing.T) {
mw, recorder, server := setup(t)
middleware, recorder, server := setup(t)
defer server.Close()
cfg := awsv2.Config{Region: "us-east-1"}
assert.NoError(t, ConfigureSDKv2(mw, &cfg))
assert.NoError(t, ConfigureSDKv2(middleware, &cfg))
s3Client := s3v2.NewFromConfig(cfg, func(options *s3v2.Options) {
options.BaseEndpoint = awsv2.String(server.URL)
})
Expand Down
14 changes: 12 additions & 2 deletions extension/awsmiddleware/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ func namedResponseHandler(handler ResponseHandler) request.NamedHandler {

type requestMiddleware struct {
RequestHandler
position middleware.RelativePosition
}

var _ middleware.BuildMiddleware = (*requestMiddleware)(nil)
Expand All @@ -38,9 +37,14 @@ func (r requestMiddleware) HandleBuild(ctx context.Context, in middleware.BuildI
return next.HandleBuild(ctx, in)
}

func withBuildOption(rmw *requestMiddleware, position middleware.RelativePosition) func(stack *middleware.Stack) error {
return func(stack *middleware.Stack) error {
return stack.Build.Add(rmw, position)
}
}

type responseMiddleware struct {
ResponseHandler
position middleware.RelativePosition
}

var _ middleware.DeserializeMiddleware = (*responseMiddleware)(nil)
Expand All @@ -53,3 +57,9 @@ func (r responseMiddleware) HandleDeserialize(ctx context.Context, in middleware
}
return
}

func withDeserializeOption(rmw *responseMiddleware, position middleware.RelativePosition) func(stack *middleware.Stack) error {
return func(stack *middleware.Stack) error {
return stack.Deserialize.Add(rmw, position)
}
}

0 comments on commit 506bd65

Please sign in to comment.