Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom HTTP handler for authentication errors #211

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,9 @@ There are several ways to adjust functionality of the library:
1. `ClaimsUpdater` - interface with `Update(claims Claims) Claims` method. This is the primary way to alter a token at login time and add any attributes, set ip, email, admin status, roles and so on.
1. `Validator` - interface with `Validate(token string, claims Claims) bool` method. This is post-token hook and will be called on **each request** wrapped with `Auth` middleware. This will be the place for special logic to reject some tokens or users.
1. `UserUpdater` - interface with `Update(claims token.User) token.User` method. This method will be called on **each request** wrapped with `UpdateUser` middleware. This will be the place for special logic modify User Info in request context. [Example of usage.](https://github.com/go-pkgz/auth/blob/19c1b6d26608494955a4480f8f6165af85b1deab/_example/main.go#L189)
1. `AuthErrorHTTPHandler` - interface with `ServeAuthError(w http.ResponseWriter, r *http.Request, and other params)` method. It is possible to change how authentication errors are written into HTTP responses by configuring custom implementations of this interface for the middlewares.

All of the interfaces above have corresponding Func adapters - `SecretFunc`, `ClaimsUpdFunc`, `ValidatorFunc` and `UserUpdFunc`.
All of the interfaces above except `AuthErrorHTTPHandler` have corresponding Func adapters - `SecretFunc`, `ClaimsUpdFunc`, `ValidatorFunc` and `UserUpdFunc`.

### Implementing black list logic or some other filters

Expand Down
22 changes: 12 additions & 10 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,13 @@ type Opts struct {
AvatarRoutePath string // avatar routing prefix, i.e. "/api/v1/avatar", default `/avatar`
UseGravatar bool // for email based auth (verified provider) use gravatar service

AdminPasswd string // if presented, allows basic auth with user admin and given password
BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored
AudienceReader token.Audience // list of allowed aud values, default (empty) allows any
AudSecrets bool // allow multiple secrets (secret per aud)
Logger logger.L // logger interface, default is no logging at all
RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens
AdminPasswd string // if presented, allows basic auth with user admin and given password
BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored
AudienceReader token.Audience // list of allowed aud values, default (empty) allows any
AudSecrets bool // allow multiple secrets (secret per aud)
Logger logger.L // logger interface, default is no logging at all
RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens
AuthErrorHTTPHandler middleware.AuthErrorHTTPHandler // optional HTTP handler for authentication errors
}

// NewService initializes everything
Expand All @@ -81,10 +82,11 @@ func NewService(opts Opts) (res *Service) {
opts: opts,
logger: opts.Logger,
authMiddleware: middleware.Authenticator{
Validator: opts.Validator,
AdminPasswd: opts.AdminPasswd,
BasicAuthChecker: opts.BasicAuthChecker,
RefreshCache: opts.RefreshCache,
Validator: opts.Validator,
AdminPasswd: opts.AdminPasswd,
BasicAuthChecker: opts.BasicAuthChecker,
RefreshCache: opts.RefreshCache,
AuthErrorHTTPHandler: opts.AuthErrorHTTPHandler,
},
issuer: opts.Issuer,
useGravatar: opts.UseGravatar,
Expand Down
144 changes: 144 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package auth
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
Expand All @@ -19,6 +20,7 @@ import (

"github.com/go-pkgz/auth/avatar"
"github.com/go-pkgz/auth/logger"
"github.com/go-pkgz/auth/middleware"
"github.com/go-pkgz/auth/provider"
"github.com/go-pkgz/auth/token"
)
Expand Down Expand Up @@ -246,6 +248,148 @@ func TestIntegrationList(t *testing.T) {
assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b))
}

type testAuthErrorHTTPHandler struct {
statusCode int
contentType string
responseBody string
wasCalled bool
}

func (h *testAuthErrorHTTPHandler) ServeAuthError(
w http.ResponseWriter,
_ *http.Request,
authError error,
reason string,
statusCode int,
) {
h.wasCalled = true
w.Header().Set("Content-Type", h.contentType)
w.WriteHeader(h.statusCode)
fmt.Fprint(w, h.responseBody)
}

func TestIntegrationAuthErrorHTTPHandler(t *testing.T) {
paskal marked this conversation as resolved.
Show resolved Hide resolved
apiHandler := http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("must not be called\n"))
t.Error("auth error must be raised before this HTTP handler is called")
},
)
defaultAuthErrorHTTPHandler := testAuthErrorHTTPHandler{
statusCode: 400,
contentType: "application/json",
responseBody: `{"code": 400, "message": "from general error handler"}`,
}
type apiCall struct {
requestPath string
expectedHandler *testAuthErrorHTTPHandler
createMiddleware func(apiCall, middleware.Authenticator) http.Handler
}
apiCalls := []apiCall{
{
requestPath: "/private1",
expectedHandler: &defaultAuthErrorHTTPHandler,
createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
return a.Auth(apiHandler)
},
},
{
requestPath: "/private2",
expectedHandler: &testAuthErrorHTTPHandler{
statusCode: 402,
contentType: "application/json",
responseBody: `{"code": 402, "message": "from private2 error handler"}`,
},
createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
return a.AuthWithErrorHTTPHandler(apiHandler, ac.expectedHandler)
},
},
{
requestPath: "/admin1",
expectedHandler: &defaultAuthErrorHTTPHandler,
createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
return a.AdminOnly(apiHandler)
},
},
{
requestPath: "/admin2",
expectedHandler: &testAuthErrorHTTPHandler{
statusCode: 404,
contentType: "application/json",
responseBody: `{"code": 404, "message": "from admin2 error handler"}`,
},
createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
return a.AdminOnlyWithErrorHTTPHandler(apiHandler, ac.expectedHandler)
},
},
{
requestPath: "/rbac1",
expectedHandler: &defaultAuthErrorHTTPHandler,
createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
return a.RBAC("role1", "role2")(apiHandler)
},
},
{
requestPath: "/rbac2",
expectedHandler: &testAuthErrorHTTPHandler{
statusCode: 406,
contentType: "text/html",
responseBody: `<html><body><h1>from RBAC2 error handler</h1></body></html>`,
},
createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
return a.RBACwithErrorHTTPHandler(ac.expectedHandler, "role1", "role2")(apiHandler)
},
},
}

options := Opts{
SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }),
Issuer: "my-test-app",
URL: "http://127.0.0.1:8089",
AuthErrorHTTPHandler: &defaultAuthErrorHTTPHandler,
}

svc := NewService(options)
svc.AddDevProvider("localhost", 18084) // add dev provider on 18084

mux := http.NewServeMux()
m := svc.Middleware()

for _, ac := range apiCalls {
mux.Handle(ac.requestPath, ac.createMiddleware(ac, m))
}

l, listenErr := net.Listen("tcp", "127.0.0.1:8089")
require.Nil(t, listenErr)
ts := httptest.NewUnstartedServer(mux)
assert.NoError(t, ts.Listener.Close())
ts.Listener = l
ts.Start()
defer func() {
ts.Close()
}()

for _, ac := range apiCalls {
t.Run("auth error test for endpoint "+ac.requestPath, func(t *testing.T) {
th := ac.expectedHandler
th.wasCalled = false

resp, err := http.Get("http://127.0.0.1:8089" + ac.requestPath)
require.NoError(t, err)
defer resp.Body.Close()

require.True(t, th.wasCalled)

require.Equal(t, th.statusCode, resp.StatusCode)
require.Equal(t, th.contentType, resp.Header.Get("Content-Type"))

b, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, th.responseBody, string(b))
})
}
}

func TestIntegrationUserInfo(t *testing.T) {
_, teardown := prepService(t)
defer teardown()
Expand Down
102 changes: 83 additions & 19 deletions middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ import (
// Authenticator is top level auth object providing middlewares
type Authenticator struct {
logger.L
JWTService TokenService
Providers []provider.Service
Validator token.Validator
AdminPasswd string
BasicAuthChecker BasicAuthFunc
RefreshCache RefreshCache
JWTService TokenService
Providers []provider.Service
Validator token.Validator
AdminPasswd string
BasicAuthChecker BasicAuthFunc
RefreshCache RefreshCache
AuthErrorHTTPHandler AuthErrorHTTPHandler
}

// RefreshCache defines interface storing and retrieving refreshed tokens
Expand All @@ -45,6 +46,28 @@ type TokenService interface {
// The second return parameter `User` need for add user claims into context of request.
type BasicAuthFunc func(user, passwd string) (ok bool, userInfo token.User, err error)

// AuthErrorHTTPHandler defines interface for handling HTTP responses in case of authentication errors
type AuthErrorHTTPHandler interface {
// Serves HTTP response in case of authentication error
// w - response writer
// r - original request
// authError - authentication error with technical details
// reason - reason text
// statusCode - HTTP status code
ServeAuthError(w http.ResponseWriter, r *http.Request, authError error, reason string, statusCode int)
}

// DefaultAuthErrorHTTPHandler is a default implementation, which writes text/plain responses using http.Error()
type DefaultAuthErrorHTTPHandler struct {
logger.L
}

// ServeAuthError writes text/plain responses using http.Error()
func (h DefaultAuthErrorHTTPHandler) ServeAuthError(w http.ResponseWriter, _ *http.Request, authError error, reason string, statusCode int) {
h.Logf("[DEBUG] auth failed, %v", authError)
http.Error(w, reason, statusCode)
}

// adminUser sets claims for an optional basic auth
var adminUser = token.User{
ID: "admin",
Expand All @@ -56,24 +79,29 @@ var adminUser = token.User{

// Auth middleware adds auth from session and populates user info
func (a *Authenticator) Auth(next http.Handler) http.Handler {
return a.auth(true)(next)
return a.auth(true, a.getAuthErrorHTTPHandler())(next)
}

// AuthWithErrorHTTPHandler middleware adds auth from session and populates user info.
// errorHttpHandler parameter may be used to write custom HTTP responses in case of authentication error.
func (a *Authenticator) AuthWithErrorHTTPHandler(next http.Handler, errorHTTPHandler AuthErrorHTTPHandler) http.Handler {
return a.auth(true, errorHTTPHandler)(next)
}

// Trace middleware doesn't require valid user but if user info presented populates info
func (a *Authenticator) Trace(next http.Handler) http.Handler {
return a.auth(false)(next)
return a.auth(false, a.getAuthErrorHTTPHandler())(next)
}

// auth implements all logic for authentication (reqAuth=true) and tracing (reqAuth=false)
func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler {
func (a *Authenticator) auth(reqAuth bool, errorHTTPHandler AuthErrorHTTPHandler) func(http.Handler) http.Handler {

onError := func(h http.Handler, w http.ResponseWriter, r *http.Request, err error) {
if !reqAuth { // if no auth required allow to proceeded on error
h.ServeHTTP(w, r)
return
}
a.Logf("[DEBUG] auth failed, %v", err)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
errorHTTPHandler.ServeAuthError(w, r, err, "Unauthorized", http.StatusUnauthorized)
}

f := func(h http.Handler) http.Handler {
Expand Down Expand Up @@ -191,23 +219,34 @@ func (a *Authenticator) refreshExpiredToken(w http.ResponseWriter, claims token.
return c, nil
}

// AdminOnly middleware allows access for admins only
// this handler internally wrapped with auth(true) to avoid situation if AdminOnly defined without prior Auth
// AdminOnly middleware allows access for admins only.
// This handler internally wrapped with auth(true) to avoid situation if AdminOnly defined without prior Auth
func (a *Authenticator) AdminOnly(next http.Handler) http.Handler {
return a.adminOnly(next, a.getAuthErrorHTTPHandler())
}

// AdminOnlyWithErrorHTTPHandler middleware allows access for admins only.
// This handler internally wrapped with auth(true) to avoid situation if AdminOnly defined without prior Auth.
// errorHttpHandler parameter may be used to write custom HTTP responses in case of authentication error.
func (a *Authenticator) AdminOnlyWithErrorHTTPHandler(next http.Handler, errorHTTPHandler AuthErrorHTTPHandler) http.Handler {
return a.adminOnly(next, errorHTTPHandler)
}

func (a *Authenticator) adminOnly(next http.Handler, errorHTTPHandler AuthErrorHTTPHandler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
user, err := token.GetUserInfo(r)
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
errorHTTPHandler.ServeAuthError(w, r, err, "Unauthorized", http.StatusUnauthorized)
return
}

if !user.IsAdmin() {
http.Error(w, "Access denied", http.StatusForbidden)
errorHTTPHandler.ServeAuthError(w, r, fmt.Errorf("user %s/%s is not admin", user.Name, user.ID), "Access denied", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
}
return a.auth(true)(http.HandlerFunc(fn)) // enforce auth
return a.auth(true, errorHTTPHandler)(http.HandlerFunc(fn)) // enforce auth
}

// basic auth for admin user
Expand All @@ -234,12 +273,23 @@ func (a *Authenticator) basicAdminUser(r *http.Request) bool {
// RBAC middleware allows role based control for routes
// this handler internally wrapped with auth(true) to avoid situation if RBAC defined without prior Auth
func (a *Authenticator) RBAC(roles ...string) func(http.Handler) http.Handler {
return a.rbac(a.getAuthErrorHTTPHandler(), roles...)
}

// RBACwithErrorHTTPHandler middleware allows role based control for routes
// this handler internally wrapped with auth(true) to avoid situation if RBAC defined without prior Auth
// errorHttpHandler parameter may be used to write custom HTTP responses in case of authentication error.
func (a *Authenticator) RBACwithErrorHTTPHandler(errorHTTPHandler AuthErrorHTTPHandler, roles ...string) func(http.Handler) http.Handler {
return a.rbac(errorHTTPHandler, roles...)
}

func (a *Authenticator) rbac(errorHTTPHandler AuthErrorHTTPHandler, roles ...string) func(http.Handler) http.Handler {

f := func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
user, err := token.GetUserInfo(r)
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
errorHTTPHandler.ServeAuthError(w, r, err, "Unauthorized", http.StatusUnauthorized)
return
}

Expand All @@ -251,12 +301,26 @@ func (a *Authenticator) RBAC(roles ...string) func(http.Handler) http.Handler {
}
}
if !matched {
http.Error(w, "Access denied", http.StatusForbidden)
errorHTTPHandler.ServeAuthError(
w,
r,
fmt.Errorf("user %s/%s does not have any of required roles: %s", user.Name, user.ID, roles),
"Access denied",
http.StatusForbidden,
)
return
}
h.ServeHTTP(w, r)
}
return a.auth(true)(http.HandlerFunc(fn)) // enforce auth
return a.auth(true, errorHTTPHandler)(http.HandlerFunc(fn)) // enforce auth
}
return f
}

func (a *Authenticator) getAuthErrorHTTPHandler() AuthErrorHTTPHandler {
if a.AuthErrorHTTPHandler != nil {
return a.AuthErrorHTTPHandler
}

return DefaultAuthErrorHTTPHandler{L: a.L}
}
Loading
Loading