Skip to content

Commit

Permalink
Create S2S tokens (#3)
Browse files Browse the repository at this point in the history
* Create S2S tokens

* Create a config struct S2SClient

* Fix CI
  • Loading branch information
VojtechVitek authored Oct 24, 2024
1 parent bc9c31b commit fe8f8fe
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 57 deletions.
18 changes: 3 additions & 15 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,20 @@ import (
"net/http/httptest"
"testing"

"github.com/go-chi/jwtauth/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/0xsequence/authcontrol"
"github.com/0xsequence/authcontrol/proto"
)

func mustJWT(t *testing.T, auth *jwtauth.JWTAuth, claims map[string]any) *string {
t.Helper()
if claims == nil {
return nil
}

_, token, err := auth.Encode(claims)
require.NoError(t, err)
return &token
}

const HeaderKey = "Test-Key"

func keyFunc(r *http.Request) string {
return r.Header.Get(HeaderKey)
}

func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, path, accessKey string, jwt *string) (bool, error) {
func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, path, accessKey string, jwt string) (bool, error) {
req, err := http.NewRequest("POST", path, nil)
require.NoError(t, err)

Expand All @@ -42,8 +30,8 @@ func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, pat
req.Header.Set(HeaderKey, accessKey)
}

if jwt != nil {
req.Header.Set("Authorization", "Bearer "+*jwt)
if jwt != "" {
req.Header.Set("Authorization", "Bearer "+jwt)
}

rr := httptest.NewRecorder()
Expand Down
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ require (
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect
github.com/go-chi/traceid v0.2.0 // indirect
github.com/go-chi/transport v0.4.0 // indirect
github.com/goccy/go-json v0.10.3 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/httprc v1.0.6 // indirect
Expand Down
7 changes: 7 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@ github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw=
github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-chi/jwtauth/v5 v5.3.1 h1:1ePWrjVctvp1tyBq5b/2ER8Th/+RbYc7x4qNsc5rh5A=
github.com/go-chi/jwtauth/v5 v5.3.1/go.mod h1:6Fl2RRmWXs3tJYE1IQGX81FsPoGqDwq9c15j52R5q80=
github.com/go-chi/traceid v0.2.0 h1:M4SVlzbnq6zfNCOvi8LwLFGugY04El+hS8njO0Pwml4=
github.com/go-chi/traceid v0.2.0/go.mod h1:XFfEEYZjqgML4ySh+wYBU29eqJkc2um7oEzgIc63e74=
github.com/go-chi/transport v0.4.0 h1:wKkHHapDbijGz7sGicEqyQIj6KD/LV5+R7H8QrZQMco=
github.com/go-chi/transport v0.4.0/go.mod h1:uoCleTaQiFtoatEiiqcXFZ5OxIp6s1DfGeVsCVbalT4=
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU=
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
Expand All @@ -32,6 +38,7 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
Expand Down
58 changes: 58 additions & 0 deletions http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package authcontrol

import (
"maps"
"net/http"
"time"

"github.com/go-chi/jwtauth/v5"
"github.com/go-chi/traceid"
"github.com/go-chi/transport"
"github.com/lestrrat-go/jwx/v2/jwt"
)

type S2SClientConfig struct {
Service string
JWTSecret string
DebugRequests bool
}

// Service-to-service HTTP client for internal communication between Sequence services.
func S2SClient(cfg *S2SClientConfig) *http.Client {
httpClient := &http.Client{
Transport: transport.Chain(http.DefaultTransport,
traceid.Transport,
transport.SetHeaderFunc("Authorization", s2sAuthHeader(cfg.JWTSecret, map[string]any{"service": cfg.Service})),
transport.If(cfg.DebugRequests, transport.LogRequests(transport.LogOptions{Concise: true, CURL: true})),
),
}

return httpClient
}

// Create short-lived service-to-service JWT token for internal communication between Sequence services.
func S2SToken(jwtSecret string, claims map[string]any) string {
jwtAuth := jwtauth.New("HS256", []byte(jwtSecret), nil, jwt.WithAcceptableSkew(2*time.Minute))

now := time.Now().UTC()

c := maps.Clone(claims)
if c == nil {
c = map[string]any{}
}

c["iat"] = now

if _, ok := c["exp"]; !ok {
c["exp"] = now.Add(30 * time.Second)
}

_, t, _ := jwtAuth.Encode(c)
return t
}

func s2sAuthHeader(jwtSecret string, claims map[string]any) func(req *http.Request) string {
return func(req *http.Request) string {
return "BEARER " + S2SToken(jwtSecret, claims)
}
}
23 changes: 13 additions & 10 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@ import (
)

type Options struct {
JWTSecret string
KeyFuncs []KeyFunc
UserStore UserStore
ErrHandler ErrHandler
}

func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Handler {
func Session(cfg *Options) func(next http.Handler) http.Handler {
auth := jwtauth.New("HS256", []byte(cfg.JWTSecret), nil)

eh := errHandler
if o != nil && o.ErrHandler != nil {
eh = o.ErrHandler
if cfg != nil && cfg.ErrHandler != nil {
eh = cfg.ErrHandler
}

return func(next http.Handler) http.Handler {
Expand All @@ -38,8 +41,8 @@ func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Han
token jwt.Token
)

if o != nil {
for _, f := range o.KeyFuncs {
if cfg != nil {
for _, f := range cfg.KeyFuncs {
if accessKey = f(r); accessKey != "" {
break
}
Expand Down Expand Up @@ -79,8 +82,8 @@ func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Han
ctx = withAccount(ctx, accountClaim)
sessionType = proto.SessionType_Wallet

if o != nil && o.UserStore != nil {
user, isAdmin, err := o.UserStore.GetUser(ctx, accountClaim)
if cfg != nil && cfg.UserStore != nil {
user, isAdmin, err := cfg.UserStore.GetUser(ctx, accountClaim)
if err != nil {
eh(r, w, err)
return
Expand Down Expand Up @@ -121,10 +124,10 @@ func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Han

// AccessControl middleware that checks if the session type is allowed to access the endpoint.
// It also sets the compute units on the context if the endpoint requires it.
func AccessControl(acl Config[ACL], o *Options) func(next http.Handler) http.Handler {
func AccessControl(acl Config[ACL], cfg *Options) func(next http.Handler) http.Handler {
eh := errHandler
if o != nil && o.ErrHandler != nil {
eh = o.ErrHandler
if cfg != nil && cfg.ErrHandler != nil {
eh = cfg.ErrHandler
}

return func(next http.Handler) http.Handler {
Expand Down
59 changes: 27 additions & 32 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"time"

"github.com/go-chi/chi/v5"
"github.com/go-chi/jwtauth/v5"
"github.com/stretchr/testify/assert"

"github.com/0xsequence/authcontrol"
Expand All @@ -20,6 +19,8 @@ import (

type mockStore map[string]bool

var secret = "secret"

func (m mockStore) GetUser(ctx context.Context, address string) (any, bool, error) {
v, ok := m[address]
if !ok {
Expand Down Expand Up @@ -77,9 +78,8 @@ func TestSession(t *testing.T) {
ServiceName = "serviceName"
)

auth := jwtauth.New("HS256", []byte("secret"), nil)

options := &authcontrol.Options{
JWTSecret: secret,
UserStore: mockStore{
UserAddress: false,
AdminAddress: true,
Expand All @@ -89,7 +89,7 @@ func TestSession(t *testing.T) {

r := chi.NewRouter()
r.Use(
authcontrol.Session(auth, options),
authcontrol.Session(options),
authcontrol.AccessControl(ACLConfig, options),
)
r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
Expand Down Expand Up @@ -132,7 +132,7 @@ func TestSession(t *testing.T) {
claims = map[string]any{"service": ServiceName}
}

ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), tc.AccessKey, mustJWT(t, auth, claims))
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), tc.AccessKey, authcontrol.S2SToken(secret, claims))

session := tc.Session
switch {
Expand Down Expand Up @@ -178,9 +178,8 @@ func TestInvalid(t *testing.T) {
AdminAddress = "adminAddress"
)

auth := jwtauth.New("HS256", []byte("secret"), nil)

options := &authcontrol.Options{
JWTSecret: secret,
UserStore: mockStore{
UserAddress: false,
AdminAddress: true,
Expand All @@ -189,10 +188,9 @@ func TestInvalid(t *testing.T) {
}

r := chi.NewRouter()
r.Use(
authcontrol.Session(auth, options),
authcontrol.AccessControl(ACLConfig, options),
)
r.Use(authcontrol.Session(options))
r.Use(authcontrol.AccessControl(ACLConfig, options))

r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
resp := map[string]any{}
Expand All @@ -206,60 +204,58 @@ func TestInvalid(t *testing.T) {
}))

// Without JWT
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, nil)
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, "")
assert.True(t, ok)
assert.NoError(t, err)

// Wrong JWT
wrongJwt := "nope"
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, &wrongJwt)
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, "wrong-secret")
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)

var claims map[string]any
claims = map[string]any{"service": "client_service"}
claims := map[string]any{"service": "client_service"}

// Valid Request
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, mustJWT(t, auth, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(secret, claims))
assert.True(t, ok)
assert.NoError(t, err)

// Invalid request path with wrong not enough parts in path for valid RPC request
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/%s/%s", ServiceName, MethodName), AccessKey, mustJWT(t, auth, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(secret, claims))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)

// Invalid request path with wrong "rpc"
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/pcr/%s/%s", ServiceName, MethodName), AccessKey, mustJWT(t, auth, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/pcr/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(secret, claims))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)

// Invalid Service
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, mustJWT(t, auth, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, authcontrol.S2SToken(secret, claims))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)

// Invalid Method
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), AccessKey, mustJWT(t, auth, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), AccessKey, authcontrol.S2SToken(secret, claims))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)

// Expired JWT Token
claims["exp"] = time.Now().Add(-time.Second).Unix()
jwt := mustJWT(t, auth, claims)
expiredJWT := authcontrol.S2SToken(secret, claims)

// Expired JWT Token valid method
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, jwt)
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, expiredJWT)
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrSessionExpired)

// Expired JWT Token invalid service
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, jwt)
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, expiredJWT)
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrSessionExpired)

// Expired JWT Token invalid method
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), AccessKey, jwt)
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), AccessKey, expiredJWT)
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrSessionExpired)
}
Expand Down Expand Up @@ -292,9 +288,8 @@ func TestCustomErrHandler(t *testing.T) {
HTTPStatus: 400,
}

auth := jwtauth.New("HS256", []byte("secret"), nil)

options := &authcontrol.Options{
opts := &authcontrol.Options{
JWTSecret: secret,
UserStore: mockStore{
UserAddress: false,
AdminAddress: true,
Expand All @@ -313,21 +308,21 @@ func TestCustomErrHandler(t *testing.T) {

r := chi.NewRouter()
r.Use(
authcontrol.Session(auth, options),
authcontrol.AccessControl(ACLConfig, options),
authcontrol.Session(opts),
authcontrol.AccessControl(ACLConfig, opts),
)
r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

var claims map[string]any
claims = map[string]any{"service": "client_service"}

// Valid Request
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, mustJWT(t, auth, claims))
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(secret, claims))
assert.True(t, ok)
assert.NoError(t, err)

// Invalid service which should return custom error from overrided ErrHandler
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, mustJWT(t, auth, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, authcontrol.S2SToken(secret, claims))
assert.False(t, ok)
assert.ErrorIs(t, err, customErr)
}

0 comments on commit fe8f8fe

Please sign in to comment.