diff --git a/driver/config/provider.go b/driver/config/provider.go index 52b9ee45a3f..1c66bb32c36 100644 --- a/driver/config/provider.go +++ b/driver/config/provider.go @@ -673,8 +673,8 @@ func (p *DefaultProvider) cookieSuffix(ctx context.Context, key string) string { func (p *DefaultProvider) RefreshTokenRotationGracePeriod(ctx context.Context) time.Duration { gracePeriod := p.getProvider(ctx).DurationF(KeyRefreshTokenRotationGracePeriod, 0) - if gracePeriod > time.Hour { - return time.Hour + if gracePeriod > time.Minute*5 { + return time.Minute * 5 } return gracePeriod } diff --git a/internal/testhelpers/oauth2.go b/internal/testhelpers/oauth2.go index 41f0ddaec8e..63d28c98f65 100644 --- a/internal/testhelpers/oauth2.go +++ b/internal/testhelpers/oauth2.go @@ -111,6 +111,20 @@ func IntrospectToken(t testing.TB, conf *oauth2.Config, token string, adminTS *h return gjson.ParseBytes(ioutilx.MustReadAll(res.Body)) } +func RevokeToken(t testing.TB, conf *oauth2.Config, token string, publicTS *httptest.Server) gjson.Result { + require.NotEmpty(t, token) + + req := httpx.MustNewRequest("POST", publicTS.URL+"/oauth2/revoke", + strings.NewReader((url.Values{"token": {token}}).Encode()), + "application/x-www-form-urlencoded") + + req.SetBasicAuth(conf.ClientID, conf.ClientSecret) + res, err := publicTS.Client().Do(req) + require.NoError(t, err) + defer res.Body.Close() + return gjson.ParseBytes(ioutilx.MustReadAll(res.Body)) +} + func UpdateClientTokenLifespans(t *testing.T, conf *oauth2.Config, clientID string, lifespans client.Lifespans, adminTS *httptest.Server) { b, err := json.Marshal(lifespans) require.NoError(t, err) diff --git a/oauth2/fosite_store_helpers_test.go b/oauth2/fosite_store_helpers_test.go index 1e1faafe68f..adb630bf8c8 100644 --- a/oauth2/fosite_store_helpers_test.go +++ b/oauth2/fosite_store_helpers_test.go @@ -6,12 +6,13 @@ package oauth2_test import ( "context" "fmt" - "github.com/ory/hydra/v2/oauth2" "net/url" "slices" "testing" "time" + "github.com/ory/hydra/v2/oauth2" + "github.com/ory/x/assertx" "github.com/ory/hydra/v2/flow" diff --git a/oauth2/fosite_store_test.go b/oauth2/fosite_store_test.go index 483160cafe1..f92dd79387a 100644 --- a/oauth2/fosite_store_test.go +++ b/oauth2/fosite_store_test.go @@ -38,7 +38,7 @@ func setupRegistries(t *testing.T) { if len(registries) == 0 && !testing.Short() { // first time called and sql tests var cleanSQL func(*testing.T) - registries["postgres"], registries["mysql"], registries["cockroach"], cleanSQL = internal.ConnectDatabases(t, true, &contextx.Default{}) + registries["postgres"], registries["mysql"], registries["cockroach"], cleanSQL = internal.ConnectDatabases(t, false, &contextx.Default{}) cleanMem := cleanRegistries cleanMem(t) cleanRegistries = func(t *testing.T) { diff --git a/oauth2/handler.go b/oauth2/handler.go index 288ed1f16f0..3f1a633038d 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -727,11 +727,14 @@ type revokeOAuth2Token struct { // default: errorOAuth2 func (h *Handler) revokeOAuth2Token(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - events.Trace(ctx, events.AccessTokenRevoked) - err := h.r.OAuth2Provider().NewRevocationRequest(ctx, r) + err := h.r.Persister().Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error { + return h.r.OAuth2Provider().NewRevocationRequest(ctx, r) + }) if err != nil { x.LogError(r, err, h.r.Logger()) + } else { + events.Trace(ctx, events.AccessTokenRevoked) } h.r.OAuth2Provider().WriteRevocationResponse(ctx, w, err) diff --git a/oauth2/helper_test.go b/oauth2/helper_test.go index c2a4194d7ff..04f41298b71 100644 --- a/oauth2/helper_test.go +++ b/oauth2/helper_test.go @@ -5,10 +5,11 @@ package oauth2_test import ( "context" + "testing" + "github.com/oleiade/reflections" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "testing" "github.com/ory/fosite" "github.com/ory/fosite/handler/oauth2" diff --git a/oauth2/helpers.go b/oauth2/helpers.go index 024d7827741..4db4bf84d8e 100644 --- a/oauth2/helpers.go +++ b/oauth2/helpers.go @@ -1,13 +1,18 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + package oauth2 import ( "context" "crypto/sha256" "fmt" + "time" + "github.com/gobuffalo/pop/v6" gofrsuuid "github.com/gofrs/uuid" + "github.com/ory/hydra/v2/x" - "time" ) func signatureFromJTI(jti string) string { diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index eb002b73a95..47ce2d4411b 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -10,8 +10,8 @@ import ( "encoding/json" "errors" "fmt" - "github.com/ory/hydra/v2/jwk" "io" + "math/rand" "net/http" "net/http/httptest" "net/url" @@ -21,6 +21,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/jwk" + "github.com/go-jose/go-jose/v3" "github.com/golang-jwt/jwt/v5" "github.com/julienschmidt/httprouter" @@ -167,7 +169,7 @@ func acceptConsentHandler(t *testing.T, c *client.Client, adminClient *hydra.API // - [x] What happens if `id_token_hint` does not match the value from the handled authentication request ("accept login") func TestAuthCodeWithDefaultStrategy(t *testing.T) { setupRegistries(t) - + rng := rand.New(rand.NewSource(time.Now().UnixNano())) ctx := context.Background() reg := internal.NewRegistrySQLFromURL(t, registries["cockroach"].Config().DSN(), true, &contextx.Default{}) @@ -187,7 +189,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { introspect := testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS) actualExp, err := strconv.ParseInt(introspect.Get("exp").String(), 10, 64) require.NoError(t, err, "%s", introspect) - requirex.EqualTime(t, expectedExp, time.Unix(actualExp, 0), time.Second) + requirex.EqualTime(t, expectedExp, time.Unix(actualExp, 0), time.Second*2) } assertIDToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedSubject, expectedNonce string, expectedExp time.Time) gjson.Result { @@ -202,7 +204,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { assert.True(t, time.Now().After(time.Unix(claims.Get("iat").Int(), 0)), "%s", claims) assert.True(t, time.Now().After(time.Unix(claims.Get("nbf").Int(), 0)), "%s", claims) assert.True(t, time.Now().Before(time.Unix(claims.Get("exp").Int(), 0)), "%s", claims) - requirex.EqualTime(t, expectedExp, time.Unix(claims.Get("exp").Int(), 0), 2*time.Second) + requirex.EqualTime(t, expectedExp, time.Unix(claims.Get("exp").Int(), 0), 3*time.Second) assert.NotEmpty(t, claims.Get("jti").String(), "%s", claims) assert.EqualValues(t, reg.Config().IssuerURL(ctx).String(), claims.Get("iss").String(), "%s", claims) assert.NotEmpty(t, claims.Get("sid").String(), "%s", claims) @@ -342,12 +344,17 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { }) t.Run("case=graceful token rotation", func(t *testing.T) { + // This is an essential and complex test suite. We need to cover the following cases: + // + // 5. Once a refresh token is used - within the grace period - it is marked as used and cannot be used again (unless we are in the grace period). + // 6. Revoking a refresh token also invalidates all tokens associated with the original consent. + run := func(t *testing.T, strategy string) { - reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, time.Second*3) + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "2s") reg.Config().MustSet(ctx, config.KeyTokenHook, nil) reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, nil) - reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, time.Minute) - reg.Config().MustSet(ctx, config.KeyAccessTokenLifespan, time.Minute) + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1m") + reg.Config().MustSet(ctx, config.KeyAccessTokenLifespan, "1m") t.Cleanup(func() { reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, nil) reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, nil) @@ -392,11 +399,230 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { return refreshedToken } - t.Run("followup=graceful token refresh can handle concurrent refreshing", func(t *testing.T) { - code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) - token, err := conf.Exchange(context.Background(), code) - require.NoError(t, err) + assertInactive := func(t *testing.T, token string, c *oauth2.Config) { + t.Helper() + at := testhelpers.IntrospectToken(t, conf, token, adminTS) + assert.False(t, at.Get("active").Bool(), "%s", at) + } + + t.Run("gracefully refreshing a token does invalidate the previous access token", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1m") + + token := issueTokens(t) + _ = refreshTokens(t, token) + + assertInactive(t, token.AccessToken, conf) // Original access token is invalid + + _ = refreshTokens(t, token) + assertInactive(t, token.AccessToken, conf) // Original access token is still invalid + }) + + t.Run("an expired refresh token can not be used even if we are in the grace period", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "500ms") + + token := issueTokens(t) + time.Sleep(time.Second) + + token.Expiry = time.Now().Add(-time.Hour * 24) + _, err := conf.TokenSource(ctx, token).Token() + require.Error(t, err, "Rotating an expired token is not possible even when we are in the grace period") + + // The access token is still valid because using an expired refresh token has no effect on the access token. + assertInactive(t, token.RefreshToken, conf) + }) + + t.Run("a used refresh token can not be re-used once the grace period ends and it triggers re-use detection", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1m") + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "500ms") + + token := issueTokens(t) + refreshed := refreshTokens(t, token) + + time.Sleep(time.Millisecond * 501) // Wait for the grace period to end + + token.Expiry = time.Now().Add(-time.Hour * 24) + _, err := conf.TokenSource(ctx, token).Token() + require.Error(t, err, "Rotating a used refresh token is not possible after the grace period") + + assertInactive(t, token.AccessToken, conf) + assertInactive(t, token.RefreshToken, conf) + + assertInactive(t, refreshed.AccessToken, conf) + assertInactive(t, refreshed.RefreshToken, conf) + }) + + // This test suite covers complex scenarios where we have multiple generations of tokens and we need to ensure + // that key security mitigations are in place: + // + // - Token re-use detection clears all tokens if a refresh token is re-used after the grace period. + // - Revoking consent clears all tokens. + // - Token revokation clears all tokens. + // + // The test creates 4 token generations, where each generations has twice as many tokens as the previous generation. + // The generations are created like this: + // + // - In the first scenario, all token generations are created at the same time. + // - In the second scenario, we create token generations with a delay that is longer than the grace period between them. + // + // Tokens for each generation are created in parallel to ensure we have no state leak anywhere. + t.Run("token generations", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1m") + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "500ms") + + createTokenGenerations := func(t *testing.T, count int, withSleep time.Duration) [][]*oauth2.Token { + generations := make([][]*oauth2.Token, count) + generations[0] = []*oauth2.Token{issueTokens(t)} + // Start from the first generation. For every next generation, we refresh all the tokens of the previous generation twice. + for i := 1; i < len(generations); i++ { + generations[i] = make([]*oauth2.Token, 0, len(generations[i-1])*2) + + var wg sync.WaitGroup + gen := func(i int, token *oauth2.Token) { + defer wg.Done() + generations[i] = append(generations[i], refreshTokens(t, token)) + } + + for _, token := range generations[i-1] { + wg.Add(2) + go gen(i, token) + go gen(i, token) + } + + wg.Wait() + if withSleep > 0 { + time.Sleep(withSleep) + } + } + return generations + } + + t.Run("re-using an old graceful refresh token invalidates all tokens", func(t *testing.T) { + // This test only works if the refresh token lifespan is longer than the grace period. + generations := createTokenGenerations(t, 4, time.Millisecond*600) + + generationIndex := rng.Intn(len(generations) - 1) // Exclude the last generation + tokenIndex := rng.Intn(len(generations[generationIndex])) + + token := generations[generationIndex][tokenIndex] + token.Expiry = time.Now().Add(-time.Hour * 24) + _, err := conf.TokenSource(ctx, token).Token() + require.Error(t, err) + + // Now all tokens are inactive + for i, generation := range generations { + t.Run(fmt.Sprintf("generation=%d", i), func(t *testing.T) { + for j, token := range generation { + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + assertInactive(t, token.AccessToken, conf) + assertInactive(t, token.RefreshToken, conf) + }) + } + }) + } + }) + + for _, withSleep := range []time.Duration{0, time.Millisecond * 600} { + t.Run(fmt.Sprintf("withSleep=%s", withSleep), func(t *testing.T) { + createTokenGenerations := func(t *testing.T, count int) [][]*oauth2.Token { + return createTokenGenerations(t, count, withSleep) + } + + t.Run("only the most recent token generation is valid across the board", func(t *testing.T) { + generations := createTokenGenerations(t, 4) + + // All generations except the last one are valid. + for i, generation := range generations[:len(generations)-1] { + t.Run(fmt.Sprintf("generation=%d", i), func(t *testing.T) { + for j, token := range generation { + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + assertInactive(t, token.AccessToken, conf) + }) + } + }) + } + + // The last generation is valid: + t.Run(fmt.Sprintf("generation=%d", len(generations)-1), func(t *testing.T) { + for j, token := range generations[len(generations)-1] { + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + introspectAccessToken(t, conf, token, subject) + assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, token, conf, time.Now().Add(reg.Config().GetRefreshTokenLifespan(ctx))) + }) + } + }) + }) + + t.Run("revoking consent revokes all tokens", func(t *testing.T) { + generations := createTokenGenerations(t, 4) + + // After revoking consent, all generations are invalid. + err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) + require.NoError(t, err) + + for i, generation := range generations { + t.Run(fmt.Sprintf("generation=%d", i), func(t *testing.T) { + for j, token := range generation { + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + assertInactive(t, token.AccessToken, conf) + assertInactive(t, token.RefreshToken, conf) + }) + } + }) + } + }) + + t.Run("re-using the a recent refresh token after the grace period has ended invalidates all tokens", func(t *testing.T) { + generations := createTokenGenerations(t, 4) + + token := generations[len(generations)-1][0] + + finalToken := refreshTokens(t, token) + time.Sleep(time.Millisecond * 600) // Wait for the grace period to end + + token.Expiry = time.Now().Add(-time.Hour * 24) + _, err := conf.TokenSource(ctx, token).Token() + require.Error(t, err) + + // Now all tokens are inactive + for i, generation := range append(generations, []*oauth2.Token{finalToken}) { + t.Run(fmt.Sprintf("generation=%d", i), func(t *testing.T) { + for j, token := range generation { + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + assertInactive(t, token.AccessToken, conf) + assertInactive(t, token.RefreshToken, conf) + }) + } + }) + } + }) + + t.Run("revoking a refresh token in the chain revokes all tokens", func(t *testing.T) { + generations := createTokenGenerations(t, 4) + + testhelpers.RevokeToken(t, conf, generations[len(generations)-1][0].RefreshToken, publicTS) + + for i, generation := range generations { + t.Run(fmt.Sprintf("generation=%d", i), func(t *testing.T) { + for j, token := range generation { + token := token + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + assertInactive(t, token.AccessToken, conf) + assertInactive(t, token.RefreshToken, conf) + }) + } + }) + } + }) + }) + } + }) + + t.Run("it is possible to refresh tokens concurrently", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1m") + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "5s") + + token := issueTokens(t) var wg sync.WaitGroup refresh := func(t *testing.T, token *oauth2.Token) *oauth2.Token { @@ -410,22 +636,19 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { refreshes := make([]*oauth2.Token, 5) for k := range refreshes { wg.Add(1) - //time.Sleep(time.Millisecond * 100) go func(k int) { defer wg.Done() - t.Logf("Refreshing token %d", k) refreshes[k] = refresh(t, token) }(k) } - wg.Wait() + // All tokens are valid. for k, actual := range refreshes { refresh := actual - - require.NotEmpty(t, refresh.RefreshToken) - require.NotEmpty(t, refresh.AccessToken) - require.NotEmpty(t, refresh.Extra("id_token")) + require.NotEmpty(t, refresh.RefreshToken, "token %d:\ntoken:%+v", k, refresh) + require.NotEmpty(t, refresh.AccessToken, "token %d:\ntoken:%+v", k, refresh) + require.NotEmpty(t, refresh.Extra("id_token"), "token %d:\ntoken:%+v", k, refresh) i := testhelpers.IntrospectToken(t, conf, refresh.AccessToken, adminTS) assert.Truef(t, i.Get("active").Bool(), "token %d:\ntoken:%+v\nresult:%s", k, refresh, i) @@ -433,158 +656,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { i = testhelpers.IntrospectToken(t, conf, refresh.RefreshToken, adminTS) assert.Truef(t, i.Get("active").Bool(), "token %d:\ntoken:%+v\nresult:%s", k, refresh, i) } - - t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { - err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) - require.NoError(t, err) - - _, err = conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) - - for k, actual := range refreshes { - i := testhelpers.IntrospectToken(t, conf, actual.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "token %d: %s", k, i) - } - }) - }) - - t.Run("followup=graceful token refresh with reuse detection", func(t *testing.T) { - start := time.Now() - - token := issueTokens(t) - var first, second *oauth2.Token - t.Run("followup=first refresh", func(t *testing.T) { - first = refreshTokens(t, token) - }) - - t.Run("followup=second refresh", func(t *testing.T) { - second = refreshTokens(t, token) - }) - - t.Run("followup=all resulting tokens are valid", func(t *testing.T) { - i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) - assert.True(t, i.Get("active").Bool(), "%s", i) - - i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) - assert.True(t, i.Get("active").Bool(), "%s", i) - - i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) - assert.True(t, i.Get("active").Bool(), "%s", i) - - i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) - assert.True(t, i.Get("active").Bool(), "%s", i) - }) - - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(4 * time.Second))) - - t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { - i := testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) - assert.Falsef(t, i.Get("active").Bool(), "The refresh token should no longer be valid because it was used: %s", i) - - token.Expiry = time.Now().Add(-time.Hour * 24) - result, err := conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err, "%+v", result) - - i = testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - - i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - - i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - - i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - }) }) - - //t.Run("followup=graceful token refresh with reuse detection with consent revocation", func(t *testing.T) { - // token := issueTokens(t) - // var first, second *oauth2.Token - // t.Run("followup=first refresh", func(t *testing.T) { - // first = refreshTokens(t, token) - // }) - // - // t.Run("followup=second refresh", func(t *testing.T) { - // second = refreshTokens(t, token) - // }) - // - // t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { - // err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) - // require.NoError(t, err) - // - // _, err = conf.TokenSource(context.Background(), token).Token() - // assert.Error(t, err) - // - // i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) - // assert.False(t, i.Get("active").Bool(), "%s", i) - // - // i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) - // assert.False(t, i.Get("active").Bool(), "%s", i) - // - // i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) - // assert.False(t, i.Get("active").Bool(), "%s", i) - // - // i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) - // assert.False(t, i.Get("active").Bool(), "%s", i) - // }) - //}) - // - //t.Run("followup=graceful refresh tokens with multiple nested branches belong to the same request", func(t *testing.T) { - // start := time.Now() - // token := issueTokens(t) - // var a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB *oauth2.Token - // t.Run("followup=first refresh", func(t *testing.T) { - // a1Refresh = refreshTokens(t, token) - // }) - // - // t.Run("followup=second refresh", func(t *testing.T) { - // b1Refresh = refreshTokens(t, token) - // }) - // - // t.Run("followup=first refresh from first refresh", func(t *testing.T) { - // a2RefreshA = refreshTokens(t, a1Refresh) - // }) - // - // t.Run("followup=second refresh from first refresh", func(t *testing.T) { - // a2RefreshB = refreshTokens(t, a1Refresh) - // }) - // - // t.Run("followup=first refresh from second refresh", func(t *testing.T) { - // b2RefreshA = refreshTokens(t, b1Refresh) - // }) - // - // t.Run("followup=second refresh from second refresh", func(t *testing.T) { - // b2RefreshB = refreshTokens(t, b1Refresh) - // }) - // - // // Sleep until the grace period is over - // time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - // t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { - // _, err := conf.TokenSource(context.Background(), token).Token() - // assert.Error(t, err) - // - // for k, token := range []*oauth2.Token{ - // a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB, - // } { - // t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { - // i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - // assert.False(t, i.Get("active").Bool(), "%s", i) - // - // i = testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - // assert.False(t, i.Get("active").Bool(), "%s", i) - // - // i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) - // assert.False(t, i.Get("active").Bool(), "%s", i) - // - // i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) - // assert.False(t, i.Get("active").Bool(), "%s", i) - // }) - // } - // }) - //}) } t.Run("strategy=jwt", func(t *testing.T) { diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index 571ac0f4520..dd7ce53305f 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -10,13 +10,14 @@ import ( "encoding/hex" "encoding/json" "fmt" - "github.com/gobuffalo/pop/v6" - "github.com/ory/x/dbal" - "go.opentelemetry.io/otel/attribute" "net/url" "strings" "time" + "go.opentelemetry.io/otel/attribute" + + "github.com/ory/x/dbal" + "github.com/ory/hydra/v2/x" "github.com/ory/x/sqlxx" @@ -463,7 +464,7 @@ func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature str OAuth2RequestSQL: *req, AccessTokenSignature: sql.NullString{ Valid: true, - String: accessTokenSignature, + String: x.SignatureHash(accessTokenSignature), }, })); errors.Is(err, sqlcon.ErrConcurrentUpdate) { return fosite.ErrSerializationFailure.WithWrap(err) @@ -478,41 +479,29 @@ func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRefreshTokenSession") defer otelx.End(span, &err) - request, _, err = p.getPotentiallyGracefulRefreshToken(ctx, p.QueryWithNetwork(ctx).Where("signature = ?", signature), session) - return request, err -} - -func (p *Persister) getPotentiallyGracefulRefreshToken(ctx context.Context, q *pop.Query, session fosite.Session) (_ fosite.Requester, _ *OAuth2RefreshTable, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.getPotentiallyGracefulRefreshToken") - defer otelx.End(span, &err) - var row OAuth2RefreshTable - if err := q.First(&row); errors.Is(err, sql.ErrNoRows) { - return nil, nil, errorsx.WithStack(fosite.ErrNotFound) + if err := p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&row); errors.Is(err, sql.ErrNoRows) { + return nil, errorsx.WithStack(fosite.ErrNotFound) } else if err != nil { - return nil, nil, sqlcon.HandleError(err) + return nil, sqlcon.HandleError(err) } fositeRequest, err := row.toRequest(ctx, session, p) if err != nil { - return nil, nil, err - } - - if row.Active { - return fositeRequest, &row, nil - } - - if !row.FirstUsedAt.Valid { - return fositeRequest, &row, errors.WithStack(fosite.ErrInactiveToken) + return nil, err } gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx) - if gracePeriod < 1 { - return fositeRequest, &row, errors.WithStack(fosite.ErrInactiveToken) + if row.Active { + return fositeRequest, nil + } else if row.FirstUsedAt.Valid && + row.FirstUsedAt.Time.Add(gracePeriod).After(time.Now()) && + gracePeriod > 0 { + // There is ONLY ONE other case where this token is not active: the token has been used before and is now in the grace period. + return fositeRequest, nil } - row.Active = true // We set active to true because we are in the grace period. - return fositeRequest, &row, err + return fositeRequest, errors.WithStack(fosite.ErrInactiveToken) } func (p *Persister) DeleteRefreshTokenSession(ctx context.Context, signature string) (err error) { @@ -563,7 +552,7 @@ func (p *Persister) DeletePKCERequestSession(ctx context.Context, signature stri func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshToken") defer otelx.End(span, &err) - return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) + return p.deleteSessionByRequestID(ctx, id, sqlTableRefresh) } func (p *Persister) RevokeAccessToken(ctx context.Context, id string) (err error) { @@ -651,7 +640,6 @@ func (p *Persister) strictRefreshRotation(ctx context.Context, requestID string) defer otelx.End(span, &err) c := p.Connection(ctx) - now := time.Now().UTC().Round(time.Millisecond) // In strict rotation we only have one token chain for every request. Therefore, we remove all // access tokens associated with the request ID. @@ -661,8 +649,7 @@ func (p *Persister) strictRefreshRotation(ctx context.Context, requestID string) // The same applies to refresh tokens in strict mode. We disable all old refresh tokens when rotating. return sqlcon.HandleError(c.RawQuery( - "UPDATE hydra_oauth2_refresh SET active=false, first_used_at = ? WHERE request_id=? AND nid = ? AND active", - now, + "UPDATE hydra_oauth2_refresh SET active=false WHERE request_id=? AND nid = ? AND active", requestID, p.NetworkID(ctx), ).Exec()) diff --git a/spec/config.json b/spec/config.json index 72f81534c66..effd1cc866d 100644 --- a/spec/config.json +++ b/spec/config.json @@ -1071,9 +1071,9 @@ "refresh_token": { "type": "object", "properties": { - "grace_period": { + "rotation_grace_period": { "title": "Refresh Token Rotation Grace Period", - "description": "Configures how long a Refresh Token remains valid after it has been used. The maximum value is one hour.", + "description": "Configures how long a Refresh Token remains valid after it has been used. The maximum value is 5 minutes.", "default": "0s", "allOf": [ {