Skip to content

Commit

Permalink
Split session (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
klaidliadon authored Nov 19, 2024
1 parent f5ebb9d commit e9b255d
Show file tree
Hide file tree
Showing 9 changed files with 246 additions and 88 deletions.
80 changes: 77 additions & 3 deletions common.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
package authcontrol

import (
"cmp"
"context"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"net/http"
"strconv"
"strings"

"github.com/0xsequence/authcontrol/proto"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
)

const (
Expand Down Expand Up @@ -42,10 +49,11 @@ type UserStore interface {
GetUser(ctx context.Context, address string) (user any, isAdmin bool, err error)
}

// ProjectStore is a pluggable backend that verifies if the project exists.
// If the project doesn't exist, it should return nil, nil.
// ProjectStore is a pluggable backend that verifies if a project exists.
// If the project does not exist, it should return nil, nil, nil.
// The optional Auth, when returned, will be used for instead of the standard one.
type ProjectStore interface {
GetProject(ctx context.Context, id uint64) (project any, err error)
GetProject(ctx context.Context, id uint64) (project any, auth *Auth, err error)
}

// Config is a generic map of services/methods to a config value.
Expand Down Expand Up @@ -121,3 +129,69 @@ func (a ACL) And(session ...proto.SessionType) ACL {
func (t ACL) Includes(session proto.SessionType) bool {
return t&ACL(1<<session) != 0
}

// NewAuth creates a new Auth HS256 with the given secret.
func NewAuth(secret string) *Auth {
return &Auth{Algorithm: jwa.HS256, Private: []byte(secret)}
}

// Auth is a struct that holds the private and public keys for JWT signing and verification.
type Auth struct {
Algorithm jwa.SignatureAlgorithm
Private []byte
Public []byte
}

// GetVerifier returns a JWTAuth using the private secret when available, otherwise the public key
func (a Auth) GetVerifier(options ...jwt.ValidateOption) (*jwtauth.JWTAuth, error) {
if a.Algorithm == "" {
return nil, fmt.Errorf("missing algorithm")
}

if a.Private != nil {
return jwtauth.New(string(a.Algorithm), a.Private, a.Private, options...), nil
}

if a.Public == nil {
return nil, fmt.Errorf("missing public key")
}

block, _ := pem.Decode(a.Public)

pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse public key: %w", err)
}

return jwtauth.New(a.Algorithm.String(), nil, pub, options...), nil
}

// findProjectClaim looks for the project_id/project claim in the JWT
func findProjectClaim(r *http.Request) (uint64, error) {
raw := cmp.Or(jwtauth.TokenFromHeader(r))

token, err := jwt.ParseString(raw, jwt.WithVerify(false))
if err != nil {
return 0, fmt.Errorf("parse token: %w", err)
}

claims := token.PrivateClaims()

claim := cmp.Or(claims["project_id"], claims["project"])
if claim == nil {
return 0, fmt.Errorf("missing project claim")
}

switch val := claim.(type) {
case float64:
return uint64(val), nil
case string:
v, err := strconv.ParseUint(val, 10, 64)
if err != nil {
return 0, fmt.Errorf("invalid value")
}
return v, nil
default:
return 0, fmt.Errorf("invalid type: %T", val)
}
}
2 changes: 1 addition & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ var (
ctxKeyUser = &contextKey{"User"}
ctxKeyService = &contextKey{"Service"}
ctxKeyAccessKey = &contextKey{"AccessKey"}
ctxKeyProject = &contextKey{"Project"}
ctxKeyProjectID = &contextKey{"ProjectID"}
ctxKeyProject = &contextKey{"Project"}
)

//
Expand Down
114 changes: 78 additions & 36 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,22 @@ import (

// Options for the authcontrol middleware handlers Session and AccessControl.
type Options struct {
// JWT secret used to verify the JWT token.
// JWTsecret is required, and it is used for the JWT verification.
// If a Project Store is also provided and the request has a project claim,
// it could be replaced by the a specific verifier.
JWTSecret string

// ProjectStore is a pluggable backends that verifies if the project from the claim exists.
// When provived, it checks the Project from the JWT, and can override the JWT Auth.
ProjectStore ProjectStore

// AccessKeyFuncs are used to extract the access key from the request.
AccessKeyFuncs []AccessKeyFunc

// UserStore is a pluggable backends that verifies if the account exists.
// When provided, it can upgrade a Wallet session to a User or Admin session.
UserStore UserStore

// ProjectStore is a pluggable backends that verifies if the project exists.
ProjectStore ProjectStore

// ErrHandler is a function that is used to handle and respond to errors.
ErrHandler ErrHandler
}
Expand All @@ -46,34 +50,47 @@ func (o *Options) ApplyDefaults() {
}
}

func Session(cfg Options) func(next http.Handler) http.Handler {
func VerifyToken(cfg Options) func(next http.Handler) http.Handler {
cfg.ApplyDefaults()
auth := jwtauth.New("HS256", []byte(cfg.JWTSecret), nil, jwt.WithAcceptableSkew(2*time.Minute))
jwtOptions := []jwt.ValidateOption{
jwt.WithAcceptableSkew(2 * time.Minute),
}

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// check if the request already contains session, if it does then continue
if _, ok := GetSessionType(ctx); ok {
next.ServeHTTP(w, r)
return
}
auth := NewAuth(cfg.JWTSecret)

var (
sessionType proto.SessionType
accessKey string
token jwt.Token
)
if cfg.ProjectStore != nil {
projectID, err := findProjectClaim(r)
if err != nil {
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("get project claim: %w", err))
return
}

for _, f := range cfg.AccessKeyFuncs {
if accessKey = f(r); accessKey != "" {
break
project, _auth, err := cfg.ProjectStore.GetProject(ctx, projectID)
if err != nil {
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("get project: %w", err))
return
}
if project == nil {
cfg.ErrHandler(r, w, proto.ErrProjectNotFound)
return
}
if _auth != nil {
auth = _auth
}
ctx = WithProject(ctx, project)
}

// Verify JWT token and validate its claims.
token, err := jwtauth.VerifyRequest(auth, r, jwtauth.TokenFromHeader)
jwtAuth, err := auth.GetVerifier(jwtOptions...)
if err != nil {
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("get verifier: %w", err))
return
}

token, err := jwtauth.VerifyRequest(jwtAuth, r, jwtauth.TokenFromHeader)
if err != nil {
if errors.Is(err, jwtauth.ErrExpired) {
cfg.ErrHandler(r, w, proto.ErrSessionExpired)
Expand All @@ -89,7 +106,7 @@ func Session(cfg Options) func(next http.Handler) http.Handler {
if token != nil {
claims, err := token.AsMap(ctx)
if err != nil {
cfg.ErrHandler(r, w, err)
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("invalid token: %w", err))
return
}

Expand All @@ -102,6 +119,44 @@ func Session(cfg Options) func(next http.Handler) http.Handler {
}
}

ctx = jwtauth.NewContext(ctx, token, nil)
}

next.ServeHTTP(w, r.WithContext(ctx))
})
}
}

func Session(cfg Options) func(next http.Handler) http.Handler {
cfg.ApplyDefaults()

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// if a custom middleware already sets the session type, skip this middleware
if _, ok := GetSessionType(ctx); ok {
next.ServeHTTP(w, r)
return
}

var (
accessKey string
sessionType proto.SessionType
)

for _, f := range cfg.AccessKeyFuncs {
if accessKey = f(r); accessKey != "" {
break
}
}

_, claims, err := jwtauth.FromContext(ctx)
if err != nil {
cfg.ErrHandler(r, w, err)
return
}
if claims != nil {
serviceClaim, _ := claims["service"].(string)
accountClaim, _ := claims["account"].(string)
adminClaim, _ := claims["admin"].(bool)
Expand Down Expand Up @@ -140,20 +195,7 @@ func Session(cfg Options) func(next http.Handler) http.Handler {
}

if projectClaim > 0 {
projectID := uint64(projectClaim)
if cfg.ProjectStore != nil {
project, err := cfg.ProjectStore.GetProject(ctx, projectID)
if err != nil {
cfg.ErrHandler(r, w, err)
return
}
if project == nil {
cfg.ErrHandler(r, w, proto.ErrProjectNotFound)
return
}
ctx = WithProject(ctx, project)
}
ctx = WithProjectID(ctx, projectID)
ctx = WithProjectID(ctx, uint64(projectClaim))
sessionType = proto.SessionType_Project
}
}
Expand Down
Loading

0 comments on commit e9b255d

Please sign in to comment.