diff --git a/go.mod b/go.mod index 8092bf8a20..809bb14386 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,9 @@ replace github.com/gobuffalo/pop/v6 => github.com/ory/pop/v6 v6.2.0 // // This is needed until we release the next version of the master branch, as that branch already contains the redirect URI validation fix, which // may be breaking for some users. -replace github.com/ory/fosite => github.com/ory/fosite v0.47.1-0.20241125094724-b6468e644902 +//replace github.com/ory/fosite => github.com/ory/fosite v0.47.1-0.20241125094724-b6468e644902 + +replace github.com/ory/fosite => ../fosite require ( github.com/ThalesIgnite/crypto11 v1.2.5 diff --git a/go.sum b/go.sum index a81e9e76d2..2e7f7a6fb0 100644 --- a/go.sum +++ b/go.sum @@ -378,8 +378,6 @@ github.com/ory/analytics-go/v5 v5.0.1 h1:LX8T5B9FN8KZXOtxgN+R3I4THRRVB6+28IKgKBp github.com/ory/analytics-go/v5 v5.0.1/go.mod h1:lWCiCjAaJkKfgR/BN5DCLMol8BjKS1x+4jxBxff/FF0= github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d h1:By96ZSVuH5LyjXLVVMfvJoLVGHaT96LdOnwgFSLVf0E= github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d/go.mod h1:F2FIjwwAk6CsNAs//B8+aPFQF0t84pbM8oliyNXwQrk= -github.com/ory/fosite v0.47.1-0.20241125094724-b6468e644902 h1:X0ngo+uPWCw90ueY3Kh6q8IyF2fbwkJ8bf9RvAmD71U= -github.com/ory/fosite v0.47.1-0.20241125094724-b6468e644902/go.mod h1:AZyn1jrABUaGN12RHcWorRLbqLn52gTdHaIYY81m5J0= github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe h1:rvu4obdvqR0fkSIJ8IfgzKOWwZ5kOT2UNfLq81Qk7rc= github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe/go.mod h1:z4n3u6as84LbV4YmgjHhnwtccQqzf4cZlSk9f1FhygI= github.com/ory/go-convenience v0.1.0 h1:zouLKfF2GoSGnJwGq+PE/nJAE6dj2Zj5QlTgmMTsTS8= diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index a106006df7..19407ccf56 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -166,8 +166,14 @@ func acceptConsentHandler(t *testing.T, c *client.Client, adminClient *hydra.API // - [x] If `id_token_hint` is handled properly // - [x] What happens if `id_token_hint` does not match the value from the handled authentication request ("accept login") func TestAuthCodeWithDefaultStrategy(t *testing.T) { + setupRegistries(t) + ctx := context.Background() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := internal.NewRegistrySQLFromURL(t, registries["cockroach"].Config().DSN(), true, &contextx.Default{}) + + require.NoError(t, jwk.EnsureAsymmetricKeypairExists(ctx, reg, string(jose.ES256), x.OpenIDConnectKeyName)) + require.NoError(t, jwk.EnsureAsymmetricKeypairExists(ctx, reg, string(jose.ES256), x.OAuth2JWTKeyName)) + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "") publicTS, adminTS := testhelpers.NewOAuth2Server(ctx, t, reg) @@ -337,9 +343,15 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("case=graceful token rotation", func(t *testing.T) { run := func(t *testing.T, strategy string) { - reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "5s") + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, time.Second*3) + reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, nil) + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, time.Minute) + reg.Config().MustSet(ctx, config.KeyAccessTokenLifespan, time.Minute) t.Cleanup(func() { reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, nil) + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, nil) + reg.Config().MustSet(ctx, config.KeyAccessTokenLifespan, nil) }) c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) @@ -364,7 +376,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { refreshTokens := func(t *testing.T, token *oauth2.Token) *oauth2.Token { require.NotEmpty(t, token.RefreshToken) - token.Expiry = token.Expiry.Add(-time.Hour * 24) + token.Expiry = time.Now().Add(-time.Hour * 24) iat := time.Now() refreshedToken, err := conf.TokenSource(context.Background(), token).Token() require.NoError(t, err) @@ -380,40 +392,63 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { return refreshedToken } - t.Run("followup=graceful token refresh with reuse detection", func(t *testing.T) { - start := time.Now() + t.Run("followup=graceful token refresh can handle concurrent refreshing", func(t *testing.T) { + code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + require.NoError(t, err) - token := issueTokens(t) - var first, second *oauth2.Token - t.Run("followup=first refresh", func(t *testing.T) { - first = refreshTokens(t, token) - }) + var wg sync.WaitGroup + refresh := func(t *testing.T, token *oauth2.Token) *oauth2.Token { + require.NotEmpty(t, token.RefreshToken) + token.Expiry = time.Now().Add(-time.Hour * 24) + tt, err := conf.TokenSource(context.Background(), token).Token() + require.NoError(t, err) + return tt + } - t.Run("followup=second refresh", func(t *testing.T) { - second = refreshTokens(t, token) - }) + refreshes := make([]*oauth2.Token, 5) + for k := range refreshes { + wg.Add(1) + //time.Sleep(time.Millisecond * 100) + go func(k int) { + defer wg.Done() + t.Logf("Refreshing token %d", k) + refreshes[k] = refresh(t, token) + }(k) + } - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { - _, err := conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) + wg.Wait() - i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + for k, actual := range refreshes { + refresh := actual - i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + require.NotEmpty(t, refresh.RefreshToken) + require.NotEmpty(t, refresh.AccessToken) + require.NotEmpty(t, refresh.Extra("id_token")) - i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + i := testhelpers.IntrospectToken(t, conf, refresh.AccessToken, adminTS) + assert.Truef(t, i.Get("active").Bool(), "token %d:\ntoken:%+v\nresult:%s", k, refresh, i) - i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + i = testhelpers.IntrospectToken(t, conf, refresh.RefreshToken, adminTS) + assert.Truef(t, i.Get("active").Bool(), "token %d:\ntoken:%+v\nresult:%s", k, refresh, i) + } + + t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { + err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) + require.NoError(t, err) + + _, err = conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + + for k, actual := range refreshes { + i := testhelpers.IntrospectToken(t, conf, actual.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "token %d: %s", k, i) + } }) }) - t.Run("followup=graceful token refresh with reuse detection with consent revocation", func(t *testing.T) { + t.Run("followup=graceful token refresh with reuse detection", func(t *testing.T) { start := time.Now() token := issueTokens(t) @@ -426,66 +461,36 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { second = refreshTokens(t, token) }) - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { - err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) - require.NoError(t, err) - - _, err = conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) - + t.Run("followup=all resulting tokens are valid", func(t *testing.T) { i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + assert.True(t, i.Get("active").Bool(), "%s", i) i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + assert.True(t, i.Get("active").Bool(), "%s", i) i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + assert.True(t, i.Get("active").Bool(), "%s", i) i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + assert.True(t, i.Get("active").Bool(), "%s", i) }) - }) - - t.Run("followup=graceful token refresh can handle concurrent refreshing", func(t *testing.T) { - start := time.Now() - token := issueTokens(t) - var first, second *oauth2.Token - var wg sync.WaitGroup - refreshes := make([]*oauth2.Token, 5) - for k := range refreshes { - wg.Add(1) - go func(k int) { - defer wg.Done() - t.Logf("Refreshing token %d", k) - refreshes[k] = refreshTokens(t, token) - }(k) - } + // Sleep until the grace period is over + time.Sleep(time.Until(start.Add(4 * time.Second))) - wg.Wait() - for k, refresh := range refreshes { - t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { - iat := time.Now() - introspectAccessToken(t, conf, refresh, subject) - assertJWTAccessToken(t, strategy, conf, refresh, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) - assertIDToken(t, refresh, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) - assertRefreshToken(t, refresh, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) - }) - } + t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { + // Fetching the token again should cause an error because we are no longer in the grace period. + token.Expiry = time.Now().Add(-time.Hour * 24) + i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { - err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) - require.NoError(t, err) + i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) - _, err = conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) + result, err := conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err, "%+v", result) - i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) + i = testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) assert.False(t, i.Get("active").Bool(), "%s", i) i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) @@ -498,60 +503,92 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { assert.False(t, i.Get("active").Bool(), "%s", i) }) }) - - t.Run("followup=graceful refresh tokens with multiple nested branches belong to the same request", func(t *testing.T) { - start := time.Now() - token := issueTokens(t) - var a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB *oauth2.Token - t.Run("followup=first refresh", func(t *testing.T) { - a1Refresh = refreshTokens(t, token) - }) - - t.Run("followup=second refresh", func(t *testing.T) { - b1Refresh = refreshTokens(t, token) - }) - - t.Run("followup=first refresh from first refresh", func(t *testing.T) { - a2RefreshA = refreshTokens(t, a1Refresh) - }) - - t.Run("followup=second refresh from first refresh", func(t *testing.T) { - a2RefreshB = refreshTokens(t, a1Refresh) - }) - - t.Run("followup=first refresh from second refresh", func(t *testing.T) { - b2RefreshA = refreshTokens(t, b1Refresh) - }) - - t.Run("followup=second refresh from second refresh", func(t *testing.T) { - b2RefreshB = refreshTokens(t, b1Refresh) - }) - - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { - _, err := conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) - - for k, token := range []*oauth2.Token{ - a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB, - } { - t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { - i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - - i = testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - - i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - - i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - }) - } - }) - }) + // + //t.Run("followup=graceful token refresh with reuse detection with consent revocation", func(t *testing.T) { + // token := issueTokens(t) + // var first, second *oauth2.Token + // t.Run("followup=first refresh", func(t *testing.T) { + // first = refreshTokens(t, token) + // }) + // + // t.Run("followup=second refresh", func(t *testing.T) { + // second = refreshTokens(t, token) + // }) + // + // t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { + // err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) + // require.NoError(t, err) + // + // _, err = conf.TokenSource(context.Background(), token).Token() + // assert.Error(t, err) + // + // i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // }) + //}) + // + //t.Run("followup=graceful refresh tokens with multiple nested branches belong to the same request", func(t *testing.T) { + // start := time.Now() + // token := issueTokens(t) + // var a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB *oauth2.Token + // t.Run("followup=first refresh", func(t *testing.T) { + // a1Refresh = refreshTokens(t, token) + // }) + // + // t.Run("followup=second refresh", func(t *testing.T) { + // b1Refresh = refreshTokens(t, token) + // }) + // + // t.Run("followup=first refresh from first refresh", func(t *testing.T) { + // a2RefreshA = refreshTokens(t, a1Refresh) + // }) + // + // t.Run("followup=second refresh from first refresh", func(t *testing.T) { + // a2RefreshB = refreshTokens(t, a1Refresh) + // }) + // + // t.Run("followup=first refresh from second refresh", func(t *testing.T) { + // b2RefreshA = refreshTokens(t, b1Refresh) + // }) + // + // t.Run("followup=second refresh from second refresh", func(t *testing.T) { + // b2RefreshB = refreshTokens(t, b1Refresh) + // }) + // + // // Sleep until the grace period is over + // time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) + // t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { + // _, err := conf.TokenSource(context.Background(), token).Token() + // assert.Error(t, err) + // + // for k, token := range []*oauth2.Token{ + // a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB, + // } { + // t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + // i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // }) + // } + // }) + //}) } t.Run("strategy=jwt", func(t *testing.T) { @@ -1535,21 +1572,17 @@ func createVCProofJWT(t *testing.T, pubKey *jose.JSONWebKey, privKey any, nonce // - [x] should pass with prompt=login when authentication time is recent // - [x] should fail with prompt=login when authentication time is in the past func TestAuthCodeWithMockStrategy(t *testing.T) { - setupRegistries(t) - - for k := range map[string]driver.Registry{ - "cockroach": registries["cockroach"], - } { - t.Run("registry="+k, func(t *testing.T) { - ctx := context.Background() - reg := internal.NewRegistrySQLFromURL(t, registries[k].Config().DSN(), true, &contextx.Default{}) - - require.NoError(t, jwk.EnsureAsymmetricKeypairExists(ctx, reg, string(jose.ES256), x.OpenIDConnectKeyName)) - require.NoError(t, jwk.EnsureAsymmetricKeypairExists(ctx, reg, string(jose.ES256), x.OAuth2JWTKeyName)) - - conf := reg.Config() + ctx := context.Background() + for _, strat := range []struct{ d string }{{d: "opaque"}, {d: "jwt"}} { + t.Run("strategy="+strat.d, func(t *testing.T) { + conf := internal.NewConfigurationWithDefaults() conf.MustSet(ctx, config.KeyAccessTokenLifespan, time.Second*2) conf.MustSet(ctx, config.KeyScopeStrategy, "DEPRECATED_HIERARCHICAL_SCOPE_STRATEGY") + conf.MustSet(ctx, config.KeyAccessTokenStrategy, strat.d) + reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) + internal.MustEnsureRegistryKeys(ctx, reg, x.OAuth2JWTKeyName) + consentStrategy := &consentMock{} router := x.NewRouterPublic() ts := httptest.NewServer(router) @@ -1587,531 +1620,528 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { Scopes: []string{"hydra.*", "offline", "openid"}, } - for _, strat := range []struct{ d string }{{d: "opaque"}, {d: "jwt"}} { - conf := reg.Config() - conf.MustSet(ctx, config.KeyAccessTokenStrategy, strat.d) - - t.Run("strategy="+strat.d, func(t *testing.T) { - var code string - for k, tc := range []struct { - cj http.CookieJar - d string - cb func(t *testing.T) httprouter.Handle - authURL string - shouldPassConsentStrategy bool - expectOAuthAuthError bool - expectOAuthTokenError bool - checkExpiry bool - authTime time.Time - requestTime time.Time - assertAccessToken func(*testing.T, string) - }{ - { - d: "should pass request if strategy passes", - authURL: oauthConfig.AuthCodeURL("some-foo-state"), - shouldPassConsentStrategy: true, - checkExpiry: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - require.NotEmpty(t, code) - _, _ = w.Write([]byte(r.URL.Query().Get("code"))) - } - }, - assertAccessToken: func(t *testing.T, token string) { - if strat.d != "jwt" { - return - } + var code string + for k, tc := range []struct { + cj http.CookieJar + d string + cb func(t *testing.T) httprouter.Handle + authURL string + shouldPassConsentStrategy bool + expectOAuthAuthError bool + expectOAuthTokenError bool + checkExpiry bool + authTime time.Time + requestTime time.Time + assertAccessToken func(*testing.T, string) + }{ + { + d: "should pass request if strategy passes", + authURL: oauthConfig.AuthCodeURL("some-foo-state"), + shouldPassConsentStrategy: true, + checkExpiry: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + require.NotEmpty(t, code) + _, _ = w.Write([]byte(r.URL.Query().Get("code"))) + } + }, + assertAccessToken: func(t *testing.T, token string) { + if strat.d != "jwt" { + return + } - body, err := x.DecodeSegment(strings.Split(token, ".")[1]) - require.NoError(t, err) + body, err := x.DecodeSegment(strings.Split(token, ".")[1]) + require.NoError(t, err) - data := map[string]interface{}{} - require.NoError(t, json.Unmarshal(body, &data)) - - assert.EqualValues(t, "app-client", data["client_id"]) - assert.EqualValues(t, "foo", data["sub"]) - assert.NotEmpty(t, data["iss"]) - assert.NotEmpty(t, data["jti"]) - assert.NotEmpty(t, data["exp"]) - assert.NotEmpty(t, data["iat"]) - assert.NotEmpty(t, data["nbf"]) - assert.EqualValues(t, data["nbf"], data["iat"]) - assert.EqualValues(t, []interface{}{"offline", "openid", "hydra.*"}, data["scp"]) - }, - }, - { - d: "should fail because prompt=none and max_age > auth_time", - authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none&max_age=1", - authTime: time.Now().UTC().Add(-time.Minute), - requestTime: time.Now().UTC(), - shouldPassConsentStrategy: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - err := r.URL.Query().Get("error") - require.Empty(t, code) - require.EqualValues(t, fosite.ErrLoginRequired.Error(), err) - } - }, - expectOAuthAuthError: true, - }, - { - d: "should pass because prompt=none and max_age is less than auth_time", - authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none&max_age=3600", - authTime: time.Now().UTC().Add(-time.Minute), - requestTime: time.Now().UTC(), - shouldPassConsentStrategy: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - require.NotEmpty(t, code) - _, _ = w.Write([]byte(r.URL.Query().Get("code"))) - } - }, - }, - { - d: "should fail because prompt=none but auth_time suggests recent authentication", - authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none", - authTime: time.Now().UTC().Add(-time.Minute), - requestTime: time.Now().UTC().Add(-time.Hour), - shouldPassConsentStrategy: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - err := r.URL.Query().Get("error") - require.Empty(t, code) - require.EqualValues(t, fosite.ErrLoginRequired.Error(), err) - } - }, - expectOAuthAuthError: true, - }, - { - d: "should fail because consent strategy fails", - authURL: oauthConfig.AuthCodeURL("some-foo-state"), - expectOAuthAuthError: true, - shouldPassConsentStrategy: false, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - require.Empty(t, r.URL.Query().Get("code")) - assert.Equal(t, fosite.ErrRequestForbidden.Error(), r.URL.Query().Get("error")) - } - }, - }, - { - d: "should pass with prompt=login when authentication time is recent", - authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=login", - authTime: time.Now().UTC().Add(-time.Second), - requestTime: time.Now().UTC().Add(-time.Minute), - shouldPassConsentStrategy: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - require.NotEmpty(t, code) - _, _ = w.Write([]byte(r.URL.Query().Get("code"))) - } - }, - }, - { - d: "should fail with prompt=login when authentication time is in the past", - authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=login", - authTime: time.Now().UTC().Add(-time.Minute), - requestTime: time.Now().UTC(), - expectOAuthAuthError: true, - shouldPassConsentStrategy: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - require.Empty(t, code) - assert.Equal(t, fosite.ErrLoginRequired.Error(), r.URL.Query().Get("error")) - } - }, - }, - } { - t.Run(fmt.Sprintf("case=%d/description=%s", k, tc.d), func(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - if tc.cb == nil { - tc.cb = noopHandler - } + data := map[string]interface{}{} + require.NoError(t, json.Unmarshal(body, &data)) + + assert.EqualValues(t, "app-client", data["client_id"]) + assert.EqualValues(t, "foo", data["sub"]) + assert.NotEmpty(t, data["iss"]) + assert.NotEmpty(t, data["jti"]) + assert.NotEmpty(t, data["exp"]) + assert.NotEmpty(t, data["iat"]) + assert.NotEmpty(t, data["nbf"]) + assert.EqualValues(t, data["nbf"], data["iat"]) + assert.EqualValues(t, []interface{}{"offline", "openid", "hydra.*"}, data["scp"]) + }, + }, + { + d: "should fail because prompt=none and max_age > auth_time", + authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none&max_age=1", + authTime: time.Now().UTC().Add(-time.Minute), + requestTime: time.Now().UTC(), + shouldPassConsentStrategy: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + err := r.URL.Query().Get("error") + require.Empty(t, code) + require.EqualValues(t, fosite.ErrLoginRequired.Error(), err) + } + }, + expectOAuthAuthError: true, + }, + { + d: "should pass because prompt=none and max_age is less than auth_time", + authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none&max_age=3600", + authTime: time.Now().UTC().Add(-time.Minute), + requestTime: time.Now().UTC(), + shouldPassConsentStrategy: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + require.NotEmpty(t, code) + _, _ = w.Write([]byte(r.URL.Query().Get("code"))) + } + }, + }, + { + d: "should fail because prompt=none but auth_time suggests recent authentication", + authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none", + authTime: time.Now().UTC().Add(-time.Minute), + requestTime: time.Now().UTC().Add(-time.Hour), + shouldPassConsentStrategy: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + err := r.URL.Query().Get("error") + require.Empty(t, code) + require.EqualValues(t, fosite.ErrLoginRequired.Error(), err) + } + }, + expectOAuthAuthError: true, + }, + { + d: "should fail because consent strategy fails", + authURL: oauthConfig.AuthCodeURL("some-foo-state"), + expectOAuthAuthError: true, + shouldPassConsentStrategy: false, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + require.Empty(t, r.URL.Query().Get("code")) + assert.Equal(t, fosite.ErrRequestForbidden.Error(), r.URL.Query().Get("error")) + } + }, + }, + { + d: "should pass with prompt=login when authentication time is recent", + authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=login", + authTime: time.Now().UTC().Add(-time.Second), + requestTime: time.Now().UTC().Add(-time.Minute), + shouldPassConsentStrategy: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + require.NotEmpty(t, code) + _, _ = w.Write([]byte(r.URL.Query().Get("code"))) + } + }, + }, + { + d: "should fail with prompt=login when authentication time is in the past", + authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=login", + authTime: time.Now().UTC().Add(-time.Minute), + requestTime: time.Now().UTC(), + expectOAuthAuthError: true, + shouldPassConsentStrategy: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + require.Empty(t, code) + assert.Equal(t, fosite.ErrLoginRequired.Error(), r.URL.Query().Get("error")) + } + }, + }, + } { + t.Run(fmt.Sprintf("case=%d/description=%s", k, tc.d), func(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + if tc.cb == nil { + tc.cb = noopHandler + } - consentStrategy.deny = !tc.shouldPassConsentStrategy - consentStrategy.authTime = tc.authTime - consentStrategy.requestTime = tc.requestTime + consentStrategy.deny = !tc.shouldPassConsentStrategy + consentStrategy.authTime = tc.authTime + consentStrategy.requestTime = tc.requestTime - cb := tc.cb(t) - callbackHandler = &cb + cb := tc.cb(t) + callbackHandler = &cb - req, err := http.NewRequest("GET", tc.authURL, nil) - require.NoError(t, err) + req, err := http.NewRequest("GET", tc.authURL, nil) + require.NoError(t, err) - if tc.cj == nil { - tc.cj = testhelpers.NewEmptyCookieJar(t) - } + if tc.cj == nil { + tc.cj = testhelpers.NewEmptyCookieJar(t) + } + + resp, err := (&http.Client{Jar: tc.cj}).Do(req) + require.NoError(t, err, tc.authURL, ts.URL) + defer resp.Body.Close() + + if tc.expectOAuthAuthError { + require.Empty(t, code) + return + } + + require.NotEmpty(t, code) + + token, err := oauthConfig.Exchange(context.TODO(), code) + if tc.expectOAuthTokenError { + require.Error(t, err) + return + } + + require.NoError(t, err, code) + if tc.assertAccessToken != nil { + tc.assertAccessToken(t, token.AccessToken) + } + + t.Run("case=userinfo", func(t *testing.T) { + var makeRequest = func(req *http.Request) *http.Response { + resp, err = http.DefaultClient.Do(req) + require.NoError(t, err) + return resp + } - resp, err := (&http.Client{Jar: tc.cj}).Do(req) - require.NoError(t, err, tc.authURL, ts.URL) + var testSuccess = func(response *http.Response) { defer resp.Body.Close() - if tc.expectOAuthAuthError { - require.Empty(t, code) - return - } + require.Equal(t, http.StatusOK, resp.StatusCode) - require.NotEmpty(t, code) + var claims map[string]interface{} + require.NoError(t, json.NewDecoder(resp.Body).Decode(&claims)) + assert.Equal(t, "foo", claims["sub"]) + } - token, err := oauthConfig.Exchange(context.TODO(), code) - if tc.expectOAuthTokenError { - require.Error(t, err) - return - } + req, err = http.NewRequest("GET", ts.URL+"/userinfo", nil) + req.Header.Add("Authorization", "bearer "+token.AccessToken) + testSuccess(makeRequest(req)) - require.NoError(t, err, code) - if tc.assertAccessToken != nil { - tc.assertAccessToken(t, token.AccessToken) - } + req, err = http.NewRequest("POST", ts.URL+"/userinfo", nil) + req.Header.Add("Authorization", "bearer "+token.AccessToken) + testSuccess(makeRequest(req)) - t.Run("case=userinfo", func(t *testing.T) { - var makeRequest = func(req *http.Request) *http.Response { - resp, err = http.DefaultClient.Do(req) - require.NoError(t, err) - return resp - } + req, err = http.NewRequest("POST", ts.URL+"/userinfo", bytes.NewBuffer([]byte("access_token="+token.AccessToken))) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + testSuccess(makeRequest(req)) - var testSuccess = func(response *http.Response) { - defer resp.Body.Close() + req, err = http.NewRequest("GET", ts.URL+"/userinfo", nil) + req.Header.Add("Authorization", "bearer asdfg") + resp := makeRequest(req) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + }) - require.Equal(t, http.StatusOK, resp.StatusCode) + res, err := testRefresh(t, token, ts.URL, tc.checkExpiry) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) - var claims map[string]interface{} - require.NoError(t, json.NewDecoder(resp.Body).Decode(&claims)) - assert.Equal(t, "foo", claims["sub"]) - } + body, err := io.ReadAll(res.Body) + require.NoError(t, err) - req, err = http.NewRequest("GET", ts.URL+"/userinfo", nil) - req.Header.Add("Authorization", "bearer "+token.AccessToken) - testSuccess(makeRequest(req)) + var refreshedToken oauth2.Token + require.NoError(t, json.Unmarshal(body, &refreshedToken)) - req, err = http.NewRequest("POST", ts.URL+"/userinfo", nil) - req.Header.Add("Authorization", "bearer "+token.AccessToken) - testSuccess(makeRequest(req)) + if tc.assertAccessToken != nil { + tc.assertAccessToken(t, refreshedToken.AccessToken) + } - req, err = http.NewRequest("POST", ts.URL+"/userinfo", bytes.NewBuffer([]byte("access_token="+token.AccessToken))) - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - testSuccess(makeRequest(req)) + t.Run("the tokens should be different", func(t *testing.T) { + if strat.d != "jwt" { + t.Skip() + } - req, err = http.NewRequest("GET", ts.URL+"/userinfo", nil) - req.Header.Add("Authorization", "bearer asdfg") - resp := makeRequest(req) - require.Equal(t, http.StatusUnauthorized, resp.StatusCode) - }) + body, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[1]) + require.NoError(t, err) - res, err := testRefresh(t, token, ts.URL, tc.checkExpiry) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) + origPayload := map[string]interface{}{} + require.NoError(t, json.Unmarshal(body, &origPayload)) - body, err := io.ReadAll(res.Body) - require.NoError(t, err) + body, err = x.DecodeSegment(strings.Split(refreshedToken.AccessToken, ".")[1]) + require.NoError(t, err) - var refreshedToken oauth2.Token - require.NoError(t, json.Unmarshal(body, &refreshedToken)) + refreshedPayload := map[string]interface{}{} + require.NoError(t, json.Unmarshal(body, &refreshedPayload)) - if tc.assertAccessToken != nil { - tc.assertAccessToken(t, refreshedToken.AccessToken) - } + if tc.checkExpiry { + assert.NotEqual(t, refreshedPayload["exp"], origPayload["exp"]) + assert.NotEqual(t, refreshedPayload["iat"], origPayload["iat"]) + assert.NotEqual(t, refreshedPayload["nbf"], origPayload["nbf"]) + } + assert.NotEqual(t, refreshedPayload["jti"], origPayload["jti"]) + assert.Equal(t, refreshedPayload["client_id"], origPayload["client_id"]) + }) - t.Run("the tokens should be different", func(t *testing.T) { - if strat.d != "jwt" { - t.Skip() - } + require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) - body, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[1]) - require.NoError(t, err) + t.Run("old token should no longer be usable", func(t *testing.T) { + req, err := http.NewRequest("GET", ts.URL+"/userinfo", nil) + require.NoError(t, err) + req.Header.Add("Authorization", "bearer "+token.AccessToken) + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.EqualValues(t, http.StatusUnauthorized, res.StatusCode) + }) - origPayload := map[string]interface{}{} - require.NoError(t, json.Unmarshal(body, &origPayload)) + t.Run("refreshing new refresh token should work", func(t *testing.T) { + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) - body, err = x.DecodeSegment(strings.Split(refreshedToken.AccessToken, ".")[1]) - require.NoError(t, err) + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(body, &refreshedToken)) + }) - refreshedPayload := map[string]interface{}{} - require.NoError(t, json.Unmarshal(body, &refreshedPayload)) + t.Run("should call refresh token hook if configured", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") + + expectedGrantedScopes := []string{"openid", "offline", "hydra.*"} + expectedSubject := "foo" + + exceptKeys := []string{ + "session.kid", + "session.id_token.expires_at", + "session.id_token.headers.extra.kid", + "session.id_token.id_token_claims.iat", + "session.id_token.id_token_claims.exp", + "session.id_token.id_token_claims.rat", + "session.id_token.id_token_claims.auth_time", + } + + if hookType == "legacy" { + var hookReq hydraoauth2.RefreshTokenHookRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) + require.Equal(t, hookReq.Subject, expectedSubject) + require.ElementsMatch(t, hookReq.GrantedScopes, expectedGrantedScopes) + require.ElementsMatch(t, hookReq.GrantedAudience, []string{}) + require.Equal(t, hookReq.ClientID, oauthConfig.ClientID) + require.NotEmpty(t, hookReq.Session) + require.Equal(t, hookReq.Session.Subject, expectedSubject) + require.Equal(t, hookReq.Session.ClientID, oauthConfig.ClientID) + require.NotEmpty(t, hookReq.Requester) + require.Equal(t, hookReq.Requester.ClientID, oauthConfig.ClientID) + require.ElementsMatch(t, hookReq.Requester.GrantedScopes, expectedGrantedScopes) + + snapshotx.SnapshotT(t, hookReq, snapshotx.ExceptPaths(exceptKeys...)) + } else { + var hookReq hydraoauth2.TokenHookRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) + require.NotEmpty(t, hookReq.Session) + require.Equal(t, hookReq.Session.Subject, expectedSubject) + require.Equal(t, hookReq.Session.ClientID, oauthConfig.ClientID) + require.NotEmpty(t, hookReq.Request) + require.Equal(t, hookReq.Request.ClientID, oauthConfig.ClientID) + require.ElementsMatch(t, hookReq.Request.GrantedScopes, expectedGrantedScopes) + require.ElementsMatch(t, hookReq.Request.GrantedAudience, []string{}) + require.Equal(t, hookReq.Request.Payload, map[string][]string{"grant_type": {"refresh_token"}}) + + snapshotx.SnapshotT(t, hookReq, snapshotx.ExceptPaths(exceptKeys...)) + } + + claims := map[string]interface{}{ + "hooked": hookType, + } + + hookResp := hydraoauth2.TokenHookResponse{ + Session: flow.AcceptOAuth2ConsentRequestSession{ + AccessToken: claims, + IDToken: claims, + }, + } - if tc.checkExpiry { - assert.NotEqual(t, refreshedPayload["exp"], origPayload["exp"]) - assert.NotEqual(t, refreshedPayload["iat"], origPayload["iat"]) - assert.NotEqual(t, refreshedPayload["nbf"], origPayload["nbf"]) + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) + t.Cleanup(func() { + conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) + }) + } else { + conf.MustSet(ctx, config.KeyTokenHook, hs.URL) + t.Cleanup(func() { + conf.MustSet(ctx, config.KeyTokenHook, nil) + }) } - assert.NotEqual(t, refreshedPayload["jti"], origPayload["jti"]) - assert.Equal(t, refreshedPayload["client_id"], origPayload["client_id"]) - }) - require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) - t.Run("old token should no longer be usable", func(t *testing.T) { - req, err := http.NewRequest("GET", ts.URL+"/userinfo", nil) + body, err := io.ReadAll(res.Body) require.NoError(t, err) - req.Header.Add("Authorization", "bearer "+token.AccessToken) - res, err := http.DefaultClient.Do(req) + require.NoError(t, json.Unmarshal(body, &refreshedToken)) + + accessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) + require.Equal(t, accessTokenClaims.Get("ext.hooked").String(), hookType) + + idTokenBody, err := x.DecodeSegment( + strings.Split( + gjson.GetBytes(body, "id_token").String(), + ".", + )[1], + ) require.NoError(t, err) - assert.EqualValues(t, http.StatusUnauthorized, res.StatusCode) - }) - t.Run("refreshing new refresh token should work", func(t *testing.T) { + require.Equal(t, gjson.GetBytes(idTokenBody, "hooked").String(), hookType) + } + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) + }) + + t.Run("should not override session data if token refresh hook returns no content", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) }) + } else { + conf.MustSet(ctx, config.KeyTokenHook, hs.URL) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyTokenHook, nil) }) + } + + origAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) + res, err := testRefresh(t, &refreshedToken, ts.URL, false) require.NoError(t, err) assert.Equal(t, http.StatusOK, res.StatusCode) - body, err := io.ReadAll(res.Body) + body, err = io.ReadAll(res.Body) require.NoError(t, err) + require.NoError(t, json.Unmarshal(body, &refreshedToken)) - }) - - t.Run("should call refresh token hook if configured", func(t *testing.T) { - run := func(hookType string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") - - expectedGrantedScopes := []string{"openid", "offline", "hydra.*"} - expectedSubject := "foo" - - exceptKeys := []string{ - "session.kid", - "session.id_token.expires_at", - "session.id_token.headers.extra.kid", - "session.id_token.id_token_claims.iat", - "session.id_token.id_token_claims.exp", - "session.id_token.id_token_claims.rat", - "session.id_token.id_token_claims.auth_time", - } - - if hookType == "legacy" { - var hookReq hydraoauth2.RefreshTokenHookRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) - require.Equal(t, hookReq.Subject, expectedSubject) - require.ElementsMatch(t, hookReq.GrantedScopes, expectedGrantedScopes) - require.ElementsMatch(t, hookReq.GrantedAudience, []string{}) - require.Equal(t, hookReq.ClientID, oauthConfig.ClientID) - require.NotEmpty(t, hookReq.Session) - require.Equal(t, hookReq.Session.Subject, expectedSubject) - require.Equal(t, hookReq.Session.ClientID, oauthConfig.ClientID) - require.NotEmpty(t, hookReq.Requester) - require.Equal(t, hookReq.Requester.ClientID, oauthConfig.ClientID) - require.ElementsMatch(t, hookReq.Requester.GrantedScopes, expectedGrantedScopes) - - snapshotx.SnapshotT(t, hookReq, snapshotx.ExceptPaths(exceptKeys...)) - } else { - var hookReq hydraoauth2.TokenHookRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) - require.NotEmpty(t, hookReq.Session) - require.Equal(t, hookReq.Session.Subject, expectedSubject) - require.Equal(t, hookReq.Session.ClientID, oauthConfig.ClientID) - require.NotEmpty(t, hookReq.Request) - require.Equal(t, hookReq.Request.ClientID, oauthConfig.ClientID) - require.ElementsMatch(t, hookReq.Request.GrantedScopes, expectedGrantedScopes) - require.ElementsMatch(t, hookReq.Request.GrantedAudience, []string{}) - require.Equal(t, hookReq.Request.Payload, map[string][]string{"grant_type": {"refresh_token"}}) - - snapshotx.SnapshotT(t, hookReq, snapshotx.ExceptPaths(exceptKeys...)) - } - - claims := map[string]interface{}{ - "hooked": hookType, - } - - hookResp := hydraoauth2.TokenHookResponse{ - Session: flow.AcceptOAuth2ConsentRequestSession{ - AccessToken: claims, - IDToken: claims, - }, - } - - w.WriteHeader(http.StatusOK) - require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) - })) - defer hs.Close() - - if hookType == "legacy" { - conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) - } else { - conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) - } - - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) - - body, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.NoError(t, json.Unmarshal(body, &refreshedToken)) - - accessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) - require.Equal(t, accessTokenClaims.Get("ext.hooked").String(), hookType) - - idTokenBody, err := x.DecodeSegment( - strings.Split( - gjson.GetBytes(body, "id_token").String(), - ".", - )[1], - ) - require.NoError(t, err) - - require.Equal(t, gjson.GetBytes(idTokenBody, "hooked").String(), hookType) - } - } - t.Run("hook=legacy", run("legacy")) - t.Run("hook=new", run("new")) - }) - - t.Run("should not override session data if token refresh hook returns no content", func(t *testing.T) { - run := func(hookType string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNoContent) - })) - defer hs.Close() - - if hookType == "legacy" { - conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) - } else { - conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) - } - - origAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) - - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) - - body, err = io.ReadAll(res.Body) - require.NoError(t, err) - - require.NoError(t, json.Unmarshal(body, &refreshedToken)) - - refreshedAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) - assertx.EqualAsJSONExcept(t, json.RawMessage(origAccessTokenClaims.Raw), json.RawMessage(refreshedAccessTokenClaims.Raw), []string{"exp", "iat", "nbf"}) - } - } - t.Run("hook=legacy", run("legacy")) - t.Run("hook=new", run("new")) - }) - - t.Run("should fail token refresh with `server_error` if refresh hook fails", func(t *testing.T) { - run := func(hookType string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - })) - defer hs.Close() - - if hookType == "legacy" { - conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) - } else { - conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) - } - - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) - - var errBody fosite.RFC6749ErrorJson - require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) - require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) - require.Equal(t, "An error occurred while executing the token hook.", errBody.Description) - } - } - t.Run("hook=legacy", run("legacy")) - t.Run("hook=new", run("new")) - }) - - t.Run("should fail token refresh with `access_denied` if legacy refresh hook denied the request", func(t *testing.T) { - run := func(hookType string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusForbidden) - })) - defer hs.Close() - - if hookType == "legacy" { - conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) - } else { - conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) - } - - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusForbidden, res.StatusCode) - - var errBody fosite.RFC6749ErrorJson - require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) - require.Equal(t, fosite.ErrAccessDenied.Error(), errBody.Name) - require.Equal(t, "The token hook target responded with an error. Make sure that the request you are making is valid. Maybe the credential or request parameters you are using are limited in scope or otherwise restricted.", errBody.Description) - } + + refreshedAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) + assertx.EqualAsJSONExcept(t, json.RawMessage(origAccessTokenClaims.Raw), json.RawMessage(refreshedAccessTokenClaims.Raw), []string{"exp", "iat", "nbf"}) + } + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) + }) + + t.Run("should fail token refresh with `server_error` if refresh hook fails", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) }) + } else { + conf.MustSet(ctx, config.KeyTokenHook, hs.URL) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyTokenHook, nil) }) } - t.Run("hook=legacy", run("legacy")) - t.Run("hook=new", run("new")) - }) - - t.Run("should fail token refresh with `server_error` if refresh hook response is malformed", func(t *testing.T) { - run := func(hookType string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer hs.Close() - - if hookType == "legacy" { - conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) - } else { - conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) - } - - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) - - var errBody fosite.RFC6749ErrorJson - require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) - require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) - require.Equal(t, "The token hook target responded with an error.", errBody.Description) - } + + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + + var errBody fosite.RFC6749ErrorJson + require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) + require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) + require.Equal(t, "An error occurred while executing the token hook.", errBody.Description) + } + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) + }) + + t.Run("should fail token refresh with `access_denied` if legacy refresh hook denied the request", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) }) + } else { + conf.MustSet(ctx, config.KeyTokenHook, hs.URL) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyTokenHook, nil) }) } - t.Run("hook=legacy", run("legacy")) - t.Run("hook=new", run("new")) - }) - t.Run("refreshing old token should no longer work", func(t *testing.T) { - res, err := testRefresh(t, token, ts.URL, false) + res, err := testRefresh(t, &refreshedToken, ts.URL, false) require.NoError(t, err) - assert.Equal(t, http.StatusUnauthorized, res.StatusCode) - }) + assert.Equal(t, http.StatusForbidden, res.StatusCode) + + var errBody fosite.RFC6749ErrorJson + require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) + require.Equal(t, fosite.ErrAccessDenied.Error(), errBody.Name) + require.Equal(t, "The token hook target responded with an error. Make sure that the request you are making is valid. Maybe the credential or request parameters you are using are limited in scope or otherwise restricted.", errBody.Description) + } + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) + }) + + t.Run("should fail token refresh with `server_error` if refresh hook response is malformed", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) }) + } else { + conf.MustSet(ctx, config.KeyTokenHook, hs.URL) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyTokenHook, nil) }) + } - t.Run("attempt to refresh old token should revoke new token", func(t *testing.T) { res, err := testRefresh(t, &refreshedToken, ts.URL, false) require.NoError(t, err) - assert.Equal(t, http.StatusUnauthorized, res.StatusCode) - }) + assert.Equal(t, http.StatusInternalServerError, res.StatusCode) - t.Run("duplicate code exchange fails", func(t *testing.T) { - token, err := oauthConfig.Exchange(context.TODO(), code) - require.Error(t, err) - require.Nil(t, token) - }) + var errBody fosite.RFC6749ErrorJson + require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) + require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) + require.Equal(t, "The token hook target responded with an error.", errBody.Description) + } + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) + }) - code = "" - }) - } + t.Run("refreshing old token should no longer work", func(t *testing.T) { + res, err := testRefresh(t, token, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + }) + + t.Run("attempt to refresh old token should revoke new token", func(t *testing.T) { + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + }) + + t.Run("duplicate code exchange fails", func(t *testing.T) { + token, err := oauthConfig.Exchange(context.TODO(), code) + require.Error(t, err) + require.Nil(t, token) + }) + + code = "" }) } }) @@ -2183,6 +2213,7 @@ func newOAuth2Client( return c, &oauth2.Config{ ClientID: c.GetID(), ClientSecret: secret, + RedirectURL: callbackURL, Endpoint: oauth2.Endpoint{ AuthURL: reg.Config().OAuth2AuthURL(ctx).String(), TokenURL: reg.Config().OAuth2TokenURL(ctx).String(), diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index b67b6ae17e..071bbe0550 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -10,6 +10,9 @@ import ( "encoding/hex" "encoding/json" "fmt" + "github.com/gobuffalo/pop/v6" + "github.com/ory/x/dbal" + "go.opentelemetry.io/otel/attribute" "net/url" "strings" "time" @@ -60,7 +63,8 @@ type ( } OAuth2RefreshTable struct { OAuth2RequestSQL - FirstUsedAt sql.NullTime `db:"first_used_at"` + FirstUsedAt sql.NullTime `db:"first_used_at"` + AccessTokenSignature sqlxx.NullString `db:"access_token_signature"` } ) @@ -452,7 +456,7 @@ func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature str return p.createSession(ctx, signature, requester, sqlTableRefresh, requester.GetSession().GetExpiresAt(fosite.RefreshToken).UTC()) } -func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { +func (p *Persister) RotateRefreshToken(ctx context.Context, refreshTokenSignature string) (requestID string, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRefreshTokenSession") defer otelx.End(span, &err) @@ -536,6 +540,140 @@ func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) (err erro return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) } +func (p *Persister) gracefulRefreshRotation(ctx context.Context, c *pop.Connection, requestID string, refreshSignature string, period time.Duration) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.gracefulRefreshRotation", + trace.WithAttributes( + attribute.String("request_id", requestID), + attribute.String("refresh_signature", refreshSignature), + attribute.String("network_id", p.NetworkID(ctx).String()), + attribute.String("grace_period", period.String()), + )) + defer otelx.End(span, &err) + + if p.conn.Dialect.Name() == dbal.DriverMySQL { + // MySQL does not support returning values from an update query, so we need to do two queries. + var tokensToRevoke []OAuth2RefreshTable + if err := c. + Select("access_token_signature"). + Where("request_id=? AND nid = ? AND active", id, p.NetworkID(ctx)). + Limit(500). + All(&tokensToRevoke); err != nil { + return sqlcon.HandleError(err) + } + + } + + return nil +} + +func (p *Persister) RevokeRotatedTokens(ctx context.Context, refreshSignature string) (fosite.Requester, error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRotatedTokens") + defer otelx.End(span, &err) + + err = p.QueryWithNetwork(ctx). + Where("request_id=?", id). + Delete(&OAuth2RequestSQL{Table: sqlTableAccess}) + if errors.Is(err, sql.ErrNoRows) { + return errorsx.WithStack(fosite.ErrNotFound) + } + + if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + if gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx); gracePeriod > 0 { + return p.gracefulRefreshRotation(ctx, c, requestID, refreshSignature, gracePeriod) + } + + if err := p.deleteSessionByRequestID(ctx, requestID, sqlTableAccess); err != nil { + return err + } + + _, err := c.Where("signature = ? AND nid = ? AND active", refreshSignature, p.NetworkID(ctx)).UpdateQuery(&OAuth2RefreshTable{ + OAuth2RequestSQL: OAuth2RequestSQL{ + Active: false, + }, + FirstUsedAt: sql.NullTime{ + Time: time.Now().UTC().Round(time.Millisecond), + Valid: true, + }, + }, "active", "first_used_at") + return sqlcon.HandleError(err) + }); err != nil { + if errors.Is(err, sqlcon.ErrConcurrentUpdate) { + return fosite.ErrSerializationFailure.WithWrap(err) + } + if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock + return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + } + return err + } + + return nil + + /* #nosec G201 table is static */ + + /* + err = p.QueryWithNetwork(ctx). + Where("request_id=?", id). + Delete(&OAuth2RequestSQL{Table: sqlTableAccess}) + if errors.Is(err, sql.ErrNoRows) { + return errorsx.WithStack(fosite.ErrNotFound) + } + + p.Transaction(ctx, func(ctx context.Context, c *sql.Tx) error { + type tokens []revokingRefreshToken + if p.conn.Dialect.Name() == dbal.DriverMySQL { + // MySQL does not support returning values from an update query, so we need to do two queries. + var t tokens + if err := p.Connection(ctx).Where("request_id=? AND nid = ? AND active", id, p.NetworkID(ctx)).Limit(500).All(&t); err != nil { + return sqlcon.HandleError(err) + } + + } else { + + } + + }) + + if err := p.Connection(ctx).RawQuery(` + SELECT access_token_signature, signature, first_used_at + FROM hydra_oauth2_refresh + WHERE request_id=? AND nid = ? AND active + ORDER BY signature LIMIT 500 + `).All(&tokens); err != nil { + return err + } + + p.Connection(ctx).Where("signature IN (?)", _).Limit(500).UpdateQuery(&OAuth2RequestSQL{ + + Table: sqlTableRefresh, + }, + "active", "first_used_at") + p.Connection(ctx).RawQuery(` + UPDATE hydra_oauth2_refresh + SET active=false, first_used_at = CURRENT_TIMESTAMP + WHERE signature in = (?) + LIMIT 500 + `).All(&t) + + p.Connection(ctx).RawQuery(` + UPDATE hydra_oauth2_refresh + SET active=false, first_used_at = CURRENT_TIMESTAMP + WHERE request_id=? AND nid = ? AND active + RETURNING access_token_signature + LIMIT 500 + `).All(&t) + + // mysql: + // "GET access_token_signature, id (?) WHERE request_id=? AND nid = ? AND active A RETURNING access_token_signature ORDER BY signature LIMIT 500" + // "UPDATE ... SET ..." + + // others: + // "UPDATE refresh SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active A RETURNING access_token_signature ORDER BY signature LIMIT 500" + + // "UPDATE access SET active=false WHERE request_id=? AND signature IN (?) LIMIT 500" + */ + /* #nosec G201 table is static */ +} + func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, _ string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshTokenMaybeGracePeriod") defer otelx.End(span, &err) @@ -544,7 +682,7 @@ func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id s return sqlcon.HandleError( p.Connection(ctx). RawQuery( - fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()), + fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active LIMIT 500", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()), id, p.NetworkID(ctx), ).