diff --git a/common_test.go b/common_test.go index cc5aa91..b221a18 100644 --- a/common_test.go +++ b/common_test.go @@ -8,7 +8,6 @@ import ( "net/http/httptest" "testing" - "github.com/go-chi/jwtauth/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -16,24 +15,13 @@ import ( "github.com/0xsequence/authcontrol/proto" ) -func mustJWT(t *testing.T, auth *jwtauth.JWTAuth, claims map[string]any) *string { - t.Helper() - if claims == nil { - return nil - } - - _, token, err := auth.Encode(claims) - require.NoError(t, err) - return &token -} - const HeaderKey = "Test-Key" func keyFunc(r *http.Request) string { return r.Header.Get(HeaderKey) } -func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, path, accessKey string, jwt *string) (bool, error) { +func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, path, accessKey string, jwt string) (bool, error) { req, err := http.NewRequest("POST", path, nil) require.NoError(t, err) @@ -42,8 +30,8 @@ func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, pat req.Header.Set(HeaderKey, accessKey) } - if jwt != nil { - req.Header.Set("Authorization", "Bearer "+*jwt) + if jwt != "" { + req.Header.Set("Authorization", "Bearer "+jwt) } rr := httptest.NewRecorder() diff --git a/go.mod b/go.mod index 01639c7..bf68bff 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,10 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect + github.com/go-chi/traceid v0.2.0 // indirect + github.com/go-chi/transport v0.4.0 // indirect github.com/goccy/go-json v0.10.3 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/lestrrat-go/blackmagic v1.0.2 // indirect github.com/lestrrat-go/httpcc v1.0.1 // indirect github.com/lestrrat-go/httprc v1.0.6 // indirect diff --git a/go.sum b/go.sum index b504e0a..0893db6 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,14 @@ github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-chi/jwtauth/v5 v5.3.1 h1:1ePWrjVctvp1tyBq5b/2ER8Th/+RbYc7x4qNsc5rh5A= github.com/go-chi/jwtauth/v5 v5.3.1/go.mod h1:6Fl2RRmWXs3tJYE1IQGX81FsPoGqDwq9c15j52R5q80= +github.com/go-chi/traceid v0.2.0 h1:M4SVlzbnq6zfNCOvi8LwLFGugY04El+hS8njO0Pwml4= +github.com/go-chi/traceid v0.2.0/go.mod h1:XFfEEYZjqgML4ySh+wYBU29eqJkc2um7oEzgIc63e74= +github.com/go-chi/transport v0.4.0 h1:wKkHHapDbijGz7sGicEqyQIj6KD/LV5+R7H8QrZQMco= +github.com/go-chi/transport v0.4.0/go.mod h1:uoCleTaQiFtoatEiiqcXFZ5OxIp6s1DfGeVsCVbalT4= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= @@ -32,6 +38,7 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/http.go b/http.go new file mode 100644 index 0000000..243db54 --- /dev/null +++ b/http.go @@ -0,0 +1,58 @@ +package authcontrol + +import ( + "maps" + "net/http" + "time" + + "github.com/go-chi/jwtauth/v5" + "github.com/go-chi/traceid" + "github.com/go-chi/transport" + "github.com/lestrrat-go/jwx/v2/jwt" +) + +type S2SClientConfig struct { + Service string + JWTSecret string + DebugRequests bool +} + +// Service-to-service HTTP client for internal communication between Sequence services. +func S2SClient(cfg *S2SClientConfig) *http.Client { + httpClient := &http.Client{ + Transport: transport.Chain(http.DefaultTransport, + traceid.Transport, + transport.SetHeaderFunc("Authorization", s2sAuthHeader(cfg.JWTSecret, map[string]any{"service": cfg.Service})), + transport.If(cfg.DebugRequests, transport.LogRequests(transport.LogOptions{Concise: true, CURL: true})), + ), + } + + return httpClient +} + +// Create short-lived service-to-service JWT token for internal communication between Sequence services. +func S2SToken(jwtSecret string, claims map[string]any) string { + jwtAuth := jwtauth.New("HS256", []byte(jwtSecret), nil, jwt.WithAcceptableSkew(2*time.Minute)) + + now := time.Now().UTC() + + c := maps.Clone(claims) + if c == nil { + c = map[string]any{} + } + + c["iat"] = now + + if _, ok := c["exp"]; !ok { + c["exp"] = now.Add(30 * time.Second) + } + + _, t, _ := jwtAuth.Encode(c) + return t +} + +func s2sAuthHeader(jwtSecret string, claims map[string]any) func(req *http.Request) string { + return func(req *http.Request) string { + return "BEARER " + S2SToken(jwtSecret, claims) + } +} diff --git a/middleware.go b/middleware.go index 66f9918..d0ebfd6 100644 --- a/middleware.go +++ b/middleware.go @@ -11,15 +11,18 @@ import ( ) type Options struct { + JWTSecret string KeyFuncs []KeyFunc UserStore UserStore ErrHandler ErrHandler } -func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Handler { +func Session(cfg *Options) func(next http.Handler) http.Handler { + auth := jwtauth.New("HS256", []byte(cfg.JWTSecret), nil) + eh := errHandler - if o != nil && o.ErrHandler != nil { - eh = o.ErrHandler + if cfg != nil && cfg.ErrHandler != nil { + eh = cfg.ErrHandler } return func(next http.Handler) http.Handler { @@ -38,8 +41,8 @@ func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Han token jwt.Token ) - if o != nil { - for _, f := range o.KeyFuncs { + if cfg != nil { + for _, f := range cfg.KeyFuncs { if accessKey = f(r); accessKey != "" { break } @@ -79,8 +82,8 @@ func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Han ctx = withAccount(ctx, accountClaim) sessionType = proto.SessionType_Wallet - if o != nil && o.UserStore != nil { - user, isAdmin, err := o.UserStore.GetUser(ctx, accountClaim) + if cfg != nil && cfg.UserStore != nil { + user, isAdmin, err := cfg.UserStore.GetUser(ctx, accountClaim) if err != nil { eh(r, w, err) return @@ -121,10 +124,10 @@ func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Han // AccessControl middleware that checks if the session type is allowed to access the endpoint. // It also sets the compute units on the context if the endpoint requires it. -func AccessControl(acl Config[ACL], o *Options) func(next http.Handler) http.Handler { +func AccessControl(acl Config[ACL], cfg *Options) func(next http.Handler) http.Handler { eh := errHandler - if o != nil && o.ErrHandler != nil { - eh = o.ErrHandler + if cfg != nil && cfg.ErrHandler != nil { + eh = cfg.ErrHandler } return func(next http.Handler) http.Handler { diff --git a/middleware_test.go b/middleware_test.go index b21d0a4..58cf61e 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -11,7 +11,6 @@ import ( "time" "github.com/go-chi/chi/v5" - "github.com/go-chi/jwtauth/v5" "github.com/stretchr/testify/assert" "github.com/0xsequence/authcontrol" @@ -20,6 +19,8 @@ import ( type mockStore map[string]bool +var secret = "secret" + func (m mockStore) GetUser(ctx context.Context, address string) (any, bool, error) { v, ok := m[address] if !ok { @@ -77,9 +78,8 @@ func TestSession(t *testing.T) { ServiceName = "serviceName" ) - auth := jwtauth.New("HS256", []byte("secret"), nil) - options := &authcontrol.Options{ + JWTSecret: secret, UserStore: mockStore{ UserAddress: false, AdminAddress: true, @@ -89,7 +89,7 @@ func TestSession(t *testing.T) { r := chi.NewRouter() r.Use( - authcontrol.Session(auth, options), + authcontrol.Session(options), authcontrol.AccessControl(ACLConfig, options), ) r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) @@ -132,7 +132,7 @@ func TestSession(t *testing.T) { claims = map[string]any{"service": ServiceName} } - ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), tc.AccessKey, mustJWT(t, auth, claims)) + ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), tc.AccessKey, authcontrol.S2SToken(secret, claims)) session := tc.Session switch { @@ -178,9 +178,8 @@ func TestInvalid(t *testing.T) { AdminAddress = "adminAddress" ) - auth := jwtauth.New("HS256", []byte("secret"), nil) - options := &authcontrol.Options{ + JWTSecret: secret, UserStore: mockStore{ UserAddress: false, AdminAddress: true, @@ -189,10 +188,9 @@ func TestInvalid(t *testing.T) { } r := chi.NewRouter() - r.Use( - authcontrol.Session(auth, options), - authcontrol.AccessControl(ACLConfig, options), - ) + r.Use(authcontrol.Session(options)) + r.Use(authcontrol.AccessControl(ACLConfig, options)) + r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() resp := map[string]any{} @@ -206,60 +204,58 @@ func TestInvalid(t *testing.T) { })) // Without JWT - ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, nil) + ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, "") assert.True(t, ok) assert.NoError(t, err) // Wrong JWT - wrongJwt := "nope" - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, &wrongJwt) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, "wrong-secret") assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrUnauthorized) - var claims map[string]any - claims = map[string]any{"service": "client_service"} + claims := map[string]any{"service": "client_service"} // Valid Request - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, mustJWT(t, auth, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(secret, claims)) assert.True(t, ok) assert.NoError(t, err) // Invalid request path with wrong not enough parts in path for valid RPC request - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/%s/%s", ServiceName, MethodName), AccessKey, mustJWT(t, auth, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(secret, claims)) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrUnauthorized) // Invalid request path with wrong "rpc" - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/pcr/%s/%s", ServiceName, MethodName), AccessKey, mustJWT(t, auth, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/pcr/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(secret, claims)) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrUnauthorized) // Invalid Service - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, mustJWT(t, auth, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, authcontrol.S2SToken(secret, claims)) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrUnauthorized) // Invalid Method - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), AccessKey, mustJWT(t, auth, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), AccessKey, authcontrol.S2SToken(secret, claims)) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrUnauthorized) // Expired JWT Token claims["exp"] = time.Now().Add(-time.Second).Unix() - jwt := mustJWT(t, auth, claims) + expiredJWT := authcontrol.S2SToken(secret, claims) // Expired JWT Token valid method - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, jwt) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, expiredJWT) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrSessionExpired) // Expired JWT Token invalid service - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, jwt) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, expiredJWT) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrSessionExpired) // Expired JWT Token invalid method - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), AccessKey, jwt) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), AccessKey, expiredJWT) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrSessionExpired) } @@ -292,9 +288,8 @@ func TestCustomErrHandler(t *testing.T) { HTTPStatus: 400, } - auth := jwtauth.New("HS256", []byte("secret"), nil) - - options := &authcontrol.Options{ + opts := &authcontrol.Options{ + JWTSecret: secret, UserStore: mockStore{ UserAddress: false, AdminAddress: true, @@ -313,8 +308,8 @@ func TestCustomErrHandler(t *testing.T) { r := chi.NewRouter() r.Use( - authcontrol.Session(auth, options), - authcontrol.AccessControl(ACLConfig, options), + authcontrol.Session(opts), + authcontrol.AccessControl(ACLConfig, opts), ) r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) @@ -322,12 +317,12 @@ func TestCustomErrHandler(t *testing.T) { claims = map[string]any{"service": "client_service"} // Valid Request - ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, mustJWT(t, auth, claims)) + ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(secret, claims)) assert.True(t, ok) assert.NoError(t, err) // Invalid service which should return custom error from overrided ErrHandler - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, mustJWT(t, auth, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, authcontrol.S2SToken(secret, claims)) assert.False(t, ok) assert.ErrorIs(t, err, customErr) }