diff --git a/README.md b/README.md
index 2ddb757a..d752ea89 100644
--- a/README.md
+++ b/README.md
@@ -359,8 +359,9 @@ There are several ways to adjust functionality of the library:
1. `ClaimsUpdater` - interface with `Update(claims Claims) Claims` method. This is the primary way to alter a token at login time and add any attributes, set ip, email, admin status, roles and so on.
1. `Validator` - interface with `Validate(token string, claims Claims) bool` method. This is post-token hook and will be called on **each request** wrapped with `Auth` middleware. This will be the place for special logic to reject some tokens or users.
1. `UserUpdater` - interface with `Update(claims token.User) token.User` method. This method will be called on **each request** wrapped with `UpdateUser` middleware. This will be the place for special logic modify User Info in request context. [Example of usage.](https://github.com/go-pkgz/auth/blob/19c1b6d26608494955a4480f8f6165af85b1deab/_example/main.go#L189)
+1. `AuthErrorHTTPHandler` - interface with `ServeAuthError(w http.ResponseWriter, r *http.Request, and other params)` method. It is possible to change how authentication errors are written into HTTP responses by configuring custom implementations of this interface for the middlewares.
-All of the interfaces above have corresponding Func adapters - `SecretFunc`, `ClaimsUpdFunc`, `ValidatorFunc` and `UserUpdFunc`.
+All of the interfaces above except `AuthErrorHTTPHandler` have corresponding Func adapters - `SecretFunc`, `ClaimsUpdFunc`, `ValidatorFunc` and `UserUpdFunc`.
### Implementing black list logic or some other filters
diff --git a/auth.go b/auth.go
index 5bd9d879..bcdd7d68 100644
--- a/auth.go
+++ b/auth.go
@@ -66,12 +66,13 @@ type Opts struct {
AvatarRoutePath string // avatar routing prefix, i.e. "/api/v1/avatar", default `/avatar`
UseGravatar bool // for email based auth (verified provider) use gravatar service
- AdminPasswd string // if presented, allows basic auth with user admin and given password
- BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored
- AudienceReader token.Audience // list of allowed aud values, default (empty) allows any
- AudSecrets bool // allow multiple secrets (secret per aud)
- Logger logger.L // logger interface, default is no logging at all
- RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens
+ AdminPasswd string // if presented, allows basic auth with user admin and given password
+ BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored
+ AudienceReader token.Audience // list of allowed aud values, default (empty) allows any
+ AudSecrets bool // allow multiple secrets (secret per aud)
+ Logger logger.L // logger interface, default is no logging at all
+ RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens
+ AuthErrorHTTPHandler middleware.AuthErrorHTTPHandler // optional HTTP handler for authentication errors
}
// NewService initializes everything
@@ -81,10 +82,11 @@ func NewService(opts Opts) (res *Service) {
opts: opts,
logger: opts.Logger,
authMiddleware: middleware.Authenticator{
- Validator: opts.Validator,
- AdminPasswd: opts.AdminPasswd,
- BasicAuthChecker: opts.BasicAuthChecker,
- RefreshCache: opts.RefreshCache,
+ Validator: opts.Validator,
+ AdminPasswd: opts.AdminPasswd,
+ BasicAuthChecker: opts.BasicAuthChecker,
+ RefreshCache: opts.RefreshCache,
+ AuthErrorHTTPHandler: opts.AuthErrorHTTPHandler,
},
issuer: opts.Issuer,
useGravatar: opts.UseGravatar,
diff --git a/auth_test.go b/auth_test.go
index 06a3b8aa..73615981 100644
--- a/auth_test.go
+++ b/auth_test.go
@@ -3,6 +3,7 @@ package auth
import (
"context"
"encoding/json"
+ "fmt"
"io"
"net"
"net/http"
@@ -19,6 +20,7 @@ import (
"github.com/go-pkgz/auth/avatar"
"github.com/go-pkgz/auth/logger"
+ "github.com/go-pkgz/auth/middleware"
"github.com/go-pkgz/auth/provider"
"github.com/go-pkgz/auth/token"
)
@@ -246,6 +248,148 @@ func TestIntegrationList(t *testing.T) {
assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b))
}
+type testAuthErrorHTTPHandler struct {
+ statusCode int
+ contentType string
+ responseBody string
+ wasCalled bool
+}
+
+func (h *testAuthErrorHTTPHandler) ServeAuthError(
+ w http.ResponseWriter,
+ _ *http.Request,
+ authError error,
+ reason string,
+ statusCode int,
+) {
+ h.wasCalled = true
+ w.Header().Set("Content-Type", h.contentType)
+ w.WriteHeader(h.statusCode)
+ fmt.Fprint(w, h.responseBody)
+}
+
+func TestIntegrationAuthErrorHTTPHandler(t *testing.T) {
+ apiHandler := http.HandlerFunc(
+ func(w http.ResponseWriter, _ *http.Request) {
+ _, _ = w.Write([]byte("must not be called\n"))
+ t.Error("auth error must be raised before this HTTP handler is called")
+ },
+ )
+ defaultAuthErrorHTTPHandler := testAuthErrorHTTPHandler{
+ statusCode: 400,
+ contentType: "application/json",
+ responseBody: `{"code": 400, "message": "from general error handler"}`,
+ }
+ type apiCall struct {
+ requestPath string
+ expectedHandler *testAuthErrorHTTPHandler
+ createMiddleware func(apiCall, middleware.Authenticator) http.Handler
+ }
+ apiCalls := []apiCall{
+ {
+ requestPath: "/private1",
+ expectedHandler: &defaultAuthErrorHTTPHandler,
+ createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
+ return a.Auth(apiHandler)
+ },
+ },
+ {
+ requestPath: "/private2",
+ expectedHandler: &testAuthErrorHTTPHandler{
+ statusCode: 402,
+ contentType: "application/json",
+ responseBody: `{"code": 402, "message": "from private2 error handler"}`,
+ },
+ createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
+ return a.AuthWithErrorHTTPHandler(apiHandler, ac.expectedHandler)
+ },
+ },
+ {
+ requestPath: "/admin1",
+ expectedHandler: &defaultAuthErrorHTTPHandler,
+ createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
+ return a.AdminOnly(apiHandler)
+ },
+ },
+ {
+ requestPath: "/admin2",
+ expectedHandler: &testAuthErrorHTTPHandler{
+ statusCode: 404,
+ contentType: "application/json",
+ responseBody: `{"code": 404, "message": "from admin2 error handler"}`,
+ },
+ createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
+ return a.AdminOnlyWithErrorHTTPHandler(apiHandler, ac.expectedHandler)
+ },
+ },
+ {
+ requestPath: "/rbac1",
+ expectedHandler: &defaultAuthErrorHTTPHandler,
+ createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
+ return a.RBAC("role1", "role2")(apiHandler)
+ },
+ },
+ {
+ requestPath: "/rbac2",
+ expectedHandler: &testAuthErrorHTTPHandler{
+ statusCode: 406,
+ contentType: "text/html",
+ responseBody: `
from RBAC2 error handler
`,
+ },
+ createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
+ return a.RBACwithErrorHTTPHandler(ac.expectedHandler, "role1", "role2")(apiHandler)
+ },
+ },
+ }
+
+ options := Opts{
+ SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }),
+ Issuer: "my-test-app",
+ URL: "http://127.0.0.1:8089",
+ AuthErrorHTTPHandler: &defaultAuthErrorHTTPHandler,
+ }
+
+ svc := NewService(options)
+ svc.AddDevProvider("localhost", 18084) // add dev provider on 18084
+
+ mux := http.NewServeMux()
+ m := svc.Middleware()
+
+ for _, ac := range apiCalls {
+ mux.Handle(ac.requestPath, ac.createMiddleware(ac, m))
+ }
+
+ l, listenErr := net.Listen("tcp", "127.0.0.1:8089")
+ require.Nil(t, listenErr)
+ ts := httptest.NewUnstartedServer(mux)
+ assert.NoError(t, ts.Listener.Close())
+ ts.Listener = l
+ ts.Start()
+ defer func() {
+ ts.Close()
+ }()
+
+ for _, ac := range apiCalls {
+ t.Run("auth error test for endpoint "+ac.requestPath, func(t *testing.T) {
+ th := ac.expectedHandler
+ th.wasCalled = false
+
+ resp, err := http.Get("http://127.0.0.1:8089" + ac.requestPath)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ require.True(t, th.wasCalled)
+
+ require.Equal(t, th.statusCode, resp.StatusCode)
+ require.Equal(t, th.contentType, resp.Header.Get("Content-Type"))
+
+ b, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ require.Equal(t, th.responseBody, string(b))
+ })
+ }
+}
+
func TestIntegrationUserInfo(t *testing.T) {
_, teardown := prepService(t)
defer teardown()
diff --git a/middleware/auth.go b/middleware/auth.go
index 64d6c7ce..23e95041 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -18,12 +18,13 @@ import (
// Authenticator is top level auth object providing middlewares
type Authenticator struct {
logger.L
- JWTService TokenService
- Providers []provider.Service
- Validator token.Validator
- AdminPasswd string
- BasicAuthChecker BasicAuthFunc
- RefreshCache RefreshCache
+ JWTService TokenService
+ Providers []provider.Service
+ Validator token.Validator
+ AdminPasswd string
+ BasicAuthChecker BasicAuthFunc
+ RefreshCache RefreshCache
+ AuthErrorHTTPHandler AuthErrorHTTPHandler
}
// RefreshCache defines interface storing and retrieving refreshed tokens
@@ -45,6 +46,28 @@ type TokenService interface {
// The second return parameter `User` need for add user claims into context of request.
type BasicAuthFunc func(user, passwd string) (ok bool, userInfo token.User, err error)
+// AuthErrorHTTPHandler defines interface for handling HTTP responses in case of authentication errors
+type AuthErrorHTTPHandler interface {
+ // Serves HTTP response in case of authentication error
+ // w - response writer
+ // r - original request
+ // authError - authentication error with technical details
+ // reason - reason text
+ // statusCode - HTTP status code
+ ServeAuthError(w http.ResponseWriter, r *http.Request, authError error, reason string, statusCode int)
+}
+
+// DefaultAuthErrorHTTPHandler is a default implementation, which writes text/plain responses using http.Error()
+type DefaultAuthErrorHTTPHandler struct {
+ logger.L
+}
+
+// ServeAuthError writes text/plain responses using http.Error()
+func (h DefaultAuthErrorHTTPHandler) ServeAuthError(w http.ResponseWriter, _ *http.Request, authError error, reason string, statusCode int) {
+ h.Logf("[DEBUG] auth failed, %v", authError)
+ http.Error(w, reason, statusCode)
+}
+
// adminUser sets claims for an optional basic auth
var adminUser = token.User{
ID: "admin",
@@ -56,24 +79,29 @@ var adminUser = token.User{
// Auth middleware adds auth from session and populates user info
func (a *Authenticator) Auth(next http.Handler) http.Handler {
- return a.auth(true)(next)
+ return a.auth(true, a.getAuthErrorHTTPHandler())(next)
+}
+
+// AuthWithErrorHTTPHandler middleware adds auth from session and populates user info.
+// errorHttpHandler parameter may be used to write custom HTTP responses in case of authentication error.
+func (a *Authenticator) AuthWithErrorHTTPHandler(next http.Handler, errorHTTPHandler AuthErrorHTTPHandler) http.Handler {
+ return a.auth(true, errorHTTPHandler)(next)
}
// Trace middleware doesn't require valid user but if user info presented populates info
func (a *Authenticator) Trace(next http.Handler) http.Handler {
- return a.auth(false)(next)
+ return a.auth(false, a.getAuthErrorHTTPHandler())(next)
}
// auth implements all logic for authentication (reqAuth=true) and tracing (reqAuth=false)
-func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler {
+func (a *Authenticator) auth(reqAuth bool, errorHTTPHandler AuthErrorHTTPHandler) func(http.Handler) http.Handler {
onError := func(h http.Handler, w http.ResponseWriter, r *http.Request, err error) {
if !reqAuth { // if no auth required allow to proceeded on error
h.ServeHTTP(w, r)
return
}
- a.Logf("[DEBUG] auth failed, %v", err)
- http.Error(w, "Unauthorized", http.StatusUnauthorized)
+ errorHTTPHandler.ServeAuthError(w, r, err, "Unauthorized", http.StatusUnauthorized)
}
f := func(h http.Handler) http.Handler {
@@ -191,23 +219,34 @@ func (a *Authenticator) refreshExpiredToken(w http.ResponseWriter, claims token.
return c, nil
}
-// AdminOnly middleware allows access for admins only
-// this handler internally wrapped with auth(true) to avoid situation if AdminOnly defined without prior Auth
+// AdminOnly middleware allows access for admins only.
+// This handler internally wrapped with auth(true) to avoid situation if AdminOnly defined without prior Auth
func (a *Authenticator) AdminOnly(next http.Handler) http.Handler {
+ return a.adminOnly(next, a.getAuthErrorHTTPHandler())
+}
+
+// AdminOnlyWithErrorHTTPHandler middleware allows access for admins only.
+// This handler internally wrapped with auth(true) to avoid situation if AdminOnly defined without prior Auth.
+// errorHttpHandler parameter may be used to write custom HTTP responses in case of authentication error.
+func (a *Authenticator) AdminOnlyWithErrorHTTPHandler(next http.Handler, errorHTTPHandler AuthErrorHTTPHandler) http.Handler {
+ return a.adminOnly(next, errorHTTPHandler)
+}
+
+func (a *Authenticator) adminOnly(next http.Handler, errorHTTPHandler AuthErrorHTTPHandler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
user, err := token.GetUserInfo(r)
if err != nil {
- http.Error(w, "Unauthorized", http.StatusUnauthorized)
+ errorHTTPHandler.ServeAuthError(w, r, err, "Unauthorized", http.StatusUnauthorized)
return
}
if !user.IsAdmin() {
- http.Error(w, "Access denied", http.StatusForbidden)
+ errorHTTPHandler.ServeAuthError(w, r, fmt.Errorf("user %s/%s is not admin", user.Name, user.ID), "Access denied", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
}
- return a.auth(true)(http.HandlerFunc(fn)) // enforce auth
+ return a.auth(true, errorHTTPHandler)(http.HandlerFunc(fn)) // enforce auth
}
// basic auth for admin user
@@ -234,12 +273,23 @@ func (a *Authenticator) basicAdminUser(r *http.Request) bool {
// RBAC middleware allows role based control for routes
// this handler internally wrapped with auth(true) to avoid situation if RBAC defined without prior Auth
func (a *Authenticator) RBAC(roles ...string) func(http.Handler) http.Handler {
+ return a.rbac(a.getAuthErrorHTTPHandler(), roles...)
+}
+
+// RBACwithErrorHTTPHandler middleware allows role based control for routes
+// this handler internally wrapped with auth(true) to avoid situation if RBAC defined without prior Auth
+// errorHttpHandler parameter may be used to write custom HTTP responses in case of authentication error.
+func (a *Authenticator) RBACwithErrorHTTPHandler(errorHTTPHandler AuthErrorHTTPHandler, roles ...string) func(http.Handler) http.Handler {
+ return a.rbac(errorHTTPHandler, roles...)
+}
+
+func (a *Authenticator) rbac(errorHTTPHandler AuthErrorHTTPHandler, roles ...string) func(http.Handler) http.Handler {
f := func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
user, err := token.GetUserInfo(r)
if err != nil {
- http.Error(w, "Unauthorized", http.StatusUnauthorized)
+ errorHTTPHandler.ServeAuthError(w, r, err, "Unauthorized", http.StatusUnauthorized)
return
}
@@ -251,12 +301,26 @@ func (a *Authenticator) RBAC(roles ...string) func(http.Handler) http.Handler {
}
}
if !matched {
- http.Error(w, "Access denied", http.StatusForbidden)
+ errorHTTPHandler.ServeAuthError(
+ w,
+ r,
+ fmt.Errorf("user %s/%s does not have any of required roles: %s", user.Name, user.ID, roles),
+ "Access denied",
+ http.StatusForbidden,
+ )
return
}
h.ServeHTTP(w, r)
}
- return a.auth(true)(http.HandlerFunc(fn)) // enforce auth
+ return a.auth(true, errorHTTPHandler)(http.HandlerFunc(fn)) // enforce auth
}
return f
}
+
+func (a *Authenticator) getAuthErrorHTTPHandler() AuthErrorHTTPHandler {
+ if a.AuthErrorHTTPHandler != nil {
+ return a.AuthErrorHTTPHandler
+ }
+
+ return DefaultAuthErrorHTTPHandler{L: a.L}
+}
diff --git a/middleware/auth_test.go b/middleware/auth_test.go
index 18208aee..c923b1a4 100644
--- a/middleware/auth_test.go
+++ b/middleware/auth_test.go
@@ -503,6 +503,70 @@ func TestRBAC(t *testing.T) {
assert.Equal(t, "Access denied\n", string(data))
}
+type testAuthErrorHTTPHandler struct {
+ wasCalled bool
+ statusCode int
+}
+
+func (h *testAuthErrorHTTPHandler) ServeAuthError(
+ w http.ResponseWriter,
+ _ *http.Request,
+ _ error,
+ _ string,
+ _ int,
+) {
+ h.wasCalled = true
+ w.Header().Set("Content-Type", "text/plain")
+ w.WriteHeader(h.statusCode)
+ fmt.Fprint(w, "Unauthorized")
+}
+
+func TestAuthErrorHTTPHandler(t *testing.T) {
+ testErrorHandler1 := &testAuthErrorHTTPHandler{statusCode: 401}
+ testErrorHandler2 := &testAuthErrorHTTPHandler{statusCode: 402}
+ testErrorHandler3 := &testAuthErrorHTTPHandler{statusCode: 403}
+ testErrorHandler4 := &testAuthErrorHTTPHandler{statusCode: 404}
+
+ a := makeTestAuth(t)
+ a.AuthErrorHTTPHandler = testErrorHandler1
+
+ handler := http.HandlerFunc(
+ func(w http.ResponseWriter, _ *http.Request) { // token required
+ _, _ = w.Write([]byte("must not be called\n"))
+ t.Error("auth error must be raised before this HTTP handler is called")
+ },
+ )
+
+ mux := http.NewServeMux()
+ mux.Handle("/private1", a.Auth(handler))
+ mux.Handle("/private2", a.AuthWithErrorHTTPHandler(handler, testErrorHandler2))
+ mux.Handle("/admin1", a.AdminOnly(handler))
+ mux.Handle("/admin2", a.AdminOnlyWithErrorHTTPHandler(handler, testErrorHandler3))
+ mux.Handle("/rbac1", a.RBAC("role1")(handler))
+ mux.Handle("/rbac2", a.RBACwithErrorHTTPHandler(testErrorHandler4, "role1")(handler))
+
+ server := httptest.NewServer(mux)
+ defer server.Close()
+
+ assertThatHandlerWasCalledProperly := func(t *testing.T, errorHandler *testAuthErrorHTTPHandler, path string) {
+ errorHandler.wasCalled = false
+
+ resp, err := http.Get(server.URL + path)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ require.True(t, errorHandler.wasCalled, "error handler must be called")
+ require.Equal(t, errorHandler.statusCode, resp.StatusCode, "error handler must produce proper status code")
+ }
+
+ assertThatHandlerWasCalledProperly(t, testErrorHandler1, "/private1")
+ assertThatHandlerWasCalledProperly(t, testErrorHandler2, "/private2")
+ assertThatHandlerWasCalledProperly(t, testErrorHandler1, "/admin1")
+ assertThatHandlerWasCalledProperly(t, testErrorHandler3, "/admin2")
+ assertThatHandlerWasCalledProperly(t, testErrorHandler1, "/rbac1")
+ assertThatHandlerWasCalledProperly(t, testErrorHandler4, "/rbac2")
+}
+
func makeTestMux(_ *testing.T, a *Authenticator, required bool) http.Handler {
mux := http.NewServeMux()
authMiddleware := a.Auth
diff --git a/v2/auth.go b/v2/auth.go
index 30d4a613..e5908399 100644
--- a/v2/auth.go
+++ b/v2/auth.go
@@ -66,12 +66,13 @@ type Opts struct {
AvatarRoutePath string // avatar routing prefix, i.e. "/api/v1/avatar", default `/avatar`
UseGravatar bool // for email based auth (verified provider) use gravatar service
- AdminPasswd string // if presented, allows basic auth with user admin and given password
- BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored
- AudienceReader token.Audience // list of allowed aud values, default (empty) allows any
- AudSecrets bool // allow multiple secrets (secret per aud)
- Logger logger.L // logger interface, default is no logging at all
- RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens
+ AdminPasswd string // if presented, allows basic auth with user admin and given password
+ BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored
+ AudienceReader token.Audience // list of allowed aud values, default (empty) allows any
+ AudSecrets bool // allow multiple secrets (secret per aud)
+ Logger logger.L // logger interface, default is no logging at all
+ RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens
+ AuthErrorHTTPHandler middleware.AuthErrorHTTPHandler // optional HTTP handler for authentication errors
}
// NewService initializes everything
@@ -81,10 +82,11 @@ func NewService(opts Opts) (res *Service) {
opts: opts,
logger: opts.Logger,
authMiddleware: middleware.Authenticator{
- Validator: opts.Validator,
- AdminPasswd: opts.AdminPasswd,
- BasicAuthChecker: opts.BasicAuthChecker,
- RefreshCache: opts.RefreshCache,
+ Validator: opts.Validator,
+ AdminPasswd: opts.AdminPasswd,
+ BasicAuthChecker: opts.BasicAuthChecker,
+ RefreshCache: opts.RefreshCache,
+ AuthErrorHTTPHandler: opts.AuthErrorHTTPHandler,
},
issuer: opts.Issuer,
useGravatar: opts.UseGravatar,
diff --git a/v2/auth_test.go b/v2/auth_test.go
index 81655bb6..bc3c1852 100644
--- a/v2/auth_test.go
+++ b/v2/auth_test.go
@@ -3,6 +3,7 @@ package auth
import (
"context"
"encoding/json"
+ "fmt"
"io"
"net"
"net/http"
@@ -19,6 +20,7 @@ import (
"github.com/go-pkgz/auth/v2/avatar"
"github.com/go-pkgz/auth/v2/logger"
+ "github.com/go-pkgz/auth/v2/middleware"
"github.com/go-pkgz/auth/v2/provider"
"github.com/go-pkgz/auth/v2/token"
)
@@ -246,6 +248,148 @@ func TestIntegrationList(t *testing.T) {
assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b))
}
+type testAuthErrorHTTPHandler struct {
+ statusCode int
+ contentType string
+ responseBody string
+ wasCalled bool
+}
+
+func (h *testAuthErrorHTTPHandler) ServeAuthError(
+ w http.ResponseWriter,
+ _ *http.Request,
+ authError error,
+ reason string,
+ statusCode int,
+) {
+ h.wasCalled = true
+ w.Header().Set("Content-Type", h.contentType)
+ w.WriteHeader(h.statusCode)
+ fmt.Fprint(w, h.responseBody)
+}
+
+func TestIntegrationAuthErrorHTTPHandler(t *testing.T) {
+ apiHandler := http.HandlerFunc(
+ func(w http.ResponseWriter, _ *http.Request) {
+ _, _ = w.Write([]byte("must not be called\n"))
+ t.Error("auth error must be raised before this HTTP handler is called")
+ },
+ )
+ defaultAuthErrorHTTPHandler := testAuthErrorHTTPHandler{
+ statusCode: 400,
+ contentType: "application/json",
+ responseBody: `{"code": 400, "message": "from general error handler"}`,
+ }
+ type apiCall struct {
+ requestPath string
+ expectedHandler *testAuthErrorHTTPHandler
+ createMiddleware func(apiCall, middleware.Authenticator) http.Handler
+ }
+ apiCalls := []apiCall{
+ {
+ requestPath: "/private1",
+ expectedHandler: &defaultAuthErrorHTTPHandler,
+ createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
+ return a.Auth(apiHandler)
+ },
+ },
+ {
+ requestPath: "/private2",
+ expectedHandler: &testAuthErrorHTTPHandler{
+ statusCode: 402,
+ contentType: "application/json",
+ responseBody: `{"code": 402, "message": "from private2 error handler"}`,
+ },
+ createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
+ return a.AuthWithErrorHTTPHandler(apiHandler, ac.expectedHandler)
+ },
+ },
+ {
+ requestPath: "/admin1",
+ expectedHandler: &defaultAuthErrorHTTPHandler,
+ createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
+ return a.AdminOnly(apiHandler)
+ },
+ },
+ {
+ requestPath: "/admin2",
+ expectedHandler: &testAuthErrorHTTPHandler{
+ statusCode: 404,
+ contentType: "application/json",
+ responseBody: `{"code": 404, "message": "from admin2 error handler"}`,
+ },
+ createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
+ return a.AdminOnlyWithErrorHTTPHandler(apiHandler, ac.expectedHandler)
+ },
+ },
+ {
+ requestPath: "/rbac1",
+ expectedHandler: &defaultAuthErrorHTTPHandler,
+ createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
+ return a.RBAC("role1", "role2")(apiHandler)
+ },
+ },
+ {
+ requestPath: "/rbac2",
+ expectedHandler: &testAuthErrorHTTPHandler{
+ statusCode: 406,
+ contentType: "text/html",
+ responseBody: `from RBAC2 error handler
`,
+ },
+ createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
+ return a.RBACwithErrorHTTPHandler(ac.expectedHandler, "role1", "role2")(apiHandler)
+ },
+ },
+ }
+
+ options := Opts{
+ SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }),
+ Issuer: "my-test-app",
+ URL: "http://127.0.0.1:8089",
+ AuthErrorHTTPHandler: &defaultAuthErrorHTTPHandler,
+ }
+
+ svc := NewService(options)
+ svc.AddDevProvider("localhost", 18084) // add dev provider on 18084
+
+ mux := http.NewServeMux()
+ m := svc.Middleware()
+
+ for _, ac := range apiCalls {
+ mux.Handle(ac.requestPath, ac.createMiddleware(ac, m))
+ }
+
+ l, listenErr := net.Listen("tcp", "127.0.0.1:8089")
+ require.Nil(t, listenErr)
+ ts := httptest.NewUnstartedServer(mux)
+ assert.NoError(t, ts.Listener.Close())
+ ts.Listener = l
+ ts.Start()
+ defer func() {
+ ts.Close()
+ }()
+
+ for _, ac := range apiCalls {
+ t.Run("auth error test for endpoint "+ac.requestPath, func(t *testing.T) {
+ th := ac.expectedHandler
+ th.wasCalled = false
+
+ resp, err := http.Get("http://127.0.0.1:8089" + ac.requestPath)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ require.True(t, th.wasCalled)
+
+ require.Equal(t, th.statusCode, resp.StatusCode)
+ require.Equal(t, th.contentType, resp.Header.Get("Content-Type"))
+
+ b, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ require.Equal(t, th.responseBody, string(b))
+ })
+ }
+}
+
func TestIntegrationUserInfo(t *testing.T) {
_, teardown := prepService(t)
defer teardown()
diff --git a/v2/middleware/auth.go b/v2/middleware/auth.go
index 6e8c5ed8..61e1c08c 100644
--- a/v2/middleware/auth.go
+++ b/v2/middleware/auth.go
@@ -18,12 +18,13 @@ import (
// Authenticator is top level auth object providing middlewares
type Authenticator struct {
logger.L
- JWTService TokenService
- Providers []provider.Service
- Validator token.Validator
- AdminPasswd string
- BasicAuthChecker BasicAuthFunc
- RefreshCache RefreshCache
+ JWTService TokenService
+ Providers []provider.Service
+ Validator token.Validator
+ AdminPasswd string
+ BasicAuthChecker BasicAuthFunc
+ RefreshCache RefreshCache
+ AuthErrorHTTPHandler AuthErrorHTTPHandler
}
// RefreshCache defines interface storing and retrieving refreshed tokens
@@ -45,6 +46,28 @@ type TokenService interface {
// The second return parameter `User` need for add user claims into context of request.
type BasicAuthFunc func(user, passwd string) (ok bool, userInfo token.User, err error)
+// AuthErrorHTTPHandler defines interface for handling HTTP responses in case of authentication errors
+type AuthErrorHTTPHandler interface {
+ // Serves HTTP response in case of authentication error
+ // w - response writer
+ // r - original request
+ // authError - authentication error with technical details
+ // reason - reason text
+ // statusCode - HTTP status code
+ ServeAuthError(w http.ResponseWriter, r *http.Request, authError error, reason string, statusCode int)
+}
+
+// DefaultAuthErrorHTTPHandler is a default implementation, which writes text/plain responses using http.Error()
+type DefaultAuthErrorHTTPHandler struct {
+ logger.L
+}
+
+// ServeAuthError writes text/plain responses using http.Error()
+func (h DefaultAuthErrorHTTPHandler) ServeAuthError(w http.ResponseWriter, _ *http.Request, authError error, reason string, statusCode int) {
+ h.Logf("[DEBUG] auth failed, %v", authError)
+ http.Error(w, reason, statusCode)
+}
+
// adminUser sets claims for an optional basic auth
var adminUser = token.User{
ID: "admin",
@@ -56,24 +79,29 @@ var adminUser = token.User{
// Auth middleware adds auth from session and populates user info
func (a *Authenticator) Auth(next http.Handler) http.Handler {
- return a.auth(true)(next)
+ return a.auth(true, a.getAuthErrorHTTPHandler())(next)
+}
+
+// AuthWithErrorHTTPHandler middleware adds auth from session and populates user info.
+// errorHttpHandler parameter may be used to write custom HTTP responses in case of authentication error.
+func (a *Authenticator) AuthWithErrorHTTPHandler(next http.Handler, errorHTTPHandler AuthErrorHTTPHandler) http.Handler {
+ return a.auth(true, errorHTTPHandler)(next)
}
// Trace middleware doesn't require valid user but if user info presented populates info
func (a *Authenticator) Trace(next http.Handler) http.Handler {
- return a.auth(false)(next)
+ return a.auth(false, a.getAuthErrorHTTPHandler())(next)
}
// auth implements all logic for authentication (reqAuth=true) and tracing (reqAuth=false)
-func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler {
+func (a *Authenticator) auth(reqAuth bool, errorHTTPHandler AuthErrorHTTPHandler) func(http.Handler) http.Handler {
onError := func(h http.Handler, w http.ResponseWriter, r *http.Request, err error) {
if !reqAuth { // if no auth required allow to proceeded on error
h.ServeHTTP(w, r)
return
}
- a.Logf("[DEBUG] auth failed, %v", err)
- http.Error(w, "Unauthorized", http.StatusUnauthorized)
+ errorHTTPHandler.ServeAuthError(w, r, err, "Unauthorized", http.StatusUnauthorized)
}
f := func(h http.Handler) http.Handler {
@@ -191,23 +219,34 @@ func (a *Authenticator) refreshExpiredToken(w http.ResponseWriter, claims token.
return c, nil
}
-// AdminOnly middleware allows access for admins only
-// this handler internally wrapped with auth(true) to avoid situation if AdminOnly defined without prior Auth
+// AdminOnly middleware allows access for admins only.
+// This handler internally wrapped with auth(true) to avoid situation if AdminOnly defined without prior Auth
func (a *Authenticator) AdminOnly(next http.Handler) http.Handler {
+ return a.adminOnly(next, a.getAuthErrorHTTPHandler())
+}
+
+// AdminOnlyWithErrorHTTPHandler middleware allows access for admins only.
+// This handler internally wrapped with auth(true) to avoid situation if AdminOnly defined without prior Auth.
+// errorHttpHandler parameter may be used to write custom HTTP responses in case of authentication error.
+func (a *Authenticator) AdminOnlyWithErrorHTTPHandler(next http.Handler, errorHTTPHandler AuthErrorHTTPHandler) http.Handler {
+ return a.adminOnly(next, errorHTTPHandler)
+}
+
+func (a *Authenticator) adminOnly(next http.Handler, errorHTTPHandler AuthErrorHTTPHandler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
user, err := token.GetUserInfo(r)
if err != nil {
- http.Error(w, "Unauthorized", http.StatusUnauthorized)
+ errorHTTPHandler.ServeAuthError(w, r, err, "Unauthorized", http.StatusUnauthorized)
return
}
if !user.IsAdmin() {
- http.Error(w, "Access denied", http.StatusForbidden)
+ errorHTTPHandler.ServeAuthError(w, r, fmt.Errorf("user %s/%s is not admin", user.Name, user.ID), "Access denied", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
}
- return a.auth(true)(http.HandlerFunc(fn)) // enforce auth
+ return a.auth(true, errorHTTPHandler)(http.HandlerFunc(fn)) // enforce auth
}
// basic auth for admin user
@@ -234,12 +273,23 @@ func (a *Authenticator) basicAdminUser(r *http.Request) bool {
// RBAC middleware allows role based control for routes
// this handler internally wrapped with auth(true) to avoid situation if RBAC defined without prior Auth
func (a *Authenticator) RBAC(roles ...string) func(http.Handler) http.Handler {
+ return a.rbac(a.getAuthErrorHTTPHandler(), roles...)
+}
+
+// RBACwithErrorHTTPHandler middleware allows role based control for routes
+// this handler internally wrapped with auth(true) to avoid situation if RBAC defined without prior Auth
+// errorHttpHandler parameter may be used to write custom HTTP responses in case of authentication error.
+func (a *Authenticator) RBACwithErrorHTTPHandler(errorHTTPHandler AuthErrorHTTPHandler, roles ...string) func(http.Handler) http.Handler {
+ return a.rbac(errorHTTPHandler, roles...)
+}
+
+func (a *Authenticator) rbac(errorHTTPHandler AuthErrorHTTPHandler, roles ...string) func(http.Handler) http.Handler {
f := func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
user, err := token.GetUserInfo(r)
if err != nil {
- http.Error(w, "Unauthorized", http.StatusUnauthorized)
+ errorHTTPHandler.ServeAuthError(w, r, err, "Unauthorized", http.StatusUnauthorized)
return
}
@@ -251,12 +301,26 @@ func (a *Authenticator) RBAC(roles ...string) func(http.Handler) http.Handler {
}
}
if !matched {
- http.Error(w, "Access denied", http.StatusForbidden)
+ errorHTTPHandler.ServeAuthError(
+ w,
+ r,
+ fmt.Errorf("user %s/%s does not have any of required roles: %s", user.Name, user.ID, roles),
+ "Access denied",
+ http.StatusForbidden,
+ )
return
}
h.ServeHTTP(w, r)
}
- return a.auth(true)(http.HandlerFunc(fn)) // enforce auth
+ return a.auth(true, errorHTTPHandler)(http.HandlerFunc(fn)) // enforce auth
}
return f
}
+
+func (a *Authenticator) getAuthErrorHTTPHandler() AuthErrorHTTPHandler {
+ if a.AuthErrorHTTPHandler != nil {
+ return a.AuthErrorHTTPHandler
+ }
+
+ return DefaultAuthErrorHTTPHandler{L: a.L}
+}
diff --git a/v2/middleware/auth_test.go b/v2/middleware/auth_test.go
index d71e0c9f..37baa629 100644
--- a/v2/middleware/auth_test.go
+++ b/v2/middleware/auth_test.go
@@ -502,6 +502,70 @@ func TestRBAC(t *testing.T) {
assert.Equal(t, "Access denied\n", string(data))
}
+type testAuthErrorHTTPHandler struct {
+ wasCalled bool
+ statusCode int
+}
+
+func (h *testAuthErrorHTTPHandler) ServeAuthError(
+ w http.ResponseWriter,
+ _ *http.Request,
+ _ error,
+ _ string,
+ _ int,
+) {
+ h.wasCalled = true
+ w.Header().Set("Content-Type", "text/plain")
+ w.WriteHeader(h.statusCode)
+ fmt.Fprint(w, "Unauthorized")
+}
+
+func TestAuthErrorHTTPHandler(t *testing.T) {
+ testErrorHandler1 := &testAuthErrorHTTPHandler{statusCode: 401}
+ testErrorHandler2 := &testAuthErrorHTTPHandler{statusCode: 402}
+ testErrorHandler3 := &testAuthErrorHTTPHandler{statusCode: 403}
+ testErrorHandler4 := &testAuthErrorHTTPHandler{statusCode: 404}
+
+ a := makeTestAuth(t)
+ a.AuthErrorHTTPHandler = testErrorHandler1
+
+ handler := http.HandlerFunc(
+ func(w http.ResponseWriter, _ *http.Request) {
+ _, _ = w.Write([]byte("must not be called\n"))
+ t.Error("auth error must be raised before this HTTP handler is called")
+ },
+ )
+
+ mux := http.NewServeMux()
+ mux.Handle("/private1", a.Auth(handler))
+ mux.Handle("/private2", a.AuthWithErrorHTTPHandler(handler, testErrorHandler2))
+ mux.Handle("/admin1", a.AdminOnly(handler))
+ mux.Handle("/admin2", a.AdminOnlyWithErrorHTTPHandler(handler, testErrorHandler3))
+ mux.Handle("/rbac1", a.RBAC("role1")(handler))
+ mux.Handle("/rbac2", a.RBACwithErrorHTTPHandler(testErrorHandler4, "role1")(handler))
+
+ server := httptest.NewServer(mux)
+ defer server.Close()
+
+ assertThatHandlerWasCalledProperly := func(t *testing.T, errorHandler *testAuthErrorHTTPHandler, path string) {
+ errorHandler.wasCalled = false
+
+ resp, err := http.Get(server.URL + path)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ require.True(t, errorHandler.wasCalled, "error handler must be called")
+ require.Equal(t, errorHandler.statusCode, resp.StatusCode, "error handler must produce proper status code")
+ }
+
+ assertThatHandlerWasCalledProperly(t, testErrorHandler1, "/private1")
+ assertThatHandlerWasCalledProperly(t, testErrorHandler2, "/private2")
+ assertThatHandlerWasCalledProperly(t, testErrorHandler1, "/admin1")
+ assertThatHandlerWasCalledProperly(t, testErrorHandler3, "/admin2")
+ assertThatHandlerWasCalledProperly(t, testErrorHandler1, "/rbac1")
+ assertThatHandlerWasCalledProperly(t, testErrorHandler4, "/rbac2")
+}
+
func makeTestMux(_ *testing.T, a *Authenticator, required bool) http.Handler {
mux := http.NewServeMux()
authMiddleware := a.Auth