Skip to content

Commit

Permalink
Pass options separately
Browse files Browse the repository at this point in the history
  • Loading branch information
klaidliadon committed Oct 23, 2024
1 parent 989c1fc commit 8c3254a
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 46 deletions.
54 changes: 38 additions & 16 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@ package authcontrol
import (
"context"
"encoding/json"
"errors"
"net/http"
"reflect"
"strings"

"github.com/0xsequence/authcontrol/proto"
)

func defaultErrHandler(r *http.Request, w http.ResponseWriter, err error) {
type ErrHandler func(r *http.Request, w http.ResponseWriter, err error)

func DefaultErrorHandler(r *http.Request, w http.ResponseWriter, err error) {
rpcErr, ok := err.(proto.WebRPCError)
if !ok {
rpcErr = proto.ErrWebrpcEndpoint.WithCause(err)
Expand All @@ -32,45 +36,45 @@ type UserStore interface {
// map[service]map[method]T
type Config[T any] map[string]map[string]T

// returns the config value for the given request.
func (c Config[T]) Get(r *rcpRequest) (v T, ok bool) {
// Get returns the config value for the given request.
func (c Config[T]) Get(r *Request) (v T, ok bool) {
if c == nil {
return v, false
}

methodCfg, ok := c[r.serviceName][r.methodName]
methodCfg, ok := c[r.ServiceName][r.MethodName]
if !ok {
return v, false
}

return methodCfg, true
}

// rcpRequest is a parsed RPC request.
type rcpRequest struct {
packageName string
serviceName string
methodName string
// Request is a parsed RPC request.
type Request struct {
PackageName string
ServiceName string
MethodName string
}

// newRequest parses a path into an rcpRequest.
func newRequest(path string) *rcpRequest {
func ParseRequest(path string) *Request {
p := strings.Split(path, "/")
if len(p) < 4 {
return nil
}

t := &rcpRequest{
packageName: p[len(p)-3],
serviceName: p[len(p)-2],
methodName: p[len(p)-1],
r := &Request{
PackageName: p[len(p)-3],
ServiceName: p[len(p)-2],
MethodName: p[len(p)-1],
}

if t.packageName != "rpc" {
if r.PackageName != "rpc" {
return nil
}

return t
return r
}

// ACL is a list of session types, encoded as a bitfield.
Expand Down Expand Up @@ -98,3 +102,21 @@ func (a ACL) And(session ...proto.SessionType) ACL {
func (t ACL) Includes(session proto.SessionType) bool {
return t&ACL(1<<session) != 0
}

func VerifyACL[T any](acl Config[ACL]) error {
var t T
iType := reflect.TypeOf(&t).Elem()
service := iType.Name()
m, ok := acl[service]
if !ok {
return errors.New("service " + service + " not found")
}
var errList []error
for i := 0; i < iType.NumMethod(); i++ {
method := iType.Method(i)
if _, ok := m[method.Name]; !ok {
errList = append(errList, errors.New(""+service+"."+method.Name+" not found"))
}
}
return errors.Join(errList...)
}
31 changes: 9 additions & 22 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,9 @@ import (
"github.com/0xsequence/authcontrol/proto"
)

type Options struct {
KeyFuncs []KeyFunc
UserStore UserStore
ErrHandler func(r *http.Request, w http.ResponseWriter, err error)
}

func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Handler {
eh := defaultErrHandler
if o != nil && o.ErrHandler != nil {
eh = o.ErrHandler
}

var keyFuncs []KeyFunc
if o != nil {
keyFuncs = o.KeyFuncs
func Session(auth *jwtauth.JWTAuth, u UserStore, eh ErrHandler, keyFuncs ...KeyFunc) func(next http.Handler) http.Handler {
if eh == nil {
eh = DefaultErrorHandler
}

return func(next http.Handler) http.Handler {
Expand Down Expand Up @@ -82,8 +70,8 @@ func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Han
ctx = withAccount(ctx, accountClaim)
sessionType = proto.SessionType_Wallet

if o != nil && o.UserStore != nil {
user, isAdmin, err := o.UserStore.GetUser(ctx, accountClaim)
if u != nil {
user, isAdmin, err := u.GetUser(ctx, accountClaim)
if err != nil {
eh(r, w, err)
return
Expand Down Expand Up @@ -124,15 +112,14 @@ func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Han

// AccessControl middleware that checks if the session type is allowed to access the endpoint.
// It also sets the compute units on the context if the endpoint requires it.
func AccessControl(acl Config[ACL], o *Options) func(next http.Handler) http.Handler {
eh := defaultErrHandler
if o != nil && o.ErrHandler != nil {
eh = o.ErrHandler
func AccessControl(acl Config[ACL], eh ErrHandler) func(next http.Handler) http.Handler {
if eh == nil {
eh = DefaultErrorHandler
}

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req := newRequest(r.URL.Path)
req := ParseRequest(r.URL.Path)
if req == nil {
eh(r, w, proto.ErrUnauthorized.WithCausef("invalid rpc method"))
return
Expand Down
13 changes: 5 additions & 8 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,15 @@ func TestSession(t *testing.T) {

auth := jwtauth.New("HS256", []byte("secret"), nil)

options := authcontrol.Options{
UserStore: mockStore{
UserAddress: false,
AdminAddress: true,
},
KeyFuncs: []authcontrol.KeyFunc{keyFunc},
userStore := mockStore{
UserAddress: false,
AdminAddress: true,
}

r := chi.NewRouter()
r.Use(
authcontrol.Session(auth, &options),
authcontrol.AccessControl(ACLConfig, &options),
authcontrol.Session(auth, userStore, nil, keyFunc),
authcontrol.AccessControl(ACLConfig, nil),
)
r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

Expand Down

0 comments on commit 8c3254a

Please sign in to comment.