From e9ec854c4741a25d9c021aee38fe83a42b2826b2 Mon Sep 17 00:00:00 2001 From: Per Bockman Date: Tue, 13 Feb 2024 22:10:10 +0100 Subject: [PATCH] feat!: add context param to error handler func --- internal/test/chi/oapi_validate_test.go | 2 +- internal/test/gorilla/oapi_validate_test.go | 4 ++-- internal/test/nethttp/oapi_validate_test.go | 4 ++-- oapi_validate.go | 5 +++-- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/internal/test/chi/oapi_validate_test.go b/internal/test/chi/oapi_validate_test.go index a2b03b2..c6a0f56 100644 --- a/internal/test/chi/oapi_validate_test.go +++ b/internal/test/chi/oapi_validate_test.go @@ -266,7 +266,7 @@ func TestOapiRequestValidatorWithOptions(t *testing.T) { // Set up an authenticator to check authenticated function. It will allow // access to "someScope", but disallow others. options := middleware.Options{ - ErrorHandler: func(w http.ResponseWriter, message string, statusCode int) { + ErrorHandler: func(ctx context.Context, w http.ResponseWriter, message string, statusCode int) { http.Error(w, "test: "+message, statusCode) }, Options: openapi3filter.Options{ diff --git a/internal/test/gorilla/oapi_validate_test.go b/internal/test/gorilla/oapi_validate_test.go index cd1502d..62726cd 100644 --- a/internal/test/gorilla/oapi_validate_test.go +++ b/internal/test/gorilla/oapi_validate_test.go @@ -10,8 +10,8 @@ import ( "net/url" "testing" - "github.com/oapi-codegen/testutil" middleware "github.com/oapi-codegen/nethttp-middleware" + "github.com/oapi-codegen/testutil" "github.com/getkin/kin-openapi/openapi3" "github.com/getkin/kin-openapi/openapi3filter" @@ -266,7 +266,7 @@ func TestOapiRequestValidatorWithOptions(t *testing.T) { // Set up an authenticator to check authenticated function. It will allow // access to "someScope", but disallow others. options := middleware.Options{ - ErrorHandler: func(w http.ResponseWriter, message string, statusCode int) { + ErrorHandler: func(ctx context.Context, w http.ResponseWriter, message string, statusCode int) { http.Error(w, "test: "+message, statusCode) }, Options: openapi3filter.Options{ diff --git a/internal/test/nethttp/oapi_validate_test.go b/internal/test/nethttp/oapi_validate_test.go index 279969c..5dfb3ea 100644 --- a/internal/test/nethttp/oapi_validate_test.go +++ b/internal/test/nethttp/oapi_validate_test.go @@ -10,8 +10,8 @@ import ( "net/url" "testing" - "github.com/oapi-codegen/testutil" middleware "github.com/oapi-codegen/nethttp-middleware" + "github.com/oapi-codegen/testutil" "github.com/getkin/kin-openapi/openapi3" "github.com/getkin/kin-openapi/openapi3filter" @@ -281,7 +281,7 @@ func TestOapiRequestValidatorWithOptions(t *testing.T) { // Set up an authenticator to check authenticated function. It will allow // access to "someScope", but disallow others. options := middleware.Options{ - ErrorHandler: func(w http.ResponseWriter, message string, statusCode int) { + ErrorHandler: func(ctx context.Context, w http.ResponseWriter, message string, statusCode int) { http.Error(w, "test: "+message, statusCode) }, Options: openapi3filter.Options{ diff --git a/oapi_validate.go b/oapi_validate.go index 5bbce40..8b21ef5 100644 --- a/oapi_validate.go +++ b/oapi_validate.go @@ -4,6 +4,7 @@ package nethttpmiddleware import ( + "context" "errors" "fmt" "log" @@ -17,7 +18,7 @@ import ( ) // ErrorHandler is called when there is an error in validation -type ErrorHandler func(w http.ResponseWriter, message string, statusCode int) +type ErrorHandler func(ctx context.Context, w http.ResponseWriter, message string, statusCode int) // MultiErrorHandler is called when oapi returns a MultiError type type MultiErrorHandler func(openapi3.MultiError) (int, error) @@ -53,7 +54,7 @@ func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) func // validate request if statusCode, err := validateRequest(r, router, options); err != nil { if options != nil && options.ErrorHandler != nil { - options.ErrorHandler(w, err.Error(), statusCode) + options.ErrorHandler(r.Context(), w, err.Error(), statusCode) } else { http.Error(w, err.Error(), statusCode) }