diff --git a/README.md b/README.md index 2ddb757a..d752ea89 100644 --- a/README.md +++ b/README.md @@ -359,8 +359,9 @@ There are several ways to adjust functionality of the library: 1. `ClaimsUpdater` - interface with `Update(claims Claims) Claims` method. This is the primary way to alter a token at login time and add any attributes, set ip, email, admin status, roles and so on. 1. `Validator` - interface with `Validate(token string, claims Claims) bool` method. This is post-token hook and will be called on **each request** wrapped with `Auth` middleware. This will be the place for special logic to reject some tokens or users. 1. `UserUpdater` - interface with `Update(claims token.User) token.User` method. This method will be called on **each request** wrapped with `UpdateUser` middleware. This will be the place for special logic modify User Info in request context. [Example of usage.](https://github.com/go-pkgz/auth/blob/19c1b6d26608494955a4480f8f6165af85b1deab/_example/main.go#L189) +1. `AuthErrorHTTPHandler` - interface with `ServeAuthError(w http.ResponseWriter, r *http.Request, and other params)` method. It is possible to change how authentication errors are written into HTTP responses by configuring custom implementations of this interface for the middlewares. -All of the interfaces above have corresponding Func adapters - `SecretFunc`, `ClaimsUpdFunc`, `ValidatorFunc` and `UserUpdFunc`. +All of the interfaces above except `AuthErrorHTTPHandler` have corresponding Func adapters - `SecretFunc`, `ClaimsUpdFunc`, `ValidatorFunc` and `UserUpdFunc`. ### Implementing black list logic or some other filters diff --git a/auth.go b/auth.go index 5bd9d879..bcdd7d68 100644 --- a/auth.go +++ b/auth.go @@ -66,12 +66,13 @@ type Opts struct { AvatarRoutePath string // avatar routing prefix, i.e. "/api/v1/avatar", default `/avatar` UseGravatar bool // for email based auth (verified provider) use gravatar service - AdminPasswd string // if presented, allows basic auth with user admin and given password - BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored - AudienceReader token.Audience // list of allowed aud values, default (empty) allows any - AudSecrets bool // allow multiple secrets (secret per aud) - Logger logger.L // logger interface, default is no logging at all - RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens + AdminPasswd string // if presented, allows basic auth with user admin and given password + BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored + AudienceReader token.Audience // list of allowed aud values, default (empty) allows any + AudSecrets bool // allow multiple secrets (secret per aud) + Logger logger.L // logger interface, default is no logging at all + RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens + AuthErrorHTTPHandler middleware.AuthErrorHTTPHandler // optional HTTP handler for authentication errors } // NewService initializes everything @@ -81,10 +82,11 @@ func NewService(opts Opts) (res *Service) { opts: opts, logger: opts.Logger, authMiddleware: middleware.Authenticator{ - Validator: opts.Validator, - AdminPasswd: opts.AdminPasswd, - BasicAuthChecker: opts.BasicAuthChecker, - RefreshCache: opts.RefreshCache, + Validator: opts.Validator, + AdminPasswd: opts.AdminPasswd, + BasicAuthChecker: opts.BasicAuthChecker, + RefreshCache: opts.RefreshCache, + AuthErrorHTTPHandler: opts.AuthErrorHTTPHandler, }, issuer: opts.Issuer, useGravatar: opts.UseGravatar, diff --git a/auth_test.go b/auth_test.go index 06a3b8aa..73615981 100644 --- a/auth_test.go +++ b/auth_test.go @@ -3,6 +3,7 @@ package auth import ( "context" "encoding/json" + "fmt" "io" "net" "net/http" @@ -19,6 +20,7 @@ import ( "github.com/go-pkgz/auth/avatar" "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/middleware" "github.com/go-pkgz/auth/provider" "github.com/go-pkgz/auth/token" ) @@ -246,6 +248,148 @@ func TestIntegrationList(t *testing.T) { assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b)) } +type testAuthErrorHTTPHandler struct { + statusCode int + contentType string + responseBody string + wasCalled bool +} + +func (h *testAuthErrorHTTPHandler) ServeAuthError( + w http.ResponseWriter, + _ *http.Request, + authError error, + reason string, + statusCode int, +) { + h.wasCalled = true + w.Header().Set("Content-Type", h.contentType) + w.WriteHeader(h.statusCode) + fmt.Fprint(w, h.responseBody) +} + +func TestIntegrationAuthErrorHTTPHandler(t *testing.T) { + apiHandler := http.HandlerFunc( + func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("must not be called\n")) + t.Error("auth error must be raised before this HTTP handler is called") + }, + ) + defaultAuthErrorHTTPHandler := testAuthErrorHTTPHandler{ + statusCode: 400, + contentType: "application/json", + responseBody: `{"code": 400, "message": "from general error handler"}`, + } + type apiCall struct { + requestPath string + expectedHandler *testAuthErrorHTTPHandler + createMiddleware func(apiCall, middleware.Authenticator) http.Handler + } + apiCalls := []apiCall{ + { + requestPath: "/private1", + expectedHandler: &defaultAuthErrorHTTPHandler, + createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler { + return a.Auth(apiHandler) + }, + }, + { + requestPath: "/private2", + expectedHandler: &testAuthErrorHTTPHandler{ + statusCode: 402, + contentType: "application/json", + responseBody: `{"code": 402, "message": "from private2 error handler"}`, + }, + createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler { + return a.AuthWithErrorHTTPHandler(apiHandler, ac.expectedHandler) + }, + }, + { + requestPath: "/admin1", + expectedHandler: &defaultAuthErrorHTTPHandler, + createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler { + return a.AdminOnly(apiHandler) + }, + }, + { + requestPath: "/admin2", + expectedHandler: &testAuthErrorHTTPHandler{ + statusCode: 404, + contentType: "application/json", + responseBody: `{"code": 404, "message": "from admin2 error handler"}`, + }, + createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler { + return a.AdminOnlyWithErrorHTTPHandler(apiHandler, ac.expectedHandler) + }, + }, + { + requestPath: "/rbac1", + expectedHandler: &defaultAuthErrorHTTPHandler, + createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler { + return a.RBAC("role1", "role2")(apiHandler) + }, + }, + { + requestPath: "/rbac2", + expectedHandler: &testAuthErrorHTTPHandler{ + statusCode: 406, + contentType: "text/html", + responseBody: `

from RBAC2 error handler

`, + }, + createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler { + return a.RBACwithErrorHTTPHandler(ac.expectedHandler, "role1", "role2")(apiHandler) + }, + }, + } + + options := Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + Issuer: "my-test-app", + URL: "http://127.0.0.1:8089", + AuthErrorHTTPHandler: &defaultAuthErrorHTTPHandler, + } + + svc := NewService(options) + svc.AddDevProvider("localhost", 18084) // add dev provider on 18084 + + mux := http.NewServeMux() + m := svc.Middleware() + + for _, ac := range apiCalls { + mux.Handle(ac.requestPath, ac.createMiddleware(ac, m)) + } + + l, listenErr := net.Listen("tcp", "127.0.0.1:8089") + require.Nil(t, listenErr) + ts := httptest.NewUnstartedServer(mux) + assert.NoError(t, ts.Listener.Close()) + ts.Listener = l + ts.Start() + defer func() { + ts.Close() + }() + + for _, ac := range apiCalls { + t.Run("auth error test for endpoint "+ac.requestPath, func(t *testing.T) { + th := ac.expectedHandler + th.wasCalled = false + + resp, err := http.Get("http://127.0.0.1:8089" + ac.requestPath) + require.NoError(t, err) + defer resp.Body.Close() + + require.True(t, th.wasCalled) + + require.Equal(t, th.statusCode, resp.StatusCode) + require.Equal(t, th.contentType, resp.Header.Get("Content-Type")) + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, th.responseBody, string(b)) + }) + } +} + func TestIntegrationUserInfo(t *testing.T) { _, teardown := prepService(t) defer teardown() diff --git a/middleware/auth.go b/middleware/auth.go index 64d6c7ce..23e95041 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -18,12 +18,13 @@ import ( // Authenticator is top level auth object providing middlewares type Authenticator struct { logger.L - JWTService TokenService - Providers []provider.Service - Validator token.Validator - AdminPasswd string - BasicAuthChecker BasicAuthFunc - RefreshCache RefreshCache + JWTService TokenService + Providers []provider.Service + Validator token.Validator + AdminPasswd string + BasicAuthChecker BasicAuthFunc + RefreshCache RefreshCache + AuthErrorHTTPHandler AuthErrorHTTPHandler } // RefreshCache defines interface storing and retrieving refreshed tokens @@ -45,6 +46,28 @@ type TokenService interface { // The second return parameter `User` need for add user claims into context of request. type BasicAuthFunc func(user, passwd string) (ok bool, userInfo token.User, err error) +// AuthErrorHTTPHandler defines interface for handling HTTP responses in case of authentication errors +type AuthErrorHTTPHandler interface { + // Serves HTTP response in case of authentication error + // w - response writer + // r - original request + // authError - authentication error with technical details + // reason - reason text + // statusCode - HTTP status code + ServeAuthError(w http.ResponseWriter, r *http.Request, authError error, reason string, statusCode int) +} + +// DefaultAuthErrorHTTPHandler is a default implementation, which writes text/plain responses using http.Error() +type DefaultAuthErrorHTTPHandler struct { + logger.L +} + +// ServeAuthError writes text/plain responses using http.Error() +func (h DefaultAuthErrorHTTPHandler) ServeAuthError(w http.ResponseWriter, _ *http.Request, authError error, reason string, statusCode int) { + h.Logf("[DEBUG] auth failed, %v", authError) + http.Error(w, reason, statusCode) +} + // adminUser sets claims for an optional basic auth var adminUser = token.User{ ID: "admin", @@ -56,24 +79,29 @@ var adminUser = token.User{ // Auth middleware adds auth from session and populates user info func (a *Authenticator) Auth(next http.Handler) http.Handler { - return a.auth(true)(next) + return a.auth(true, a.getAuthErrorHTTPHandler())(next) +} + +// AuthWithErrorHTTPHandler middleware adds auth from session and populates user info. +// errorHttpHandler parameter may be used to write custom HTTP responses in case of authentication error. +func (a *Authenticator) AuthWithErrorHTTPHandler(next http.Handler, errorHTTPHandler AuthErrorHTTPHandler) http.Handler { + return a.auth(true, errorHTTPHandler)(next) } // Trace middleware doesn't require valid user but if user info presented populates info func (a *Authenticator) Trace(next http.Handler) http.Handler { - return a.auth(false)(next) + return a.auth(false, a.getAuthErrorHTTPHandler())(next) } // auth implements all logic for authentication (reqAuth=true) and tracing (reqAuth=false) -func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler { +func (a *Authenticator) auth(reqAuth bool, errorHTTPHandler AuthErrorHTTPHandler) func(http.Handler) http.Handler { onError := func(h http.Handler, w http.ResponseWriter, r *http.Request, err error) { if !reqAuth { // if no auth required allow to proceeded on error h.ServeHTTP(w, r) return } - a.Logf("[DEBUG] auth failed, %v", err) - http.Error(w, "Unauthorized", http.StatusUnauthorized) + errorHTTPHandler.ServeAuthError(w, r, err, "Unauthorized", http.StatusUnauthorized) } f := func(h http.Handler) http.Handler { @@ -191,23 +219,34 @@ func (a *Authenticator) refreshExpiredToken(w http.ResponseWriter, claims token. return c, nil } -// AdminOnly middleware allows access for admins only -// this handler internally wrapped with auth(true) to avoid situation if AdminOnly defined without prior Auth +// AdminOnly middleware allows access for admins only. +// This handler internally wrapped with auth(true) to avoid situation if AdminOnly defined without prior Auth func (a *Authenticator) AdminOnly(next http.Handler) http.Handler { + return a.adminOnly(next, a.getAuthErrorHTTPHandler()) +} + +// AdminOnlyWithErrorHTTPHandler middleware allows access for admins only. +// This handler internally wrapped with auth(true) to avoid situation if AdminOnly defined without prior Auth. +// errorHttpHandler parameter may be used to write custom HTTP responses in case of authentication error. +func (a *Authenticator) AdminOnlyWithErrorHTTPHandler(next http.Handler, errorHTTPHandler AuthErrorHTTPHandler) http.Handler { + return a.adminOnly(next, errorHTTPHandler) +} + +func (a *Authenticator) adminOnly(next http.Handler, errorHTTPHandler AuthErrorHTTPHandler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { user, err := token.GetUserInfo(r) if err != nil { - http.Error(w, "Unauthorized", http.StatusUnauthorized) + errorHTTPHandler.ServeAuthError(w, r, err, "Unauthorized", http.StatusUnauthorized) return } if !user.IsAdmin() { - http.Error(w, "Access denied", http.StatusForbidden) + errorHTTPHandler.ServeAuthError(w, r, fmt.Errorf("user %s/%s is not admin", user.Name, user.ID), "Access denied", http.StatusForbidden) return } next.ServeHTTP(w, r) } - return a.auth(true)(http.HandlerFunc(fn)) // enforce auth + return a.auth(true, errorHTTPHandler)(http.HandlerFunc(fn)) // enforce auth } // basic auth for admin user @@ -234,12 +273,23 @@ func (a *Authenticator) basicAdminUser(r *http.Request) bool { // RBAC middleware allows role based control for routes // this handler internally wrapped with auth(true) to avoid situation if RBAC defined without prior Auth func (a *Authenticator) RBAC(roles ...string) func(http.Handler) http.Handler { + return a.rbac(a.getAuthErrorHTTPHandler(), roles...) +} + +// RBACwithErrorHTTPHandler middleware allows role based control for routes +// this handler internally wrapped with auth(true) to avoid situation if RBAC defined without prior Auth +// errorHttpHandler parameter may be used to write custom HTTP responses in case of authentication error. +func (a *Authenticator) RBACwithErrorHTTPHandler(errorHTTPHandler AuthErrorHTTPHandler, roles ...string) func(http.Handler) http.Handler { + return a.rbac(errorHTTPHandler, roles...) +} + +func (a *Authenticator) rbac(errorHTTPHandler AuthErrorHTTPHandler, roles ...string) func(http.Handler) http.Handler { f := func(h http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { user, err := token.GetUserInfo(r) if err != nil { - http.Error(w, "Unauthorized", http.StatusUnauthorized) + errorHTTPHandler.ServeAuthError(w, r, err, "Unauthorized", http.StatusUnauthorized) return } @@ -251,12 +301,26 @@ func (a *Authenticator) RBAC(roles ...string) func(http.Handler) http.Handler { } } if !matched { - http.Error(w, "Access denied", http.StatusForbidden) + errorHTTPHandler.ServeAuthError( + w, + r, + fmt.Errorf("user %s/%s does not have any of required roles: %s", user.Name, user.ID, roles), + "Access denied", + http.StatusForbidden, + ) return } h.ServeHTTP(w, r) } - return a.auth(true)(http.HandlerFunc(fn)) // enforce auth + return a.auth(true, errorHTTPHandler)(http.HandlerFunc(fn)) // enforce auth } return f } + +func (a *Authenticator) getAuthErrorHTTPHandler() AuthErrorHTTPHandler { + if a.AuthErrorHTTPHandler != nil { + return a.AuthErrorHTTPHandler + } + + return DefaultAuthErrorHTTPHandler{L: a.L} +} diff --git a/middleware/auth_test.go b/middleware/auth_test.go index 18208aee..c923b1a4 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -503,6 +503,70 @@ func TestRBAC(t *testing.T) { assert.Equal(t, "Access denied\n", string(data)) } +type testAuthErrorHTTPHandler struct { + wasCalled bool + statusCode int +} + +func (h *testAuthErrorHTTPHandler) ServeAuthError( + w http.ResponseWriter, + _ *http.Request, + _ error, + _ string, + _ int, +) { + h.wasCalled = true + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(h.statusCode) + fmt.Fprint(w, "Unauthorized") +} + +func TestAuthErrorHTTPHandler(t *testing.T) { + testErrorHandler1 := &testAuthErrorHTTPHandler{statusCode: 401} + testErrorHandler2 := &testAuthErrorHTTPHandler{statusCode: 402} + testErrorHandler3 := &testAuthErrorHTTPHandler{statusCode: 403} + testErrorHandler4 := &testAuthErrorHTTPHandler{statusCode: 404} + + a := makeTestAuth(t) + a.AuthErrorHTTPHandler = testErrorHandler1 + + handler := http.HandlerFunc( + func(w http.ResponseWriter, _ *http.Request) { // token required + _, _ = w.Write([]byte("must not be called\n")) + t.Error("auth error must be raised before this HTTP handler is called") + }, + ) + + mux := http.NewServeMux() + mux.Handle("/private1", a.Auth(handler)) + mux.Handle("/private2", a.AuthWithErrorHTTPHandler(handler, testErrorHandler2)) + mux.Handle("/admin1", a.AdminOnly(handler)) + mux.Handle("/admin2", a.AdminOnlyWithErrorHTTPHandler(handler, testErrorHandler3)) + mux.Handle("/rbac1", a.RBAC("role1")(handler)) + mux.Handle("/rbac2", a.RBACwithErrorHTTPHandler(testErrorHandler4, "role1")(handler)) + + server := httptest.NewServer(mux) + defer server.Close() + + assertThatHandlerWasCalledProperly := func(t *testing.T, errorHandler *testAuthErrorHTTPHandler, path string) { + errorHandler.wasCalled = false + + resp, err := http.Get(server.URL + path) + require.NoError(t, err) + defer resp.Body.Close() + + require.True(t, errorHandler.wasCalled, "error handler must be called") + require.Equal(t, errorHandler.statusCode, resp.StatusCode, "error handler must produce proper status code") + } + + assertThatHandlerWasCalledProperly(t, testErrorHandler1, "/private1") + assertThatHandlerWasCalledProperly(t, testErrorHandler2, "/private2") + assertThatHandlerWasCalledProperly(t, testErrorHandler1, "/admin1") + assertThatHandlerWasCalledProperly(t, testErrorHandler3, "/admin2") + assertThatHandlerWasCalledProperly(t, testErrorHandler1, "/rbac1") + assertThatHandlerWasCalledProperly(t, testErrorHandler4, "/rbac2") +} + func makeTestMux(_ *testing.T, a *Authenticator, required bool) http.Handler { mux := http.NewServeMux() authMiddleware := a.Auth diff --git a/v2/auth.go b/v2/auth.go index 30d4a613..e5908399 100644 --- a/v2/auth.go +++ b/v2/auth.go @@ -66,12 +66,13 @@ type Opts struct { AvatarRoutePath string // avatar routing prefix, i.e. "/api/v1/avatar", default `/avatar` UseGravatar bool // for email based auth (verified provider) use gravatar service - AdminPasswd string // if presented, allows basic auth with user admin and given password - BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored - AudienceReader token.Audience // list of allowed aud values, default (empty) allows any - AudSecrets bool // allow multiple secrets (secret per aud) - Logger logger.L // logger interface, default is no logging at all - RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens + AdminPasswd string // if presented, allows basic auth with user admin and given password + BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored + AudienceReader token.Audience // list of allowed aud values, default (empty) allows any + AudSecrets bool // allow multiple secrets (secret per aud) + Logger logger.L // logger interface, default is no logging at all + RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens + AuthErrorHTTPHandler middleware.AuthErrorHTTPHandler // optional HTTP handler for authentication errors } // NewService initializes everything @@ -81,10 +82,11 @@ func NewService(opts Opts) (res *Service) { opts: opts, logger: opts.Logger, authMiddleware: middleware.Authenticator{ - Validator: opts.Validator, - AdminPasswd: opts.AdminPasswd, - BasicAuthChecker: opts.BasicAuthChecker, - RefreshCache: opts.RefreshCache, + Validator: opts.Validator, + AdminPasswd: opts.AdminPasswd, + BasicAuthChecker: opts.BasicAuthChecker, + RefreshCache: opts.RefreshCache, + AuthErrorHTTPHandler: opts.AuthErrorHTTPHandler, }, issuer: opts.Issuer, useGravatar: opts.UseGravatar, diff --git a/v2/auth_test.go b/v2/auth_test.go index 81655bb6..bc3c1852 100644 --- a/v2/auth_test.go +++ b/v2/auth_test.go @@ -3,6 +3,7 @@ package auth import ( "context" "encoding/json" + "fmt" "io" "net" "net/http" @@ -19,6 +20,7 @@ import ( "github.com/go-pkgz/auth/v2/avatar" "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/middleware" "github.com/go-pkgz/auth/v2/provider" "github.com/go-pkgz/auth/v2/token" ) @@ -246,6 +248,148 @@ func TestIntegrationList(t *testing.T) { assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b)) } +type testAuthErrorHTTPHandler struct { + statusCode int + contentType string + responseBody string + wasCalled bool +} + +func (h *testAuthErrorHTTPHandler) ServeAuthError( + w http.ResponseWriter, + _ *http.Request, + authError error, + reason string, + statusCode int, +) { + h.wasCalled = true + w.Header().Set("Content-Type", h.contentType) + w.WriteHeader(h.statusCode) + fmt.Fprint(w, h.responseBody) +} + +func TestIntegrationAuthErrorHTTPHandler(t *testing.T) { + apiHandler := http.HandlerFunc( + func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("must not be called\n")) + t.Error("auth error must be raised before this HTTP handler is called") + }, + ) + defaultAuthErrorHTTPHandler := testAuthErrorHTTPHandler{ + statusCode: 400, + contentType: "application/json", + responseBody: `{"code": 400, "message": "from general error handler"}`, + } + type apiCall struct { + requestPath string + expectedHandler *testAuthErrorHTTPHandler + createMiddleware func(apiCall, middleware.Authenticator) http.Handler + } + apiCalls := []apiCall{ + { + requestPath: "/private1", + expectedHandler: &defaultAuthErrorHTTPHandler, + createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler { + return a.Auth(apiHandler) + }, + }, + { + requestPath: "/private2", + expectedHandler: &testAuthErrorHTTPHandler{ + statusCode: 402, + contentType: "application/json", + responseBody: `{"code": 402, "message": "from private2 error handler"}`, + }, + createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler { + return a.AuthWithErrorHTTPHandler(apiHandler, ac.expectedHandler) + }, + }, + { + requestPath: "/admin1", + expectedHandler: &defaultAuthErrorHTTPHandler, + createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler { + return a.AdminOnly(apiHandler) + }, + }, + { + requestPath: "/admin2", + expectedHandler: &testAuthErrorHTTPHandler{ + statusCode: 404, + contentType: "application/json", + responseBody: `{"code": 404, "message": "from admin2 error handler"}`, + }, + createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler { + return a.AdminOnlyWithErrorHTTPHandler(apiHandler, ac.expectedHandler) + }, + }, + { + requestPath: "/rbac1", + expectedHandler: &defaultAuthErrorHTTPHandler, + createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler { + return a.RBAC("role1", "role2")(apiHandler) + }, + }, + { + requestPath: "/rbac2", + expectedHandler: &testAuthErrorHTTPHandler{ + statusCode: 406, + contentType: "text/html", + responseBody: `

from RBAC2 error handler

`, + }, + createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler { + return a.RBACwithErrorHTTPHandler(ac.expectedHandler, "role1", "role2")(apiHandler) + }, + }, + } + + options := Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + Issuer: "my-test-app", + URL: "http://127.0.0.1:8089", + AuthErrorHTTPHandler: &defaultAuthErrorHTTPHandler, + } + + svc := NewService(options) + svc.AddDevProvider("localhost", 18084) // add dev provider on 18084 + + mux := http.NewServeMux() + m := svc.Middleware() + + for _, ac := range apiCalls { + mux.Handle(ac.requestPath, ac.createMiddleware(ac, m)) + } + + l, listenErr := net.Listen("tcp", "127.0.0.1:8089") + require.Nil(t, listenErr) + ts := httptest.NewUnstartedServer(mux) + assert.NoError(t, ts.Listener.Close()) + ts.Listener = l + ts.Start() + defer func() { + ts.Close() + }() + + for _, ac := range apiCalls { + t.Run("auth error test for endpoint "+ac.requestPath, func(t *testing.T) { + th := ac.expectedHandler + th.wasCalled = false + + resp, err := http.Get("http://127.0.0.1:8089" + ac.requestPath) + require.NoError(t, err) + defer resp.Body.Close() + + require.True(t, th.wasCalled) + + require.Equal(t, th.statusCode, resp.StatusCode) + require.Equal(t, th.contentType, resp.Header.Get("Content-Type")) + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, th.responseBody, string(b)) + }) + } +} + func TestIntegrationUserInfo(t *testing.T) { _, teardown := prepService(t) defer teardown() diff --git a/v2/middleware/auth.go b/v2/middleware/auth.go index 6e8c5ed8..61e1c08c 100644 --- a/v2/middleware/auth.go +++ b/v2/middleware/auth.go @@ -18,12 +18,13 @@ import ( // Authenticator is top level auth object providing middlewares type Authenticator struct { logger.L - JWTService TokenService - Providers []provider.Service - Validator token.Validator - AdminPasswd string - BasicAuthChecker BasicAuthFunc - RefreshCache RefreshCache + JWTService TokenService + Providers []provider.Service + Validator token.Validator + AdminPasswd string + BasicAuthChecker BasicAuthFunc + RefreshCache RefreshCache + AuthErrorHTTPHandler AuthErrorHTTPHandler } // RefreshCache defines interface storing and retrieving refreshed tokens @@ -45,6 +46,28 @@ type TokenService interface { // The second return parameter `User` need for add user claims into context of request. type BasicAuthFunc func(user, passwd string) (ok bool, userInfo token.User, err error) +// AuthErrorHTTPHandler defines interface for handling HTTP responses in case of authentication errors +type AuthErrorHTTPHandler interface { + // Serves HTTP response in case of authentication error + // w - response writer + // r - original request + // authError - authentication error with technical details + // reason - reason text + // statusCode - HTTP status code + ServeAuthError(w http.ResponseWriter, r *http.Request, authError error, reason string, statusCode int) +} + +// DefaultAuthErrorHTTPHandler is a default implementation, which writes text/plain responses using http.Error() +type DefaultAuthErrorHTTPHandler struct { + logger.L +} + +// ServeAuthError writes text/plain responses using http.Error() +func (h DefaultAuthErrorHTTPHandler) ServeAuthError(w http.ResponseWriter, _ *http.Request, authError error, reason string, statusCode int) { + h.Logf("[DEBUG] auth failed, %v", authError) + http.Error(w, reason, statusCode) +} + // adminUser sets claims for an optional basic auth var adminUser = token.User{ ID: "admin", @@ -56,24 +79,29 @@ var adminUser = token.User{ // Auth middleware adds auth from session and populates user info func (a *Authenticator) Auth(next http.Handler) http.Handler { - return a.auth(true)(next) + return a.auth(true, a.getAuthErrorHTTPHandler())(next) +} + +// AuthWithErrorHTTPHandler middleware adds auth from session and populates user info. +// errorHttpHandler parameter may be used to write custom HTTP responses in case of authentication error. +func (a *Authenticator) AuthWithErrorHTTPHandler(next http.Handler, errorHTTPHandler AuthErrorHTTPHandler) http.Handler { + return a.auth(true, errorHTTPHandler)(next) } // Trace middleware doesn't require valid user but if user info presented populates info func (a *Authenticator) Trace(next http.Handler) http.Handler { - return a.auth(false)(next) + return a.auth(false, a.getAuthErrorHTTPHandler())(next) } // auth implements all logic for authentication (reqAuth=true) and tracing (reqAuth=false) -func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler { +func (a *Authenticator) auth(reqAuth bool, errorHTTPHandler AuthErrorHTTPHandler) func(http.Handler) http.Handler { onError := func(h http.Handler, w http.ResponseWriter, r *http.Request, err error) { if !reqAuth { // if no auth required allow to proceeded on error h.ServeHTTP(w, r) return } - a.Logf("[DEBUG] auth failed, %v", err) - http.Error(w, "Unauthorized", http.StatusUnauthorized) + errorHTTPHandler.ServeAuthError(w, r, err, "Unauthorized", http.StatusUnauthorized) } f := func(h http.Handler) http.Handler { @@ -191,23 +219,34 @@ func (a *Authenticator) refreshExpiredToken(w http.ResponseWriter, claims token. return c, nil } -// AdminOnly middleware allows access for admins only -// this handler internally wrapped with auth(true) to avoid situation if AdminOnly defined without prior Auth +// AdminOnly middleware allows access for admins only. +// This handler internally wrapped with auth(true) to avoid situation if AdminOnly defined without prior Auth func (a *Authenticator) AdminOnly(next http.Handler) http.Handler { + return a.adminOnly(next, a.getAuthErrorHTTPHandler()) +} + +// AdminOnlyWithErrorHTTPHandler middleware allows access for admins only. +// This handler internally wrapped with auth(true) to avoid situation if AdminOnly defined without prior Auth. +// errorHttpHandler parameter may be used to write custom HTTP responses in case of authentication error. +func (a *Authenticator) AdminOnlyWithErrorHTTPHandler(next http.Handler, errorHTTPHandler AuthErrorHTTPHandler) http.Handler { + return a.adminOnly(next, errorHTTPHandler) +} + +func (a *Authenticator) adminOnly(next http.Handler, errorHTTPHandler AuthErrorHTTPHandler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { user, err := token.GetUserInfo(r) if err != nil { - http.Error(w, "Unauthorized", http.StatusUnauthorized) + errorHTTPHandler.ServeAuthError(w, r, err, "Unauthorized", http.StatusUnauthorized) return } if !user.IsAdmin() { - http.Error(w, "Access denied", http.StatusForbidden) + errorHTTPHandler.ServeAuthError(w, r, fmt.Errorf("user %s/%s is not admin", user.Name, user.ID), "Access denied", http.StatusForbidden) return } next.ServeHTTP(w, r) } - return a.auth(true)(http.HandlerFunc(fn)) // enforce auth + return a.auth(true, errorHTTPHandler)(http.HandlerFunc(fn)) // enforce auth } // basic auth for admin user @@ -234,12 +273,23 @@ func (a *Authenticator) basicAdminUser(r *http.Request) bool { // RBAC middleware allows role based control for routes // this handler internally wrapped with auth(true) to avoid situation if RBAC defined without prior Auth func (a *Authenticator) RBAC(roles ...string) func(http.Handler) http.Handler { + return a.rbac(a.getAuthErrorHTTPHandler(), roles...) +} + +// RBACwithErrorHTTPHandler middleware allows role based control for routes +// this handler internally wrapped with auth(true) to avoid situation if RBAC defined without prior Auth +// errorHttpHandler parameter may be used to write custom HTTP responses in case of authentication error. +func (a *Authenticator) RBACwithErrorHTTPHandler(errorHTTPHandler AuthErrorHTTPHandler, roles ...string) func(http.Handler) http.Handler { + return a.rbac(errorHTTPHandler, roles...) +} + +func (a *Authenticator) rbac(errorHTTPHandler AuthErrorHTTPHandler, roles ...string) func(http.Handler) http.Handler { f := func(h http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { user, err := token.GetUserInfo(r) if err != nil { - http.Error(w, "Unauthorized", http.StatusUnauthorized) + errorHTTPHandler.ServeAuthError(w, r, err, "Unauthorized", http.StatusUnauthorized) return } @@ -251,12 +301,26 @@ func (a *Authenticator) RBAC(roles ...string) func(http.Handler) http.Handler { } } if !matched { - http.Error(w, "Access denied", http.StatusForbidden) + errorHTTPHandler.ServeAuthError( + w, + r, + fmt.Errorf("user %s/%s does not have any of required roles: %s", user.Name, user.ID, roles), + "Access denied", + http.StatusForbidden, + ) return } h.ServeHTTP(w, r) } - return a.auth(true)(http.HandlerFunc(fn)) // enforce auth + return a.auth(true, errorHTTPHandler)(http.HandlerFunc(fn)) // enforce auth } return f } + +func (a *Authenticator) getAuthErrorHTTPHandler() AuthErrorHTTPHandler { + if a.AuthErrorHTTPHandler != nil { + return a.AuthErrorHTTPHandler + } + + return DefaultAuthErrorHTTPHandler{L: a.L} +} diff --git a/v2/middleware/auth_test.go b/v2/middleware/auth_test.go index d71e0c9f..37baa629 100644 --- a/v2/middleware/auth_test.go +++ b/v2/middleware/auth_test.go @@ -502,6 +502,70 @@ func TestRBAC(t *testing.T) { assert.Equal(t, "Access denied\n", string(data)) } +type testAuthErrorHTTPHandler struct { + wasCalled bool + statusCode int +} + +func (h *testAuthErrorHTTPHandler) ServeAuthError( + w http.ResponseWriter, + _ *http.Request, + _ error, + _ string, + _ int, +) { + h.wasCalled = true + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(h.statusCode) + fmt.Fprint(w, "Unauthorized") +} + +func TestAuthErrorHTTPHandler(t *testing.T) { + testErrorHandler1 := &testAuthErrorHTTPHandler{statusCode: 401} + testErrorHandler2 := &testAuthErrorHTTPHandler{statusCode: 402} + testErrorHandler3 := &testAuthErrorHTTPHandler{statusCode: 403} + testErrorHandler4 := &testAuthErrorHTTPHandler{statusCode: 404} + + a := makeTestAuth(t) + a.AuthErrorHTTPHandler = testErrorHandler1 + + handler := http.HandlerFunc( + func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("must not be called\n")) + t.Error("auth error must be raised before this HTTP handler is called") + }, + ) + + mux := http.NewServeMux() + mux.Handle("/private1", a.Auth(handler)) + mux.Handle("/private2", a.AuthWithErrorHTTPHandler(handler, testErrorHandler2)) + mux.Handle("/admin1", a.AdminOnly(handler)) + mux.Handle("/admin2", a.AdminOnlyWithErrorHTTPHandler(handler, testErrorHandler3)) + mux.Handle("/rbac1", a.RBAC("role1")(handler)) + mux.Handle("/rbac2", a.RBACwithErrorHTTPHandler(testErrorHandler4, "role1")(handler)) + + server := httptest.NewServer(mux) + defer server.Close() + + assertThatHandlerWasCalledProperly := func(t *testing.T, errorHandler *testAuthErrorHTTPHandler, path string) { + errorHandler.wasCalled = false + + resp, err := http.Get(server.URL + path) + require.NoError(t, err) + defer resp.Body.Close() + + require.True(t, errorHandler.wasCalled, "error handler must be called") + require.Equal(t, errorHandler.statusCode, resp.StatusCode, "error handler must produce proper status code") + } + + assertThatHandlerWasCalledProperly(t, testErrorHandler1, "/private1") + assertThatHandlerWasCalledProperly(t, testErrorHandler2, "/private2") + assertThatHandlerWasCalledProperly(t, testErrorHandler1, "/admin1") + assertThatHandlerWasCalledProperly(t, testErrorHandler3, "/admin2") + assertThatHandlerWasCalledProperly(t, testErrorHandler1, "/rbac1") + assertThatHandlerWasCalledProperly(t, testErrorHandler4, "/rbac2") +} + func makeTestMux(_ *testing.T, a *Authenticator, required bool) http.Handler { mux := http.NewServeMux() authMiddleware := a.Auth