Skip to content

Commit

Permalink
improve TestIntegrationAuthErrorHTTPHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
cyb3r4nt committed Aug 5, 2024
1 parent 0d7af14 commit 727ffc4
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 302 deletions.
245 changes: 94 additions & 151 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/go-pkgz/auth/avatar"
"github.com/go-pkgz/auth/logger"
"github.com/go-pkgz/auth/middleware"
"github.com/go-pkgz/auth/provider"
"github.com/go-pkgz/auth/token"
)
Expand Down Expand Up @@ -248,10 +249,10 @@ func TestIntegrationList(t *testing.T) {
}

type testAuthErrorHTTPHandler struct {
wasCalled bool
statusCode int
contentType string
responseBody string
wasCalled bool
}

func (h *testAuthErrorHTTPHandler) ServeAuthError(
Expand All @@ -268,96 +269,95 @@ func (h *testAuthErrorHTTPHandler) ServeAuthError(
}

func TestIntegrationAuthErrorHTTPHandler(t *testing.T) {
testErrorHandler1 := &testAuthErrorHTTPHandler{
statusCode: 401,
apiHandler := http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("must not be called\n"))
t.Error("auth error must be raised before this HTTP handler is called")
},
)
defaultAuthErrorHTTPHandler := testAuthErrorHTTPHandler{
statusCode: 400,
contentType: "application/json",
responseBody: `{"code": 401, "message": "from general error handler"}`,
}
testErrorHandler2 := &testAuthErrorHTTPHandler{
statusCode: 403,
contentType: "text/html",
responseBody: `<html><body><h1>from private2 error handler</h1></body></html>`,
responseBody: `{"code": 400, "message": "from general error handler"}`,
}
testErrorHandler3 := &testAuthErrorHTTPHandler{
statusCode: 403,
contentType: "application/json",
responseBody: `{"code": 401, "message": "from admin error handler"}`,
type apiCall struct {
requestPath string
expectedHandler *testAuthErrorHTTPHandler
createMiddleware func(apiCall, middleware.Authenticator) http.Handler
}
testErrorHandler4 := &testAuthErrorHTTPHandler{
statusCode: 403,
contentType: "text/html",
responseBody: `<html><body><h1>from RBAC error handler</h1></body></html>`,
apiCalls := []apiCall{
{
requestPath: "/private1",
expectedHandler: &defaultAuthErrorHTTPHandler,
createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
return a.Auth(apiHandler)
},
},
{
requestPath: "/private2",
expectedHandler: &testAuthErrorHTTPHandler{
statusCode: 402,
contentType: "application/json",
responseBody: `{"code": 402, "message": "from private2 error handler"}`,
},
createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
return a.AuthWithErrorHTTPHandler(apiHandler, ac.expectedHandler)
},
},
{
requestPath: "/admin1",
expectedHandler: &defaultAuthErrorHTTPHandler,
createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
return a.AdminOnly(apiHandler)
},
},
{
requestPath: "/admin2",
expectedHandler: &testAuthErrorHTTPHandler{
statusCode: 404,
contentType: "application/json",
responseBody: `{"code": 404, "message": "from admin2 error handler"}`,
},
createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
return a.AdminOnlyWithErrorHTTPHandler(apiHandler, ac.expectedHandler)
},
},
{
requestPath: "/rbac1",
expectedHandler: &defaultAuthErrorHTTPHandler,
createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
return a.RBAC("role1", "role2")(apiHandler)
},
},
{
requestPath: "/rbac2",
expectedHandler: &testAuthErrorHTTPHandler{
statusCode: 406,
contentType: "text/html",
responseBody: `<html><body><h1>from RBAC2 error handler</h1></body></html>`,
},
createMiddleware: func(ac apiCall, a middleware.Authenticator) http.Handler {
return a.RBACwithErrorHTTPHandler(ac.expectedHandler, "role1", "role2")(apiHandler)
},
},
}

options := Opts{
SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }),
Issuer: "my-test-app",
URL: "http://127.0.0.1:8089",
SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }),
Issuer: "my-test-app",
URL: "http://127.0.0.1:8089",
AuthErrorHTTPHandler: &defaultAuthErrorHTTPHandler,
}

svc := NewService(options)
svc.AddDevProvider("localhost", 18084) // add dev provider on 18084
svc.authMiddleware.AuthErrorHTTPHandler = testErrorHandler1

// setup http server
m := svc.Middleware()
mux := http.NewServeMux()
mux.Handle("/private1",
m.Auth(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("protected route1\n"))
},
),
),
)
mux.Handle("/private2",
m.AuthWithErrorHTTPHandler(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("protected route2\n"))
},
),
testErrorHandler2,
),
)
mux.Handle("/admin1",
m.AdminOnly(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("admin route1\n"))
},
),
),
)
mux.Handle("/admin2",
m.AdminOnlyWithErrorHTTPHandler(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("admin route2\n"))
},
),
testErrorHandler3,
),
)
mux.Handle("/rbac1",
m.RBAC("role1", "role2")(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("rbac route1\n"))
},
),
),
)
mux.Handle("/rbac2",
m.RBACwithErrorHTTPHandler(testErrorHandler4, "role1", "role2")(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("rbac route2\n"))
},
),
),
)
m := svc.Middleware()

for _, ac := range apiCalls {
mux.Handle(ac.requestPath, ac.createMiddleware(ac, m))
}

l, listenErr := net.Listen("tcp", "127.0.0.1:8089")
require.Nil(t, listenErr)
Expand All @@ -369,82 +369,25 @@ func TestIntegrationAuthErrorHTTPHandler(t *testing.T) {
ts.Close()
}()

assertBodyEquals := func(t *testing.T, r *http.Response, expectedBody string) {
b, err := io.ReadAll(r.Body)
require.NoError(t, err)
assert.Equal(t, expectedBody, string(b))
}
assertContentTypeEquals := func(t *testing.T, r *http.Response, expectedContentType string) {
assert.Equal(t, expectedContentType, r.Header.Get("Content-Type"))
}

// private1
resp, err := http.Get("http://127.0.0.1:8089/private1")
require.NoError(t, err)
defer resp.Body.Close()

require.True(t, testErrorHandler1.wasCalled)

assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assertContentTypeEquals(t, resp, "application/json")
assertBodyEquals(t, resp, `{"code": 401, "message": "from general error handler"}`)

// private2
resp, err = http.Get("http://127.0.0.1:8089/private2")
require.NoError(t, err)
defer resp.Body.Close()

require.True(t, testErrorHandler2.wasCalled)

assert.Equal(t, http.StatusForbidden, resp.StatusCode)
assertContentTypeEquals(t, resp, "text/html")
assertBodyEquals(t, resp, `<html><body><h1>from private2 error handler</h1></body></html>`)

// admin1
testErrorHandler1.wasCalled = false
resp, err = http.Get("http://127.0.0.1:8089/admin1")
require.NoError(t, err)
defer resp.Body.Close()
for _, ac := range apiCalls {
t.Run("auth error test for endpoint "+ac.requestPath, func(t *testing.T) {
th := ac.expectedHandler
th.wasCalled = false

require.True(t, testErrorHandler1.wasCalled)
resp, err := http.Get("http://127.0.0.1:8089" + ac.requestPath)
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assertContentTypeEquals(t, resp, "application/json")
assertBodyEquals(t, resp, `{"code": 401, "message": "from general error handler"}`)
require.True(t, th.wasCalled)

// admin2
resp, err = http.Get("http://127.0.0.1:8089/admin2")
require.NoError(t, err)
defer resp.Body.Close()

require.True(t, testErrorHandler3.wasCalled)

assert.Equal(t, http.StatusForbidden, resp.StatusCode)
assertContentTypeEquals(t, resp, "application/json")
assertBodyEquals(t, resp, `{"code": 401, "message": "from admin error handler"}`)

// rbac1
testErrorHandler1.wasCalled = false
resp, err = http.Get("http://127.0.0.1:8089/rbac1")
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, th.statusCode, resp.StatusCode)
require.Equal(t, th.contentType, resp.Header.Get("Content-Type"))

require.True(t, testErrorHandler1.wasCalled)

assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assertContentTypeEquals(t, resp, "application/json")
assertBodyEquals(t, resp, `{"code": 401, "message": "from general error handler"}`)

// rbac2
resp, err = http.Get("http://127.0.0.1:8089/rbac2")
require.NoError(t, err)
defer resp.Body.Close()

require.True(t, testErrorHandler4.wasCalled)

assert.Equal(t, http.StatusForbidden, resp.StatusCode)
assertContentTypeEquals(t, resp, "text/html")
assertBodyEquals(t, resp, `<html><body><h1>from RBAC error handler</h1></body></html>`)
b, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, th.responseBody, string(b))
})
}
}

func TestIntegrationUserInfo(t *testing.T) {
Expand Down
Loading

0 comments on commit 727ffc4

Please sign in to comment.