diff --git a/common.go b/common.go index 24abb0e..ee75542 100644 --- a/common.go +++ b/common.go @@ -1,14 +1,21 @@ package authcontrol import ( + "cmp" "context" + "crypto/x509" "encoding/json" + "encoding/pem" "errors" "fmt" "net/http" + "strconv" "strings" "github.com/0xsequence/authcontrol/proto" + "github.com/go-chi/jwtauth/v5" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwt" ) const ( @@ -42,10 +49,11 @@ type UserStore interface { GetUser(ctx context.Context, address string) (user any, isAdmin bool, err error) } -// ProjectStore is a pluggable backend that verifies if the project exists. -// If the project doesn't exist, it should return nil, nil. +// ProjectStore is a pluggable backend that verifies if a project exists. +// If the project does not exist, it should return nil, nil, nil. +// The optional Auth, when returned, will be used for instead of the standard one. type ProjectStore interface { - GetProject(ctx context.Context, id uint64) (project any, err error) + GetProject(ctx context.Context, id uint64) (project any, auth *Auth, err error) } // Config is a generic map of services/methods to a config value. @@ -121,3 +129,69 @@ func (a ACL) And(session ...proto.SessionType) ACL { func (t ACL) Includes(session proto.SessionType) bool { return t&ACL(1< 0 { - projectID := uint64(projectClaim) - if cfg.ProjectStore != nil { - project, err := cfg.ProjectStore.GetProject(ctx, projectID) - if err != nil { - cfg.ErrHandler(r, w, err) - return - } - if project == nil { - cfg.ErrHandler(r, w, proto.ErrProjectNotFound) - return - } - ctx = WithProject(ctx, project) - } - ctx = WithProjectID(ctx, projectID) + ctx = WithProjectID(ctx, uint64(projectClaim)) sessionType = proto.SessionType_Project } } diff --git a/middleware_test.go b/middleware_test.go index ed8f8d7..02fd6b2 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -2,7 +2,11 @@ package authcontrol_test import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" "encoding/json" + "encoding/pem" "fmt" "net/http" "strings" @@ -10,7 +14,9 @@ import ( "time" "github.com/go-chi/chi/v5" + "github.com/go-chi/jwtauth/v5" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/0xsequence/authcontrol" "github.com/0xsequence/authcontrol/proto" @@ -31,17 +37,6 @@ func (m MockUserStore) GetUser(ctx context.Context, address string) (user any, i return struct{}{}, v, nil } -// MockProjectStore is a simple in-memory Project store for testing, it stores the project. -type MockProjectStore map[uint64]struct{} - -// GetProject returns the project from the store. -func (m MockProjectStore) GetProject(ctx context.Context, id uint64) (project any, err error) { - if _, ok := m[id]; !ok { - return nil, nil - } - return struct{}{}, nil -} - func TestSession(t *testing.T) { const ( MethodPublic = "MethodPublic" @@ -80,17 +75,14 @@ func TestSession(t *testing.T) { UserAddress: false, AdminAddress: true, }, - ProjectStore: MockProjectStore{ - ProjectID: struct{}{}, - }, AccessKeyFuncs: []authcontrol.AccessKeyFunc{keyFunc}, } r := chi.NewRouter() - r.Use( - authcontrol.Session(options), - authcontrol.AccessControl(ACLConfig, options), - ) + r.Use(authcontrol.VerifyToken(options)) + r.Use(authcontrol.Session(options)) + r.Use(authcontrol.AccessControl(ACLConfig, options)) + r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) ctx := context.Background() @@ -204,13 +196,11 @@ func TestInvalid(t *testing.T) { UserAddress: false, AdminAddress: true, }, - ProjectStore: MockProjectStore{ - ProjectID: struct{}{}, - }, AccessKeyFuncs: []authcontrol.AccessKeyFunc{keyFunc}, } r := chi.NewRouter() + r.Use(authcontrol.VerifyToken(options)) r.Use(authcontrol.Session(options)) r.Use(authcontrol.AccessControl(ACLConfig, options)) @@ -281,12 +271,6 @@ func TestInvalid(t *testing.T) { ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), accessKey(AccessKey), jwt(expiredJWT)) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrSessionExpired) - - // Invalid Project - wrongProject := authcontrol.S2SToken(JWTSecret, map[string]any{"account": WalletAddress, "project_id": ProjectID + 1}) - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), jwt(wrongProject)) - assert.False(t, ok) - assert.ErrorIs(t, err, proto.ErrProjectNotFound) } func TestCustomErrHandler(t *testing.T) { @@ -336,10 +320,10 @@ func TestCustomErrHandler(t *testing.T) { } r := chi.NewRouter() - r.Use( - authcontrol.Session(opts), - authcontrol.AccessControl(ACLConfig, opts), - ) + r.Use(authcontrol.VerifyToken(opts)) + r.Use(authcontrol.Session(opts)) + r.Use(authcontrol.AccessControl(ACLConfig, opts)) + r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) var claims map[string]any @@ -364,6 +348,7 @@ func TestOrigin(t *testing.T) { } r := chi.NewRouter() + r.Use(authcontrol.VerifyToken(opts)) r.Use(authcontrol.Session(opts)) r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) @@ -387,3 +372,66 @@ func TestOrigin(t *testing.T) { assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrUnauthorized) } + +type MockProjectStore map[uint64]*authcontrol.Auth + +func (m MockProjectStore) GetProject(ctx context.Context, projectID uint64) (any, *authcontrol.Auth, error) { + auth, ok := m[projectID] + if !ok { + return nil, nil, nil + } + return struct{}{}, auth, nil +} + +func TestProjectVerifier(t *testing.T) { + ctx := context.Background() + + authStore := MockProjectStore{} + + opts := authcontrol.Options{ + ProjectStore: authStore, + } + + r := chi.NewRouter() + r.Use(authcontrol.VerifyToken(opts)) + r.Use(authcontrol.Session(opts)) + r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + projectID := uint64(7) + + authStore[projectID] = authcontrol.NewAuth(JWTSecret) + + token := authcontrol.S2SToken(JWTSecret, map[string]any{ + "project_id": projectID, + }) + + ok, err := executeRequest(t, ctx, r, "", jwt(token)) + assert.True(t, ok) + assert.NoError(t, err) + + privateKey, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) + require.NoError(t, privateKey.Validate()) + + publicRaw, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) + require.NoError(t, err) + + public := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: publicRaw, + }) + + authStore[projectID] = &authcontrol.Auth{ + Algorithm: "RS256", + Public: public, + } + + _, token, err = jwtauth.New("RS256", privateKey, nil).Encode(map[string]any{ + "project_id": projectID, + }) + require.NoError(t, err) + + ok, err = executeRequest(t, ctx, r, "", jwt(token)) + assert.True(t, ok) + assert.NoError(t, err) +} diff --git a/proto/authcontrol.gen.go b/proto/authcontrol.gen.go index 79f48d8..fbcb8a1 100644 --- a/proto/authcontrol.gen.go +++ b/proto/authcontrol.gen.go @@ -1,4 +1,4 @@ -// access-control v0.9.0 4ec5724f11e7e078f8eaf99b10a5f1a020c46821 +// authcontrol v0.9.0 896c9dd61e9b52933577b35ac021a229cdc54fe2 // -- // Code generated by webrpc-gen@v0.21.0 with golang@v0.16.0 generator. DO NOT EDIT. // @@ -15,7 +15,7 @@ import ( const WebrpcHeader = "Webrpc" -const WebrpcHeaderValue = "webrpc@v0.21.0;gen-golang@v0.16.0;access-control@v0.9.0" +const WebrpcHeaderValue = "webrpc@v0.21.0;gen-golang@v0.16.0;authcontrol@v0.9.0" // WebRPC description and code-gen version func WebRPCVersion() string { @@ -29,7 +29,7 @@ func WebRPCSchemaVersion() string { // Schema hash generated from your RIDL schema func WebRPCSchemaHash() string { - return "4ec5724f11e7e078f8eaf99b10a5f1a020c46821" + return "896c9dd61e9b52933577b35ac021a229cdc54fe2" } type WebrpcGenVersions struct { diff --git a/proto/authcontrol.gen.ts b/proto/authcontrol.gen.ts index 2194385..fcb273b 100644 --- a/proto/authcontrol.gen.ts +++ b/proto/authcontrol.gen.ts @@ -1,5 +1,5 @@ /* eslint-disable */ -// access-control v0.9.0 4ec5724f11e7e078f8eaf99b10a5f1a020c46821 +// authcontrol v0.9.0 896c9dd61e9b52933577b35ac021a229cdc54fe2 // -- // Code generated by webrpc-gen@v0.21.0 with typescript@v0.15.0 generator. DO NOT EDIT. // @@ -7,7 +7,7 @@ export const WebrpcHeader = "Webrpc" -export const WebrpcHeaderValue = "webrpc@v0.21.0;gen-typescript@v0.15.0;access-control@v0.9.0" +export const WebrpcHeaderValue = "webrpc@v0.21.0;gen-typescript@v0.15.0;authcontrol@v0.9.0" // WebRPC description and code-gen version export const WebRPCVersion = "v1" @@ -16,7 +16,7 @@ export const WebRPCVersion = "v1" export const WebRPCSchemaVersion = "v0.9.0" // Schema hash generated from your RIDL schema -export const WebRPCSchemaHash = "4ec5724f11e7e078f8eaf99b10a5f1a020c46821" +export const WebRPCSchemaHash = "896c9dd61e9b52933577b35ac021a229cdc54fe2" type WebrpcGenVersions = { webrpcGenVersion: string; diff --git a/proto/authcontrol.ridl b/proto/authcontrol.ridl index cd5a746..4ad9dab 100644 --- a/proto/authcontrol.ridl +++ b/proto/authcontrol.ridl @@ -1,6 +1,6 @@ webrpc = v1 -name = access-control +name = authcontrol version = v0.9.0 enum SessionType: uint16 diff --git a/s2s.go b/s2s.go index b2b1eb4..14dd3fe 100644 --- a/s2s.go +++ b/s2s.go @@ -5,10 +5,8 @@ import ( "net/http" "time" - "github.com/go-chi/jwtauth/v5" "github.com/go-chi/traceid" "github.com/go-chi/transport" - "github.com/lestrrat-go/jwx/v2/jwt" ) type S2SClientConfig struct { @@ -34,8 +32,7 @@ func S2SClient(cfg *S2SClientConfig) *http.Client { // Create short-lived service-to-service JWT token for internal communication between Sequence services. func S2SToken(jwtSecret string, claims map[string]any) string { - jwtAuth := jwtauth.New("HS256", []byte(jwtSecret), nil, jwt.WithAcceptableSkew(2*time.Minute)) - + jwtAuth, _ := NewAuth(jwtSecret).GetVerifier(nil) now := time.Now().UTC() c := maps.Clone(claims) diff --git a/s2s_test.go b/s2s_test.go index 9d36218..927d0e7 100644 --- a/s2s_test.go +++ b/s2s_test.go @@ -7,6 +7,7 @@ import ( "github.com/0xsequence/authcontrol" "github.com/go-chi/jwtauth/v5" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -23,12 +24,8 @@ func TestS2SToken(t *testing.T) { require.NoError(t, err) expiresIn := time.Until(jwt.Expiration()) - if expiresIn < 29*time.Second { - t.Errorf("expected default expiry to be at least 30s, got %v", expiresIn) - } + assert.Greater(t, expiresIn, 29*time.Second) service := claims["service"].(string) - if service != "test" { - t.Errorf("unexpected service claim, got %q", service) - } + assert.Equal(t, "test", service) }