Skip to content

Commit

Permalink
test: add tests for rotation persister
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Dec 3, 2024
1 parent 9e34b19 commit 3811d20
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 73 deletions.
190 changes: 121 additions & 69 deletions oauth2/fosite_store_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,36 +11,27 @@ import (
"testing"
"time"

"github.com/ory/hydra/v2/oauth2"

"github.com/ory/x/assertx"

"github.com/ory/hydra/v2/flow"
"github.com/ory/hydra/v2/jwk"

"github.com/go-jose/go-jose/v3"
"github.com/pborman/uuid"

"github.com/ory/fosite/handler/rfc7523"

"github.com/ory/hydra/v2/oauth2/trust"

"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/x"

"github.com/ory/fosite/storage"
"github.com/ory/x/sqlxx"

gofrsuuid "github.com/gofrs/uuid"
"github.com/pborman/uuid"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"github.com/ory/x/sqlcon"

"github.com/ory/fosite/handler/rfc7523"
"github.com/ory/fosite/storage"
"github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/flow"
"github.com/ory/hydra/v2/jwk"
"github.com/ory/hydra/v2/oauth2"
"github.com/ory/hydra/v2/oauth2/trust"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/assertx"
"github.com/ory/x/sqlcon"
"github.com/ory/x/sqlxx"
)

var defaultIgnoreKeys = []string{
Expand All @@ -57,28 +48,32 @@ var defaultIgnoreKeys = []string{
"client.client_secret",
}

var defaultRequest = fosite.Request{
ID: "blank",
RequestedAt: time.Now().UTC().Round(time.Second),
Client: &client.Client{
ID: "foobar",
Contacts: []string{},
RedirectURIs: []string{},
Audience: []string{},
AllowedCORSOrigins: []string{},
ResponseTypes: []string{},
GrantTypes: []string{},
JSONWebKeys: &x.JoseJSONWebKeySet{},
Metadata: sqlxx.JSONRawMessage("{}"),
},
RequestedScope: fosite.Arguments{"fa", "ba"},
GrantedScope: fosite.Arguments{"fa", "ba"},
RequestedAudience: fosite.Arguments{"ad1", "ad2"},
GrantedAudience: fosite.Arguments{"ad1", "ad2"},
Form: url.Values{"foo": []string{"bar", "baz"}},
Session: oauth2.NewSession("bar"),
func newDefaultRequest(id string) fosite.Request {
return fosite.Request{
ID: id,
RequestedAt: time.Now().UTC().Round(time.Second),
Client: &client.Client{
ID: "foobar",
Contacts: []string{},
RedirectURIs: []string{},
Audience: []string{},
AllowedCORSOrigins: []string{},
ResponseTypes: []string{},
GrantTypes: []string{},
JSONWebKeys: &x.JoseJSONWebKeySet{},
Metadata: sqlxx.JSONRawMessage("{}"),
},
RequestedScope: fosite.Arguments{"fa", "ba"},
GrantedScope: fosite.Arguments{"fa", "ba"},
RequestedAudience: fosite.Arguments{"ad1", "ad2"},
GrantedAudience: fosite.Arguments{"ad1", "ad2"},
Form: url.Values{"foo": []string{"bar", "baz"}},
Session: oauth2.NewSession("bar"),
}
}

var defaultRequest = newDefaultRequest("blank")

// var lifespan = time.Hour
var flushRequests = []*fosite.Request{
{
Expand Down Expand Up @@ -497,54 +492,111 @@ func testHelperRotateRefreshToken(x oauth2.InternalRegistry) func(t *testing.T)
return func(t *testing.T) {
ctx := context.Background()

createTokens := func(t *testing.T, r *fosite.Request) (refreshTokenSession string, accessTokenSession string) {
refreshTokenSession = fmt.Sprintf("refresh_token_%s", uuid.New())
accessTokenSession = fmt.Sprintf("access_token_%s", uuid.New())
err := x.OAuth2Storage().CreateAccessTokenSession(ctx, accessTokenSession, r)
require.NoError(t, err)

err = x.OAuth2Storage().CreateRefreshTokenSession(ctx, refreshTokenSession, accessTokenSession, r)
require.NoError(t, err)

// Sanity check
req, err := x.OAuth2Storage().GetRefreshTokenSession(ctx, refreshTokenSession, nil)
require.NoError(t, err)
require.EqualValues(t, r.GetID(), req.GetID())

req, err = x.OAuth2Storage().GetAccessTokenSession(ctx, accessTokenSession, nil)
require.NoError(t, err)
require.EqualValues(t, r.GetID(), req.GetID())
return
}

t.Run("Revokes refresh token when grace period not configured", func(t *testing.T) {
// SETUP
m := x.OAuth2Storage()
r := newDefaultRequest(uuid.New())
refreshTokenSession, accessTokenSession := createTokens(t, &r)

refreshTokenSession := fmt.Sprintf("refresh_token_%d", time.Now().Unix())
err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, "", &defaultRequest)
require.NoError(t, err, "precondition failed: could not create refresh token session")

// ACT
err = m.RotateRefreshToken(ctx, defaultRequest.GetID(), refreshTokenSession)
err := m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)
require.NoError(t, err)

tmpSession := new(fosite.Session)
_, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, *tmpSession)
_, err = m.GetAccessTokenSession(ctx, accessTokenSession, nil)
assert.ErrorIs(t, err, fosite.ErrNotFound, "Token is no longer active because it was refreshed")

// ASSERT
// a revoked refresh token returns an error when getting the token again
assert.ErrorIs(t, err, fosite.ErrInactiveToken)
_, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil)
assert.ErrorIs(t, err, fosite.ErrInactiveToken, "Token is no longer active because it was refreshed")
})

t.Run("refresh token enters grace period when configured,", func(t *testing.T) {
// SETUP
x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1m")
t.Run("refresh token is valid until the grace period has ended", func(t *testing.T) {
x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1s")
t.Cleanup(func() {
x.Config().Delete(ctx, config.KeyRefreshTokenRotationGracePeriod)
})

m := x.OAuth2Storage()
r := newDefaultRequest(uuid.New())
refreshTokenSession, accessTokenSession1 := createTokens(t, &r)
accessTokenSession2 := fmt.Sprintf("access_token_%s", uuid.New())
require.NoError(t, m.CreateAccessTokenSession(ctx, accessTokenSession2, &r))

// Create a second access token
require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession))
require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession))
require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession))

req, err := m.GetAccessTokenSession(ctx, accessTokenSession1, nil)
assert.ErrorIs(t, err, fosite.ErrNotFound)

req, err = m.GetAccessTokenSession(ctx, accessTokenSession2, nil)
assert.NoError(t, err, "The second access token is still valid.")

req, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil)
assert.NoError(t, err)
assert.Equal(t, r.GetID(), req.GetID())

// always reset back to the default
time.Sleep(time.Second)

req, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil)
assert.Error(t, err)
})

t.Run("refresh token revokes all access tokens from the request if the access token signature is not found", func(t *testing.T) {
x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1s")
t.Cleanup(func() {
x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "0m")
x.Config().Delete(ctx, config.KeyRefreshTokenRotationGracePeriod)
})

m := x.OAuth2Storage()
r := newDefaultRequest(uuid.New())

refreshTokenSession := fmt.Sprintf("refresh_token_%d_with_grace_period", time.Now().Unix())
err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, "", &defaultRequest)
require.NoError(t, err, "precondition failed: could not create refresh token session")
refreshTokenSession := fmt.Sprintf("refresh_token_%s", uuid.New())
accessTokenSession1 := fmt.Sprintf("access_token_%s", uuid.New())
accessTokenSession2 := fmt.Sprintf("access_token_%s", uuid.New())
require.NoError(t, m.CreateAccessTokenSession(ctx, accessTokenSession1, &r))
require.NoError(t, m.CreateAccessTokenSession(ctx, accessTokenSession2, &r))

require.NoError(t, m.CreateRefreshTokenSession(ctx, refreshTokenSession, "", &r),
"precondition failed: could not create refresh token session")

// ACT
require.NoError(t, m.RotateRefreshToken(ctx, defaultRequest.GetID(), refreshTokenSession))
require.NoError(t, m.RotateRefreshToken(ctx, defaultRequest.GetID(), refreshTokenSession))
require.NoError(t, m.RotateRefreshToken(ctx, defaultRequest.GetID(), refreshTokenSession))
require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession))
require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession))
require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession))

req, err := m.GetAccessTokenSession(ctx, accessTokenSession1, nil)
assert.ErrorIs(t, err, fosite.ErrNotFound)

req, err := m.GetRefreshTokenSession(ctx, refreshTokenSession, nil)
req, err = m.GetAccessTokenSession(ctx, accessTokenSession2, nil)
assert.ErrorIs(t, err, fosite.ErrNotFound)

// ASSERT
// when grace period is configured the refresh token can be obtained within
// the grace period without error
req, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil)
assert.NoError(t, err)
assert.Equal(t, r.GetID(), req.GetID())

assert.Equal(t, defaultRequest.GetID(), req.GetID())
time.Sleep(time.Second)

req, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil)
assert.Error(t, err)
})
}
}
Expand Down
13 changes: 9 additions & 4 deletions persistence/sql/persister_oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,12 +455,17 @@ func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature str
return err
}

if err = sqlcon.HandleError(p.CreateWithNetwork(ctx, &OAuth2RefreshTable{
OAuth2RequestSQL: *req,
AccessTokenSignature: sql.NullString{
var sig sql.NullString
if len(accessTokenSignature) > 0 {
sig = sql.NullString{
Valid: true,
String: x.SignatureHash(accessTokenSignature),
},
}
}

if err = sqlcon.HandleError(p.CreateWithNetwork(ctx, &OAuth2RefreshTable{
OAuth2RequestSQL: *req,
AccessTokenSignature: sig,
})); errors.Is(err, sqlcon.ErrConcurrentUpdate) {
return fosite.ErrSerializationFailure.WithWrap(err)
} else if err != nil {
Expand Down

0 comments on commit 3811d20

Please sign in to comment.