Skip to content

Commit

Permalink
minor improvements (#11)
Browse files Browse the repository at this point in the history
* minor improvements

* update

* remove unused return value
  • Loading branch information
pkieltyka authored Oct 29, 2024
1 parent 86f2003 commit 1b8b441
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 23 deletions.
2 changes: 1 addition & 1 deletion common.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func errHandler(r *http.Request, w http.ResponseWriter, err error) {
w.Write(respBody)
}

type KeyFunc func(*http.Request) string
type AccessKeyFunc func(*http.Request) string

type UserStore interface {
GetUser(ctx context.Context, address string) (any, bool, error)
Expand Down
47 changes: 28 additions & 19 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,33 @@ import (
"github.com/0xsequence/authcontrol/proto"
)

// Options for the authcontrol middleware handlers Session and AccessControl.
type Options struct {
JWTSecret string
KeyFuncs []KeyFunc
UserStore UserStore
// JWT secret used to verify the JWT token.
JWTSecret string

// AccessKeyFuncs is a list of functions that are used to extract the access key
// from the request.
AccessKeyFuncs []AccessKeyFunc

// UserStore is a function that is used to get the user from the request
// with pluggable backends.
UserStore UserStore

// ErrHandler is a function that is used to handle and respond to errors.
ErrHandler ErrHandler
}

func (o *Options) ApplyDefaults() {
if o.ErrHandler == nil {
o.ErrHandler = errHandler
}
}

func Session(cfg *Options) func(next http.Handler) http.Handler {
cfg.ApplyDefaults()
auth := jwtauth.New("HS256", []byte(cfg.JWTSecret), nil)

eh := errHandler
if cfg != nil && cfg.ErrHandler != nil {
eh = cfg.ErrHandler
}

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
Expand All @@ -42,7 +54,7 @@ func Session(cfg *Options) func(next http.Handler) http.Handler {
)

if cfg != nil {
for _, f := range cfg.KeyFuncs {
for _, f := range cfg.AccessKeyFuncs {
if accessKey = f(r); accessKey != "" {
break
}
Expand All @@ -52,20 +64,20 @@ func Session(cfg *Options) func(next http.Handler) http.Handler {
token, err := jwtauth.VerifyRequest(auth, r, jwtauth.TokenFromHeader)
if err != nil {
if errors.Is(err, jwtauth.ErrExpired) {
eh(r, w, proto.ErrSessionExpired)
cfg.ErrHandler(r, w, proto.ErrSessionExpired)
return
}

if !errors.Is(err, jwtauth.ErrNoTokenFound) {
eh(r, w, proto.ErrUnauthorized)
cfg.ErrHandler(r, w, proto.ErrUnauthorized)
return
}
}

if token != nil {
claims, err := token.AsMap(ctx)
if err != nil {
eh(r, w, err)
cfg.ErrHandler(r, w, err)
return
}

Expand All @@ -85,7 +97,7 @@ func Session(cfg *Options) func(next http.Handler) http.Handler {
if cfg != nil && cfg.UserStore != nil {
user, isAdmin, err := cfg.UserStore.GetUser(ctx, accountClaim)
if err != nil {
eh(r, w, err)
cfg.ErrHandler(r, w, err)
return
}

Expand Down Expand Up @@ -125,16 +137,13 @@ func Session(cfg *Options) func(next http.Handler) http.Handler {
// AccessControl middleware that checks if the session type is allowed to access the endpoint.
// It also sets the compute units on the context if the endpoint requires it.
func AccessControl(acl Config[ACL], cfg *Options) func(next http.Handler) http.Handler {
eh := errHandler
if cfg != nil && cfg.ErrHandler != nil {
eh = cfg.ErrHandler
}
cfg.ApplyDefaults()

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
acl, err := acl.Get(r.URL.Path)
if err != nil {
eh(r, w, proto.ErrUnauthorized.WithCausef("get acl: %w", err))
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("get acl: %w", err))
return
}

Expand All @@ -144,7 +153,7 @@ func AccessControl(acl Config[ACL], cfg *Options) func(next http.Handler) http.H
err = proto.ErrUnauthorized
}

eh(r, w, err)
cfg.ErrHandler(r, w, err)
return
}

Expand Down
6 changes: 3 additions & 3 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func TestSession(t *testing.T) {
UserAddress: false,
AdminAddress: true,
},
KeyFuncs: []authcontrol.KeyFunc{keyFunc},
AccessKeyFuncs: []authcontrol.AccessKeyFunc{keyFunc},
}

r := chi.NewRouter()
Expand Down Expand Up @@ -184,7 +184,7 @@ func TestInvalid(t *testing.T) {
UserAddress: false,
AdminAddress: true,
},
KeyFuncs: []authcontrol.KeyFunc{keyFunc},
AccessKeyFuncs: []authcontrol.AccessKeyFunc{keyFunc},
}

r := chi.NewRouter()
Expand Down Expand Up @@ -294,7 +294,7 @@ func TestCustomErrHandler(t *testing.T) {
UserAddress: false,
AdminAddress: true,
},
KeyFuncs: []authcontrol.KeyFunc{keyFunc},
AccessKeyFuncs: []authcontrol.AccessKeyFunc{keyFunc},
ErrHandler: func(r *http.Request, w http.ResponseWriter, err error) {
rpcErr := customErr

Expand Down

0 comments on commit 1b8b441

Please sign in to comment.