Skip to content

Commit

Permalink
suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
david-littlefarmer committed Oct 23, 2024
1 parent 5ce6340 commit da24c66
Show file tree
Hide file tree
Showing 12 changed files with 459 additions and 406 deletions.
57 changes: 28 additions & 29 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,70 +29,69 @@ type UserStore interface {
}

// Config is a generic map of services/methods to a config value.
// map[service]map[method]T
type Config[T any] map[string]map[string]T

// Get returns the config value for the given request.
func (c Config[T]) Get(r *rcpRequest) (v T, ok bool) {
if c == nil || r.Package != "rpc" {
if c == nil {
return v, false
}
serviceCfg, ok := c[r.Service]
if !ok {
return v, false
}
methodCfg, ok := serviceCfg[r.Method]

methodCfg, ok := c[r.serviceName][r.methodName]
if !ok {
return v, false
}

return methodCfg, true
}

// rcpRequest is a parsed RPC request.
type rcpRequest struct {
Package string
Service string
Method string
packageName string
serviceName string
methodName string
}

// newRequest parses a path into an rcpRequest.
func newRequest(path string) *rcpRequest {
parts := strings.Split(path, "/")
if len(parts) != 4 {
return nil
}
if parts[0] != "" {
p := strings.Split(path, "/")
if len(p) < 4 {
return nil
}
t := rcpRequest{
Package: parts[1],
Service: parts[2],
Method: parts[3],

t := &rcpRequest{
packageName: p[len(p)-3],
serviceName: p[len(p)-2],
methodName: p[len(p)-1],
}
if t.Package == "" || t.Service == "" || t.Method == "" {

if t.packageName != "rpc" {
return nil
}
return &t

return t
}

// ACL is a list of session types, encoded as a bitfield.
// SessionType(n) is represented by n=-the bit.
type ACL uint64

// NewACL returns a new ACL with the given session types.
func NewACL(t ...proto.SessionType) ACL {
var types ACL
for _, v := range t {
types = types.And(v)
func NewACL(sessions ...proto.SessionType) ACL {
var acl ACL
for _, v := range sessions {
acl = acl.And(v)
}
return types
return acl
}

// And returns a new ACL with the given session types added.
func (t ACL) And(types ...proto.SessionType) ACL {
for _, v := range types {
t |= 1 << v
func (a ACL) And(session ...proto.SessionType) ACL {
for _, v := range session {
a |= 1 << v
}
return t
return a
}

// Includes returns true if the ACL includes the given session type.
Expand Down
15 changes: 8 additions & 7 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func mustJWT(t *testing.T, auth *jwtauth.JWTAuth, claims map[string]any) string
if claims == nil {
return ""
}

_, token, err := auth.Encode(claims)
require.NoError(t, err)
return token
Expand All @@ -28,11 +29,10 @@ func keyFunc(r *http.Request) string {
return r.Header.Get(HeaderKey)
}

func executeRequest(ctx context.Context, handler http.Handler, path, accessKey, jwt string) (bool, http.Header, error) {
func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, path, accessKey, jwt string) (bool, error) {
req, err := http.NewRequest("POST", path, nil)
if err != nil {
return false, nil, err
}
require.NoError(t, err)

req.Header.Set("X-Real-IP", "127.0.0.1")
if accessKey != "" {
req.Header.Set(HeaderKey, accessKey)
Expand All @@ -46,9 +46,10 @@ func executeRequest(ctx context.Context, handler http.Handler, path, accessKey,

if status := rr.Result().StatusCode; status < http.StatusOK || status >= http.StatusBadRequest {
w := proto.WebRPCError{}
json.Unmarshal(rr.Body.Bytes(), &w)
return false, rr.Header(), w
err = json.Unmarshal(rr.Body.Bytes(), &w)
require.NoError(t, err)
return false, w
}

return true, rr.Header(), nil
return true, nil
}
44 changes: 34 additions & 10 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@ var (
ctxKeyProjectID = &contextKey{"ProjectID"}
)

// WithSessionType adds the access key to the context.
func WithSessionType(ctx context.Context, accessType proto.SessionType) context.Context {
//
// Session Type
//

// withSessionType adds the access key to the context.
func withSessionType(ctx context.Context, accessType proto.SessionType) context.Context {
return context.WithValue(ctx, ctxKeySessionType, accessType)
}

Expand All @@ -37,8 +41,12 @@ func GetSessionType(ctx context.Context) (proto.SessionType, bool) {
return v, true
}

// WithAccount adds the account to the context.
func WithAccount(ctx context.Context, account string) context.Context {
//
// Account
//

// withAccount adds the account to the context.
func withAccount(ctx context.Context, account string) context.Context {
return context.WithValue(ctx, ctxKeyAccount, account)
}

Expand All @@ -48,8 +56,12 @@ func GetAccount(ctx context.Context) (string, bool) {
return v, ok
}

// WithUser adds the user to the context.
func WithUser(ctx context.Context, user any) context.Context {
//
// User
//

// withUser adds the user to the context.
func withUser(ctx context.Context, user any) context.Context {
return context.WithValue(ctx, ctxKeyUser, user)
}

Expand All @@ -59,8 +71,12 @@ func GetUser[T any](ctx context.Context) (T, bool) {
return v, ok
}

// WithService adds the service to the context.
func WithService(ctx context.Context, service string) context.Context {
//
// Service
//

// withService adds the service to the context.
func withService(ctx context.Context, service string) context.Context {
return context.WithValue(ctx, ctxKeyService, service)
}

Expand All @@ -70,8 +86,12 @@ func GetService(ctx context.Context) (string, bool) {
return v, ok
}

// WithAccessKey adds the access key to the context.
func WithAccessKey(ctx context.Context, accessKey string) context.Context {
//
// AccessKey
//

// withAccessKey adds the access key to the context.
func withAccessKey(ctx context.Context, accessKey string) context.Context {
return context.WithValue(ctx, ctxKeyAccessKey, accessKey)
}

Expand All @@ -81,6 +101,10 @@ func GetAccessKey(ctx context.Context) (string, bool) {
return v, ok
}

//
// Project ID
//

// withProjectID adds the projectID to the context.
func withProjectID(ctx context.Context, projectID uint64) context.Context {
return context.WithValue(ctx, ctxKeyProjectID, projectID)
Expand Down
36 changes: 22 additions & 14 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@ func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Han
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// check if the request already contains session, if it does then continue
if _, ok := GetSessionType(ctx); ok {
next.ServeHTTP(w, r)
return
}

var (
sessionType proto.SessionType
accessKey string
Expand All @@ -51,14 +54,15 @@ func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Han
eh(r, w, proto.ErrSessionExpired)
return
}

if !errors.Is(err, jwtauth.ErrNoTokenFound) {
eh(r, w, proto.ErrUnauthorized)
return
}
}

if token != nil {
claims, err := token.AsMap(r.Context())
claims, err := token.AsMap(ctx)
if err != nil {
eh(r, w, err)
return
Expand All @@ -68,12 +72,13 @@ func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Han
accountClaim, _ := claims["account"].(string)
adminClaim, _ := claims["admin"].(bool)
projectClaim, _ := claims["project"].(float64)

switch {
case serviceClaim != "":
ctx = WithService(ctx, serviceClaim)
ctx = withService(ctx, serviceClaim)
sessionType = proto.SessionType_Service
case accountClaim != "":
ctx = WithAccount(ctx, accountClaim)
ctx = withAccount(ctx, accountClaim)
sessionType = proto.SessionType_Wallet

if o != nil && o.UserStore != nil {
Expand All @@ -82,34 +87,37 @@ func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Han
eh(r, w, err)
return
}

if user != nil {
ctx = withUser(ctx, user)

sessionType = proto.SessionType_User
if isAdmin {
sessionType = proto.SessionType_Admin
} else {
sessionType = proto.SessionType_User
}
ctx = WithUser(ctx, user)
}
}

if adminClaim {
sessionType = proto.SessionType_Admin
break
}

if projectClaim > 0 {
projectID := uint64(projectClaim)
ctx = withProjectID(ctx, projectID)
sessionType = proto.SessionType_Project
}
case adminClaim:
sessionType = proto.SessionType_Admin
}
}

if accessKey != "" && sessionType < proto.SessionType_Admin {
ctx = WithAccessKey(ctx, accessKey)
ctx = withAccessKey(ctx, accessKey)
sessionType = max(sessionType, proto.SessionType_AccessKey)
}

ctx = WithSessionType(ctx, sessionType)
ctx = withSessionType(ctx, sessionType)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
Expand All @@ -122,6 +130,7 @@ func AccessControl(acl Config[ACL], o *Options) func(next http.Handler) http.Han
if o != nil && o.ErrHandler != nil {
eh = o.ErrHandler
}

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req := newRequest(r.URL.Path)
Expand All @@ -130,24 +139,23 @@ func AccessControl(acl Config[ACL], o *Options) func(next http.Handler) http.Han
return
}

types, ok := acl.Get(req)
acl, ok := acl.Get(req)
if !ok {
eh(r, w, proto.ErrUnauthorized.WithCausef("rpc method not found"))
return
}

if session, _ := GetSessionType(r.Context()); !types.Includes(session) {
if session, _ := GetSessionType(r.Context()); !acl.Includes(session) {
err := proto.ErrPermissionDenied
if session == proto.SessionType_Public {
err = proto.ErrUnauthorized
}

eh(r, w, err)
return
}

ctx := r.Context()

next.ServeHTTP(w, r.WithContext(ctx))
next.ServeHTTP(w, r)
})
}
}
12 changes: 6 additions & 6 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ import (
"strings"
"testing"

"github.com/0xsequence/authcontrol"
"github.com/0xsequence/authcontrol/proto"
"github.com/go-chi/chi/v5"
"github.com/go-chi/jwtauth/v5"
"github.com/stretchr/testify/assert"

"github.com/0xsequence/authcontrol"
"github.com/0xsequence/authcontrol/proto"
)

type mockStore map[string]bool
Expand Down Expand Up @@ -54,9 +55,9 @@ func TestSession(t *testing.T) {
MethodService = "MethodService"
)

var Methods = []string{MethodPublic, MethodAccount, MethodAccessKey, MethodProject, MethodUser, MethodAdmin, MethodService}
Methods := []string{MethodPublic, MethodAccount, MethodAccessKey, MethodProject, MethodUser, MethodAdmin, MethodService}

var ACLConfig = authcontrol.Config[authcontrol.ACL]{"Service": {
ACLConfig := authcontrol.Config[authcontrol.ACL]{"Service": {
MethodPublic: authcontrol.NewACL(proto.SessionType_Public.OrHigher()...),
MethodAccount: authcontrol.NewACL(proto.SessionType_Wallet.OrHigher()...),
MethodAccessKey: authcontrol.NewACL(proto.SessionType_AccessKey.OrHigher()...),
Expand Down Expand Up @@ -129,7 +130,7 @@ func TestSession(t *testing.T) {
claims = map[string]any{"service": ServiceName}
}

ok, _, err := executeRequest(ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), tc.AccessKey, mustJWT(t, auth, claims))
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), tc.AccessKey, mustJWT(t, auth, claims))

session := tc.Session
switch {
Expand All @@ -150,6 +151,5 @@ func TestSession(t *testing.T) {
})
}
}

}
}
Loading

0 comments on commit da24c66

Please sign in to comment.