diff --git a/.docker/Dockerfile-hsm b/.docker/Dockerfile-hsm index 39cd1b1ad99..c4199fe87e4 100644 --- a/.docker/Dockerfile-hsm +++ b/.docker/Dockerfile-hsm @@ -17,20 +17,32 @@ COPY . . ############################### -FROM builder as build-hydra +FROM builder AS build-hydra RUN go build -tags sqlite,hsm -o /usr/bin/hydra ############################### -FROM builder as test-hsm +FROM builder AS test-hsm ENV HSM_ENABLED=true ENV HSM_LIBRARY=/usr/lib/softhsm/libsofthsm2.so ENV HSM_TOKEN_LABEL=hydra ENV HSM_PIN=1234 -RUN apt-get -y install softhsm opensc &&\ - pkcs11-tool --module "$HSM_LIBRARY" --slot 0 --init-token --so-pin 0000 --init-pin --pin "$HSM_PIN" --label "$HSM_TOKEN_LABEL" &&\ - go test -p 1 -v -failfast -short -tags=sqlite,hsm ./... +RUN apt-get -y install softhsm opensc +RUN pkcs11-tool --module "$HSM_LIBRARY" --slot 0 --init-token --so-pin 0000 --init-pin --pin "$HSM_PIN" --label "$HSM_TOKEN_LABEL" +RUN go test -p 1 -failfast -short -tags=sqlite,hsm ./... + + +FROM builder AS test-refresh-hsm +ENV HSM_ENABLED=true +ENV HSM_LIBRARY=/usr/lib/softhsm/libsofthsm2.so +ENV HSM_TOKEN_LABEL=hydra +ENV HSM_PIN=1234 +ENV UPDATE_SNAPSHOTS=true + +RUN apt-get -y install softhsm opensc +RUN pkcs11-tool --module "$HSM_LIBRARY" --slot 0 --init-token --so-pin 0000 --init-pin --pin "$HSM_PIN" --label "$HSM_TOKEN_LABEL" +RUN go test -p 1 -failfast -short -tags=sqlite,hsm,refresh ./... ############################### diff --git a/.schema/config.schema.json b/.schema/config.schema.json index bc1d1476c08..804e6b6024f 100644 --- a/.schema/config.schema.json +++ b/.schema/config.schema.json @@ -1101,11 +1101,11 @@ "examples": ["https://my-example.app/token-refresh-hook"], "oneOf": [ { - "type": "string", - "format": "uri" + "$ref": "#/definitions/webhook_config" }, { - "$ref": "#/definitions/webhook_config" + "type": "string", + "format": "uri" } ] }, @@ -1114,11 +1114,11 @@ "examples": ["https://my-example.app/token-hook"], "oneOf": [ { - "type": "string", - "format": "uri" + "$ref": "#/definitions/webhook_config" }, { - "$ref": "#/definitions/webhook_config" + "type": "string", + "format": "uri" } ] } diff --git a/Makefile b/Makefile index 75b912e0521..49c66ec5a71 100644 --- a/Makefile +++ b/Makefile @@ -90,9 +90,10 @@ quicktest: quicktest-hsm: DOCKER_BUILDKIT=1 DOCKER_CONTENT_TRUST=1 docker build --progress=plain -f .docker/Dockerfile-hsm --target test-hsm -t oryd/hydra:${IMAGE_TAG} --target test-hsm . -.PHONY: refresh -refresh: +.PHONY: test-refresh +test-refresh: UPDATE_SNAPSHOTS=true go test -failfast -short -tags sqlite,sqlite_omit_load_extension ./... + DOCKER_BUILDKIT=1 DOCKER_CONTENT_TRUST=1 docker build --progress=plain -f .docker/Dockerfile-hsm --target test-refresh-hsm -t oryd/hydra:${IMAGE_TAG} --target test-refresh-hsm . authors: # updates the AUTHORS file curl https://raw.githubusercontent.com/ory/ci/master/authors/authors.sh | env PRODUCT="Ory Hydra" bash diff --git a/aead/aead_test.go b/aead/aead_test.go index 4cb93f5c3e7..d1b614710a2 100644 --- a/aead/aead_test.go +++ b/aead/aead_test.go @@ -10,13 +10,14 @@ import ( "io" "testing" - "github.com/ory/hydra/v2/aead" - "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" + "github.com/ory/hydra/v2/internal/testhelpers" "github.com/pborman/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/ory/hydra/v2/aead" + "github.com/ory/hydra/v2/driver/config" ) func secret(t *testing.T) string { @@ -43,7 +44,7 @@ func TestAEAD(t *testing.T) { t.Run("case=without-rotation", func(t *testing.T) { t.Parallel() ctx := context.Background() - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)}) a := NewCipher(c) @@ -63,7 +64,7 @@ func TestAEAD(t *testing.T) { t.Run("case=wrong-secret", func(t *testing.T) { t.Parallel() ctx := context.Background() - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)}) a := NewCipher(c) @@ -78,7 +79,7 @@ func TestAEAD(t *testing.T) { t.Run("case=with-rotation", func(t *testing.T) { t.Parallel() ctx := context.Background() - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() old := secret(t) c.MustSet(ctx, config.KeyGetSystemSecret, []string{old}) a := NewCipher(c) @@ -106,7 +107,7 @@ func TestAEAD(t *testing.T) { t.Run("case=with-rotation-wrong-secret", func(t *testing.T) { t.Parallel() ctx := context.Background() - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)}) a := NewCipher(c) @@ -123,7 +124,7 @@ func TestAEAD(t *testing.T) { t.Run("suite=with additional data", func(t *testing.T) { t.Parallel() ctx := context.Background() - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)}) a := NewCipher(c) diff --git a/client/handler_test.go b/client/handler_test.go index 3047ad4c87b..8e27caea754 100644 --- a/client/handler_test.go +++ b/client/handler_test.go @@ -35,7 +35,6 @@ import ( "github.com/stretchr/testify/require" "github.com/ory/hydra/v2/client" - "github.com/ory/hydra/v2/internal" ) type responseSnapshot struct { @@ -56,7 +55,7 @@ func getClientID(body string) string { func TestHandler(t *testing.T) { ctx := context.Background() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) h := client.NewHandler(reg) reg.WithContextualizer(&contextx.TestContextualizer{}) diff --git a/client/sdk_test.go b/client/sdk_test.go index 9db7ab7cddb..ad3193108ad 100644 --- a/client/sdk_test.go +++ b/client/sdk_test.go @@ -9,6 +9,8 @@ import ( "strings" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/x/assertx" "github.com/ory/x/ioutilx" @@ -26,8 +28,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/ory/hydra/v2/internal" - hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/hydra/v2/client" ) @@ -63,11 +63,11 @@ var defaultIgnoreFields = []string{"client_id", "registration_access_token", "re func TestClientSDK(t *testing.T) { ctx := context.Background() - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() conf.MustSet(ctx, config.KeySubjectTypesSupported, []string{"public"}) conf.MustSet(ctx, config.KeyDefaultClientScope, []string{"foo", "bar"}) conf.MustSet(ctx, config.KeyPublicAllowDynamicRegistration, true) - r := internal.NewRegistryMemory(t, conf, &contextx.Static{C: conf.Source(ctx)}) + r := testhelpers.NewRegistryMemory(t, conf, &contextx.Static{C: conf.Source(ctx)}) routerAdmin := x.NewRouterAdmin(conf.AdminURL) routerPublic := x.NewRouterPublic() diff --git a/client/validator_test.go b/client/validator_test.go index 09f69b26e30..4efe866d5a9 100644 --- a/client/validator_test.go +++ b/client/validator_test.go @@ -12,6 +12,8 @@ import ( "strings" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/hashicorp/go-retryablehttp" "github.com/ory/fosite" @@ -24,17 +26,16 @@ import ( . "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" ) func TestValidate(t *testing.T) { ctx := context.Background() - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() c.MustSet(ctx, config.KeySubjectTypesSupported, []string{"pairwise", "public"}) c.MustSet(ctx, config.KeyDefaultClientScope, []string{"openid"}) - reg := internal.NewRegistryMemory(t, c, &contextx.Static{C: c.Source(ctx)}) + reg := testhelpers.NewRegistryMemory(t, c, &contextx.Static{C: c.Source(ctx)}) v := NewValidator(reg) testCtx := context.TODO() @@ -186,7 +187,7 @@ func (f *fakeHTTP) HTTPClient(ctx context.Context, opts ...httpx.ResilientOption } func TestValidateSectorIdentifierURL(t *testing.T) { - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) var payload string var h http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { @@ -268,8 +269,8 @@ const validJWKS = ` func TestValidateIPRanges(t *testing.T) { ctx := context.Background() - c := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, c, &contextx.Static{C: c.Source(ctx)}) + c := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, c, &contextx.Static{C: c.Source(ctx)}) v := NewValidator(reg) c.MustSet(ctx, config.KeyClientHTTPNoPrivateIPRanges, true) @@ -287,10 +288,10 @@ func TestValidateIPRanges(t *testing.T) { func TestValidateDynamicRegistration(t *testing.T) { ctx := context.Background() - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() c.MustSet(ctx, config.KeySubjectTypesSupported, []string{"pairwise", "public"}) c.MustSet(ctx, config.KeyDefaultClientScope, []string{"openid"}) - reg := internal.NewRegistryMemory(t, c, &contextx.Static{C: c.Source(ctx)}) + reg := testhelpers.NewRegistryMemory(t, c, &contextx.Static{C: c.Source(ctx)}) testCtx := context.TODO() v := NewValidator(reg) diff --git a/cmd/cmd_helper_test.go b/cmd/cmd_helper_test.go index da386b4865d..4953f6b4321 100644 --- a/cmd/cmd_helper_test.go +++ b/cmd/cmd_helper_test.go @@ -19,7 +19,6 @@ import ( "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/internal/testhelpers" "github.com/ory/x/cmdx" "github.com/ory/x/contextx" @@ -40,7 +39,7 @@ func setupRoutes(t *testing.T, cmd *cobra.Command) (*httptest.Server, *httptest. ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) public, admin := testhelpers.NewOAuth2Server(ctx, t, reg) cmdx.RegisterHTTPClientFlags(cmd.Flags()) diff --git a/consent/handler_test.go b/consent/handler_test.go index d5dfe5254ad..45ba2b7733a 100644 --- a/consent/handler_test.go +++ b/consent/handler_test.go @@ -13,13 +13,14 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/stretchr/testify/require" hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/hydra/v2/client" . "github.com/ory/hydra/v2/consent" "github.com/ory/hydra/v2/flow" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" "github.com/ory/x/pointerx" @@ -42,8 +43,8 @@ func TestGetLogoutRequest(t *testing.T) { challenge := "challenge" + key requestURL := "http://192.0.2.1" - conf := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + conf := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) if tc.exists { cl := &client.Client{ID: "client" + key} @@ -97,8 +98,8 @@ func TestGetLoginRequest(t *testing.T) { challenge := "challenge" + key requestURL := "http://192.0.2.1" - conf := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + conf := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) if tc.exists { cl := &client.Client{ID: "client" + key} @@ -163,8 +164,8 @@ func TestGetConsentRequest(t *testing.T) { challenge := "challenge" + key requestURL := "http://192.0.2.1" - conf := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + conf := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) if tc.exists { cl := &client.Client{ID: "client" + key} @@ -238,8 +239,8 @@ func TestGetLoginRequestWithDuplicateAccept(t *testing.T) { challenge := "challenge" requestURL := "http://192.0.2.1" - conf := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + conf := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) cl := &client.Client{ID: "client"} require.NoError(t, reg.ClientManager().CreateClient(ctx, cl)) diff --git a/consent/sdk_test.go b/consent/sdk_test.go index f749428d5d8..0f30d16e7c8 100644 --- a/consent/sdk_test.go +++ b/consent/sdk_test.go @@ -11,6 +11,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/hydra/v2/consent/test" hydra "github.com/ory/hydra-client-go/v2" @@ -23,7 +25,6 @@ import ( . "github.com/ory/hydra/v2/consent" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" ) @@ -35,10 +36,10 @@ func makeID(base string, network string, key string) string { func TestSDK(t *testing.T) { ctx := context.Background() network := "t1" - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() conf.MustSet(ctx, config.KeyIssuerURL, "https://www.ory.sh") conf.MustSet(ctx, config.KeyAccessTokenLifespan, time.Minute) - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) consentChallenge := func(f *Flow) string { return x.Must(f.ToConsentChallenge(ctx, reg)) } consentVerifier := func(f *Flow) string { return x.Must(f.ToConsentVerifier(ctx, reg)) } diff --git a/consent/strategy_logout_test.go b/consent/strategy_logout_test.go index 6432a3e13a0..80e633e7bf6 100644 --- a/consent/strategy_logout_test.go +++ b/consent/strategy_logout_test.go @@ -28,7 +28,6 @@ import ( hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/internal/testhelpers" "github.com/ory/x/contextx" "github.com/ory/x/ioutilx" @@ -37,7 +36,7 @@ import ( func TestLogoutFlows(t *testing.T) { ctx := context.Background() fakeKratos := kratos.NewFake() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") reg.Config().MustSet(ctx, config.KeyConsentRequestMaxAge, time.Hour) diff --git a/consent/strategy_oauth_test.go b/consent/strategy_oauth_test.go index 370a3378074..a2e39d5b6ec 100644 --- a/consent/strategy_oauth_test.go +++ b/consent/strategy_oauth_test.go @@ -37,12 +37,11 @@ import ( hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" ) func TestStrategyLoginConsentNext(t *testing.T) { ctx := context.Background() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") reg.Config().MustSet(ctx, config.KeyConsentRequestMaxAge, time.Hour) reg.Config().MustSet(ctx, config.KeyConsentRequestMaxAge, time.Hour) diff --git a/consent/test/manager_test_helpers.go b/consent/test/manager_test_helpers.go index a5b141f5359..986b4f3144c 100644 --- a/consent/test/manager_test_helpers.go +++ b/consent/test/manager_test_helpers.go @@ -683,6 +683,7 @@ func ManagerTests(deps Deps, m consent.Manager, clientManager client.Manager, fo require.NoError(t, fositeManager.CreateRefreshTokenSession( ctx, makeID("", network, "rrva1"), + "", &fosite.Request{Client: cr1.Client, ID: crr1.ID, RequestedAt: time.Now(), Session: &oauth2.Session{DefaultSession: openid.NewDefaultSession()}}, )) require.NoError(t, fositeManager.CreateAccessTokenSession( @@ -693,6 +694,7 @@ func ManagerTests(deps Deps, m consent.Manager, clientManager client.Manager, fo require.NoError(t, fositeManager.CreateRefreshTokenSession( ctx, makeID("", network, "rrva2"), + "", &fosite.Request{Client: cr2.Client, ID: crr2.ID, RequestedAt: time.Now(), Session: &oauth2.Session{DefaultSession: openid.NewDefaultSession()}}, )) diff --git a/cypress/integration/oauth2/refresh_token.js b/cypress/integration/oauth2/refresh_token.js index fbbbf36e80b..2ddf7d30f19 100644 --- a/cypress/integration/oauth2/refresh_token.js +++ b/cypress/integration/oauth2/refresh_token.js @@ -87,13 +87,13 @@ describe("The OAuth 2.0 Refresh Token Grant", function () { return cy .refreshTokenBrowser(client, originalToken) .then((response) => { - expect(response.status).to.eq(401) - expect(response.body.error).to.eq("token_inactive") + expect(response.status).to.eq(400) + expect(response.body.error).to.eq("invalid_grant") }) .then(() => cy.refreshTokenBrowser(client, refreshedToken)) .then((response) => { - expect(response.status).to.eq(401) - expect(response.body.error).to.eq("token_inactive") + expect(response.status).to.eq(400) + expect(response.body.error).to.eq("invalid_grant") }) }, ) diff --git a/driver/config/provider.go b/driver/config/provider.go index 52b9ee45a3f..b02d0ae1da4 100644 --- a/driver/config/provider.go +++ b/driver/config/provider.go @@ -213,6 +213,10 @@ func (p *DefaultProvider) MustSet(ctx context.Context, key string, value interfa } } +func (p *DefaultProvider) Delete(ctx context.Context, key string) { + p.getProvider(ctx).Delete(key) +} + func (p *DefaultProvider) Source(ctx context.Context) *configx.Provider { return p.getProvider(ctx) } @@ -517,6 +521,10 @@ type ( ) func (p *DefaultProvider) getHookConfig(ctx context.Context, key string) *HookConfig { + if p.getProvider(ctx).String(key) == "" { + return nil + } + if hookURL := p.getProvider(ctx).RequestURIF(key, nil); hookURL != nil { return &HookConfig{ URL: hookURL.String(), @@ -673,8 +681,8 @@ func (p *DefaultProvider) cookieSuffix(ctx context.Context, key string) string { func (p *DefaultProvider) RefreshTokenRotationGracePeriod(ctx context.Context) time.Duration { gracePeriod := p.getProvider(ctx).DurationF(KeyRefreshTokenRotationGracePeriod, 0) - if gracePeriod > time.Hour { - return time.Hour + if gracePeriod > time.Minute*5 { + return time.Minute * 5 } return gracePeriod } diff --git a/driver/config/provider_test.go b/driver/config/provider_test.go index 168ca81d69f..7ec1dce8df9 100644 --- a/driver/config/provider_test.go +++ b/driver/config/provider_test.go @@ -296,7 +296,7 @@ func TestViperProviderValidates(t *testing.T) { require.NoError(t, c.Set(ctx, KeyRefreshTokenRotationGracePeriod, "1s")) assert.Equal(t, time.Second, c.RefreshTokenRotationGracePeriod(ctx)) require.NoError(t, c.Set(ctx, KeyRefreshTokenRotationGracePeriod, "2h")) - assert.Equal(t, time.Hour, c.RefreshTokenRotationGracePeriod(ctx)) + assert.Equal(t, time.Minute*5, c.RefreshTokenRotationGracePeriod(ctx)) // urls assert.Equal(t, urlx.ParseOrPanic("https://issuer"), c.IssuerURL(ctx)) diff --git a/go.mod b/go.mod index 0c9b9277cdf..210341d99d0 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 github.com/oleiade/reflections v1.0.1 github.com/ory/analytics-go/v5 v5.0.1 - github.com/ory/fosite v0.48.0 + github.com/ory/fosite v0.48.1-0.20241204100720-b57570a26c3e github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe github.com/ory/graceful v0.1.3 github.com/ory/herodot v0.10.3-0.20230626083119-d7e5192f0d88 diff --git a/go.sum b/go.sum index c29c141d383..4761d71ae8e 100644 --- a/go.sum +++ b/go.sum @@ -378,8 +378,8 @@ github.com/ory/analytics-go/v5 v5.0.1 h1:LX8T5B9FN8KZXOtxgN+R3I4THRRVB6+28IKgKBp github.com/ory/analytics-go/v5 v5.0.1/go.mod h1:lWCiCjAaJkKfgR/BN5DCLMol8BjKS1x+4jxBxff/FF0= github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d h1:By96ZSVuH5LyjXLVVMfvJoLVGHaT96LdOnwgFSLVf0E= github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d/go.mod h1:F2FIjwwAk6CsNAs//B8+aPFQF0t84pbM8oliyNXwQrk= -github.com/ory/fosite v0.48.0 h1:zxNPNrCBsFwujviVPhbHZzSHZNzjBFZ36MeBFz6tCuU= -github.com/ory/fosite v0.48.0/go.mod h1:M+C+Ng1UDNgwX4SaErnuZwEw26uDN7I3kNUt0WyValI= +github.com/ory/fosite v0.48.1-0.20241204100720-b57570a26c3e h1:C55B0tN1yuintGQ0N+nTnFlrHlxidM3vagM/+7xQrio= +github.com/ory/fosite v0.48.1-0.20241204100720-b57570a26c3e/go.mod h1:M+C+Ng1UDNgwX4SaErnuZwEw26uDN7I3kNUt0WyValI= github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe h1:rvu4obdvqR0fkSIJ8IfgzKOWwZ5kOT2UNfLq81Qk7rc= github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe/go.mod h1:z4n3u6as84LbV4YmgjHhnwtccQqzf4cZlSk9f1FhygI= github.com/ory/go-convenience v0.1.0 h1:zouLKfF2GoSGnJwGq+PE/nJAE6dj2Zj5QlTgmMTsTS8= diff --git a/health/handler_test.go b/health/handler_test.go index 4b717a02c79..b7821d3cab4 100644 --- a/health/handler_test.go +++ b/health/handler_test.go @@ -9,6 +9,8 @@ import ( "net/http/httptest" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/stretchr/testify/assert" "github.com/ory/x/contextx" @@ -16,7 +18,6 @@ import ( "github.com/stretchr/testify/require" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" "github.com/ory/x/healthx" ) @@ -71,12 +72,12 @@ func TestPublicHealthHandler(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() for k, v := range tc.config { conf.MustSet(ctx, config.PublicInterface.Key(k), v) } - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) public := x.NewRouterPublic() reg.RegisterRoutes(ctx, x.NewRouterAdmin(conf.AdminURL), public) diff --git a/internal/driver.go b/internal/testhelpers/driver.go similarity index 84% rename from internal/driver.go rename to internal/testhelpers/driver.go index 38a8d8144d4..34a3f40b8bd 100644 --- a/internal/driver.go +++ b/internal/testhelpers/driver.go @@ -1,13 +1,15 @@ // Copyright © 2022 Ory Corp // SPDX-License-Identifier: Apache-2.0 -package internal +package testhelpers import ( "context" "sync" "testing" + "github.com/ory/x/dbal" + "github.com/go-jose/go-jose/v3" "github.com/stretchr/testify/require" @@ -44,24 +46,28 @@ func NewConfigurationWithDefaultsAndHTTPS() *config.DefaultProvider { } func NewRegistryMemory(t testing.TB, c *config.DefaultProvider, ctxer contextx.Contextualizer) driver.Registry { - return newRegistryDefault(t, "memory", c, true, ctxer) + return registryFactory(t, dbal.NewSQLiteTestDatabase(t), c, true, ctxer) } func NewMockedRegistry(t testing.TB, ctxer contextx.Contextualizer) driver.Registry { - return newRegistryDefault(t, "memory", NewConfigurationWithDefaults(), true, ctxer) + return registryFactory(t, dbal.NewSQLiteTestDatabase(t), NewConfigurationWithDefaults(), true, ctxer) } func NewRegistrySQLFromURL(t testing.TB, url string, migrate bool, ctxer contextx.Contextualizer) driver.Registry { - return newRegistryDefault(t, url, NewConfigurationWithDefaults(), migrate, ctxer) + return registryFactory(t, url, NewConfigurationWithDefaults(), migrate, ctxer) +} + +func registryFactory(t testing.TB, url string, c *config.DefaultProvider, migrate bool, ctxer contextx.Contextualizer) driver.Registry { + return RegistryFactory(t, url, c, !migrate, migrate, ctxer) } -func newRegistryDefault(t testing.TB, url string, c *config.DefaultProvider, migrate bool, ctxer contextx.Contextualizer) driver.Registry { +func RegistryFactory(t testing.TB, url string, c *config.DefaultProvider, networkInit, migrate bool, ctxer contextx.Contextualizer) driver.Registry { ctx := context.Background() c.MustSet(ctx, config.KeyLogLevel, "trace") c.MustSet(ctx, config.KeyDSN, url) c.MustSet(ctx, "dev", true) - r, err := driver.NewRegistryFromDSN(ctx, c, logrusx.New("test_hydra", "master"), false, migrate, ctxer) + r, err := driver.NewRegistryFromDSN(ctx, c, logrusx.New("test_hydra", "master"), networkInit, migrate, ctxer) require.NoError(t, err) return r diff --git a/internal/testhelpers/janitor_test_helper.go b/internal/testhelpers/janitor_test_helper.go index f70d7c27495..c452b3248f1 100644 --- a/internal/testhelpers/janitor_test_helper.go +++ b/internal/testhelpers/janitor_test_helper.go @@ -21,7 +21,6 @@ import ( "github.com/ory/hydra/v2/driver" "github.com/ory/hydra/v2/driver/config" "github.com/ory/hydra/v2/flow" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/oauth2" "github.com/ory/hydra/v2/oauth2/trust" "github.com/ory/hydra/v2/x" @@ -50,7 +49,7 @@ type createGrantRequest struct { const lifespan = time.Hour func NewConsentJanitorTestHelper(uniqueName string) *JanitorConsentTestHelper { - conf := internal.NewConfigurationWithDefaults() + conf := NewConfigurationWithDefaults() conf.MustSet(context.Background(), config.KeyScopeStrategy, "DEPRECATED_HIERARCHICAL_SCOPE_STRATEGY") conf.MustSet(context.Background(), config.KeyIssuerURL, "http://hydra.localhost") conf.MustSet(context.Background(), config.KeyAccessTokenLifespan, lifespan) @@ -126,7 +125,7 @@ func (j *JanitorConsentTestHelper) RefreshTokenNotAfterSetup(ctx context.Context // Create refresh token clients and session for _, fr := range j.flushRefreshRequests { require.NoError(t, cl.CreateClient(ctx, fr.Client.(*client.Client))) - require.NoError(t, store.CreateRefreshTokenSession(ctx, fr.ID, fr)) + require.NoError(t, store.CreateRefreshTokenSession(ctx, fr.ID, "", fr)) } } } diff --git a/internal/testhelpers/oauth2.go b/internal/testhelpers/oauth2.go index 41f0ddaec8e..4a7b5bc696e 100644 --- a/internal/testhelpers/oauth2.go +++ b/internal/testhelpers/oauth2.go @@ -32,7 +32,6 @@ import ( "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" ) @@ -67,8 +66,8 @@ func NewOAuth2Server(ctx context.Context, t testing.TB, reg driver.Registry) (pu public, admin := x.NewRouterPublic(), x.NewRouterAdmin(reg.Config().AdminURL) - internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) - internal.MustEnsureRegistryKeys(ctx, reg, x.OAuth2JWTKeyName) + MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) + MustEnsureRegistryKeys(ctx, reg, x.OAuth2JWTKeyName) reg.RegisterRoutes(ctx, admin, public) @@ -111,6 +110,20 @@ func IntrospectToken(t testing.TB, conf *oauth2.Config, token string, adminTS *h return gjson.ParseBytes(ioutilx.MustReadAll(res.Body)) } +func RevokeToken(t testing.TB, conf *oauth2.Config, token string, publicTS *httptest.Server) gjson.Result { + require.NotEmpty(t, token) + + req := httpx.MustNewRequest("POST", publicTS.URL+"/oauth2/revoke", + strings.NewReader((url.Values{"token": {token}}).Encode()), + "application/x-www-form-urlencoded") + + req.SetBasicAuth(conf.ClientID, conf.ClientSecret) + res, err := publicTS.Client().Do(req) + require.NoError(t, err) + defer res.Body.Close() + return gjson.ParseBytes(ioutilx.MustReadAll(res.Body)) +} + func UpdateClientTokenLifespans(t *testing.T, conf *oauth2.Config, clientID string, lifespans client.Lifespans, adminTS *httptest.Server) { b, err := json.Marshal(lifespans) require.NoError(t, err) diff --git a/jwk/handler_test.go b/jwk/handler_test.go index 5df8182de60..0dc8f6afcdc 100644 --- a/jwk/handler_test.go +++ b/jwk/handler_test.go @@ -10,6 +10,8 @@ import ( "net/http/httptest" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/x/httprouterx" "github.com/ory/hydra/v2/jwk" @@ -20,15 +22,14 @@ import ( "github.com/stretchr/testify/require" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" ) func TestHandlerWellKnown(t *testing.T) { t.Parallel() - conf := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + conf := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) conf.MustSet(context.Background(), config.KeyWellKnownKeys, []string{x.OpenIDConnectKeyName, x.OpenIDConnectKeyName}) router := x.NewRouterPublic() h := reg.KeyHandler() diff --git a/jwk/helper_test.go b/jwk/helper_test.go index c1a5ee46387..5a6dabd6a60 100644 --- a/jwk/helper_test.go +++ b/jwk/helper_test.go @@ -17,6 +17,8 @@ import ( "strings" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + hydra "github.com/ory/hydra-client-go/v2" "github.com/go-jose/go-jose/v3" @@ -27,7 +29,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" @@ -210,7 +211,7 @@ func TestExcludeOpaquePrivateKeys(t *testing.T) { func TestGetOrGenerateKeys(t *testing.T) { t.Parallel() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) setId := uuid.NewUUID().String() keyId := uuid.NewUUID().String() diff --git a/jwk/jwt_strategy_test.go b/jwk/jwt_strategy_test.go index 8389d20a610..b4def161005 100644 --- a/jwk/jwt_strategy_test.go +++ b/jwk/jwt_strategy_test.go @@ -9,12 +9,13 @@ import ( "strings" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" "github.com/ory/fosite/token/jwt" - "github.com/ory/hydra/v2/internal" . "github.com/ory/hydra/v2/jwk" "github.com/ory/x/contextx" ) @@ -22,8 +23,8 @@ import ( func TestJWTStrategy(t *testing.T) { for _, alg := range []string{"RS256", "ES256", "ES512"} { t.Run("case="+alg, func(t *testing.T) { - conf := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + conf := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) m := reg.KeyManager() _, err := m.GenerateAndPersistKeySet(context.Background(), "foo-set", "foo", alg, "sig") diff --git a/jwk/sdk_test.go b/jwk/sdk_test.go index f7f7d6a21e8..b2088239884 100644 --- a/jwk/sdk_test.go +++ b/jwk/sdk_test.go @@ -9,12 +9,13 @@ import ( "net/http/httptest" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" . "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" @@ -23,8 +24,8 @@ import ( func TestJWKSDK(t *testing.T) { t.Parallel() ctx := context.Background() - conf := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + conf := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) router := x.NewRouterAdmin(conf.AdminURL) h := NewHandler(reg) diff --git a/oauth2/equalKeys.go b/oauth2/equalKeys.go deleted file mode 100644 index e16568e078a..00000000000 --- a/oauth2/equalKeys.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright © 2022 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package oauth2 - -import ( - "testing" - - "github.com/oleiade/reflections" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func AssertObjectKeysEqual(t *testing.T, a, b interface{}, keys ...string) { - assert.True(t, len(keys) > 0, "No keys provided.") - for _, k := range keys { - c, err := reflections.GetField(a, k) - assert.Nil(t, err) - d, err := reflections.GetField(b, k) - assert.Nil(t, err) - assert.Equal(t, c, d, "%s", k) - } -} - -func AssertObjectKeysNotEqual(t *testing.T, a, b interface{}, keys ...string) { - assert.True(t, len(keys) > 0, "No keys provided.") - for _, k := range keys { - c, err := reflections.GetField(a, k) - assert.Nil(t, err) - d, err := reflections.GetField(b, k) - assert.Nil(t, err) - assert.NotEqual(t, c, d, "%s", k) - } -} - -func RequireObjectKeysEqual(t *testing.T, a, b interface{}, keys ...string) { - assert.True(t, len(keys) > 0, "No keys provided.") - for _, k := range keys { - c, err := reflections.GetField(a, k) - assert.Nil(t, err) - d, err := reflections.GetField(b, k) - assert.Nil(t, err) - require.Equal(t, c, d, "%s", k) - } -} -func RequireObjectKeysNotEqual(t *testing.T, a, b interface{}, keys ...string) { - assert.True(t, len(keys) > 0, "No keys provided.") - for _, k := range keys { - c, err := reflections.GetField(a, k) - assert.Nil(t, err) - d, err := reflections.GetField(b, k) - assert.Nil(t, err) - require.NotEqual(t, c, d, "%s", k) - } -} diff --git a/oauth2/equalKeys_test.go b/oauth2/equalKeys_test.go deleted file mode 100644 index 13243a94bf3..00000000000 --- a/oauth2/equalKeys_test.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright © 2022 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package oauth2 - -import "testing" - -func TestAssertObjectsAreEqualByKeys(t *testing.T) { - type foo struct { - Name string - Body int - } - a := &foo{"foo", 1} - b := &foo{"bar", 1} - c := &foo{"baz", 3} - - AssertObjectKeysEqual(t, a, a, "Name", "Body") - AssertObjectKeysNotEqual(t, a, b, "Name") - AssertObjectKeysNotEqual(t, a, c, "Name", "Body") -} diff --git a/oauth2/fosite_store_helpers.go b/oauth2/fosite_store_helpers_test.go similarity index 65% rename from oauth2/fosite_store_helpers.go rename to oauth2/fosite_store_helpers_test.go index 553a6bae62b..1084e31629c 100644 --- a/oauth2/fosite_store_helpers.go +++ b/oauth2/fosite_store_helpers_test.go @@ -1,85 +1,41 @@ // Copyright © 2022 Ory Corp // SPDX-License-Identifier: Apache-2.0 -package oauth2 +package oauth2_test import ( "context" - "crypto/sha256" "fmt" "net/url" "slices" "testing" "time" - "github.com/ory/x/assertx" - - "github.com/ory/hydra/v2/flow" - "github.com/ory/hydra/v2/jwk" + "github.com/ory/hydra/v2/persistence/sql" "github.com/go-jose/go-jose/v3" - "github.com/gobuffalo/pop/v6" - "github.com/pborman/uuid" - - "github.com/ory/fosite/handler/rfc7523" - - "github.com/ory/hydra/v2/oauth2/trust" - - "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/x" - - "github.com/ory/fosite/storage" - "github.com/ory/x/sqlxx" - gofrsuuid "github.com/gofrs/uuid" + "github.com/pborman/uuid" "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" - "github.com/ory/x/sqlcon" - + "github.com/ory/fosite/handler/rfc7523" + "github.com/ory/fosite/storage" "github.com/ory/hydra/v2/client" + "github.com/ory/hydra/v2/driver/config" + "github.com/ory/hydra/v2/flow" + "github.com/ory/hydra/v2/jwk" + "github.com/ory/hydra/v2/oauth2" + "github.com/ory/hydra/v2/oauth2/trust" + "github.com/ory/hydra/v2/x" + "github.com/ory/x/assertx" + "github.com/ory/x/sqlcon" + "github.com/ory/x/sqlxx" ) -func signatureFromJTI(jti string) string { - return fmt.Sprintf("%x", sha256.Sum256([]byte(jti))) -} - -type BlacklistedJTI struct { - JTI string `db:"-"` - ID string `db:"signature"` - Expiry time.Time `db:"expires_at"` - NID gofrsuuid.UUID `db:"nid"` -} - -func (j *BlacklistedJTI) AfterFind(_ *pop.Connection) error { - j.Expiry = j.Expiry.UTC() - return nil -} - -func (BlacklistedJTI) TableName() string { - return "hydra_oauth2_jti_blacklist" -} - -func NewBlacklistedJTI(jti string, exp time.Time) *BlacklistedJTI { - return &BlacklistedJTI{ - JTI: jti, - ID: signatureFromJTI(jti), - // because the database timestamp types are not as accurate as time.Time we truncate to seconds (which should always work) - Expiry: exp.UTC().Truncate(time.Second), - } -} - -type AssertionJWTReader interface { - x.FositeStorer - - GetClientAssertionJWT(ctx context.Context, jti string) (*BlacklistedJTI, error) - - SetClientAssertionJWTRaw(context.Context, *BlacklistedJTI) error -} - var defaultIgnoreKeys = []string{ "id", "session", @@ -94,29 +50,33 @@ var defaultIgnoreKeys = []string{ "client.client_secret", } -var defaultRequest = fosite.Request{ - ID: "blank", - RequestedAt: time.Now().UTC().Round(time.Second), - Client: &client.Client{ - ID: "foobar", - Contacts: []string{}, - RedirectURIs: []string{}, - Audience: []string{}, - AllowedCORSOrigins: []string{}, - ResponseTypes: []string{}, - GrantTypes: []string{}, - JSONWebKeys: &x.JoseJSONWebKeySet{}, - Metadata: sqlxx.JSONRawMessage("{}"), - }, - RequestedScope: fosite.Arguments{"fa", "ba"}, - GrantedScope: fosite.Arguments{"fa", "ba"}, - RequestedAudience: fosite.Arguments{"ad1", "ad2"}, - GrantedAudience: fosite.Arguments{"ad1", "ad2"}, - Form: url.Values{"foo": []string{"bar", "baz"}}, - Session: NewSession("bar"), +func newDefaultRequest(id string) fosite.Request { + return fosite.Request{ + ID: id, + RequestedAt: time.Now().UTC().Round(time.Second), + Client: &client.Client{ + ID: "foobar", + Contacts: []string{}, + RedirectURIs: []string{}, + Audience: []string{}, + AllowedCORSOrigins: []string{}, + ResponseTypes: []string{}, + GrantTypes: []string{}, + JSONWebKeys: &x.JoseJSONWebKeySet{}, + Metadata: sqlxx.JSONRawMessage("{}"), + }, + RequestedScope: fosite.Arguments{"fa", "ba"}, + GrantedScope: fosite.Arguments{"fa", "ba"}, + RequestedAudience: fosite.Arguments{"ad1", "ad2"}, + GrantedAudience: fosite.Arguments{"ad1", "ad2"}, + Form: url.Values{"foo": []string{"bar", "baz"}}, + Session: oauth2.NewSession("bar"), + } } -var lifespan = time.Hour +var defaultRequest = newDefaultRequest("blank") + +// var lifespan = time.Hour var flushRequests = []*fosite.Request{ { ID: "flush-1", @@ -125,7 +85,7 @@ var flushRequests = []*fosite.Request{ RequestedScope: fosite.Arguments{"fa", "ba"}, GrantedScope: fosite.Arguments{"fa", "ba"}, Form: url.Values{"foo": []string{"bar", "baz"}}, - Session: &Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, + Session: &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, }, { ID: "flush-2", @@ -134,7 +94,7 @@ var flushRequests = []*fosite.Request{ RequestedScope: fosite.Arguments{"fa", "ba"}, GrantedScope: fosite.Arguments{"fa", "ba"}, Form: url.Values{"foo": []string{"bar", "baz"}}, - Session: &Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, + Session: &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, }, { ID: "flush-3", @@ -143,11 +103,11 @@ var flushRequests = []*fosite.Request{ RequestedScope: fosite.Arguments{"fa", "ba"}, GrantedScope: fosite.Arguments{"fa", "ba"}, Form: url.Values{"foo": []string{"bar", "baz"}}, - Session: &Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, + Session: &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, }, } -func mockRequestForeignKey(t *testing.T, id string, x InternalRegistry) { +func mockRequestForeignKey(t *testing.T, id string, x oauth2.InternalRegistry) { cl := &client.Client{ID: "foobar"} cr := &flow.OAuth2ConsentRequest{ Client: cl, @@ -193,43 +153,10 @@ func mockRequestForeignKey(t *testing.T, id string, x InternalRegistry) { require.NoError(t, err) } -// TestHelperRunner is used to run the database suite of tests in this package. -// KEEP EXPORTED AND AVAILABLE FOR THIRD PARTIES TO TEST PLUGINS! -func TestHelperRunner(t *testing.T, store InternalRegistry, k string) { - t.Helper() - if k != "memory" { - t.Run(fmt.Sprintf("case=testHelperUniqueConstraints/db=%s", k), testHelperRequestIDMultiples(store, k)) - t.Run("case=testFositeSqlStoreTransactionsCommitAccessToken", testFositeSqlStoreTransactionCommitAccessToken(store)) - t.Run("case=testFositeSqlStoreTransactionsRollbackAccessToken", testFositeSqlStoreTransactionRollbackAccessToken(store)) - t.Run("case=testFositeSqlStoreTransactionCommitRefreshToken", testFositeSqlStoreTransactionCommitRefreshToken(store)) - t.Run("case=testFositeSqlStoreTransactionRollbackRefreshToken", testFositeSqlStoreTransactionRollbackRefreshToken(store)) - t.Run("case=testFositeSqlStoreTransactionCommitAuthorizeCode", testFositeSqlStoreTransactionCommitAuthorizeCode(store)) - t.Run("case=testFositeSqlStoreTransactionRollbackAuthorizeCode", testFositeSqlStoreTransactionRollbackAuthorizeCode(store)) - t.Run("case=testFositeSqlStoreTransactionCommitPKCERequest", testFositeSqlStoreTransactionCommitPKCERequest(store)) - t.Run("case=testFositeSqlStoreTransactionRollbackPKCERequest", testFositeSqlStoreTransactionRollbackPKCERequest(store)) - t.Run("case=testFositeSqlStoreTransactionCommitOpenIdConnectSession", testFositeSqlStoreTransactionCommitOpenIdConnectSession(store)) - t.Run("case=testFositeSqlStoreTransactionRollbackOpenIdConnectSession", testFositeSqlStoreTransactionRollbackOpenIdConnectSession(store)) - - } - t.Run(fmt.Sprintf("case=testHelperCreateGetDeleteAuthorizeCodes/db=%s", k), testHelperCreateGetDeleteAuthorizeCodes(store)) - t.Run(fmt.Sprintf("case=testHelperExpiryFields/db=%s", k), testHelperExpiryFields(store)) - t.Run(fmt.Sprintf("case=testHelperCreateGetDeleteAccessTokenSession/db=%s", k), testHelperCreateGetDeleteAccessTokenSession(store)) - t.Run(fmt.Sprintf("case=testHelperNilAccessToken/db=%s", k), testHelperNilAccessToken(store)) - t.Run(fmt.Sprintf("case=testHelperCreateGetDeleteOpenIDConnectSession/db=%s", k), testHelperCreateGetDeleteOpenIDConnectSession(store)) - t.Run(fmt.Sprintf("case=testHelperCreateGetDeleteRefreshTokenSession/db=%s", k), testHelperCreateGetDeleteRefreshTokenSession(store)) - t.Run(fmt.Sprintf("case=testHelperRevokeRefreshToken/db=%s", k), testHelperRevokeRefreshToken(store)) - t.Run(fmt.Sprintf("case=testHelperCreateGetDeletePKCERequestSession/db=%s", k), testHelperCreateGetDeletePKCERequestSession(store)) - t.Run(fmt.Sprintf("case=testHelperFlushTokens/db=%s", k), testHelperFlushTokens(store, time.Hour)) - t.Run(fmt.Sprintf("case=testHelperFlushTokensWithLimitAndBatchSize/db=%s", k), testHelperFlushTokensWithLimitAndBatchSize(store, 3, 2)) - t.Run(fmt.Sprintf("case=testFositeStoreSetClientAssertionJWT/db=%s", k), testFositeStoreSetClientAssertionJWT(store)) - t.Run(fmt.Sprintf("case=testFositeStoreClientAssertionJWTValid/db=%s", k), testFositeStoreClientAssertionJWTValid(store)) - t.Run(fmt.Sprintf("case=testHelperDeleteAccessTokens/db=%s", k), testHelperDeleteAccessTokens(store)) - t.Run(fmt.Sprintf("case=testHelperRevokeAccessToken/db=%s", k), testHelperRevokeAccessToken(store)) - t.Run(fmt.Sprintf("case=testFositeJWTBearerGrantStorage/db=%s", k), testFositeJWTBearerGrantStorage(store)) - t.Run(fmt.Sprintf("case=testHelperRevokeRefreshTokenMaybeGracePeriod/db=%s", k), testHelperRevokeRefreshTokenMaybeGracePeriod(store)) +func TestHelperRunner(t *testing.T) { } -func testHelperRequestIDMultiples(m InternalRegistry, _ string) func(t *testing.T) { +func testHelperRequestIDMultiples(m oauth2.InternalRegistry, _ string) func(t *testing.T) { return func(t *testing.T) { ctx := context.Background() requestID := uuid.New() @@ -240,12 +167,13 @@ func testHelperRequestIDMultiples(m InternalRegistry, _ string) func(t *testing. ID: requestID, Client: cl, RequestedAt: time.Now().UTC().Round(time.Second), - Session: NewSession("bar"), + Session: oauth2.NewSession("bar"), } for i := 0; i < 4; i++ { signature := uuid.New() - err := m.OAuth2Storage().CreateRefreshTokenSession(ctx, signature, fositeRequest) + accessSignature := uuid.New() + err := m.OAuth2Storage().CreateRefreshTokenSession(ctx, signature, accessSignature, fositeRequest) assert.NoError(t, err) err = m.OAuth2Storage().CreateAccessTokenSession(ctx, signature, fositeRequest) assert.NoError(t, err) @@ -259,58 +187,60 @@ func testHelperRequestIDMultiples(m InternalRegistry, _ string) func(t *testing. } } -func testHelperCreateGetDeleteOpenIDConnectSession(x InternalRegistry) func(t *testing.T) { +func testHelperCreateGetDeleteOpenIDConnectSession(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() + code := uuid.New() ctx := context.Background() - _, err := m.GetOpenIDConnectSession(ctx, "4321", &fosite.Request{Session: NewSession("bar")}) + _, err := m.GetOpenIDConnectSession(ctx, code, &fosite.Request{Session: oauth2.NewSession("bar")}) assert.NotNil(t, err) - err = m.CreateOpenIDConnectSession(ctx, "4321", &defaultRequest) + err = m.CreateOpenIDConnectSession(ctx, code, &defaultRequest) require.NoError(t, err) - res, err := m.GetOpenIDConnectSession(ctx, "4321", &fosite.Request{Session: NewSession("bar")}) + res, err := m.GetOpenIDConnectSession(ctx, code, &fosite.Request{Session: oauth2.NewSession("bar")}) require.NoError(t, err) AssertObjectKeysEqual(t, &defaultRequest, res, "RequestedScope", "GrantedScope", "Form", "Session") - err = m.DeleteOpenIDConnectSession(ctx, "4321") + err = m.DeleteOpenIDConnectSession(ctx, code) require.NoError(t, err) - _, err = m.GetOpenIDConnectSession(ctx, "4321", &fosite.Request{Session: NewSession("bar")}) + _, err = m.GetOpenIDConnectSession(ctx, code, &fosite.Request{Session: oauth2.NewSession("bar")}) assert.NotNil(t, err) } } -func testHelperCreateGetDeleteRefreshTokenSession(x InternalRegistry) func(t *testing.T) { +func testHelperCreateGetDeleteRefreshTokenSession(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() + code := uuid.New() ctx := context.Background() - _, err := m.GetRefreshTokenSession(ctx, "4321", NewSession("bar")) + _, err := m.GetRefreshTokenSession(ctx, code, oauth2.NewSession("bar")) assert.NotNil(t, err) - err = m.CreateRefreshTokenSession(ctx, "4321", &defaultRequest) + err = m.CreateRefreshTokenSession(ctx, code, "", &defaultRequest) require.NoError(t, err) - res, err := m.GetRefreshTokenSession(ctx, "4321", NewSession("bar")) + res, err := m.GetRefreshTokenSession(ctx, code, oauth2.NewSession("bar")) require.NoError(t, err) AssertObjectKeysEqual(t, &defaultRequest, res, "RequestedScope", "GrantedScope", "Form", "Session") - err = m.DeleteRefreshTokenSession(ctx, "4321") + err = m.DeleteRefreshTokenSession(ctx, code) require.NoError(t, err) - _, err = m.GetRefreshTokenSession(ctx, "4321", NewSession("bar")) + _, err = m.GetRefreshTokenSession(ctx, code, oauth2.NewSession("bar")) assert.NotNil(t, err) } } -func testHelperRevokeRefreshToken(x InternalRegistry) func(t *testing.T) { +func testHelperRevokeRefreshToken(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() ctx := context.Background() - _, err := m.GetRefreshTokenSession(ctx, "1111", NewSession("bar")) + _, err := m.GetRefreshTokenSession(ctx, "1111", oauth2.NewSession("bar")) assert.Error(t, err) reqIdOne := uuid.New() @@ -319,23 +249,23 @@ func testHelperRevokeRefreshToken(x InternalRegistry) func(t *testing.T) { mockRequestForeignKey(t, reqIdOne, x) mockRequestForeignKey(t, reqIdTwo, x) - err = m.CreateRefreshTokenSession(ctx, "1111", &fosite.Request{ + err = m.CreateRefreshTokenSession(ctx, "1111", "", &fosite.Request{ ID: reqIdOne, Client: &client.Client{ID: "foobar"}, RequestedAt: time.Now().UTC().Round(time.Second), - Session: NewSession("user"), + Session: oauth2.NewSession("user"), }) require.NoError(t, err) - err = m.CreateRefreshTokenSession(ctx, "1122", &fosite.Request{ + err = m.CreateRefreshTokenSession(ctx, "1122", "", &fosite.Request{ ID: reqIdTwo, Client: &client.Client{ID: "foobar"}, RequestedAt: time.Now().UTC().Round(time.Second), - Session: NewSession("user"), + Session: oauth2.NewSession("user"), }) require.NoError(t, err) - _, err = m.GetRefreshTokenSession(ctx, "1111", NewSession("bar")) + _, err = m.GetRefreshTokenSession(ctx, "1111", oauth2.NewSession("bar")) require.NoError(t, err) err = m.RevokeRefreshToken(ctx, reqIdOne) @@ -344,39 +274,40 @@ func testHelperRevokeRefreshToken(x InternalRegistry) func(t *testing.T) { err = m.RevokeRefreshToken(ctx, reqIdTwo) require.NoError(t, err) - req, err := m.GetRefreshTokenSession(ctx, "1111", NewSession("bar")) - assert.NotNil(t, req) - assert.EqualError(t, err, fosite.ErrInactiveToken.Error()) - - req, err = m.GetRefreshTokenSession(ctx, "1122", NewSession("bar")) - assert.NotNil(t, req) - assert.EqualError(t, err, fosite.ErrInactiveToken.Error()) + req, err := m.GetRefreshTokenSession(ctx, "1111", oauth2.NewSession("bar")) + assert.Nil(t, req) + assert.EqualError(t, err, fosite.ErrNotFound.Error()) + req, err = m.GetRefreshTokenSession(ctx, "1122", oauth2.NewSession("bar")) + assert.Nil(t, req) + assert.EqualError(t, err, fosite.ErrNotFound.Error()) } } -func testHelperCreateGetDeleteAuthorizeCodes(x InternalRegistry) func(t *testing.T) { +func testHelperCreateGetDeleteAuthorizeCodes(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() mockRequestForeignKey(t, "blank", x) + code := uuid.New() + ctx := context.Background() - res, err := m.GetAuthorizeCodeSession(ctx, "4321", NewSession("bar")) + res, err := m.GetAuthorizeCodeSession(ctx, code, oauth2.NewSession("bar")) assert.Error(t, err) assert.Nil(t, res) - err = m.CreateAuthorizeCodeSession(ctx, "4321", &defaultRequest) + err = m.CreateAuthorizeCodeSession(ctx, code, &defaultRequest) require.NoError(t, err) - res, err = m.GetAuthorizeCodeSession(ctx, "4321", NewSession("bar")) + res, err = m.GetAuthorizeCodeSession(ctx, code, oauth2.NewSession("bar")) require.NoError(t, err) AssertObjectKeysEqual(t, &defaultRequest, res, "RequestedScope", "GrantedScope", "Form", "Session") - err = m.InvalidateAuthorizeCodeSession(ctx, "4321") + err = m.InvalidateAuthorizeCodeSession(ctx, code) require.NoError(t, err) - res, err = m.GetAuthorizeCodeSession(ctx, "4321", NewSession("bar")) + res, err = m.GetAuthorizeCodeSession(ctx, code, oauth2.NewSession("bar")) require.Error(t, err) assert.EqualError(t, err, fosite.ErrInvalidatedAuthorizeCode.Error()) assert.NotNil(t, res) @@ -392,7 +323,7 @@ func (r testHelperExpiryFieldsResult) TableName() string { return "hydra_oauth2_" + r.name } -func testHelperExpiryFields(reg InternalRegistry) func(t *testing.T) { +func testHelperExpiryFields(reg oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := reg.OAuth2Storage() t.Parallel() @@ -401,7 +332,7 @@ func testHelperExpiryFields(reg InternalRegistry) func(t *testing.T) { ctx := context.Background() - s := NewSession("bar") + s := oauth2.NewSession("bar") s.SetExpiresAt(fosite.AccessToken, time.Now().Add(time.Hour).Round(time.Minute)) s.SetExpiresAt(fosite.RefreshToken, time.Now().Add(time.Hour*2).Round(time.Minute)) s.SetExpiresAt(fosite.AuthorizeCode, time.Now().Add(time.Hour*3).Round(time.Minute)) @@ -433,7 +364,7 @@ func testHelperExpiryFields(reg InternalRegistry) func(t *testing.T) { t.Run("case=CreateRefreshTokenSession", func(t *testing.T) { id := uuid.New() - err := m.CreateRefreshTokenSession(ctx, id, &request) + err := m.CreateRefreshTokenSession(ctx, id, "", &request) require.NoError(t, err) r := testHelperExpiryFieldsResult{name: "refresh"} @@ -473,12 +404,12 @@ func testHelperExpiryFields(reg InternalRegistry) func(t *testing.T) { } } -func testHelperNilAccessToken(x InternalRegistry) func(t *testing.T) { +func testHelperNilAccessToken(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() - c := &client.Client{ID: "nil-request-client-id-123"} + c := &client.Client{ID: uuid.New()} require.NoError(t, x.ClientManager().CreateClient(context.Background(), c)) - err := m.CreateAccessTokenSession(context.Background(), "nil-request-id", &fosite.Request{ + err := m.CreateAccessTokenSession(context.Background(), uuid.New(), &fosite.Request{ ID: "", RequestedAt: time.Now().UTC().Round(time.Second), Client: c, @@ -487,158 +418,251 @@ func testHelperNilAccessToken(x InternalRegistry) func(t *testing.T) { RequestedAudience: fosite.Arguments{"ad1", "ad2"}, GrantedAudience: fosite.Arguments{"ad1", "ad2"}, Form: url.Values{"foo": []string{"bar", "baz"}}, - Session: NewSession("bar"), + Session: oauth2.NewSession("bar"), }) require.NoError(t, err) } } -func testHelperCreateGetDeleteAccessTokenSession(x InternalRegistry) func(t *testing.T) { +func testHelperCreateGetDeleteAccessTokenSession(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() + code := uuid.New() ctx := context.Background() - _, err := m.GetAccessTokenSession(ctx, "4321", NewSession("bar")) + _, err := m.GetAccessTokenSession(ctx, code, oauth2.NewSession("bar")) assert.Error(t, err) - err = m.CreateAccessTokenSession(ctx, "4321", &defaultRequest) + err = m.CreateAccessTokenSession(ctx, code, &defaultRequest) require.NoError(t, err) - res, err := m.GetAccessTokenSession(ctx, "4321", NewSession("bar")) + res, err := m.GetAccessTokenSession(ctx, code, oauth2.NewSession("bar")) require.NoError(t, err) AssertObjectKeysEqual(t, &defaultRequest, res, "RequestedScope", "GrantedScope", "Form", "Session") - err = m.DeleteAccessTokenSession(ctx, "4321") + err = m.DeleteAccessTokenSession(ctx, code) require.NoError(t, err) - _, err = m.GetAccessTokenSession(ctx, "4321", NewSession("bar")) + _, err = m.GetAccessTokenSession(ctx, code, oauth2.NewSession("bar")) assert.Error(t, err) } } -func testHelperDeleteAccessTokens(x InternalRegistry) func(t *testing.T) { +func testHelperDeleteAccessTokens(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() ctx := context.Background() - err := m.CreateAccessTokenSession(ctx, "4321", &defaultRequest) + code := uuid.New() + err := m.CreateAccessTokenSession(ctx, code, &defaultRequest) require.NoError(t, err) - _, err = m.GetAccessTokenSession(ctx, "4321", NewSession("bar")) + _, err = m.GetAccessTokenSession(ctx, code, oauth2.NewSession("bar")) require.NoError(t, err) err = m.DeleteAccessTokens(ctx, defaultRequest.Client.GetID()) require.NoError(t, err) - req, err := m.GetAccessTokenSession(ctx, "4321", NewSession("bar")) + req, err := m.GetAccessTokenSession(ctx, code, oauth2.NewSession("bar")) assert.Nil(t, req) assert.EqualError(t, err, fosite.ErrNotFound.Error()) } } -func testHelperRevokeAccessToken(x InternalRegistry) func(t *testing.T) { +func testHelperRevokeAccessToken(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() ctx := context.Background() - err := m.CreateAccessTokenSession(ctx, "4321", &defaultRequest) + code := uuid.New() + err := m.CreateAccessTokenSession(ctx, code, &defaultRequest) require.NoError(t, err) - _, err = m.GetAccessTokenSession(ctx, "4321", NewSession("bar")) + _, err = m.GetAccessTokenSession(ctx, code, oauth2.NewSession("bar")) require.NoError(t, err) err = m.RevokeAccessToken(ctx, defaultRequest.GetID()) require.NoError(t, err) - req, err := m.GetAccessTokenSession(ctx, "4321", NewSession("bar")) + req, err := m.GetAccessTokenSession(ctx, code, oauth2.NewSession("bar")) assert.Nil(t, req) assert.EqualError(t, err, fosite.ErrNotFound.Error()) } } -func testHelperRevokeRefreshTokenMaybeGracePeriod(x InternalRegistry) func(t *testing.T) { +func testHelperRotateRefreshToken(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { ctx := context.Background() + createTokens := func(t *testing.T, r *fosite.Request) (refreshTokenSession string, accessTokenSession string) { + refreshTokenSession = fmt.Sprintf("refresh_token_%s", uuid.New()) + accessTokenSession = fmt.Sprintf("access_token_%s", uuid.New()) + err := x.OAuth2Storage().CreateAccessTokenSession(ctx, accessTokenSession, r) + require.NoError(t, err) + + err = x.OAuth2Storage().CreateRefreshTokenSession(ctx, refreshTokenSession, accessTokenSession, r) + require.NoError(t, err) + + // Sanity check + req, err := x.OAuth2Storage().GetRefreshTokenSession(ctx, refreshTokenSession, nil) + require.NoError(t, err) + require.EqualValues(t, r.GetID(), req.GetID()) + + req, err = x.OAuth2Storage().GetAccessTokenSession(ctx, accessTokenSession, nil) + require.NoError(t, err) + require.EqualValues(t, r.GetID(), req.GetID()) + return + } + t.Run("Revokes refresh token when grace period not configured", func(t *testing.T) { - // SETUP m := x.OAuth2Storage() + r := newDefaultRequest(uuid.New()) + refreshTokenSession, accessTokenSession := createTokens(t, &r) - refreshTokenSession := fmt.Sprintf("refresh_token_%d", time.Now().Unix()) - err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest) - require.NoError(t, err, "precondition failed: could not create refresh token session") - - // ACT - err = m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession) + err := m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession) require.NoError(t, err) - tmpSession := new(fosite.Session) - _, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, *tmpSession) + _, err = m.GetAccessTokenSession(ctx, accessTokenSession, nil) + assert.ErrorIs(t, err, fosite.ErrNotFound, "Token is no longer active because it was refreshed") + + _, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) + assert.ErrorIs(t, err, fosite.ErrInactiveToken, "Token is no longer active because it was refreshed") + }) + + t.Run("refresh token is valid until the grace period has ended", func(t *testing.T) { + x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1s") + + // By setting this to one hour we ensure that using the refresh token triggers the start of the grace period. + x.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1h") + t.Cleanup(func() { + x.Config().Delete(ctx, config.KeyRefreshTokenRotationGracePeriod) + }) + + m := x.OAuth2Storage() + r := newDefaultRequest(uuid.New()) + refreshTokenSession, accessTokenSession1 := createTokens(t, &r) + accessTokenSession2 := fmt.Sprintf("access_token_%s", uuid.New()) + require.NoError(t, m.CreateAccessTokenSession(ctx, accessTokenSession2, &r)) + + // Create a second access token + require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) + require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) + require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) + + req, err := m.GetAccessTokenSession(ctx, accessTokenSession1, nil) + assert.ErrorIs(t, err, fosite.ErrNotFound) + + req, err = m.GetAccessTokenSession(ctx, accessTokenSession2, nil) + assert.NoError(t, err, "The second access token is still valid.") + + req, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) + assert.NoError(t, err) + assert.Equal(t, r.GetID(), req.GetID()) - // ASSERT - // a revoked refresh token returns an error when getting the token again - assert.ErrorIs(t, err, fosite.ErrInactiveToken) + // We only wait a second, meaning that the token is theoretically still within TTL, but since the + // grace period was issued, the token is still valid. + time.Sleep(time.Second * 2) + req, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) + assert.Error(t, err) }) - t.Run("refresh token enters grace period when configured,", func(t *testing.T) { - // SETUP - x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1m") + t.Run("the used at time does not change", func(t *testing.T) { + x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1s") + + // By setting this to one hour we ensure that using the refresh token triggers the start of the grace period. + x.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1h") + t.Cleanup(func() { + x.Config().Delete(ctx, config.KeyRefreshTokenRotationGracePeriod) + }) + + m := x.OAuth2Storage() + r := newDefaultRequest(uuid.New()) + + refreshTokenSession, _ := createTokens(t, &r) + require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) + + var expected sql.OAuth2RefreshTable + require.NoError(t, x.Persister().Connection(ctx).Where("signature=?", refreshTokenSession).First(&expected)) + assert.False(t, expected.FirstUsedAt.Time.IsZero()) + assert.True(t, expected.FirstUsedAt.Valid) + + // Refresh does not change the time + time.Sleep(time.Second * 2) + require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) + + var actual sql.OAuth2RefreshTable + require.NoError(t, x.Persister().Connection(ctx).Where("signature=?", refreshTokenSession).First(&actual)) + assert.Equal(t, expected.FirstUsedAt.Time, actual.FirstUsedAt.Time) + }) - // always reset back to the default + t.Run("refresh token revokes all access tokens from the request if the access token signature is not found", func(t *testing.T) { + x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1s") t.Cleanup(func() { - x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "0m") + x.Config().Delete(ctx, config.KeyRefreshTokenRotationGracePeriod) }) m := x.OAuth2Storage() + r := newDefaultRequest(uuid.New()) - refreshTokenSession := fmt.Sprintf("refresh_token_%d_with_grace_period", time.Now().Unix()) - err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest) - require.NoError(t, err, "precondition failed: could not create refresh token session") + refreshTokenSession := fmt.Sprintf("refresh_token_%s", uuid.New()) + accessTokenSession1 := fmt.Sprintf("access_token_%s", uuid.New()) + accessTokenSession2 := fmt.Sprintf("access_token_%s", uuid.New()) + require.NoError(t, m.CreateAccessTokenSession(ctx, accessTokenSession1, &r)) + require.NoError(t, m.CreateAccessTokenSession(ctx, accessTokenSession2, &r)) + + require.NoError(t, m.CreateRefreshTokenSession(ctx, refreshTokenSession, "", &r), + "precondition failed: could not create refresh token session") // ACT - require.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) - require.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) - require.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) + require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) + require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) + require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) + + req, err := m.GetAccessTokenSession(ctx, accessTokenSession1, nil) + assert.ErrorIs(t, err, fosite.ErrNotFound) - req, err := m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) + req, err = m.GetAccessTokenSession(ctx, accessTokenSession2, nil) + assert.ErrorIs(t, err, fosite.ErrNotFound) - // ASSERT - // when grace period is configured the refresh token can be obtained within - // the grace period without error + req, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) assert.NoError(t, err) + assert.Equal(t, r.GetID(), req.GetID()) + + time.Sleep(time.Second * 2) - assert.Equal(t, defaultRequest.GetID(), req.GetID()) + req, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) + assert.Error(t, err) }) } - } -func testHelperCreateGetDeletePKCERequestSession(x InternalRegistry) func(t *testing.T) { +func testHelperCreateGetDeletePKCERequestSession(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() + code := uuid.New() ctx := context.Background() - _, err := m.GetPKCERequestSession(ctx, "4321", NewSession("bar")) + _, err := m.GetPKCERequestSession(ctx, code, oauth2.NewSession("bar")) assert.NotNil(t, err) - err = m.CreatePKCERequestSession(ctx, "4321", &defaultRequest) + err = m.CreatePKCERequestSession(ctx, code, &defaultRequest) require.NoError(t, err) - res, err := m.GetPKCERequestSession(ctx, "4321", NewSession("bar")) + res, err := m.GetPKCERequestSession(ctx, code, oauth2.NewSession("bar")) require.NoError(t, err) AssertObjectKeysEqual(t, &defaultRequest, res, "RequestedScope", "GrantedScope", "Form", "Session") - err = m.DeletePKCERequestSession(ctx, "4321") + err = m.DeletePKCERequestSession(ctx, code) require.NoError(t, err) - _, err = m.GetPKCERequestSession(ctx, "4321", NewSession("bar")) + _, err = m.GetPKCERequestSession(ctx, code, oauth2.NewSession("bar")) assert.NotNil(t, err) } } -func testHelperFlushTokens(x InternalRegistry, lifespan time.Duration) func(t *testing.T) { +func testHelperFlushTokens(x oauth2.InternalRegistry, lifespan time.Duration) func(t *testing.T) { m := x.OAuth2Storage() - ds := &Session{} + ds := &oauth2.Session{} return func(t *testing.T) { ctx := context.Background() @@ -676,9 +700,9 @@ func testHelperFlushTokens(x InternalRegistry, lifespan time.Duration) func(t *t } } -func testHelperFlushTokensWithLimitAndBatchSize(x InternalRegistry, limit int, batchSize int) func(t *testing.T) { +func testHelperFlushTokensWithLimitAndBatchSize(x oauth2.InternalRegistry, limit int, batchSize int) func(t *testing.T) { m := x.OAuth2Storage() - ds := &Session{} + ds := &oauth2.Session{} return func(t *testing.T) { ctx := context.Background() @@ -712,7 +736,7 @@ func testHelperFlushTokensWithLimitAndBatchSize(x InternalRegistry, limit int, b } } -func testFositeSqlStoreTransactionCommitAccessToken(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionCommitAccessToken(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { { doTestCommit(m, t, m.OAuth2Storage().CreateAccessTokenSession, m.OAuth2Storage().GetAccessTokenSession, m.OAuth2Storage().RevokeAccessToken) @@ -721,7 +745,7 @@ func testFositeSqlStoreTransactionCommitAccessToken(m InternalRegistry) func(t * } } -func testFositeSqlStoreTransactionRollbackAccessToken(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionRollbackAccessToken(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { { doTestRollback(m, t, m.OAuth2Storage().CreateAccessTokenSession, m.OAuth2Storage().GetAccessTokenSession, m.OAuth2Storage().RevokeAccessToken) @@ -730,42 +754,41 @@ func testFositeSqlStoreTransactionRollbackAccessToken(m InternalRegistry) func(t } } -func testFositeSqlStoreTransactionCommitRefreshToken(m InternalRegistry) func(t *testing.T) { - +func testFositeSqlStoreTransactionCommitRefreshToken(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { - doTestCommit(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().RevokeRefreshToken) - doTestCommit(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().DeleteRefreshTokenSession) + doTestCommitRefresh(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().RevokeRefreshToken) + doTestCommitRefresh(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().DeleteRefreshTokenSession) } } -func testFositeSqlStoreTransactionRollbackRefreshToken(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionRollbackRefreshToken(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { - doTestRollback(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().RevokeRefreshToken) - doTestRollback(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().DeleteRefreshTokenSession) + doTestRollbackRefresh(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().RevokeRefreshToken) + doTestRollbackRefresh(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().DeleteRefreshTokenSession) } } -func testFositeSqlStoreTransactionCommitAuthorizeCode(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionCommitAuthorizeCode(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { doTestCommit(m, t, m.OAuth2Storage().CreateAuthorizeCodeSession, m.OAuth2Storage().GetAuthorizeCodeSession, m.OAuth2Storage().InvalidateAuthorizeCodeSession) } } -func testFositeSqlStoreTransactionRollbackAuthorizeCode(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionRollbackAuthorizeCode(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { doTestRollback(m, t, m.OAuth2Storage().CreateAuthorizeCodeSession, m.OAuth2Storage().GetAuthorizeCodeSession, m.OAuth2Storage().InvalidateAuthorizeCodeSession) } } -func testFositeSqlStoreTransactionCommitPKCERequest(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionCommitPKCERequest(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { doTestCommit(m, t, m.OAuth2Storage().CreatePKCERequestSession, m.OAuth2Storage().GetPKCERequestSession, m.OAuth2Storage().DeletePKCERequestSession) } } -func testFositeSqlStoreTransactionRollbackPKCERequest(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionRollbackPKCERequest(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { doTestRollback(m, t, m.OAuth2Storage().CreatePKCERequestSession, m.OAuth2Storage().GetPKCERequestSession, m.OAuth2Storage().DeletePKCERequestSession) } @@ -773,7 +796,7 @@ func testFositeSqlStoreTransactionRollbackPKCERequest(m InternalRegistry) func(t // OpenIdConnect tests can't use the helper functions, due to the signature of GetOpenIdConnectSession being // different from the other getter methods -func testFositeSqlStoreTransactionCommitOpenIdConnectSession(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionCommitOpenIdConnectSession(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { txnStore, ok := m.OAuth2Storage().(storage.Transactional) require.True(t, ok) @@ -808,7 +831,7 @@ func testFositeSqlStoreTransactionCommitOpenIdConnectSession(m InternalRegistry) } } -func testFositeSqlStoreTransactionRollbackOpenIdConnectSession(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionRollbackOpenIdConnectSession(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { txnStore, ok := m.OAuth2Storage().(storage.Transactional) require.True(t, ok) @@ -849,12 +872,12 @@ func testFositeSqlStoreTransactionRollbackOpenIdConnectSession(m InternalRegistr } } -func testFositeStoreSetClientAssertionJWT(m InternalRegistry) func(*testing.T) { +func testFositeStoreSetClientAssertionJWT(m oauth2.InternalRegistry) func(*testing.T) { return func(t *testing.T) { t.Run("case=basic setting works", func(t *testing.T) { - store, ok := m.OAuth2Storage().(AssertionJWTReader) + store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) - jti := NewBlacklistedJTI("basic jti", time.Now().Add(time.Minute)) + jti := oauth2.NewBlacklistedJTI(uuid.New(), time.Now().Add(time.Minute)) require.NoError(t, store.SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry)) @@ -866,20 +889,20 @@ func testFositeStoreSetClientAssertionJWT(m InternalRegistry) func(*testing.T) { }) t.Run("case=errors when the JTI is blacklisted", func(t *testing.T) { - store, ok := m.OAuth2Storage().(AssertionJWTReader) + store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) - jti := NewBlacklistedJTI("already set jti", time.Now().Add(time.Minute)) + jti := oauth2.NewBlacklistedJTI(uuid.New(), time.Now().Add(time.Minute)) require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti)) assert.ErrorIs(t, store.SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry), fosite.ErrJTIKnown) }) t.Run("case=deletes expired JTIs", func(t *testing.T) { - store, ok := m.OAuth2Storage().(AssertionJWTReader) + store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) - expiredJTI := NewBlacklistedJTI("expired jti", time.Now().Add(-time.Minute)) + expiredJTI := oauth2.NewBlacklistedJTI(uuid.New(), time.Now().Add(-time.Minute)) require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), expiredJTI)) - newJTI := NewBlacklistedJTI("some new jti", time.Now().Add(time.Minute)) + newJTI := oauth2.NewBlacklistedJTI(uuid.New(), time.Now().Add(time.Minute)) require.NoError(t, store.SetClientAssertionJWT(context.Background(), newJTI.JTI, newJTI.Expiry)) @@ -893,9 +916,9 @@ func testFositeStoreSetClientAssertionJWT(m InternalRegistry) func(*testing.T) { }) t.Run("case=inserts same JTI if expired", func(t *testing.T) { - store, ok := m.OAuth2Storage().(AssertionJWTReader) + store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) - jti := NewBlacklistedJTI("going to be reused jti", time.Now().Add(-time.Minute)) + jti := oauth2.NewBlacklistedJTI(uuid.New(), time.Now().Add(-time.Minute)) require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti)) jti.Expiry = jti.Expiry.Add(2 * time.Minute) @@ -907,19 +930,19 @@ func testFositeStoreSetClientAssertionJWT(m InternalRegistry) func(*testing.T) { } } -func testFositeStoreClientAssertionJWTValid(m InternalRegistry) func(*testing.T) { +func testFositeStoreClientAssertionJWTValid(m oauth2.InternalRegistry) func(*testing.T) { return func(t *testing.T) { t.Run("case=returns valid on unknown JTI", func(t *testing.T) { - store, ok := m.OAuth2Storage().(AssertionJWTReader) + store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) - assert.NoError(t, store.ClientAssertionJWTValid(context.Background(), "unknown jti")) + assert.NoError(t, store.ClientAssertionJWTValid(context.Background(), uuid.New())) }) t.Run("case=returns invalid on known JTI", func(t *testing.T) { - store, ok := m.OAuth2Storage().(AssertionJWTReader) + store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) - jti := NewBlacklistedJTI("known jti", time.Now().Add(time.Minute)) + jti := oauth2.NewBlacklistedJTI(uuid.New(), time.Now().Add(time.Minute)) require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti)) @@ -927,9 +950,9 @@ func testFositeStoreClientAssertionJWTValid(m InternalRegistry) func(*testing.T) }) t.Run("case=returns valid on expired JTI", func(t *testing.T) { - store, ok := m.OAuth2Storage().(AssertionJWTReader) + store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) - jti := NewBlacklistedJTI("expired jti 2", time.Now().Add(-time.Minute)) + jti := oauth2.NewBlacklistedJTI(uuid.New(), time.Now().Add(-time.Minute)) require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti)) @@ -938,7 +961,7 @@ func testFositeStoreClientAssertionJWTValid(m InternalRegistry) func(*testing.T) } } -func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { +func testFositeJWTBearerGrantStorage(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { ctx := context.Background() grantManager := x.GrantManager() @@ -946,12 +969,12 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { grantStorage := x.OAuth2Storage().(rfc7523.RFC7523KeyStorage) t.Run("case=associated key added with grant", func(t *testing.T) { - keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, "token-service-key", "sig") + keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.New(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() - issuer := "token-service" - subject := "bob@example.com" + issuer := uuid.New() + subject := "bob+" + uuid.New() + "@example.com" grant := trust.Grant{ ID: uuid.New(), Issuer: issuer, @@ -992,14 +1015,14 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { }) t.Run("case=only associated key returns", func(t *testing.T) { - keySetToNotReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, "some-key", "sig") + keySetToNotReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, uuid.New(), "sig") require.NoError(t, err) - require.NoError(t, keyManager.AddKeySet(context.Background(), "some-set", keySetToNotReturn), "adding a random key should not fail") + require.NoError(t, keyManager.AddKeySet(context.Background(), uuid.New(), keySetToNotReturn), "adding a random key should not fail") - issuer := "maria" - subject := "maria@example.com" + issuer := uuid.New() + subject := "maria+" + uuid.New() + "@example.com" - keySet1ToReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, "maria-key-1", "sig") + keySet1ToReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, uuid.New(), "sig") require.NoError(t, err) require.NoError(t, grantManager.CreateGrant(context.Background(), trust.Grant{ ID: uuid.New(), @@ -1012,7 +1035,7 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), }, keySet1ToReturn.Keys[0].Public())) - keySet2ToReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, "maria-key-2", "sig") + keySet2ToReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, uuid.New(), "sig") require.NoError(t, err) require.NoError(t, grantManager.CreateGrant(ctx, trust.Grant{ ID: uuid.New(), @@ -1055,12 +1078,12 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { }) t.Run("case=associated key is deleted, when granted is deleted", func(t *testing.T) { - keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, "hackerman-key", "sig") + keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.New(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() - issuer := "aeneas" - subject := "aeneas@example.com" + issuer := uuid.New() + subject := "aeneas+" + uuid.New() + "@example.com" grant := trust.Grant{ ID: uuid.New(), Issuer: issuer, @@ -1092,12 +1115,12 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { }) t.Run("case=associated grant is deleted, when key is deleted", func(t *testing.T) { - keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, "vladimir-key", "sig") + keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.New(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() - issuer := "vladimir" - subject := "vladimir@example.com" + issuer := uuid.New() + subject := "vladimir+" + uuid.New() + "@example.com" grant := trust.Grant{ ID: uuid.New(), Issuer: issuer, @@ -1129,12 +1152,12 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { }) t.Run("case=only returns the key when subject matches", func(t *testing.T) { - keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, "issuer-key", "sig") + keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.New(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() - issuer := "limited-issuer" - subject := "jagoba" + issuer := uuid.New() + subject := "jagoba+" + uuid.New() + "@example.com" grant := trust.Grant{ ID: uuid.New(), Issuer: issuer, @@ -1171,11 +1194,11 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { }) t.Run("case=returns the key when any subject is allowed", func(t *testing.T) { - keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, "issuer-key", "sig") + keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.New(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() - issuer := "unlimited-issuer" + issuer := uuid.New() grant := trust.Grant{ ID: uuid.New(), Issuer: issuer, @@ -1204,11 +1227,11 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { }) t.Run("case=does not return expired values", func(t *testing.T) { - keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, "issuer-expired-key", "sig") + keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.New(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() - issuer := "expired-issuer" + issuer := uuid.New() grant := trust.Grant{ ID: uuid.New(), Issuer: issuer, @@ -1230,12 +1253,11 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { } } -func doTestCommit(m InternalRegistry, t *testing.T, +func doTestCommit(m oauth2.InternalRegistry, t *testing.T, createFn func(context.Context, string, fosite.Requester) error, getFn func(context.Context, string, fosite.Session) (fosite.Requester, error), revokeFn func(context.Context, string) error, ) { - txnStore, ok := m.OAuth2Storage().(storage.Transactional) require.True(t, ok) ctx := context.Background() @@ -1248,7 +1270,44 @@ func doTestCommit(m InternalRegistry, t *testing.T, require.NoError(t, err) // Require a new context, since the old one contains the transaction. - res, err := getFn(context.Background(), signature, NewSession("bar")) + res, err := getFn(context.Background(), signature, oauth2.NewSession("bar")) + // token should have been created successfully because Commit did not return an error + require.NoError(t, err) + assertx.EqualAsJSONExcept(t, &defaultRequest, res, defaultIgnoreKeys) + // AssertObjectKeysEqual(t, &defaultRequest, res, "RequestedScope", "GrantedScope", "Form", "Session") + + // testrevoke within a transaction + ctx, err = txnStore.BeginTX(context.Background()) + require.NoError(t, err) + err = revokeFn(ctx, signature) + require.NoError(t, err) + err = txnStore.Commit(ctx) + require.NoError(t, err) + + // Require a new context, since the old one contains the transaction. + _, err = getFn(context.Background(), signature, oauth2.NewSession("bar")) + // Since commit worked for revoke, we should get an error here. + require.Error(t, err) +} + +func doTestCommitRefresh(m oauth2.InternalRegistry, t *testing.T, + createFn func(context.Context, string, string, fosite.Requester) error, + getFn func(context.Context, string, fosite.Session) (fosite.Requester, error), + revokeFn func(context.Context, string) error, +) { + txnStore, ok := m.OAuth2Storage().(storage.Transactional) + require.True(t, ok) + ctx := context.Background() + ctx, err := txnStore.BeginTX(ctx) + require.NoError(t, err) + signature := uuid.New() + err = createFn(ctx, signature, "", createTestRequest(signature)) + require.NoError(t, err) + err = txnStore.Commit(ctx) + require.NoError(t, err) + + // Require a new context, since the old one contains the transaction. + res, err := getFn(context.Background(), signature, oauth2.NewSession("bar")) // token should have been created successfully because Commit did not return an error require.NoError(t, err) assertx.EqualAsJSONExcept(t, &defaultRequest, res, defaultIgnoreKeys) @@ -1263,12 +1322,12 @@ func doTestCommit(m InternalRegistry, t *testing.T, require.NoError(t, err) // Require a new context, since the old one contains the transaction. - _, err = getFn(context.Background(), signature, NewSession("bar")) + _, err = getFn(context.Background(), signature, oauth2.NewSession("bar")) // Since commit worked for revoke, we should get an error here. require.Error(t, err) } -func doTestRollback(m InternalRegistry, t *testing.T, +func doTestRollback(m oauth2.InternalRegistry, t *testing.T, createFn func(context.Context, string, fosite.Requester) error, getFn func(context.Context, string, fosite.Session) (fosite.Requester, error), revokeFn func(context.Context, string) error, @@ -1287,7 +1346,7 @@ func doTestRollback(m InternalRegistry, t *testing.T, // Require a new context, since the old one contains the transaction. ctx = context.Background() - _, err = getFn(ctx, signature, NewSession("bar")) + _, err = getFn(ctx, signature, oauth2.NewSession("bar")) // Since we rolled back above, the token should not exist and getting it should result in an error require.Error(t, err) @@ -1295,7 +1354,48 @@ func doTestRollback(m InternalRegistry, t *testing.T, signature2 := uuid.New() err = createFn(ctx, signature2, createTestRequest(signature2)) require.NoError(t, err) - _, err = getFn(ctx, signature2, NewSession("bar")) + _, err = getFn(ctx, signature2, oauth2.NewSession("bar")) + require.NoError(t, err) + + ctx, err = txnStore.BeginTX(context.Background()) + require.NoError(t, err) + err = revokeFn(ctx, signature2) + require.NoError(t, err) + err = txnStore.Rollback(ctx) + require.NoError(t, err) + + _, err = getFn(context.Background(), signature2, oauth2.NewSession("bar")) + require.NoError(t, err) +} + +func doTestRollbackRefresh(m oauth2.InternalRegistry, t *testing.T, + createFn func(context.Context, string, string, fosite.Requester) error, + getFn func(context.Context, string, fosite.Session) (fosite.Requester, error), + revokeFn func(context.Context, string) error, +) { + txnStore, ok := m.OAuth2Storage().(storage.Transactional) + require.True(t, ok) + + ctx := context.Background() + ctx, err := txnStore.BeginTX(ctx) + require.NoError(t, err) + signature := uuid.New() + err = createFn(ctx, signature, "", createTestRequest(signature)) + require.NoError(t, err) + err = txnStore.Rollback(ctx) + require.NoError(t, err) + + // Require a new context, since the old one contains the transaction. + ctx = context.Background() + _, err = getFn(ctx, signature, oauth2.NewSession("bar")) + // Since we rolled back above, the token should not exist and getting it should result in an error + require.Error(t, err) + + // create a new token, revoke it, then rollback the revoke. We should be able to then get it successfully. + signature2 := uuid.New() + err = createFn(ctx, signature2, "", createTestRequest(signature2)) + require.NoError(t, err) + _, err = getFn(ctx, signature2, oauth2.NewSession("bar")) require.NoError(t, err) ctx, err = txnStore.BeginTX(context.Background()) @@ -1305,7 +1405,7 @@ func doTestRollback(m InternalRegistry, t *testing.T, err = txnStore.Rollback(ctx) require.NoError(t, err) - _, err = getFn(context.Background(), signature2, NewSession("bar")) + _, err = getFn(context.Background(), signature2, oauth2.NewSession("bar")) require.NoError(t, err) } @@ -1319,6 +1419,6 @@ func createTestRequest(id string) *fosite.Request { RequestedAudience: fosite.Arguments{"ad1", "ad2"}, GrantedAudience: fosite.Arguments{"ad1", "ad2"}, Form: url.Values{"foo": []string{"bar", "baz"}}, - Session: &Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, + Session: &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, } } diff --git a/oauth2/fosite_store_test.go b/oauth2/fosite_store_test.go index 2a48a52f8e7..9804704102c 100644 --- a/oauth2/fosite_store_test.go +++ b/oauth2/fosite_store_test.go @@ -7,16 +7,13 @@ import ( "context" "flag" "testing" + "time" - "github.com/stretchr/testify/require" + "github.com/ory/hydra/v2/internal/testhelpers" - "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" - . "github.com/ory/hydra/v2/oauth2" "github.com/ory/x/contextx" - "github.com/ory/x/networkx" "github.com/ory/x/sqlcon/dockertest" ) @@ -29,7 +26,7 @@ func TestMain(m *testing.M) { var registries = make(map[string]driver.Registry) var cleanRegistries = func(t *testing.T) { - registries["memory"] = internal.NewRegistryMemory(t, internal.NewConfigurationWithDefaults(), &contextx.Default{}) + registries["memory"] = testhelpers.NewRegistryMemory(t, testhelpers.NewConfigurationWithDefaults(), &contextx.Default{}) } // returns clean registries that can safely be used for one test @@ -38,7 +35,7 @@ func setupRegistries(t *testing.T) { if len(registries) == 0 && !testing.Short() { // first time called and sql tests var cleanSQL func(*testing.T) - registries["postgres"], registries["mysql"], registries["cockroach"], cleanSQL = internal.ConnectDatabases(t, true, &contextx.Default{}) + registries["postgres"], registries["mysql"], registries["cockroach"], cleanSQL = testhelpers.ConnectDatabases(t, false, &contextx.Default{}) cleanMem := cleanRegistries cleanMem(t) cleanRegistries = func(t *testing.T) { @@ -52,6 +49,8 @@ func setupRegistries(t *testing.T) { } func TestManagers(t *testing.T) { + setupRegistries(t) + ctx := context.Background() tests := []struct { name string @@ -68,18 +67,43 @@ func TestManagers(t *testing.T) { } for _, tc := range tests { t.Run("suite="+tc.name, func(t *testing.T) { - setupRegistries(t) + for k, r := range registries { + t.Run("database="+k, func(t *testing.T) { + store := testhelpers.NewRegistrySQLFromURL(t, r.Config().DSN(), true, &contextx.Default{}) + store.Config().MustSet(ctx, config.KeyEncryptSessionData, tc.enableSessionEncrypted) - require.NoError(t, registries["memory"].ClientManager().CreateClient(context.Background(), &client.Client{ID: "foobar"})) // this is a workaround because the client is not being created for memory store by test helpers. + if k != "memory" { + t.Run("testHelperUniqueConstraints", testHelperRequestIDMultiples(store, k)) + t.Run("case=testFositeSqlStoreTransactionsCommitAccessToken", testFositeSqlStoreTransactionCommitAccessToken(store)) + t.Run("case=testFositeSqlStoreTransactionsRollbackAccessToken", testFositeSqlStoreTransactionRollbackAccessToken(store)) + t.Run("case=testFositeSqlStoreTransactionCommitRefreshToken", testFositeSqlStoreTransactionCommitRefreshToken(store)) + t.Run("case=testFositeSqlStoreTransactionRollbackRefreshToken", testFositeSqlStoreTransactionRollbackRefreshToken(store)) + t.Run("case=testFositeSqlStoreTransactionCommitAuthorizeCode", testFositeSqlStoreTransactionCommitAuthorizeCode(store)) + t.Run("case=testFositeSqlStoreTransactionRollbackAuthorizeCode", testFositeSqlStoreTransactionRollbackAuthorizeCode(store)) + t.Run("case=testFositeSqlStoreTransactionCommitPKCERequest", testFositeSqlStoreTransactionCommitPKCERequest(store)) + t.Run("case=testFositeSqlStoreTransactionRollbackPKCERequest", testFositeSqlStoreTransactionRollbackPKCERequest(store)) + t.Run("case=testFositeSqlStoreTransactionCommitOpenIdConnectSession", testFositeSqlStoreTransactionCommitOpenIdConnectSession(store)) + t.Run("case=testFositeSqlStoreTransactionRollbackOpenIdConnectSession", testFositeSqlStoreTransactionRollbackOpenIdConnectSession(store)) + } - for k, store := range registries { - net := &networkx.Network{} - require.NoError(t, store.Persister().Connection(context.Background()).First(net)) - store.Config().MustSet(ctx, config.KeyEncryptSessionData, tc.enableSessionEncrypted) - store.WithContextualizer(&contextx.Static{NID: net.ID, C: store.Config().Source(ctx)}) - TestHelperRunner(t, store, k) + t.Run("testHelperCreateGetDeleteAuthorizeCodes", testHelperCreateGetDeleteAuthorizeCodes(store)) + t.Run("testHelperExpiryFields", testHelperExpiryFields(store)) + t.Run("testHelperCreateGetDeleteAccessTokenSession", testHelperCreateGetDeleteAccessTokenSession(store)) + t.Run("testHelperNilAccessToken", testHelperNilAccessToken(store)) + t.Run("testHelperCreateGetDeleteOpenIDConnectSession", testHelperCreateGetDeleteOpenIDConnectSession(store)) + t.Run("testHelperCreateGetDeleteRefreshTokenSession", testHelperCreateGetDeleteRefreshTokenSession(store)) + t.Run("testHelperRevokeRefreshToken", testHelperRevokeRefreshToken(store)) + t.Run("testHelperCreateGetDeletePKCERequestSession", testHelperCreateGetDeletePKCERequestSession(store)) + t.Run("testHelperFlushTokens", testHelperFlushTokens(store, time.Hour)) + t.Run("testHelperFlushTokensWithLimitAndBatchSize", testHelperFlushTokensWithLimitAndBatchSize(store, 3, 2)) + t.Run("testFositeStoreSetClientAssertionJWT", testFositeStoreSetClientAssertionJWT(store)) + t.Run("testFositeStoreClientAssertionJWTValid", testFositeStoreClientAssertionJWTValid(store)) + t.Run("testHelperDeleteAccessTokens", testHelperDeleteAccessTokens(store)) + t.Run("testHelperRevokeAccessToken", testHelperRevokeAccessToken(store)) + t.Run("testFositeJWTBearerGrantStorage", testFositeJWTBearerGrantStorage(store)) + t.Run("testHelperRotateRefreshToken", testHelperRotateRefreshToken(store)) + }) } }) - } } diff --git a/oauth2/handler.go b/oauth2/handler.go index 288ed1f16f0..3f1a633038d 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -727,11 +727,14 @@ type revokeOAuth2Token struct { // default: errorOAuth2 func (h *Handler) revokeOAuth2Token(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - events.Trace(ctx, events.AccessTokenRevoked) - err := h.r.OAuth2Provider().NewRevocationRequest(ctx, r) + err := h.r.Persister().Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error { + return h.r.OAuth2Provider().NewRevocationRequest(ctx, r) + }) if err != nil { x.LogError(r, err, h.r.Logger()) + } else { + events.Trace(ctx, events.AccessTokenRevoked) } h.r.OAuth2Provider().WriteRevocationResponse(ctx, w, err) diff --git a/oauth2/handler_fallback_endpoints_test.go b/oauth2/handler_fallback_endpoints_test.go index 191cd15a03a..9e3107b722b 100644 --- a/oauth2/handler_fallback_endpoints_test.go +++ b/oauth2/handler_fallback_endpoints_test.go @@ -10,22 +10,23 @@ import ( "net/http/httptest" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/x/httprouterx" "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/oauth2" "github.com/stretchr/testify/assert" ) func TestHandlerConsent(t *testing.T) { - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() conf.MustSet(context.Background(), config.KeyScopeStrategy, "DEPRECATED_HIERARCHICAL_SCOPE_STRATEGY") - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) h := reg.OAuth2Handler() r := x.NewRouterAdmin(conf.AdminURL) diff --git a/oauth2/handler_test.go b/oauth2/handler_test.go index f2d159af614..50705fad6bf 100644 --- a/oauth2/handler_test.go +++ b/oauth2/handler_test.go @@ -15,6 +15,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/x/httprouterx" @@ -31,13 +33,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" - "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" "github.com/ory/fosite/token/jwt" "github.com/ory/hydra/v2/client" + "github.com/ory/hydra/v2/driver/config" "github.com/ory/hydra/v2/oauth2" ) @@ -45,9 +45,9 @@ var lifespan = time.Hour func TestHandlerDeleteHandler(t *testing.T) { ctx := context.Background() - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() conf.MustSet(ctx, config.KeyIssuerURL, "http://hydra.localhost") - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) cm := reg.ClientManager() store := reg.OAuth2Storage() @@ -88,12 +88,12 @@ func TestHandlerDeleteHandler(t *testing.T) { func TestUserinfo(t *testing.T) { ctx := context.Background() - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() conf.MustSet(ctx, config.KeyScopeStrategy, "") conf.MustSet(ctx, config.KeyAuthCodeLifespan, lifespan) conf.MustSet(ctx, config.KeyIssuerURL, "http://hydra.localhost") - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) - internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) + testhelpers.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) ctrl := gomock.NewController(t) op := NewMockOAuth2Provider(ctrl) @@ -340,7 +340,7 @@ func TestUserinfo(t *testing.T) { func TestHandlerWellKnown(t *testing.T) { ctx := context.Background() - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() t.Run(fmt.Sprintf("hsm_enabled=%v", conf.HSMEnabled()), func(t *testing.T) { conf.MustSet(ctx, config.KeyScopeStrategy, "DEPRECATED_HIERARCHICAL_SCOPE_STRATEGY") conf.MustSet(ctx, config.KeyIssuerURL, "http://hydra.localhost") @@ -348,7 +348,7 @@ func TestHandlerWellKnown(t *testing.T) { conf.MustSet(ctx, config.KeyOIDCDiscoverySupportedClaims, []string{"sub"}) conf.MustSet(ctx, config.KeyOAuth2ClientRegistrationURL, "http://client-register/registration") conf.MustSet(ctx, config.KeyOIDCDiscoveryUserinfoEndpoint, "/userinfo") - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) h := oauth2.NewHandler(reg, conf) diff --git a/oauth2/helper_test.go b/oauth2/helper_test.go index 3a40592bfdd..04f41298b71 100644 --- a/oauth2/helper_test.go +++ b/oauth2/helper_test.go @@ -5,6 +5,11 @@ package oauth2_test import ( "context" + "testing" + + "github.com/oleiade/reflections" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ory/fosite" "github.com/ory/fosite/handler/oauth2" @@ -20,3 +25,60 @@ func Tokens(c fosite.Configurator, length int) (res [][]string) { } return res } + +func AssertObjectKeysEqual(t *testing.T, a, b interface{}, keys ...string) { + assert.True(t, len(keys) > 0, "No keys provided.") + for _, k := range keys { + c, err := reflections.GetField(a, k) + assert.Nil(t, err) + d, err := reflections.GetField(b, k) + assert.Nil(t, err) + assert.Equal(t, c, d, "%s", k) + } +} + +func AssertObjectKeysNotEqual(t *testing.T, a, b interface{}, keys ...string) { + assert.True(t, len(keys) > 0, "No keys provided.") + for _, k := range keys { + c, err := reflections.GetField(a, k) + assert.Nil(t, err) + d, err := reflections.GetField(b, k) + assert.Nil(t, err) + assert.NotEqual(t, c, d, "%s", k) + } +} + +func RequireObjectKeysEqual(t *testing.T, a, b interface{}, keys ...string) { + assert.True(t, len(keys) > 0, "No keys provided.") + for _, k := range keys { + c, err := reflections.GetField(a, k) + assert.Nil(t, err) + d, err := reflections.GetField(b, k) + assert.Nil(t, err) + require.Equal(t, c, d, "%s", k) + } +} +func RequireObjectKeysNotEqual(t *testing.T, a, b interface{}, keys ...string) { + assert.True(t, len(keys) > 0, "No keys provided.") + for _, k := range keys { + c, err := reflections.GetField(a, k) + assert.Nil(t, err) + d, err := reflections.GetField(b, k) + assert.Nil(t, err) + require.NotEqual(t, c, d, "%s", k) + } +} + +func TestAssertObjectsAreEqualByKeys(t *testing.T) { + type foo struct { + Name string + Body int + } + a := &foo{"foo", 1} + b := &foo{"bar", 1} + c := &foo{"baz", 3} + + AssertObjectKeysEqual(t, a, a, "Name", "Body") + AssertObjectKeysNotEqual(t, a, b, "Name") + AssertObjectKeysNotEqual(t, a, c, "Name", "Body") +} diff --git a/oauth2/helpers.go b/oauth2/helpers.go new file mode 100644 index 00000000000..4db4bf84d8e --- /dev/null +++ b/oauth2/helpers.go @@ -0,0 +1,51 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package oauth2 + +import ( + "context" + "crypto/sha256" + "fmt" + "time" + + "github.com/gobuffalo/pop/v6" + gofrsuuid "github.com/gofrs/uuid" + + "github.com/ory/hydra/v2/x" +) + +func signatureFromJTI(jti string) string { + return fmt.Sprintf("%x", sha256.Sum256([]byte(jti))) +} + +type BlacklistedJTI struct { + JTI string `db:"-"` + ID string `db:"signature"` + Expiry time.Time `db:"expires_at"` + NID gofrsuuid.UUID `db:"nid"` +} + +func (j *BlacklistedJTI) AfterFind(_ *pop.Connection) error { + j.Expiry = j.Expiry.UTC() + return nil +} + +func (BlacklistedJTI) TableName() string { + return "hydra_oauth2_jti_blacklist" +} + +func NewBlacklistedJTI(jti string, exp time.Time) *BlacklistedJTI { + return &BlacklistedJTI{ + JTI: jti, + ID: signatureFromJTI(jti), + // because the database timestamp types are not as accurate as time.Time we truncate to seconds (which should always work) + Expiry: exp.UTC().Truncate(time.Second), + } +} + +type AssertionJWTReader interface { + x.FositeStorer + GetClientAssertionJWT(ctx context.Context, jti string) (*BlacklistedJTI, error) + SetClientAssertionJWTRaw(context.Context, *BlacklistedJTI) error +} diff --git a/oauth2/introspector_test.go b/oauth2/introspector_test.go index 16b279f036f..43b565d2f58 100644 --- a/oauth2/introspector_test.go +++ b/oauth2/introspector_test.go @@ -12,6 +12,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/x/httprouterx" @@ -30,12 +32,12 @@ import ( func TestIntrospectorSDK(t *testing.T) { ctx := context.Background() - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() conf.MustSet(ctx, config.KeyScopeStrategy, "wildcard") conf.MustSet(ctx, config.KeyIssuerURL, "https://foobariss") - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) - internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) + testhelpers.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) internal.AddFositeExamples(reg) tokens := Tokens(reg.OAuth2ProviderConfig(), 4) diff --git a/oauth2/oauth2_auth_code_bench_test.go b/oauth2/oauth2_auth_code_bench_test.go index 568ff00287c..9347982630a 100644 --- a/oauth2/oauth2_auth_code_bench_test.go +++ b/oauth2/oauth2_auth_code_bench_test.go @@ -33,7 +33,6 @@ import ( hydra "github.com/ory/hydra-client-go/v2" hc "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/internal/testhelpers" "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/x" @@ -79,7 +78,7 @@ func BenchmarkAuthCode(b *testing.B) { dsn := stringsx.Coalesce(os.Getenv("DSN"), "postgres://postgres:secret@127.0.0.1:3445/postgres?sslmode=disable&max_conns=20&max_idle_conns=20") // dsn := "mysql://root:secret@tcp(localhost:3444)/mysql?max_conns=16&max_idle_conns=16" // dsn := "cockroach://root@localhost:3446/defaultdb?sslmode=disable&max_conns=16&max_idle_conns=16" - reg := internal.NewRegistrySQLFromURL(b, dsn, true, new(contextx.Default)).WithTracer(tracer) + reg := testhelpers.NewRegistrySQLFromURL(b, dsn, true, new(contextx.Default)).WithTracer(tracer) reg.Config().MustSet(ctx, config.KeyLogLevel, "error") reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "") diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index feea2451e27..170896b9f69 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io" + "math/rand" "net/http" "net/http/httptest" "net/url" @@ -20,6 +21,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/jwk" + "github.com/go-jose/go-jose/v3" "github.com/golang-jwt/jwt/v5" "github.com/julienschmidt/httprouter" @@ -36,7 +39,6 @@ import ( "github.com/ory/hydra/v2/driver" "github.com/ory/hydra/v2/driver/config" "github.com/ory/hydra/v2/flow" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/internal/testhelpers" hydraoauth2 "github.com/ory/hydra/v2/oauth2" "github.com/ory/hydra/v2/x" @@ -62,1299 +64,1482 @@ type clientCreator interface { CreateClient(context.Context, *client.Client) error } -// TestAuthCodeWithDefaultStrategy runs proper integration tests against in-memory and database connectors, specifically -// we test: -// -// - [x] If the flow - in general - works -// - [x] If `authenticatedAt` is properly managed across the lifecycle -// - [x] The value `authenticatedAt` should be an old time if no user interaction wrt login was required -// - [x] The value `authenticatedAt` should be a recent time if user interaction wrt login was required -// -// - [x] If `requestedAt` is properly managed across the lifecycle -// - [x] The value of `requestedAt` must be the initial request time, not some other time (e.g. when accepting login) -// -// - [x] If `id_token_hint` is handled properly -// - [x] What happens if `id_token_hint` does not match the value from the handled authentication request ("accept login") -func TestAuthCodeWithDefaultStrategy(t *testing.T) { - ctx := context.Background() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") - reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "") - publicTS, adminTS := testhelpers.NewOAuth2Server(ctx, t, reg) - - publicClient := hydra.NewAPIClient(hydra.NewConfiguration()) - publicClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: publicTS.URL}} - adminClient := hydra.NewAPIClient(hydra.NewConfiguration()) - adminClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: adminTS.URL}} - - getAuthorizeCode := func(t *testing.T, conf *oauth2.Config, c *http.Client, params ...oauth2.AuthCodeOption) (string, *http.Response) { - if c == nil { - c = testhelpers.NewEmptyJarClient(t) - } +func getAuthorizeCode(t *testing.T, conf *oauth2.Config, c *http.Client, params ...oauth2.AuthCodeOption) (string, *http.Response) { + if c == nil { + c = testhelpers.NewEmptyJarClient(t) + } - state := uuid.New() - resp, err := c.Get(conf.AuthCodeURL(state, params...)) - require.NoError(t, err) - defer resp.Body.Close() + state := uuid.New() + resp, err := c.Get(conf.AuthCodeURL(state, params...)) + require.NoError(t, err) + defer resp.Body.Close() - q := resp.Request.URL.Query() - require.EqualValues(t, state, q.Get("state")) - return q.Get("code"), resp - } + q := resp.Request.URL.Query() + require.EqualValues(t, state, q.Get("state")) + return q.Get("code"), resp +} - acceptLoginHandler := func(t *testing.T, c *client.Client, subject string, checkRequestPayload func(request *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - rr, _, err := adminClient.OAuth2API.GetOAuth2LoginRequest(context.Background()).LoginChallenge(r.URL.Query().Get("login_challenge")).Execute() - require.NoError(t, err) - - assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) - assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) - assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) - assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) - assert.EqualValues(t, c.RedirectURIs, rr.Client.RedirectUris) - assert.EqualValues(t, r.URL.Query().Get("login_challenge"), rr.Challenge) - assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) - assert.Contains(t, rr.RequestUrl, reg.Config().OAuth2AuthURL(ctx).String()) - - acceptBody := hydra.AcceptOAuth2LoginRequest{ - Subject: subject, - Remember: pointerx.Ptr(!rr.Skip), - Acr: pointerx.Ptr("1"), - Amr: []string{"pwd"}, - Context: map[string]interface{}{"context": "bar"}, - } - if checkRequestPayload != nil { - if b := checkRequestPayload(rr); b != nil { - acceptBody = *b - } - } +func acceptLoginHandler(t *testing.T, c *client.Client, adminClient *hydra.APIClient, reg driver.Registry, subject string, checkRequestPayload func(request *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + rr, _, err := adminClient.OAuth2API.GetOAuth2LoginRequest(context.Background()).LoginChallenge(r.URL.Query().Get("login_challenge")).Execute() + require.NoError(t, err) - v, _, err := adminClient.OAuth2API.AcceptOAuth2LoginRequest(context.Background()). - LoginChallenge(r.URL.Query().Get("login_challenge")). - AcceptOAuth2LoginRequest(acceptBody). - Execute() - require.NoError(t, err) - require.NotEmpty(t, v.RedirectTo) - http.Redirect(w, r, v.RedirectTo, http.StatusFound) + assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) + assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) + assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) + assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) + assert.EqualValues(t, c.RedirectURIs, rr.Client.RedirectUris) + assert.EqualValues(t, r.URL.Query().Get("login_challenge"), rr.Challenge) + assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) + assert.Contains(t, rr.RequestUrl, reg.Config().OAuth2AuthURL(ctx).String()) + + acceptBody := hydra.AcceptOAuth2LoginRequest{ + Subject: subject, + Remember: pointerx.Ptr(!rr.Skip), + Acr: pointerx.Ptr("1"), + Amr: []string{"pwd"}, + Context: map[string]interface{}{"context": "bar"}, } - } - - acceptConsentHandler := func(t *testing.T, c *client.Client, subject string, checkRequestPayload func(*hydra.OAuth2ConsentRequest)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - rr, _, err := adminClient.OAuth2API.GetOAuth2ConsentRequest(context.Background()).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute() - require.NoError(t, err) - - assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) - assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) - assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) - assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) - assert.EqualValues(t, c.RedirectURIs, rr.Client.RedirectUris) - assert.EqualValues(t, subject, pointerx.Deref(rr.Subject)) - assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) - assert.EqualValues(t, r.URL.Query().Get("consent_challenge"), rr.Challenge) - assert.Contains(t, *rr.RequestUrl, reg.Config().OAuth2AuthURL(ctx).String()) - if checkRequestPayload != nil { - checkRequestPayload(rr) + if checkRequestPayload != nil { + if b := checkRequestPayload(rr); b != nil { + acceptBody = *b } - - assert.Equal(t, map[string]interface{}{"context": "bar"}, rr.Context) - v, _, err := adminClient.OAuth2API.AcceptOAuth2ConsentRequest(context.Background()). - ConsentChallenge(r.URL.Query().Get("consent_challenge")). - AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{ - GrantScope: []string{"hydra", "offline", "openid"}, Remember: pointerx.Ptr(true), RememberFor: pointerx.Ptr[int64](0), - GrantAccessTokenAudience: rr.RequestedAccessTokenAudience, - Session: &hydra.AcceptOAuth2ConsentRequestSession{ - AccessToken: map[string]interface{}{"foo": "bar"}, - IdToken: map[string]interface{}{"bar": "baz", "email": "foo@bar.com"}, - }, - }). - Execute() - require.NoError(t, err) - require.NotEmpty(t, v.RedirectTo) - http.Redirect(w, r, v.RedirectTo, http.StatusFound) } - } - - assertRefreshToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedExp time.Time) { - introspect := testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS) - actualExp, err := strconv.ParseInt(introspect.Get("exp").String(), 10, 64) - require.NoError(t, err, "%s", introspect) - requirex.EqualTime(t, expectedExp, time.Unix(actualExp, 0), time.Second) - } - assertIDToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedSubject, expectedNonce string, expectedExp time.Time) gjson.Result { - idt, ok := token.Extra("id_token").(string) - require.True(t, ok) - assert.NotEmpty(t, idt) - - body, err := x.DecodeSegment(strings.Split(idt, ".")[1]) + v, _, err := adminClient.OAuth2API.AcceptOAuth2LoginRequest(context.Background()). + LoginChallenge(r.URL.Query().Get("login_challenge")). + AcceptOAuth2LoginRequest(acceptBody). + Execute() require.NoError(t, err) - - claims := gjson.ParseBytes(body) - assert.True(t, time.Now().After(time.Unix(claims.Get("iat").Int(), 0)), "%s", claims) - assert.True(t, time.Now().After(time.Unix(claims.Get("nbf").Int(), 0)), "%s", claims) - assert.True(t, time.Now().Before(time.Unix(claims.Get("exp").Int(), 0)), "%s", claims) - requirex.EqualTime(t, expectedExp, time.Unix(claims.Get("exp").Int(), 0), 2*time.Second) - assert.NotEmpty(t, claims.Get("jti").String(), "%s", claims) - assert.EqualValues(t, reg.Config().IssuerURL(ctx).String(), claims.Get("iss").String(), "%s", claims) - assert.NotEmpty(t, claims.Get("sid").String(), "%s", claims) - assert.Equal(t, "1", claims.Get("acr").String(), "%s", claims) - require.Len(t, claims.Get("amr").Array(), 1, "%s", claims) - assert.EqualValues(t, "pwd", claims.Get("amr").Array()[0].String(), "%s", claims) - - require.Len(t, claims.Get("aud").Array(), 1, "%s", claims) - assert.EqualValues(t, c.ClientID, claims.Get("aud").Array()[0].String(), "%s", claims) - assert.EqualValues(t, expectedSubject, claims.Get("sub").String(), "%s", claims) - assert.EqualValues(t, expectedNonce, claims.Get("nonce").String(), "%s", claims) - assert.EqualValues(t, `baz`, claims.Get("bar").String(), "%s", claims) - assert.EqualValues(t, `foo@bar.com`, claims.Get("email").String(), "%s", claims) - assert.NotEmpty(t, claims.Get("sid").String(), "%s", claims) - - return claims + require.NotEmpty(t, v.RedirectTo) + http.Redirect(w, r, v.RedirectTo, http.StatusFound) } +} - introspectAccessToken := func(t *testing.T, conf *oauth2.Config, token *oauth2.Token, expectedSubject string) gjson.Result { - require.NotEmpty(t, token.AccessToken) - i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - assert.True(t, i.Get("active").Bool(), "%s", i) - assert.EqualValues(t, conf.ClientID, i.Get("client_id").String(), "%s", i) - assert.EqualValues(t, expectedSubject, i.Get("sub").String(), "%s", i) - assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i) - return i - } +func acceptConsentHandler(t *testing.T, c *client.Client, adminClient *hydra.APIClient, reg driver.Registry, subject string, checkRequestPayload func(*hydra.OAuth2ConsentRequest)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + rr, _, err := adminClient.OAuth2API.GetOAuth2ConsentRequest(context.Background()).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute() + require.NoError(t, err) - assertJWTAccessToken := func(t *testing.T, strat string, conf *oauth2.Config, token *oauth2.Token, expectedSubject string, expectedExp time.Time, scopes string) gjson.Result { - require.NotEmpty(t, token.AccessToken) - parts := strings.Split(token.AccessToken, ".") - if strat != "jwt" { - require.Len(t, parts, 2) - return gjson.Parse("null") + assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) + assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) + assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) + assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) + assert.EqualValues(t, c.RedirectURIs, rr.Client.RedirectUris) + assert.EqualValues(t, subject, pointerx.Deref(rr.Subject)) + assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) + assert.EqualValues(t, r.URL.Query().Get("consent_challenge"), rr.Challenge) + assert.Contains(t, *rr.RequestUrl, reg.Config().OAuth2AuthURL(r.Context()).String()) + if checkRequestPayload != nil { + checkRequestPayload(rr) } - require.Len(t, parts, 3) - body, err := x.DecodeSegment(parts[1]) + assert.Equal(t, map[string]interface{}{"context": "bar"}, rr.Context) + v, _, err := adminClient.OAuth2API.AcceptOAuth2ConsentRequest(context.Background()). + ConsentChallenge(r.URL.Query().Get("consent_challenge")). + AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{ + GrantScope: []string{"hydra", "offline", "openid"}, Remember: pointerx.Ptr(true), RememberFor: pointerx.Ptr[int64](0), + GrantAccessTokenAudience: rr.RequestedAccessTokenAudience, + Session: &hydra.AcceptOAuth2ConsentRequestSession{ + AccessToken: map[string]interface{}{"foo": "bar"}, + IdToken: map[string]interface{}{"bar": "baz", "email": "foo@bar.com"}, + }, + }). + Execute() require.NoError(t, err) - - i := gjson.ParseBytes(body) - assert.NotEmpty(t, i.Get("jti").String()) - assert.EqualValues(t, conf.ClientID, i.Get("client_id").String(), "%s", i) - assert.EqualValues(t, expectedSubject, i.Get("sub").String(), "%s", i) - assert.EqualValues(t, reg.Config().IssuerURL(ctx).String(), i.Get("iss").String(), "%s", i) - assert.True(t, time.Now().After(time.Unix(i.Get("iat").Int(), 0)), "%s", i) - assert.True(t, time.Now().After(time.Unix(i.Get("nbf").Int(), 0)), "%s", i) - assert.True(t, time.Now().Before(time.Unix(i.Get("exp").Int(), 0)), "%s", i) - requirex.EqualTime(t, expectedExp, time.Unix(i.Get("exp").Int(), 0), time.Second) - assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i) - assert.EqualValues(t, scopes, i.Get("scp").Raw, "%s", i) - return i - } - - waitForRefreshTokenExpiry := func() { - time.Sleep(reg.Config().GetRefreshTokenLifespan(ctx) + time.Second) + require.NotEmpty(t, v.RedirectTo) + http.Redirect(w, r, v.RedirectTo, http.StatusFound) } +} - t.Run("case=checks if request fails when audience does not match", func(t *testing.T) { - testhelpers.NewLoginConsentUI(t, reg.Config(), testhelpers.HTTPServerNoExpectedCallHandler(t), testhelpers.HTTPServerNoExpectedCallHandler(t)) - _, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("audience", "https://not-ory-api/")) - require.Empty(t, code) - }) - - subject := "aeneas-rekkas" - nonce := uuid.New() - t.Run("case=perform authorize code flow with ID token and refresh tokens", func(t *testing.T) { - run := func(t *testing.T, strategy string) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), - ) - - code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) - token, err := conf.Exchange(context.Background(), code) - iat := time.Now() - require.NoError(t, err) - - assert.Empty(t, token.Extra("c_nonce_draft_00"), "should not be set if not requested") - assert.Empty(t, token.Extra("c_nonce_expires_in_draft_00"), "should not be set if not requested") - introspectAccessToken(t, conf, token, subject) - assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) - assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) - assertRefreshToken(t, token, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) - - t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { - require.NotEmpty(t, token.RefreshToken) - token.Expiry = token.Expiry.Add(-time.Hour * 24) - iat = time.Now() - refreshedToken, err := conf.TokenSource(context.Background(), token).Token() - require.NoError(t, err) - - require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) - require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken) - require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token")) - introspectAccessToken(t, conf, refreshedToken, subject) - - t.Run("followup=refreshed tokens contain valid tokens", func(t *testing.T) { - assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) - assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) - assertRefreshToken(t, refreshedToken, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) - }) - - t.Run("followup=original access token is no longer valid", func(t *testing.T) { - i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - }) - - t.Run("followup=original refresh token is no longer valid", func(t *testing.T) { - _, err := conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) - }) - - t.Run("followup=but fail subsequent refresh because expiry was reached", func(t *testing.T) { - waitForRefreshTokenExpiry() +// TestAuthCodeWithDefaultStrategy runs proper integration tests against in-memory and database connectors, specifically +// we test: +// +// - [x] If the flow - in general - works +// - [x] If `authenticatedAt` is properly managed across the lifecycle +// - [x] The value `authenticatedAt` should be an old time if no user interaction wrt login was required +// - [x] The value `authenticatedAt` should be a recent time if user interaction wrt login was required +// +// - [x] If `requestedAt` is properly managed across the lifecycle +// - [x] The value of `requestedAt` must be the initial request time, not some other time (e.g. when accepting login) +// +// - [x] If `id_token_hint` is handled properly +// - [x] What happens if `id_token_hint` does not match the value from the handled authentication request ("accept login") +func TestAuthCodeWithDefaultStrategy(t *testing.T) { + setupRegistries(t) + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + ctx := context.Background() - // Force golang to refresh token - refreshedToken.Expiry = refreshedToken.Expiry.Add(-time.Hour * 24) - _, err := conf.TokenSource(context.Background(), refreshedToken).Token() - require.Error(t, err) - }) - }) - } + for dbName, reg := range registries { + t.Run("registry="+dbName, func(t *testing.T) { + reg := testhelpers.NewRegistrySQLFromURL(t, reg.Config().DSN(), true, &contextx.Default{}) - t.Run("strategy=jwt", func(t *testing.T) { - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") - run(t, "jwt") - }) + require.NoError(t, jwk.EnsureAsymmetricKeypairExists(ctx, reg, string(jose.ES256), x.OpenIDConnectKeyName)) + require.NoError(t, jwk.EnsureAsymmetricKeypairExists(ctx, reg, string(jose.ES256), x.OAuth2JWTKeyName)) - t.Run("strategy=opaque", func(t *testing.T) { reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") - run(t, "opaque") - }) - }) - - t.Run("case=graceful token rotation", func(t *testing.T) { - run := func(t *testing.T, strategy string) { - reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "5s") - t.Cleanup(func() { - reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, nil) - }) + reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "") + publicTS, adminTS := testhelpers.NewOAuth2Server(ctx, t, reg) + + publicClient := hydra.NewAPIClient(hydra.NewConfiguration()) + publicClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: publicTS.URL}} + adminClient := hydra.NewAPIClient(hydra.NewConfiguration()) + adminClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: adminTS.URL}} + + assertRefreshToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedExp time.Time) { + introspect := testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS) + actualExp, err := strconv.ParseInt(introspect.Get("exp").String(), 10, 64) + require.NoError(t, err, "%s", introspect) + requirex.EqualTime(t, expectedExp, time.Unix(actualExp, 0), time.Second*3) + } - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), - ) + assertIDToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedSubject, expectedNonce string, expectedExp time.Time) gjson.Result { + idt, ok := token.Extra("id_token").(string) + require.True(t, ok) + assert.NotEmpty(t, idt) - issueTokens := func(t *testing.T) *oauth2.Token { - code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) - token, err := conf.Exchange(context.Background(), code) - iat := time.Now() + body, err := x.DecodeSegment(strings.Split(idt, ".")[1]) require.NoError(t, err) - introspectAccessToken(t, conf, token, subject) - assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) - assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) - assertRefreshToken(t, token, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) - return token + claims := gjson.ParseBytes(body) + assert.True(t, time.Now().After(time.Unix(claims.Get("iat").Int(), 0)), "%s", claims) + assert.True(t, time.Now().After(time.Unix(claims.Get("nbf").Int(), 0)), "%s", claims) + assert.True(t, time.Now().Before(time.Unix(claims.Get("exp").Int(), 0)), "%s", claims) + if !expectedExp.IsZero() { + requirex.EqualTime(t, expectedExp, time.Unix(claims.Get("exp").Int(), 0), 1*time.Second) + } + assert.NotEmpty(t, claims.Get("jti").String(), "%s", claims) + assert.EqualValues(t, reg.Config().IssuerURL(ctx).String(), claims.Get("iss").String(), "%s", claims) + assert.NotEmpty(t, claims.Get("sid").String(), "%s", claims) + assert.Equal(t, "1", claims.Get("acr").String(), "%s", claims) + require.Len(t, claims.Get("amr").Array(), 1, "%s", claims) + assert.EqualValues(t, "pwd", claims.Get("amr").Array()[0].String(), "%s", claims) + + require.Len(t, claims.Get("aud").Array(), 1, "%s", claims) + assert.EqualValues(t, c.ClientID, claims.Get("aud").Array()[0].String(), "%s", claims) + assert.EqualValues(t, expectedSubject, claims.Get("sub").String(), "%s", claims) + assert.EqualValues(t, expectedNonce, claims.Get("nonce").String(), "%s", claims) + assert.EqualValues(t, `baz`, claims.Get("bar").String(), "%s", claims) + assert.EqualValues(t, `foo@bar.com`, claims.Get("email").String(), "%s", claims) + assert.NotEmpty(t, claims.Get("sid").String(), "%s", claims) + + return claims } - refreshTokens := func(t *testing.T, token *oauth2.Token) *oauth2.Token { - require.NotEmpty(t, token.RefreshToken) - token.Expiry = token.Expiry.Add(-time.Hour * 24) - iat := time.Now() - refreshedToken, err := conf.TokenSource(context.Background(), token).Token() - require.NoError(t, err) - - require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) - require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken) - require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token")) - - introspectAccessToken(t, conf, refreshedToken, subject) - assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) - assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) - assertRefreshToken(t, refreshedToken, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) - return refreshedToken + introspectAccessToken := func(t *testing.T, conf *oauth2.Config, token *oauth2.Token, expectedSubject string) gjson.Result { + require.NotEmpty(t, token.AccessToken) + i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.True(t, i.Get("active").Bool(), "%s", i) + assert.EqualValues(t, conf.ClientID, i.Get("client_id").String(), "%s", i) + assert.EqualValues(t, expectedSubject, i.Get("sub").String(), "%s", i) + assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i) + return i } - t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { - start := time.Now() - - token := issueTokens(t) - var first, second *oauth2.Token - t.Run("followup=first refresh", func(t *testing.T) { - first = refreshTokens(t, token) - }) - - t.Run("followup=second refresh", func(t *testing.T) { - second = refreshTokens(t, token) - }) + assertJWTAccessToken := func(t *testing.T, strat string, conf *oauth2.Config, token *oauth2.Token, expectedSubject string, expectedExp time.Time, scopes string) gjson.Result { + require.NotEmpty(t, token.AccessToken) + parts := strings.Split(token.AccessToken, ".") + if strat != "jwt" { + require.Len(t, parts, 2) + return gjson.Parse("null") + } + require.Len(t, parts, 3) - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { - _, err := conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) + body, err := x.DecodeSegment(parts[1]) + require.NoError(t, err) - i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + i := gjson.ParseBytes(body) + assert.NotEmpty(t, i.Get("jti").String()) + assert.EqualValues(t, conf.ClientID, i.Get("client_id").String(), "%s", i) + assert.EqualValues(t, expectedSubject, i.Get("sub").String(), "%s", i) + assert.EqualValues(t, reg.Config().IssuerURL(ctx).String(), i.Get("iss").String(), "%s", i) + assert.True(t, time.Now().After(time.Unix(i.Get("iat").Int(), 0)), "%s", i) + assert.True(t, time.Now().After(time.Unix(i.Get("nbf").Int(), 0)), "%s", i) + assert.True(t, time.Now().Before(time.Unix(i.Get("exp").Int(), 0)), "%s", i) + requirex.EqualTime(t, expectedExp, time.Unix(i.Get("exp").Int(), 0), time.Second) + assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i) + assert.EqualValues(t, scopes, i.Get("scp").Raw, "%s", i) + return i + } - i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + waitForRefreshTokenExpiry := func() { + time.Sleep(reg.Config().GetRefreshTokenLifespan(ctx) + time.Second) + } - i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + subject := "aeneas-rekkas" + nonce := uuid.New() - i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - }) + t.Run("case=checks if request fails when audience does not match", func(t *testing.T) { + testhelpers.NewLoginConsentUI(t, reg.Config(), testhelpers.HTTPServerNoExpectedCallHandler(t), testhelpers.HTTPServerNoExpectedCallHandler(t)) + _, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("audience", "https://not-ory-api/")) + require.Empty(t, code) }) - t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { - start := time.Now() - - token := issueTokens(t) - var first, second *oauth2.Token - t.Run("followup=first refresh", func(t *testing.T) { - first = refreshTokens(t, token) - }) - - t.Run("followup=second refresh", func(t *testing.T) { - second = refreshTokens(t, token) - }) + t.Run("case=perform authorize code flow with ID token and refresh tokens", func(t *testing.T) { + run := func(t *testing.T, strategy string) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), + ) - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { - err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) + code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + iat := time.Now() require.NoError(t, err) - _, err = conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) + assert.Empty(t, token.Extra("c_nonce_draft_00"), "should not be set if not requested") + assert.Empty(t, token.Extra("c_nonce_expires_in_draft_00"), "should not be set if not requested") + introspectAccessToken(t, conf, token, subject) + assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, token, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + + t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { + require.NotEmpty(t, token.RefreshToken) + token.Expiry = token.Expiry.Add(-time.Hour * 24) + iat = time.Now() + refreshedToken, err := conf.TokenSource(context.Background(), token).Token() + require.NoError(t, err) - i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) + require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken) + require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token")) + introspectAccessToken(t, conf, refreshedToken, subject) - i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + t.Run("followup=refreshed tokens contain valid tokens", func(t *testing.T) { + assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, refreshedToken, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + }) - i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + t.Run("followup=original access token is no longer valid", func(t *testing.T) { + i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + }) - i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - }) - }) + t.Run("followup=original refresh token is no longer valid", func(t *testing.T) { + _, err := conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + }) - t.Run("followup=graceful refresh tokens are all refreshed", func(t *testing.T) { - start := time.Now() - token := issueTokens(t) - var a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB *oauth2.Token - t.Run("followup=first refresh", func(t *testing.T) { - a1Refresh = refreshTokens(t, token) - }) + t.Run("followup=but fail subsequent refresh because expiry was reached", func(t *testing.T) { + waitForRefreshTokenExpiry() - t.Run("followup=second refresh", func(t *testing.T) { - b1Refresh = refreshTokens(t, token) - }) + // Force golang to refresh token + refreshedToken.Expiry = refreshedToken.Expiry.Add(-time.Hour * 24) + _, err := conf.TokenSource(context.Background(), refreshedToken).Token() + require.Error(t, err) + }) + }) + } - t.Run("followup=first refresh from first refresh", func(t *testing.T) { - a2RefreshA = refreshTokens(t, a1Refresh) + t.Run("strategy=jwt", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") + run(t, "jwt") }) - t.Run("followup=second refresh from first refresh", func(t *testing.T) { - a2RefreshB = refreshTokens(t, a1Refresh) + t.Run("strategy=opaque", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + run(t, "opaque") }) + }) - t.Run("followup=first refresh from second refresh", func(t *testing.T) { - b2RefreshA = refreshTokens(t, b1Refresh) - }) + t.Run("case=perform authorize code flow with verifable credentials", func(t *testing.T) { + // Make sure we test against all crypto suites that we advertise. + cfg, _, err := publicClient.OidcAPI.DiscoverOidcConfiguration(ctx).Execute() + require.NoError(t, err) + supportedCryptoSuites := cfg.CredentialsSupportedDraft00[0].CryptographicSuitesSupported + + run := func(t *testing.T, strategy string) { + _, conf := newOAuth2Client( + t, + reg, + testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler), + withScope("openid userinfo_credential_draft_00"), + ) + testhelpers.NewLoginConsentUI(t, reg.Config(), + func(w http.ResponseWriter, r *http.Request) { + acceptBody := hydra.AcceptOAuth2LoginRequest{ + Subject: subject, + Acr: pointerx.Ptr("1"), + Amr: []string{"pwd"}, + Context: map[string]interface{}{"context": "bar"}, + } + v, _, err := adminClient.OAuth2API.AcceptOAuth2LoginRequest(context.Background()). + LoginChallenge(r.URL.Query().Get("login_challenge")). + AcceptOAuth2LoginRequest(acceptBody). + Execute() + require.NoError(t, err) + require.NotEmpty(t, v.RedirectTo) + http.Redirect(w, r, v.RedirectTo, http.StatusFound) + }, + func(w http.ResponseWriter, r *http.Request) { + rr, _, err := adminClient.OAuth2API.GetOAuth2ConsentRequest(context.Background()).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute() + require.NoError(t, err) - t.Run("followup=second refresh from second refresh", func(t *testing.T) { - b2RefreshB = refreshTokens(t, b1Refresh) - }) + assert.Equal(t, map[string]interface{}{"context": "bar"}, rr.Context) + v, _, err := adminClient.OAuth2API.AcceptOAuth2ConsentRequest(context.Background()). + ConsentChallenge(r.URL.Query().Get("consent_challenge")). + AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{ + GrantScope: []string{"openid", "userinfo_credential_draft_00"}, + GrantAccessTokenAudience: rr.RequestedAccessTokenAudience, + Session: &hydra.AcceptOAuth2ConsentRequestSession{ + AccessToken: map[string]interface{}{"foo": "bar"}, + IdToken: map[string]interface{}{"email": "foo@bar.com", "bar": "baz"}, + }, + }). + Execute() + require.NoError(t, err) + require.NotEmpty(t, v.RedirectTo) + http.Redirect(w, r, v.RedirectTo, http.StatusFound) + }, + ) - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { - _, err := conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("nonce", nonce), + oauth2.SetAuthURLParam("scope", "openid userinfo_credential_draft_00"), + ) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + require.NoError(t, err) + iat := time.Now() - for k, token := range []*oauth2.Token{ - a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB, - } { - t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { - i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + vcNonce := token.Extra("c_nonce_draft_00").(string) + assert.NotEmpty(t, vcNonce) + expiry := token.Extra("c_nonce_expires_in_draft_00") + assert.NotEmpty(t, expiry) + assert.NoError(t, reg.Persister().IsNonceValid(ctx, token.AccessToken, vcNonce)) - i = testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + t.Run("followup=successfully create a verifiable credential", func(t *testing.T) { + t.Parallel() - i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + for _, alg := range supportedCryptoSuites { + alg := alg + t.Run(fmt.Sprintf("alg=%s", alg), func(t *testing.T) { + t.Parallel() + assertCreateVerifiableCredential(t, reg, vcNonce, token, jose.SignatureAlgorithm(alg)) + }) + } + }) - i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + t.Run("followup=get new nonce from priming request", func(t *testing.T) { + t.Parallel() + // Assert that we can fetch a verifiable credential with the nonce. + res, err := doPrimingRequest(t, reg, token, &hydraoauth2.CreateVerifiableCredentialRequestBody{ + Format: "jwt_vc_json", + Types: []string{"VerifiableCredential", "UserInfoCredential"}, }) - } - }) - }) - } + assert.NoError(t, err) - t.Run("strategy=jwt", func(t *testing.T) { - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") - run(t, "jwt") - }) + t.Run("followup=successfully create a verifiable credential from fresh nonce", func(t *testing.T) { + assertCreateVerifiableCredential(t, reg, res.Nonce, token, jose.ES256) + }) + }) - t.Run("strategy=opaque", func(t *testing.T) { - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") - run(t, "opaque") - }) - }) + t.Run("followup=rejects proof signed by another key", func(t *testing.T) { + t.Parallel() + for _, tc := range []struct { + name string + format string + proofType string + proof func() string + }{ + { + name: "proof=mismatching keys", + proof: func() string { + // Create mismatching public and private keys. + pubKey, _, err := josex.NewSigningKey(jose.ES256, 0) + require.NoError(t, err) + _, privKey, err := josex.NewSigningKey(jose.ES256, 0) + require.NoError(t, err) + pubKeyJWK := &jose.JSONWebKey{Key: pubKey, Algorithm: string(jose.ES256)} + return createVCProofJWT(t, pubKeyJWK, privKey, vcNonce) + }, + }, + { + name: "proof=invalid format", + format: "invalid_format", + proof: func() string { + // Create mismatching public and private keys. + pubKey, privKey, err := josex.NewSigningKey(jose.ES256, 0) + require.NoError(t, err) + pubKeyJWK := &jose.JSONWebKey{Key: pubKey, Algorithm: string(jose.ES256)} + return createVCProofJWT(t, pubKeyJWK, privKey, vcNonce) + }, + }, + { + name: "proof=invalid type", + proofType: "invalid", + proof: func() string { + // Create mismatching public and private keys. + pubKey, privKey, err := josex.NewSigningKey(jose.ES256, 0) + require.NoError(t, err) + pubKeyJWK := &jose.JSONWebKey{Key: pubKey, Algorithm: string(jose.ES256)} + return createVCProofJWT(t, pubKeyJWK, privKey, vcNonce) + }, + }, + { + name: "proof=invalid nonce", + proof: func() string { + // Create mismatching public and private keys. + pubKey, privKey, err := josex.NewSigningKey(jose.ES256, 0) + require.NoError(t, err) + pubKeyJWK := &jose.JSONWebKey{Key: pubKey, Algorithm: string(jose.ES256)} + return createVCProofJWT(t, pubKeyJWK, privKey, "invalid nonce") + }, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + _, err := createVerifiableCredential(t, reg, token, &hydraoauth2.CreateVerifiableCredentialRequestBody{ + Format: stringsx.Coalesce(tc.format, "jwt_vc_json"), + Types: []string{"VerifiableCredential", "UserInfoCredential"}, + Proof: &hydraoauth2.VerifiableCredentialProof{ + ProofType: stringsx.Coalesce(tc.proofType, "jwt"), + JWT: tc.proof(), + }, + }) + require.Error(t, err) + assert.Equal(t, "invalid_request", err.Error()) + }) + } - t.Run("case=perform authorize code flow with verifable credentials", func(t *testing.T) { - // Make sure we test against all crypto suites that we advertise. - cfg, _, err := publicClient.OidcAPI.DiscoverOidcConfiguration(ctx).Execute() - require.NoError(t, err) - supportedCryptoSuites := cfg.CredentialsSupportedDraft00[0].CryptographicSuitesSupported - - run := func(t *testing.T, strategy string) { - _, conf := newOAuth2Client( - t, - reg, - testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler), - withScope("openid userinfo_credential_draft_00"), - ) - testhelpers.NewLoginConsentUI(t, reg.Config(), - func(w http.ResponseWriter, r *http.Request) { - acceptBody := hydra.AcceptOAuth2LoginRequest{ - Subject: subject, - Acr: pointerx.Ptr("1"), - Amr: []string{"pwd"}, - Context: map[string]interface{}{"context": "bar"}, - } - v, _, err := adminClient.OAuth2API.AcceptOAuth2LoginRequest(context.Background()). - LoginChallenge(r.URL.Query().Get("login_challenge")). - AcceptOAuth2LoginRequest(acceptBody). - Execute() - require.NoError(t, err) - require.NotEmpty(t, v.RedirectTo) - http.Redirect(w, r, v.RedirectTo, http.StatusFound) - }, - func(w http.ResponseWriter, r *http.Request) { - rr, _, err := adminClient.OAuth2API.GetOAuth2ConsentRequest(context.Background()).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute() - require.NoError(t, err) + }) - assert.Equal(t, map[string]interface{}{"context": "bar"}, rr.Context) - v, _, err := adminClient.OAuth2API.AcceptOAuth2ConsentRequest(context.Background()). - ConsentChallenge(r.URL.Query().Get("consent_challenge")). - AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{ - GrantScope: []string{"openid", "userinfo_credential_draft_00"}, - GrantAccessTokenAudience: rr.RequestedAccessTokenAudience, - Session: &hydra.AcceptOAuth2ConsentRequestSession{ - AccessToken: map[string]interface{}{"foo": "bar"}, - IdToken: map[string]interface{}{"email": "foo@bar.com", "bar": "baz"}, - }, - }). - Execute() - require.NoError(t, err) - require.NotEmpty(t, v.RedirectTo) - http.Redirect(w, r, v.RedirectTo, http.StatusFound) - }, - ) - - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("nonce", nonce), - oauth2.SetAuthURLParam("scope", "openid userinfo_credential_draft_00"), - ) - require.NotEmpty(t, code) - token, err := conf.Exchange(context.Background(), code) - require.NoError(t, err) - iat := time.Now() - - vcNonce := token.Extra("c_nonce_draft_00").(string) - assert.NotEmpty(t, vcNonce) - expiry := token.Extra("c_nonce_expires_in_draft_00") - assert.NotEmpty(t, expiry) - assert.NoError(t, reg.Persister().IsNonceValid(ctx, token.AccessToken, vcNonce)) - - t.Run("followup=successfully create a verifiable credential", func(t *testing.T) { - t.Parallel() - - for _, alg := range supportedCryptoSuites { - alg := alg - t.Run(fmt.Sprintf("alg=%s", alg), func(t *testing.T) { - t.Parallel() - assertCreateVerifiableCredential(t, reg, vcNonce, token, jose.SignatureAlgorithm(alg)) + t.Run("followup=access token and id token are valid", func(t *testing.T) { + assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["openid","userinfo_credential_draft_00"]`) + assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) }) } - }) - t.Run("followup=get new nonce from priming request", func(t *testing.T) { - t.Parallel() - // Assert that we can fetch a verifiable credential with the nonce. - res, err := doPrimingRequest(t, reg, token, &hydraoauth2.CreateVerifiableCredentialRequestBody{ - Format: "jwt_vc_json", - Types: []string{"VerifiableCredential", "UserInfoCredential"}, + t.Run("strategy=jwt", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") + run(t, "jwt") }) - assert.NoError(t, err) - t.Run("followup=successfully create a verifiable credential from fresh nonce", func(t *testing.T) { - assertCreateVerifiableCredential(t, reg, res.Nonce, token, jose.ES256) + t.Run("strategy=opaque", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + run(t, "opaque") }) }) - t.Run("followup=rejects proof signed by another key", func(t *testing.T) { - t.Parallel() - for _, tc := range []struct { - name string - format string - proofType string - proof func() string - }{ - { - name: "proof=mismatching keys", - proof: func() string { - // Create mismatching public and private keys. - pubKey, _, err := josex.NewSigningKey(jose.ES256, 0) - require.NoError(t, err) - _, privKey, err := josex.NewSigningKey(jose.ES256, 0) - require.NoError(t, err) - pubKeyJWK := &jose.JSONWebKey{Key: pubKey, Algorithm: string(jose.ES256)} - return createVCProofJWT(t, pubKeyJWK, privKey, vcNonce) - }, + t.Run("suite=invalid query params", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + otherClient, _ := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), + ) + + withWrongClientAfterLogin := &http.Client{ + Jar: testhelpers.NewEmptyCookieJar(t), + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + if req.URL.Path != "/oauth2/auth" { + return nil + } + q := req.URL.Query() + if !q.Has("login_verifier") { + return nil + } + q.Set("client_id", otherClient.GetID()) + req.URL.RawQuery = q.Encode() + return nil }, - { - name: "proof=invalid format", - format: "invalid_format", - proof: func() string { - // Create mismatching public and private keys. - pubKey, privKey, err := josex.NewSigningKey(jose.ES256, 0) - require.NoError(t, err) - pubKeyJWK := &jose.JSONWebKey{Key: pubKey, Algorithm: string(jose.ES256)} - return createVCProofJWT(t, pubKeyJWK, privKey, vcNonce) - }, + } + withWrongClientAfterConsent := &http.Client{ + Jar: testhelpers.NewEmptyCookieJar(t), + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + if req.URL.Path != "/oauth2/auth" { + return nil + } + q := req.URL.Query() + if !q.Has("consent_verifier") { + return nil + } + q.Set("client_id", otherClient.GetID()) + req.URL.RawQuery = q.Encode() + return nil }, - { - name: "proof=invalid type", - proofType: "invalid", - proof: func() string { - // Create mismatching public and private keys. - pubKey, privKey, err := josex.NewSigningKey(jose.ES256, 0) - require.NoError(t, err) - pubKeyJWK := &jose.JSONWebKey{Key: pubKey, Algorithm: string(jose.ES256)} - return createVCProofJWT(t, pubKeyJWK, privKey, vcNonce) - }, + } + + withWrongScopeAfterLogin := &http.Client{ + Jar: testhelpers.NewEmptyCookieJar(t), + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + if req.URL.Path != "/oauth2/auth" { + return nil + } + q := req.URL.Query() + if !q.Has("login_verifier") { + return nil + } + q.Set("scope", "invalid scope") + req.URL.RawQuery = q.Encode() + return nil }, - { - name: "proof=invalid nonce", - proof: func() string { - // Create mismatching public and private keys. - pubKey, privKey, err := josex.NewSigningKey(jose.ES256, 0) - require.NoError(t, err) - pubKeyJWK := &jose.JSONWebKey{Key: pubKey, Algorithm: string(jose.ES256)} - return createVCProofJWT(t, pubKeyJWK, privKey, "invalid nonce") - }, + } + + withWrongScopeAfterConsent := &http.Client{ + Jar: testhelpers.NewEmptyCookieJar(t), + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + if req.URL.Path != "/oauth2/auth" { + return nil + } + q := req.URL.Query() + if !q.Has("consent_verifier") { + return nil + } + q.Set("scope", "invalid scope") + req.URL.RawQuery = q.Encode() + return nil }, - } { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - _, res := createVerifiableCredential(t, reg, token, &hydraoauth2.CreateVerifiableCredentialRequestBody{ - Format: stringsx.Coalesce(tc.format, "jwt_vc_json"), - Types: []string{"VerifiableCredential", "UserInfoCredential"}, - Proof: &hydraoauth2.VerifiableCredentialProof{ - ProofType: stringsx.Coalesce(tc.proofType, "jwt"), - JWT: tc.proof(), - }, - }) + } + for _, tc := range []struct { + name string + client *http.Client + expectedResponse string + }{{ + name: "fails with wrong client ID after login", + client: withWrongClientAfterLogin, + expectedResponse: "invalid_client", + }, { + name: "fails with wrong client ID after consent", + client: withWrongClientAfterConsent, + expectedResponse: "invalid_client", + }, { + name: "fails with wrong scopes after login", + client: withWrongScopeAfterLogin, + expectedResponse: "invalid_scope", + }, { + name: "fails with wrong scopes after consent", + client: withWrongScopeAfterConsent, + expectedResponse: "invalid_scope", + }} { + t.Run("case="+tc.name, func(t *testing.T) { + state := uuid.New() + resp, err := tc.client.Get(conf.AuthCodeURL(state)) require.NoError(t, err) - require.NotNil(t, res) - assert.Equal(t, "invalid_request", res.Error()) + assert.Equal(t, tc.expectedResponse, resp.Request.URL.Query().Get("error"), "%s", resp.Request.URL.RawQuery) + resp.Body.Close() }) } - }) - t.Run("followup=access token and id token are valid", func(t *testing.T) { - assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["openid","userinfo_credential_draft_00"]`) - assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + t.Run("case=checks if request fails when subject is empty", func(t *testing.T) { + testhelpers.NewLoginConsentUI(t, reg.Config(), func(w http.ResponseWriter, r *http.Request) { + _, res, err := adminClient.OAuth2API.AcceptOAuth2LoginRequest(ctx). + LoginChallenge(r.URL.Query().Get("login_challenge")). + AcceptOAuth2LoginRequest(hydra.AcceptOAuth2LoginRequest{Subject: "", Remember: pointerx.Ptr(true)}).Execute() + require.Error(t, err) // expects 400 + body := string(ioutilx.MustReadAll(res.Body)) + assert.Contains(t, body, "Field 'subject' must not be empty", "%s", body) + }, testhelpers.HTTPServerNoExpectedCallHandler(t)) + _, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + + _, err := testhelpers.NewEmptyJarClient(t).Get(conf.AuthCodeURL(uuid.New())) + require.NoError(t, err) }) - } - t.Run("strategy=jwt", func(t *testing.T) { - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") - run(t, "jwt") - }) + t.Run("case=perform flow with prompt=registration", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - t.Run("strategy=opaque", func(t *testing.T) { - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") - run(t, "opaque") - }) - }) + regUI := httptest.NewServer(acceptLoginHandler(t, c, adminClient, reg, subject, nil)) + t.Cleanup(regUI.Close) + reg.Config().MustSet(ctx, config.KeyRegistrationURL, regUI.URL) - t.Run("suite=invalid query params", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - otherClient, _ := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), - ) - - withWrongClientAfterLogin := &http.Client{ - Jar: testhelpers.NewEmptyCookieJar(t), - CheckRedirect: func(req *http.Request, _ []*http.Request) error { - if req.URL.Path != "/oauth2/auth" { - return nil - } - q := req.URL.Query() - if !q.Has("login_verifier") { - return nil - } - q.Set("client_id", otherClient.GetID()) - req.URL.RawQuery = q.Encode() - return nil - }, - } - withWrongClientAfterConsent := &http.Client{ - Jar: testhelpers.NewEmptyCookieJar(t), - CheckRedirect: func(req *http.Request, _ []*http.Request) error { - if req.URL.Path != "/oauth2/auth" { - return nil - } - q := req.URL.Query() - if !q.Has("consent_verifier") { - return nil - } - q.Set("client_id", otherClient.GetID()) - req.URL.RawQuery = q.Encode() - return nil - }, - } + testhelpers.NewLoginConsentUI(t, reg.Config(), + nil, + acceptConsentHandler(t, c, adminClient, reg, subject, nil)) - withWrongScopeAfterLogin := &http.Client{ - Jar: testhelpers.NewEmptyCookieJar(t), - CheckRedirect: func(req *http.Request, _ []*http.Request) error { - if req.URL.Path != "/oauth2/auth" { - return nil - } - q := req.URL.Query() - if !q.Has("login_verifier") { - return nil - } - q.Set("scope", "invalid scope") - req.URL.RawQuery = q.Encode() - return nil - }, - } + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("prompt", "registration"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) - withWrongScopeAfterConsent := &http.Client{ - Jar: testhelpers.NewEmptyCookieJar(t), - CheckRedirect: func(req *http.Request, _ []*http.Request) error { - if req.URL.Path != "/oauth2/auth" { - return nil - } - q := req.URL.Query() - if !q.Has("consent_verifier") { - return nil - } - q.Set("scope", "invalid scope") - req.URL.RawQuery = q.Encode() - return nil - }, - } - for _, tc := range []struct { - name string - client *http.Client - expectedResponse string - }{{ - name: "fails with wrong client ID after login", - client: withWrongClientAfterLogin, - expectedResponse: "invalid_client", - }, { - name: "fails with wrong client ID after consent", - client: withWrongClientAfterConsent, - expectedResponse: "invalid_client", - }, { - name: "fails with wrong scopes after login", - client: withWrongScopeAfterLogin, - expectedResponse: "invalid_scope", - }, { - name: "fails with wrong scopes after consent", - client: withWrongScopeAfterConsent, - expectedResponse: "invalid_scope", - }} { - t.Run("case="+tc.name, func(t *testing.T) { - state := uuid.New() - resp, err := tc.client.Get(conf.AuthCodeURL(state)) + token, err := conf.Exchange(context.Background(), code) require.NoError(t, err) - assert.Equal(t, tc.expectedResponse, resp.Request.URL.Query().Get("error"), "%s", resp.Request.URL.RawQuery) - resp.Body.Close() + + assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) }) - } - }) - t.Run("case=checks if request fails when subject is empty", func(t *testing.T) { - testhelpers.NewLoginConsentUI(t, reg.Config(), func(w http.ResponseWriter, r *http.Request) { - _, res, err := adminClient.OAuth2API.AcceptOAuth2LoginRequest(ctx). - LoginChallenge(r.URL.Query().Get("login_challenge")). - AcceptOAuth2LoginRequest(hydra.AcceptOAuth2LoginRequest{Subject: "", Remember: pointerx.Ptr(true)}).Execute() - require.Error(t, err) // expects 400 - body := string(ioutilx.MustReadAll(res.Body)) - assert.Contains(t, body, "Field 'subject' must not be empty", "%s", body) - }, testhelpers.HTTPServerNoExpectedCallHandler(t)) - _, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - - _, err := testhelpers.NewEmptyJarClient(t).Get(conf.AuthCodeURL(uuid.New())) - require.NoError(t, err) - }) + t.Run("case=perform flow with audience", func(t *testing.T) { + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) - t.Run("case=perform flow with prompt=registration", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) - regUI := httptest.NewServer(acceptLoginHandler(t, c, subject, nil)) - t.Cleanup(regUI.Close) - reg.Config().MustSet(ctx, config.KeyRegistrationURL, regUI.URL) + token, err := conf.Exchange(context.Background(), code) + require.NoError(t, err) - testhelpers.NewLoginConsentUI(t, reg.Config(), - nil, - acceptConsentHandler(t, c, subject, nil)) + claims := introspectAccessToken(t, conf, token, subject) + aud := claims.Get("aud").Array() + require.Len(t, aud, 1) + assert.EqualValues(t, aud[0].String(), expectAud) - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("prompt", "registration"), - oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) + assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) + }) - token, err := conf.Exchange(context.Background(), code) - require.NoError(t, err) + t.Run("case=respects client token lifespan configuration", func(t *testing.T) { + run := func(t *testing.T, strategy string, c *client.Client, conf *oauth2.Config, expectedLifespans client.Lifespans) { + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), + ) - assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) - }) + code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + iat := time.Now() + require.NoError(t, err) - t.Run("case=perform flow with audience", func(t *testing.T) { - expectAud := "https://api.ory.sh/" - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - assert.False(t, r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - return nil - }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - assert.False(t, *r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - })) + body := introspectAccessToken(t, conf, token, subject) + requirex.EqualTime(t, iat.Add(expectedLifespans.AuthorizationCodeGrantAccessTokenLifespan.Duration), time.Unix(body.Get("exp").Int(), 0), time.Second) - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), - oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) + assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(expectedLifespans.AuthorizationCodeGrantAccessTokenLifespan.Duration), `["hydra","offline","openid"]`) + assertIDToken(t, token, conf, subject, nonce, iat.Add(expectedLifespans.AuthorizationCodeGrantIDTokenLifespan.Duration)) + assertRefreshToken(t, token, conf, iat.Add(expectedLifespans.AuthorizationCodeGrantRefreshTokenLifespan.Duration)) - token, err := conf.Exchange(context.Background(), code) - require.NoError(t, err) + t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { + require.NotEmpty(t, token.RefreshToken) + token.Expiry = token.Expiry.Add(-time.Hour * 24) + refreshedToken, err := conf.TokenSource(context.Background(), token).Token() + iat = time.Now() + require.NoError(t, err) + assertRefreshToken(t, refreshedToken, conf, iat.Add(expectedLifespans.RefreshTokenGrantRefreshTokenLifespan.Duration)) + assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(expectedLifespans.RefreshTokenGrantAccessTokenLifespan.Duration), `["hydra","offline","openid"]`) + assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(expectedLifespans.RefreshTokenGrantIDTokenLifespan.Duration)) - claims := introspectAccessToken(t, conf, token, subject) - aud := claims.Get("aud").Array() - require.Len(t, aud, 1) - assert.EqualValues(t, aud[0].String(), expectAud) + require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) + require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken) + require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token")) - assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) - }) + body := introspectAccessToken(t, conf, refreshedToken, subject) + requirex.EqualTime(t, iat.Add(expectedLifespans.RefreshTokenGrantAccessTokenLifespan.Duration), time.Unix(body.Get("exp").Int(), 0), time.Second) - t.Run("case=respects client token lifespan configuration", func(t *testing.T) { - run := func(t *testing.T, strategy string, c *client.Client, conf *oauth2.Config, expectedLifespans client.Lifespans) { - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), - ) - - code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) - token, err := conf.Exchange(context.Background(), code) - iat := time.Now() - require.NoError(t, err) - - body := introspectAccessToken(t, conf, token, subject) - requirex.EqualTime(t, iat.Add(expectedLifespans.AuthorizationCodeGrantAccessTokenLifespan.Duration), time.Unix(body.Get("exp").Int(), 0), time.Second) - - assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(expectedLifespans.AuthorizationCodeGrantAccessTokenLifespan.Duration), `["hydra","offline","openid"]`) - assertIDToken(t, token, conf, subject, nonce, iat.Add(expectedLifespans.AuthorizationCodeGrantIDTokenLifespan.Duration)) - assertRefreshToken(t, token, conf, iat.Add(expectedLifespans.AuthorizationCodeGrantRefreshTokenLifespan.Duration)) - - t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { - require.NotEmpty(t, token.RefreshToken) - token.Expiry = token.Expiry.Add(-time.Hour * 24) - refreshedToken, err := conf.TokenSource(context.Background(), token).Token() - iat = time.Now() - require.NoError(t, err) - assertRefreshToken(t, refreshedToken, conf, iat.Add(expectedLifespans.RefreshTokenGrantRefreshTokenLifespan.Duration)) - assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(expectedLifespans.RefreshTokenGrantAccessTokenLifespan.Duration), `["hydra","offline","openid"]`) - assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(expectedLifespans.RefreshTokenGrantIDTokenLifespan.Duration)) + t.Run("followup=original access token is no longer valid", func(t *testing.T) { + i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + }) - require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) - require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken) - require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token")) + t.Run("followup=original refresh token is no longer valid", func(t *testing.T) { + _, err := conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + }) + }) + } - body := introspectAccessToken(t, conf, refreshedToken, subject) - requirex.EqualTime(t, iat.Add(expectedLifespans.RefreshTokenGrantAccessTokenLifespan.Duration), time.Unix(body.Get("exp").Int(), 0), time.Second) + t.Run("case=custom-lifespans-active-jwt", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + ls := testhelpers.TestLifespans + ls.AuthorizationCodeGrantAccessTokenLifespan = x.NullDuration{Valid: true, Duration: 6 * time.Second} + testhelpers.UpdateClientTokenLifespans( + t, + &oauth2.Config{ClientID: c.GetID(), ClientSecret: conf.ClientSecret}, + c.GetID(), + ls, adminTS, + ) + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") + run(t, "jwt", c, conf, ls) + }) - t.Run("followup=original access token is no longer valid", func(t *testing.T) { - i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + t.Run("case=custom-lifespans-active-opaque", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + ls := testhelpers.TestLifespans + ls.AuthorizationCodeGrantAccessTokenLifespan = x.NullDuration{Valid: true, Duration: 6 * time.Second} + testhelpers.UpdateClientTokenLifespans( + t, + &oauth2.Config{ClientID: c.GetID(), ClientSecret: conf.ClientSecret}, + c.GetID(), + ls, adminTS, + ) + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + run(t, "opaque", c, conf, ls) }) - t.Run("followup=original refresh token is no longer valid", func(t *testing.T) { - _, err := conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) + t.Run("case=custom-lifespans-unset", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.UpdateClientTokenLifespans(t, &oauth2.Config{ClientID: c.GetID(), ClientSecret: conf.ClientSecret}, c.GetID(), testhelpers.TestLifespans, adminTS) + testhelpers.UpdateClientTokenLifespans(t, &oauth2.Config{ClientID: c.GetID(), ClientSecret: conf.ClientSecret}, c.GetID(), client.Lifespans{}, adminTS) + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + + //goland:noinspection GoDeprecation + expectedLifespans := client.Lifespans{ + AuthorizationCodeGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, + AuthorizationCodeGrantIDTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetIDTokenLifespan(ctx)}, + AuthorizationCodeGrantRefreshTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetRefreshTokenLifespan(ctx)}, + ClientCredentialsGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, + ImplicitGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, + ImplicitGrantIDTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetIDTokenLifespan(ctx)}, + JwtBearerGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, + PasswordGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, + PasswordGrantRefreshTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetRefreshTokenLifespan(ctx)}, + RefreshTokenGrantIDTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetIDTokenLifespan(ctx)}, + RefreshTokenGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, + RefreshTokenGrantRefreshTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetRefreshTokenLifespan(ctx)}, + } + run(t, "opaque", c, conf, expectedLifespans) }) }) - } - t.Run("case=custom-lifespans-active-jwt", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - ls := testhelpers.TestLifespans - ls.AuthorizationCodeGrantAccessTokenLifespan = x.NullDuration{Valid: true, Duration: 6 * time.Second} - testhelpers.UpdateClientTokenLifespans( - t, - &oauth2.Config{ClientID: c.GetID(), ClientSecret: conf.ClientSecret}, - c.GetID(), - ls, adminTS, - ) - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") - run(t, "jwt", c, conf, ls) - }) + t.Run("case=use remember feature and prompt=none", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), + ) - t.Run("case=custom-lifespans-active-opaque", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - ls := testhelpers.TestLifespans - ls.AuthorizationCodeGrantAccessTokenLifespan = x.NullDuration{Valid: true, Duration: 6 * time.Second} - testhelpers.UpdateClientTokenLifespans( - t, - &oauth2.Config{ClientID: c.GetID(), ClientSecret: conf.ClientSecret}, - c.GetID(), - ls, adminTS, - ) - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") - run(t, "opaque", c, conf, ls) - }) + oc := testhelpers.NewEmptyJarClient(t) + code, _ := getAuthorizeCode(t, conf, oc, + oauth2.SetAuthURLParam("nonce", nonce), + oauth2.SetAuthURLParam("prompt", "login consent"), + oauth2.SetAuthURLParam("max_age", "1"), + ) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + require.NoError(t, err) + introspectAccessToken(t, conf, token, subject) - t.Run("case=custom-lifespans-unset", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.UpdateClientTokenLifespans(t, &oauth2.Config{ClientID: c.GetID(), ClientSecret: conf.ClientSecret}, c.GetID(), testhelpers.TestLifespans, adminTS) - testhelpers.UpdateClientTokenLifespans(t, &oauth2.Config{ClientID: c.GetID(), ClientSecret: conf.ClientSecret}, c.GetID(), client.Lifespans{}, adminTS) - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + // Reset UI to check for skip values + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + require.True(t, r.Skip) + require.EqualValues(t, subject, r.Subject) + return nil + }), + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { + require.True(t, *r.Skip) + require.EqualValues(t, subject, *r.Subject) + }), + ) - //goland:noinspection GoDeprecation - expectedLifespans := client.Lifespans{ - AuthorizationCodeGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, - AuthorizationCodeGrantIDTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetIDTokenLifespan(ctx)}, - AuthorizationCodeGrantRefreshTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetRefreshTokenLifespan(ctx)}, - ClientCredentialsGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, - ImplicitGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, - ImplicitGrantIDTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetIDTokenLifespan(ctx)}, - JwtBearerGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, - PasswordGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, - PasswordGrantRefreshTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetRefreshTokenLifespan(ctx)}, - RefreshTokenGrantIDTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetIDTokenLifespan(ctx)}, - RefreshTokenGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, - RefreshTokenGrantRefreshTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetRefreshTokenLifespan(ctx)}, - } - run(t, "opaque", c, conf, expectedLifespans) - }) - }) + t.Run("followup=checks if authenticatedAt/requestedAt is properly forwarded across the lifecycle by checking if prompt=none works", func(t *testing.T) { + // In order to check if authenticatedAt/requestedAt works, we'll sleep first in order to ensure that authenticatedAt is in the past + // if handled correctly. + time.Sleep(time.Second + time.Nanosecond) - t.Run("case=use remember feature and prompt=none", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), - ) - - oc := testhelpers.NewEmptyJarClient(t) - code, _ := getAuthorizeCode(t, conf, oc, - oauth2.SetAuthURLParam("nonce", nonce), - oauth2.SetAuthURLParam("prompt", "login consent"), - oauth2.SetAuthURLParam("max_age", "1"), - ) - require.NotEmpty(t, code) - token, err := conf.Exchange(context.Background(), code) - require.NoError(t, err) - introspectAccessToken(t, conf, token, subject) - - // Reset UI to check for skip values - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - require.True(t, r.Skip) - require.EqualValues(t, subject, r.Subject) - return nil - }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - require.True(t, *r.Skip) - require.EqualValues(t, subject, *r.Subject) - }), - ) - - t.Run("followup=checks if authenticatedAt/requestedAt is properly forwarded across the lifecycle by checking if prompt=none works", func(t *testing.T) { - // In order to check if authenticatedAt/requestedAt works, we'll sleep first in order to ensure that authenticatedAt is in the past - // if handled correctly. - time.Sleep(time.Second + time.Nanosecond) - - code, _ := getAuthorizeCode(t, conf, oc, - oauth2.SetAuthURLParam("nonce", nonce), - oauth2.SetAuthURLParam("prompt", "none"), - oauth2.SetAuthURLParam("max_age", "60"), - ) - require.NotEmpty(t, code) - token, err := conf.Exchange(context.Background(), code) - require.NoError(t, err) - original := introspectAccessToken(t, conf, token, subject) - - t.Run("followup=run the flow three more times", func(t *testing.T) { - for i := 0; i < 3; i++ { - t.Run(fmt.Sprintf("run=%d", i), func(t *testing.T) { + code, _ := getAuthorizeCode(t, conf, oc, + oauth2.SetAuthURLParam("nonce", nonce), + oauth2.SetAuthURLParam("prompt", "none"), + oauth2.SetAuthURLParam("max_age", "60"), + ) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + require.NoError(t, err) + original := introspectAccessToken(t, conf, token, subject) + + t.Run("followup=run the flow three more times", func(t *testing.T) { + for i := 0; i < 3; i++ { + t.Run(fmt.Sprintf("run=%d", i), func(t *testing.T) { + code, _ := getAuthorizeCode(t, conf, oc, + oauth2.SetAuthURLParam("nonce", nonce), + oauth2.SetAuthURLParam("prompt", "none"), + oauth2.SetAuthURLParam("max_age", "60"), + ) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + require.NoError(t, err) + followup := introspectAccessToken(t, conf, token, subject) + assert.Equal(t, original.Get("auth_time").Int(), followup.Get("auth_time").Int()) + }) + } + }) + + t.Run("followup=fails when max age is reached and prompt is none", func(t *testing.T) { code, _ := getAuthorizeCode(t, conf, oc, oauth2.SetAuthURLParam("nonce", nonce), oauth2.SetAuthURLParam("prompt", "none"), - oauth2.SetAuthURLParam("max_age", "60"), + oauth2.SetAuthURLParam("max_age", "1"), + ) + require.Empty(t, code) + }) + + t.Run("followup=passes and resets skip when prompt=login", func(t *testing.T) { + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + require.False(t, r.Skip) + require.Empty(t, r.Subject) + return nil + }), + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { + require.True(t, *r.Skip) + require.EqualValues(t, subject, *r.Subject) + }), + ) + code, _ := getAuthorizeCode(t, conf, oc, + oauth2.SetAuthURLParam("nonce", nonce), + oauth2.SetAuthURLParam("prompt", "login"), + oauth2.SetAuthURLParam("max_age", "1"), ) require.NotEmpty(t, code) token, err := conf.Exchange(context.Background(), code) require.NoError(t, err) - followup := introspectAccessToken(t, conf, token, subject) - assert.Equal(t, original.Get("auth_time").Int(), followup.Get("auth_time").Int()) + introspectAccessToken(t, conf, token, subject) + assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) }) - } + }) }) - t.Run("followup=fails when max age is reached and prompt is none", func(t *testing.T) { + t.Run("case=should fail if prompt=none but no auth session given", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), + ) + + oc := testhelpers.NewEmptyJarClient(t) code, _ := getAuthorizeCode(t, conf, oc, - oauth2.SetAuthURLParam("nonce", nonce), oauth2.SetAuthURLParam("prompt", "none"), - oauth2.SetAuthURLParam("max_age", "1"), ) require.Empty(t, code) }) - t.Run("followup=passes and resets skip when prompt=login", func(t *testing.T) { + t.Run("case=requires re-authentication when id_token_hint is set to a user 'patrik-neu' but the session is 'aeneas-rekkas' and then fails because the user id from the log in endpoint is 'aeneas-rekkas'", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { require.False(t, r.Skip) require.Empty(t, r.Subject) return nil }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - require.True(t, *r.Skip) - require.EqualValues(t, subject, *r.Subject) + acceptConsentHandler(t, c, adminClient, reg, subject, nil), + ) + + oc := testhelpers.NewEmptyJarClient(t) + + // Create login session for aeneas-rekkas + code, _ := getAuthorizeCode(t, conf, oc) + require.NotEmpty(t, code) + + // Perform authentication for aeneas-rekkas which fails because id_token_hint is patrik-neu + code, _ = getAuthorizeCode(t, conf, oc, + oauth2.SetAuthURLParam("id_token_hint", testhelpers.NewIDToken(t, reg, "patrik-neu")), + ) + require.Empty(t, code) + }) + + t.Run("case=should not cause issues if max_age is very low and consent takes a long time", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + time.Sleep(time.Second * 2) + return nil }), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) - code, _ := getAuthorizeCode(t, conf, oc, - oauth2.SetAuthURLParam("nonce", nonce), - oauth2.SetAuthURLParam("prompt", "login"), - oauth2.SetAuthURLParam("max_age", "1"), + + code, _ := getAuthorizeCode(t, conf, nil) + require.NotEmpty(t, code) + }) + + t.Run("case=ensure consistent claims returned for userinfo", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) + + code, _ := getAuthorizeCode(t, conf, nil) require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) require.NoError(t, err) - introspectAccessToken(t, conf, token, subject) - assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) + + idClaims := assertIDToken(t, token, conf, subject, "", time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) + + uiClaims := testhelpers.Userinfo(t, token, publicTS) + + for _, f := range []string{ + "sub", + "iss", + "aud", + "bar", + "auth_time", + } { + assert.NotEmpty(t, uiClaims.Get(f).Raw, "%s: %s", f, uiClaims) + assert.EqualValues(t, idClaims.Get(f).Raw, uiClaims.Get(f).Raw, "%s\nuserinfo: %s\nidtoken: %s", f, uiClaims, idClaims) + } + + for _, f := range []string{ + "at_hash", + "c_hash", + "nonce", + "sid", + "jti", + } { + assert.Empty(t, uiClaims.Get(f).Raw, "%s: %s", f, uiClaims) + } }) - }) - }) - t.Run("case=should fail if prompt=none but no auth session given", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), - ) - - oc := testhelpers.NewEmptyJarClient(t) - code, _ := getAuthorizeCode(t, conf, oc, - oauth2.SetAuthURLParam("prompt", "none"), - ) - require.Empty(t, code) - }) + t.Run("case=add ext claims from hook if configured", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") + assert.Equal(t, r.Header.Get("Authorization"), "Bearer secret value") + + var hookReq hydraoauth2.TokenHookRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) + require.NotEmpty(t, hookReq.Session) + require.Equal(t, map[string]interface{}{"foo": "bar"}, hookReq.Session.Extra) + require.NotEmpty(t, hookReq.Request) + require.ElementsMatch(t, []string{}, hookReq.Request.GrantedAudience) + require.Equal(t, map[string][]string{"grant_type": {"authorization_code"}}, hookReq.Request.Payload) + + claims := map[string]interface{}{ + "hooked": true, + } - t.Run("case=requires re-authentication when id_token_hint is set to a user 'patrik-neu' but the session is 'aeneas-rekkas' and then fails because the user id from the log in endpoint is 'aeneas-rekkas'", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - require.False(t, r.Skip) - require.Empty(t, r.Subject) - return nil - }), - acceptConsentHandler(t, c, subject, nil), - ) - - oc := testhelpers.NewEmptyJarClient(t) - - // Create login session for aeneas-rekkas - code, _ := getAuthorizeCode(t, conf, oc) - require.NotEmpty(t, code) - - // Perform authentication for aeneas-rekkas which fails because id_token_hint is patrik-neu - code, _ = getAuthorizeCode(t, conf, oc, - oauth2.SetAuthURLParam("id_token_hint", testhelpers.NewIDToken(t, reg, "patrik-neu")), - ) - require.Empty(t, code) - }) + hookResp := hydraoauth2.TokenHookResponse{ + Session: flow.AcceptOAuth2ConsentRequestSession{ + AccessToken: claims, + IDToken: claims, + }, + } - t.Run("case=should not cause issues if max_age is very low and consent takes a long time", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - time.Sleep(time.Second * 2) - return nil - }), - acceptConsentHandler(t, c, subject, nil), - ) - - code, _ := getAuthorizeCode(t, conf, nil) - require.NotEmpty(t, code) - }) + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyTokenHook, &config.HookConfig{ + URL: hs.URL, + Auth: &config.Auth{ + Type: "api_key", + Config: config.AuthConfig{ + In: "header", + Name: "Authorization", + Value: "Bearer secret value", + }, + }, + }) - t.Run("case=ensure consistent claims returned for userinfo", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), - ) + t.Cleanup(func() { + reg.Config().Delete(ctx, config.KeyTokenHook) + }) - code, _ := getAuthorizeCode(t, conf, nil) - require.NotEmpty(t, code) + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) + + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) - token, err := conf.Exchange(context.Background(), code) - require.NoError(t, err) + token, err := conf.Exchange(context.Background(), code) + require.NoError(t, err) - idClaims := assertIDToken(t, token, conf, subject, "", time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) + assertJWTAccessToken(t, strategy, conf, token, subject, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) - time.Sleep(time.Second) - uiClaims := testhelpers.Userinfo(t, token, publicTS) + // NOTE: using introspect to cover both jwt and opaque strategies + accessTokenClaims := introspectAccessToken(t, conf, token, subject) + require.True(t, accessTokenClaims.Get("ext.hooked").Bool()) - for _, f := range []string{ - "sub", - "iss", - "aud", - "bar", - "auth_time", - } { - assert.NotEmpty(t, uiClaims.Get(f).Raw, "%s: %s", f, uiClaims) - assert.EqualValues(t, idClaims.Get(f).Raw, uiClaims.Get(f).Raw, "%s\nuserinfo: %s\nidtoken: %s", f, uiClaims, idClaims) - } + idTokenClaims := assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) + require.True(t, idTokenClaims.Get("hooked").Bool()) + } + } - for _, f := range []string{ - "at_hash", - "c_hash", - "nonce", - "sid", - "jti", - } { - assert.Empty(t, uiClaims.Get(f).Raw, "%s: %s", f, uiClaims) - } - }) + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("case=fail token exchange if hook fails", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) + + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) - t.Run("case=add ext claims from hook if configured", func(t *testing.T) { - run := func(strategy string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") - assert.Equal(t, r.Header.Get("Authorization"), "Bearer secret value") - - var hookReq hydraoauth2.TokenHookRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) - require.NotEmpty(t, hookReq.Session) - require.Equal(t, map[string]interface{}{"foo": "bar"}, hookReq.Session.Extra) - require.NotEmpty(t, hookReq.Request) - require.ElementsMatch(t, []string{}, hookReq.Request.GrantedAudience) - require.Equal(t, map[string][]string{"grant_type": {"authorization_code"}}, hookReq.Request.Payload) - - claims := map[string]interface{}{ - "hooked": true, + _, err := conf.Exchange(context.Background(), code) + require.Error(t, err) } + } - hookResp := hydraoauth2.TokenHookResponse{ - Session: flow.AcceptOAuth2ConsentRequestSession{ - AccessToken: claims, - IDToken: claims, - }, + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("case=fail token exchange if hook denies the request", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) + + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + + _, err := conf.Exchange(context.Background(), code) + require.Error(t, err) } + } - w.WriteHeader(http.StatusOK) - require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) - })) - defer hs.Close() - - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyTokenHook, &config.HookConfig{ - URL: hs.URL, - Auth: &config.Auth{ - Type: "api_key", - Config: config.AuthConfig{ - In: "header", - Name: "Authorization", - Value: "Bearer secret value", - }, - }, + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("case=fail token exchange if hook response is malformed", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) + + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + + _, err := conf.Exchange(context.Background(), code) + require.Error(t, err) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("case=graceful token rotation", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "2s") + reg.Config().Delete(ctx, config.KeyTokenHook) + reg.Config().Delete(ctx, config.KeyRefreshTokenHook) + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1m") + reg.Config().MustSet(ctx, config.KeyAccessTokenLifespan, "1m") + t.Cleanup(func() { + reg.Config().Delete(ctx, config.KeyRefreshTokenRotationGracePeriod) + reg.Config().Delete(ctx, config.KeyRefreshTokenLifespan) + reg.Config().Delete(ctx, config.KeyAccessTokenLifespan) }) - defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + // This is an essential and complex test suite. We need to cover the following cases: + // + // * Graceful refresh token rotation invalidates the previous access token. + // * An expired refresh token cannot be used even if grace period is active. + // * A used refresh token cannot be re-used once the grace period ends, and it triggers re-use detection. + // * A test suite with a variety of concurrent refresh token chains. + run := func(t *testing.T, strategy string) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), + ) + + issueTokens := func(t *testing.T) *oauth2.Token { + code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + iat := time.Now() + require.NoError(t, err) - expectAud := "https://api.ory.sh/" - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - assert.False(t, r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - return nil - }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - assert.False(t, *r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - })) + introspectAccessToken(t, conf, token, subject) + assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, token, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + return token + } - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), - oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) + refreshTokens := func(t *testing.T, token *oauth2.Token) *oauth2.Token { + require.NotEmpty(t, token.RefreshToken) + token.Expiry = time.Now().Add(-time.Hour * 24) + iat := time.Now() + refreshedToken, err := conf.TokenSource(context.Background(), token).Token() + require.NoError(t, err) - token, err := conf.Exchange(context.Background(), code) - require.NoError(t, err) + require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) + require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken) + require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token")) - assertJWTAccessToken(t, strategy, conf, token, subject, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + introspectAccessToken(t, conf, refreshedToken, subject) + assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, refreshedToken, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + return refreshedToken + } - // NOTE: using introspect to cover both jwt and opaque strategies - accessTokenClaims := introspectAccessToken(t, conf, token, subject) - require.True(t, accessTokenClaims.Get("ext.hooked").Bool()) + assertInactive := func(t *testing.T, token string, c *oauth2.Config) { + t.Helper() + at := testhelpers.IntrospectToken(t, conf, token, adminTS) + assert.False(t, at.Get("active").Bool(), "%s", at) + } - idTokenClaims := assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) - require.True(t, idTokenClaims.Get("hooked").Bool()) - } - } + t.Run("gracefully refreshing a token does invalidate the previous access token", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "2s") + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1m") - t.Run("strategy=opaque", run("opaque")) - t.Run("strategy=jwt", run("jwt")) - }) + token := issueTokens(t) + _ = refreshTokens(t, token) - t.Run("case=fail token exchange if hook fails", func(t *testing.T) { - run := func(strategy string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - })) - defer hs.Close() + assertInactive(t, token.AccessToken, conf) // Original access token is invalid - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL) + _ = refreshTokens(t, token) + assertInactive(t, token.AccessToken, conf) // Original access token is still invalid + }) - defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + t.Run("an expired refresh token can not be used even if we are in the grace period", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "5s") + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1s") - expectAud := "https://api.ory.sh/" - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - assert.False(t, r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - return nil - }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - assert.False(t, *r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - })) + token := issueTokens(t) + time.Sleep(time.Second * 2) // Let token expire - we need 2 seconds to reliably be longer than TTL - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), - oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) + token.Expiry = time.Now().Add(-time.Hour * 24) + _, err := conf.TokenSource(ctx, token).Token() + require.Error(t, err, "Rotating an expired token is not possible even when we are in the grace period") - _, err := conf.Exchange(context.Background(), code) - require.Error(t, err) - } - } + // The access token is still valid because using an expired refresh token has no effect on the access token. + assertInactive(t, token.RefreshToken, conf) + }) - t.Run("strategy=opaque", run("opaque")) - t.Run("strategy=jwt", run("jwt")) - }) + t.Run("a used refresh token can not be re-used once the grace period ends and it triggers re-use detection", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1s") + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1m") - t.Run("case=fail token exchange if hook denies the request", func(t *testing.T) { - run := func(strategy string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusForbidden) - })) - defer hs.Close() + token := issueTokens(t) + refreshed := refreshTokens(t, token) - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL) + time.Sleep(time.Second * 2) // Wait for the grace period to end - defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + token.Expiry = time.Now().Add(-time.Hour * 24) + _, err := conf.TokenSource(ctx, token).Token() + require.Error(t, err, "Rotating a used refresh token is not possible after the grace period") - expectAud := "https://api.ory.sh/" - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - assert.False(t, r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - return nil - }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - assert.False(t, *r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - })) + assertInactive(t, token.AccessToken, conf) + assertInactive(t, token.RefreshToken, conf) - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), - oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) + assertInactive(t, refreshed.AccessToken, conf) + assertInactive(t, refreshed.RefreshToken, conf) + }) - _, err := conf.Exchange(context.Background(), code) - require.Error(t, err) - } - } + // This test suite covers complex scenarios where we have multiple generations of tokens and we need to ensure + // that key security mitigations are in place: + // + // - Token re-use detection clears all tokens if a refresh token is re-used after the grace period. + // - Revoking consent clears all tokens. + // - Token revokation clears all tokens. + // + // The test creates 4 token generations, where each generations has twice as many tokens as the previous generation. + // The generations are created like this: + // + // - In the first scenario, all token generations are created at the same time. + // - In the second scenario, we create token generations with a delay that is longer than the grace period between them. + // + // Tokens for each generation are created in parallel to ensure we have no state leak anywhere.0 + t.Run("token generations", func(t *testing.T) { + + gracePeriod := time.Second + aboveGracePeriod := time.Second * 2 + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1m") + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, gracePeriod.String()) + reg.Config().Delete(ctx, config.KeyTokenHook) + reg.Config().Delete(ctx, config.KeyRefreshTokenHook) + + createTokenGenerations := func(t *testing.T, count int, withSleep time.Duration) [][]*oauth2.Token { + generations := make([][]*oauth2.Token, count) + generations[0] = []*oauth2.Token{issueTokens(t)} + // Start from the first generation. For every next generation, we refresh all the tokens of the previous generation twice. + for i := 1; i < len(generations); i++ { + generations[i] = make([]*oauth2.Token, 0, len(generations[i-1])*2) + + var wg sync.WaitGroup + gen := func(i int, token *oauth2.Token) { + defer wg.Done() + generations[i] = append(generations[i], refreshTokens(t, token)) + } - t.Run("strategy=opaque", run("opaque")) - t.Run("strategy=jwt", run("jwt")) - }) + for _, token := range generations[i-1] { + wg.Add(2) + if dbName == "memory" { + // SQLite can not handle concurrency + gen(i, token) + gen(i, token) + } else { + go gen(i, token) + go gen(i, token) + } + } - t.Run("case=fail token exchange if hook response is malformed", func(t *testing.T) { - run := func(strategy string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer hs.Close() + wg.Wait() + if withSleep > 0 { + time.Sleep(withSleep) + } + } + return generations + } - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL) + t.Run("re-using an old graceful refresh token invalidates all tokens", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1s") + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1m") + // This test only works if the refresh token lifespan is longer than the grace period. + generations := createTokenGenerations(t, 4, time.Second*2) + + generationIndex := rng.Intn(len(generations) - 1) // Exclude the last generation + tokenIndex := rng.Intn(len(generations[generationIndex])) + + token := generations[generationIndex][tokenIndex] + token.Expiry = time.Now().Add(-time.Hour * 24) + _, err := conf.TokenSource(ctx, token).Token() + require.Error(t, err) + + // Now all tokens are inactive + for i, generation := range generations { + t.Run(fmt.Sprintf("generation=%d", i), func(t *testing.T) { + for j, token := range generation { + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + assertInactive(t, token.AccessToken, conf) + assertInactive(t, token.RefreshToken, conf) + }) + } + }) + } + }) - defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + for _, withSleep := range []time.Duration{0, aboveGracePeriod} { + t.Run(fmt.Sprintf("withSleep=%s", withSleep), func(t *testing.T) { + createTokenGenerations := func(t *testing.T, count int) [][]*oauth2.Token { + return createTokenGenerations(t, count, withSleep) + } - expectAud := "https://api.ory.sh/" - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - assert.False(t, r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - return nil - }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - assert.False(t, *r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - })) + t.Run("only the most recent token generation is valid across the board", func(t *testing.T) { + generations := createTokenGenerations(t, 4) + + // All generations except the last one are valid. + for i, generation := range generations[:len(generations)-1] { + t.Run(fmt.Sprintf("generation=%d", i), func(t *testing.T) { + for j, token := range generation { + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + assertInactive(t, token.AccessToken, conf) + }) + } + }) + } - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), - oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) + // The last generation is valid: + t.Run(fmt.Sprintf("generation=%d", len(generations)-1), func(t *testing.T) { + for j, token := range generations[len(generations)-1] { + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + introspectAccessToken(t, conf, token, subject) + assertIDToken(t, token, conf, subject, nonce, time.Time{}) + assertRefreshToken(t, token, conf, time.Time{}) + }) + } + }) + }) + + t.Run("revoking consent revokes all tokens", func(t *testing.T) { + generations := createTokenGenerations(t, 4) + + // After revoking consent, all generations are invalid. + err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) + require.NoError(t, err) + + for i, generation := range generations { + t.Run(fmt.Sprintf("generation=%d", i), func(t *testing.T) { + for j, token := range generation { + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + assertInactive(t, token.AccessToken, conf) + assertInactive(t, token.RefreshToken, conf) + }) + } + }) + } + }) + + t.Run("re-using the a recent refresh token after the grace period has ended invalidates all tokens", func(t *testing.T) { + generations := createTokenGenerations(t, 4) + + token := generations[len(generations)-1][0] + + finalToken := refreshTokens(t, token) + time.Sleep(aboveGracePeriod) // Wait for the grace period to end + + token.Expiry = time.Now().Add(-time.Hour * 24) + _, err := conf.TokenSource(ctx, token).Token() + require.Error(t, err) + + // Now all tokens are inactive + for i, generation := range append(generations, []*oauth2.Token{finalToken}) { + t.Run(fmt.Sprintf("generation=%d", i), func(t *testing.T) { + for j, token := range generation { + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + assertInactive(t, token.AccessToken, conf) + assertInactive(t, token.RefreshToken, conf) + }) + } + }) + } + }) + + t.Run("revoking a refresh token in the chain revokes all tokens", func(t *testing.T) { + generations := createTokenGenerations(t, 4) + + testhelpers.RevokeToken(t, conf, generations[len(generations)-1][0].RefreshToken, publicTS) + + for i, generation := range generations { + t.Run(fmt.Sprintf("generation=%d", i), func(t *testing.T) { + for j, token := range generation { + token := token + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + assertInactive(t, token.AccessToken, conf) + assertInactive(t, token.RefreshToken, conf) + }) + } + }) + } + }) + }) + } + }) - _, err := conf.Exchange(context.Background(), code) - require.Error(t, err) - } - } + t.Run("it is possible to refresh tokens concurrently", func(t *testing.T) { + // SQLite can not handle concurrency + if dbName == "memory" { + t.Skip("Skipping test because SQLite can not handle concurrency") + } - t.Run("strategy=opaque", run("opaque")) - t.Run("strategy=jwt", run("jwt")) - }) + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1m") + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "5s") + + token := issueTokens(t) + + var wg sync.WaitGroup + refresh := func(t *testing.T, token *oauth2.Token) *oauth2.Token { + require.NotEmpty(t, token.RefreshToken) + token.Expiry = time.Now().Add(-time.Hour * 24) + tt, err := conf.TokenSource(context.Background(), token).Token() + require.NoError(t, err) + return tt + } + + refreshes := make([]*oauth2.Token, 5) + for k := range refreshes { + wg.Add(1) + go func(k int) { + defer wg.Done() + refreshes[k] = refresh(t, token) + }(k) + } + wg.Wait() + + // All tokens are valid. + for k, actual := range refreshes { + refresh := actual + require.NotEmpty(t, refresh.RefreshToken, "token %d:\ntoken:%+v", k, refresh) + require.NotEmpty(t, refresh.AccessToken, "token %d:\ntoken:%+v", k, refresh) + require.NotEmpty(t, refresh.Extra("id_token"), "token %d:\ntoken:%+v", k, refresh) + + i := testhelpers.IntrospectToken(t, conf, refresh.AccessToken, adminTS) + assert.Truef(t, i.Get("active").Bool(), "token %d:\ntoken:%+v\nresult:%s", k, refresh, i) + + i = testhelpers.IntrospectToken(t, conf, refresh.RefreshToken, adminTS) + assert.Truef(t, i.Get("active").Bool(), "token %d:\ntoken:%+v\nresult:%s", k, refresh, i) + } + }) + } + + t.Run("strategy=jwt", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") + run(t, "jwt") + }) + + t.Run("strategy=opaque", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + run(t, "opaque") + }) + }) + }) + } } func assertCreateVerifiableCredential(t *testing.T, reg driver.Registry, nonce string, accessToken *oauth2.Token, alg jose.SignatureAlgorithm) { @@ -1365,7 +1550,7 @@ func assertCreateVerifiableCredential(t *testing.T, reg driver.Registry, nonce s proofJWT := createVCProofJWT(t, pubKeyJWK, privKey, nonce) // Assert that we can fetch a verifiable credential with the nonce. - verifiableCredential, _ := createVerifiableCredential(t, reg, accessToken, &hydraoauth2.CreateVerifiableCredentialRequestBody{ + verifiableCredential, err := createVerifiableCredential(t, reg, accessToken, &hydraoauth2.CreateVerifiableCredentialRequestBody{ Format: "jwt_vc_json", Types: []string{"VerifiableCredential", "UserInfoCredential"}, Proof: &hydraoauth2.VerifiableCredentialProof{ @@ -1414,7 +1599,7 @@ func createVerifiableCredential( reg driver.Registry, token *oauth2.Token, createVerifiableCredentialReq *hydraoauth2.CreateVerifiableCredentialRequestBody, -) (vcRes *hydraoauth2.VerifiableCredentialResponse, vcErr *fosite.RFC6749Error) { +) (vcRes *hydraoauth2.VerifiableCredentialResponse, vcErr error) { var ( ctx = context.Background() body bytes.Buffer @@ -1486,18 +1671,18 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { ctx := context.Background() for _, strat := range []struct{ d string }{{d: "opaque"}, {d: "jwt"}} { t.Run("strategy="+strat.d, func(t *testing.T) { - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() conf.MustSet(ctx, config.KeyAccessTokenLifespan, time.Second*2) conf.MustSet(ctx, config.KeyScopeStrategy, "DEPRECATED_HIERARCHICAL_SCOPE_STRATEGY") conf.MustSet(ctx, config.KeyAccessTokenStrategy, strat.d) - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) - internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) - internal.MustEnsureRegistryKeys(ctx, reg, x.OAuth2JWTKeyName) + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) + testhelpers.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) + testhelpers.MustEnsureRegistryKeys(ctx, reg, x.OAuth2JWTKeyName) consentStrategy := &consentMock{} router := x.NewRouterPublic() ts := httptest.NewServer(router) - defer ts.Close() + t.Cleanup(ts.Close) reg.WithConsentStrategy(consentStrategy) handler := reg.OAuth2Handler() @@ -1511,7 +1696,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { }) var mutex sync.Mutex - require.NoError(t, reg.ClientManager().CreateClient(context.TODO(), &client.Client{ + require.NoError(t, reg.ClientManager().CreateClient(ctx, &client.Client{ ID: "app-client", Secret: "secret", RedirectURIs: []string{ts.URL + "/callback"}, @@ -1874,6 +2059,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { if hookType == "legacy" { conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) + } else { conf.MustSet(ctx, config.KeyTokenHook, hs.URL) defer conf.MustSet(ctx, config.KeyTokenHook, nil) @@ -2033,13 +2219,13 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { t.Run("refreshing old token should no longer work", func(t *testing.T) { res, err := testRefresh(t, token, ts.URL, false) require.NoError(t, err) - assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + assert.Equal(t, http.StatusBadRequest, res.StatusCode) }) t.Run("attempt to refresh old token should revoke new token", func(t *testing.T) { res, err := testRefresh(t, &refreshedToken, ts.URL, false) require.NoError(t, err) - assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + assert.Equal(t, http.StatusBadRequest, res.StatusCode) }) t.Run("duplicate code exchange fails", func(t *testing.T) { diff --git a/oauth2/oauth2_client_credentials_bench_test.go b/oauth2/oauth2_client_credentials_bench_test.go index 310727f34cc..560925ffb3e 100644 --- a/oauth2/oauth2_client_credentials_bench_test.go +++ b/oauth2/oauth2_client_credentials_bench_test.go @@ -22,7 +22,6 @@ import ( hc "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/internal/testhelpers" "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" @@ -36,7 +35,7 @@ func BenchmarkClientCredentials(b *testing.B) { tracer := trace.NewTracerProvider(trace.WithSpanProcessor(spans)).Tracer("") dsn := "postgres://postgres:secret@127.0.0.1:3445/postgres?sslmode=disable" - reg := internal.NewRegistrySQLFromURL(b, dsn, true, new(contextx.Default)).WithTracer(tracer) + reg := testhelpers.NewRegistrySQLFromURL(b, dsn, true, new(contextx.Default)).WithTracer(tracer) reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") public, admin := testhelpers.NewOAuth2Server(ctx, b, reg) diff --git a/oauth2/oauth2_client_credentials_test.go b/oauth2/oauth2_client_credentials_test.go index 9d5067dafb1..a93ea067716 100644 --- a/oauth2/oauth2_client_credentials_test.go +++ b/oauth2/oauth2_client_credentials_test.go @@ -28,14 +28,13 @@ import ( hc "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" "github.com/ory/x/requirex" ) func TestClientCredentials(t *testing.T) { ctx := context.Background() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") public, admin := testhelpers.NewOAuth2Server(ctx, t, reg) diff --git a/oauth2/oauth2_jwt_bearer_test.go b/oauth2/oauth2_jwt_bearer_test.go index e9e7ddf9120..0b1a862ba05 100644 --- a/oauth2/oauth2_jwt_bearer_test.go +++ b/oauth2/oauth2_jwt_bearer_test.go @@ -35,13 +35,12 @@ import ( hc "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" ) func TestJWTBearer(t *testing.T) { ctx := context.Background() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") _, admin := testhelpers.NewOAuth2Server(ctx, t, reg) diff --git a/oauth2/oauth2_refresh_token_test.go b/oauth2/oauth2_refresh_token_test.go index ffabb0dd2a0..018af343104 100644 --- a/oauth2/oauth2_refresh_token_test.go +++ b/oauth2/oauth2_refresh_token_test.go @@ -14,6 +14,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/gofrs/uuid" "github.com/stretchr/testify/assert" @@ -22,7 +24,6 @@ import ( "github.com/ory/fosite" hc "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/oauth2" "github.com/ory/x/contextx" "github.com/ory/x/dbal" @@ -89,12 +90,12 @@ func TestCreateRefreshTokenSessionStress(t *testing.T) { } net := &networkx.Network{} require.NoError(t, dbRegistry.Persister().Connection(context.Background()).First(net)) - dbRegistry.WithContextualizer(&contextx.Static{NID: net.ID, C: internal.NewConfigurationWithDefaults().Source(context.Background())}) + dbRegistry.WithContextualizer(&contextx.Static{NID: net.ID, C: testhelpers.NewConfigurationWithDefaults().Source(context.Background())}) ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second)) t.Cleanup(cancel) require.NoError(t, dbRegistry.OAuth2Storage().(clientCreator).CreateClient(ctx, &testClient)) - require.NoError(t, dbRegistry.OAuth2Storage().CreateRefreshTokenSession(ctx, tokenSignature, request)) + require.NoError(t, dbRegistry.OAuth2Storage().CreateRefreshTokenSession(ctx, tokenSignature, "", request)) _, err := dbRegistry.OAuth2Storage().GetRefreshTokenSession(ctx, tokenSignature, nil) require.NoError(t, err) provider := dbRegistry.OAuth2Provider() @@ -250,7 +251,7 @@ func TestCreateRefreshTokenSessionStress(t *testing.T) { // reset state for the next test iteration assert.NoError(t, dbRegistry.OAuth2Storage().DeleteRefreshTokenSession(ctx, tokenSignature)) - assert.NoError(t, dbRegistry.OAuth2Storage().CreateRefreshTokenSession(ctx, tokenSignature, request)) + assert.NoError(t, dbRegistry.OAuth2Storage().CreateRefreshTokenSession(ctx, tokenSignature, "", request)) } } } diff --git a/oauth2/oauth2_rop_test.go b/oauth2/oauth2_rop_test.go index 4adb4904452..0428e86e7a1 100644 --- a/oauth2/oauth2_rop_test.go +++ b/oauth2/oauth2_rop_test.go @@ -22,7 +22,6 @@ import ( "github.com/ory/hydra/v2/driver/config" "github.com/ory/hydra/v2/flow" "github.com/ory/hydra/v2/fositex" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/internal/kratos" "github.com/ory/hydra/v2/internal/testhelpers" hydraoauth2 "github.com/ory/hydra/v2/oauth2" @@ -34,7 +33,7 @@ import ( func TestResourceOwnerPasswordGrant(t *testing.T) { ctx := context.Background() fakeKratos := kratos.NewFake() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) reg.WithKratos(fakeKratos) reg.WithExtraFositeFactories([]fositex.Factory{compose.OAuth2ResourceOwnerPasswordCredentialsFactory}) publicTS, adminTS := testhelpers.NewOAuth2Server(ctx, t, reg) diff --git a/oauth2/revocator_test.go b/oauth2/revocator_test.go index 4ad0be8cac7..32283730fa9 100644 --- a/oauth2/revocator_test.go +++ b/oauth2/revocator_test.go @@ -11,6 +11,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/gobuffalo/pop/v6" "github.com/ory/x/httprouterx" @@ -60,10 +62,10 @@ func countAccessTokens(t *testing.T, c *pop.Connection) int { } func TestRevoke(t *testing.T) { - conf := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + conf := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) - internal.MustEnsureRegistryKeys(context.Background(), reg, x.OpenIDConnectKeyName) + testhelpers.MustEnsureRegistryKeys(context.Background(), reg, x.OpenIDConnectKeyName) internal.AddFositeExamples(reg) tokens := Tokens(reg.OAuth2ProviderConfig(), 4) diff --git a/oauth2/session_custom_claims_test.go b/oauth2/session_custom_claims_test.go index 5fbe3c5c1a5..5594df88021 100644 --- a/oauth2/session_custom_claims_test.go +++ b/oauth2/session_custom_claims_test.go @@ -7,11 +7,12 @@ import ( "context" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/fosite/handler/openid" "github.com/ory/fosite/token/jwt" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/oauth2" "github.com/stretchr/testify/assert" @@ -39,7 +40,7 @@ func createSessionWithCustomClaims(ctx context.Context, p *config.DefaultProvide func TestCustomClaimsInSession(t *testing.T) { ctx := context.Background() - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() t.Run("no_custom_claims", func(t *testing.T) { c.MustSet(ctx, config.KeyAllowedTopLevelClaims, []string{}) diff --git a/oauth2/trust/handler_test.go b/oauth2/trust/handler_test.go index daacc8ed282..e93066eac97 100644 --- a/oauth2/trust/handler_test.go +++ b/oauth2/trust/handler_test.go @@ -15,6 +15,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/go-jose/go-jose/v3" "github.com/tidwall/gjson" @@ -33,7 +35,6 @@ import ( hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" ) @@ -50,10 +51,10 @@ type HandlerTestSuite struct { // Setup will run before the tests in the suite are run. func (s *HandlerTestSuite) SetupSuite() { - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() conf.MustSet(context.Background(), config.KeySubjectTypesSupported, []string{"public"}) conf.MustSet(context.Background(), config.KeyDefaultClientScope, []string{"foo", "bar"}) - s.registry = internal.NewRegistryMemory(s.T(), conf, &contextx.Default{}) + s.registry = testhelpers.NewRegistryMemory(s.T(), conf, &contextx.Default{}) router := x.NewRouterAdmin(conf.AdminURL) handler := trust.NewHandler(s.registry) @@ -80,7 +81,7 @@ func (s *HandlerTestSuite) TearDownSuite() { // Will run after each test in the suite. func (s *HandlerTestSuite) TearDownTest() { - internal.CleanAndMigrate(s.registry)(s.T()) + testhelpers.CleanAndMigrate(s.registry)(s.T()) } // In order for 'go test' to run this suite, we need to create diff --git a/persistence/sql/migratest/migration_test.go b/persistence/sql/migratest/migration_test.go index 8564cfab969..71435d95687 100644 --- a/persistence/sql/migratest/migration_test.go +++ b/persistence/sql/migratest/migration_test.go @@ -13,7 +13,8 @@ import ( "testing" "time" - "github.com/ory/hydra/v2/internal" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/x/contextx" "github.com/bradleyjkemp/cupaloy/v2" @@ -64,7 +65,7 @@ func TestMigrations(t *testing.T) { connections := make(map[string]*pop.Connection, 1) if testing.Short() { - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) require.NoError(t, reg.Persister().MigrateUp(context.Background())) c := reg.Persister().Connection(context.Background()) connections["sqlite"] = c diff --git a/persistence/sql/migrations/20241129111700000000_add_refresh_token_access_token_link.autocommit.down.sql b/persistence/sql/migrations/20241129111700000000_add_refresh_token_access_token_link.autocommit.down.sql new file mode 100644 index 00000000000..46db0f98db5 --- /dev/null +++ b/persistence/sql/migrations/20241129111700000000_add_refresh_token_access_token_link.autocommit.down.sql @@ -0,0 +1 @@ +ALTER TABLE hydra_oauth2_refresh DROP COLUMN access_token_signature; diff --git a/persistence/sql/migrations/20241129111700000000_add_refresh_token_access_token_link.autocommit.up.sql b/persistence/sql/migrations/20241129111700000000_add_refresh_token_access_token_link.autocommit.up.sql new file mode 100644 index 00000000000..3b389709bc7 --- /dev/null +++ b/persistence/sql/migrations/20241129111700000000_add_refresh_token_access_token_link.autocommit.up.sql @@ -0,0 +1 @@ +ALTER TABLE hydra_oauth2_refresh ADD access_token_signature VARCHAR(255) DEFAULT NULL; diff --git a/persistence/sql/persister.go b/persistence/sql/persister.go index ba2647393a5..413e40a8eaa 100644 --- a/persistence/sql/persister.go +++ b/persistence/sql/persister.go @@ -33,7 +33,6 @@ var _ persistence.Persister = new(Persister) var _ storage.Transactional = new(Persister) var ( - ErrTransactionOpen = errors.New("There is already a Transaction in this context.") ErrNoTransactionOpen = errors.New("There is no Transaction in this context.") ) diff --git a/persistence/sql/persister_nid_test.go b/persistence/sql/persister_nid_test.go index 5d556d44b4d..93bccdcfe58 100644 --- a/persistence/sql/persister_nid_test.go +++ b/persistence/sql/persister_nid_test.go @@ -10,6 +10,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/fosite/handler/openid" "github.com/stretchr/testify/assert" @@ -29,7 +31,6 @@ import ( "github.com/ory/hydra/v2/consent" "github.com/ory/hydra/v2/driver" "github.com/ory/hydra/v2/flow" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/oauth2" "github.com/ory/hydra/v2/oauth2/trust" @@ -57,11 +58,11 @@ var _ interface { func (s *PersisterTestSuite) SetupSuite() { s.registries = map[string]driver.Registry{ - "memory": internal.NewRegistrySQLFromURL(s.T(), dbal.NewSQLiteTestDatabase(s.T()), true, &contextx.Default{}), + "memory": testhelpers.NewRegistrySQLFromURL(s.T(), dbal.NewSQLiteTestDatabase(s.T()), true, &contextx.Default{}), } if !testing.Short() { - s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], _ = internal.ConnectDatabases(s.T(), true, &contextx.Default{}) + s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], _ = testhelpers.ConnectDatabases(s.T(), true, &contextx.Default{}) } s.t1NID, s.t2NID = uuid.Must(uuid.NewV4()), uuid.Must(uuid.NewV4()) @@ -533,7 +534,7 @@ func (s *PersisterTestSuite) TestCreateRefreshTokenSession() { authorizeCode := uuid.Must(uuid.NewV4()).String() actual := persistencesql.OAuth2RequestSQL{Table: "refresh"} require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode)) - require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, authorizeCode, request)) + require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, authorizeCode, "", request)) require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode)) require.Equal(t, s.t1NID, actual.NID) }) @@ -727,7 +728,7 @@ func (s *PersisterTestSuite) TestDeleteRefreshTokenSession() { request.Session = &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "sub"}} signature := uuid.Must(uuid.NewV4()).String() - require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signature, request)) + require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signature, "", request)) actual := persistencesql.OAuth2RequestSQL{Table: "refresh"} @@ -933,7 +934,7 @@ func (s *PersisterTestSuite) TestFlushInactiveRefreshTokens() { signature := uuid.Must(uuid.NewV4()).String() require.NoError(t, r.Persister().CreateClient(s.t1, client)) - require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signature, request)) + require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signature, "", request)) actual := persistencesql.OAuth2RequestSQL{Table: "refresh"} @@ -1392,7 +1393,7 @@ func (s *PersisterTestSuite) TestGetRefreshTokenSession() { request.Session = &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "sub"}} sig := uuid.Must(uuid.NewV4()).String() require.NoError(t, r.Persister().CreateClient(s.t1, client)) - require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, sig, request)) + require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, sig, "", request)) actual, err := r.Persister().GetRefreshTokenSession(s.t2, sig, &fosite.DefaultSession{}) require.Error(t, err) @@ -1777,47 +1778,114 @@ func (s *PersisterTestSuite) TestRevokeRefreshToken() { request.Session = &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "sub"}} signature := uuid.Must(uuid.NewV4()).String() - require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signature, request)) - - actual := persistencesql.OAuth2RequestSQL{Table: "refresh"} + require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signature, "", request)) + var actualt2 persistencesql.OAuth2RefreshTable require.NoError(t, r.Persister().RevokeRefreshToken(s.t2, request.ID)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature)) - require.Equal(t, true, actual.Active) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actualt2, signature)) + require.Equal(t, true, actualt2.Active) + require.NoError(t, r.Persister().RevokeRefreshToken(s.t1, request.ID)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature)) - require.Equal(t, false, actual.Active) + require.ErrorIs(t, r.Persister().Connection(context.Background()).Find(new(persistencesql.OAuth2RefreshTable), signature), sql.ErrNoRows) }) } } -func (s *PersisterTestSuite) TestRevokeRefreshTokenMaybeGracePeriod() { +func (s *PersisterTestSuite) TestRotateRefreshToken() { t := s.T() for k, r := range s.registries { t.Run(k, func(t *testing.T) { - client := &client.Client{ID: "client-id"} - require.NoError(t, r.Persister().CreateClient(s.t1, client)) + t.Run("with access signature", func(t *testing.T) { + clientID := uuid.Must(uuid.NewV4()).String() + require.NoError(t, r.Persister().CreateClient(s.t1, &client.Client{ID: clientID})) + require.NoError(t, r.Persister().CreateClient(s.t2, &client.Client{ID: clientID})) - request := fosite.NewRequest() - request.Client = &fosite.DefaultClient{ID: "client-id"} - request.Session = &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "sub"}} + request := fosite.NewRequest() + request.Client = &fosite.DefaultClient{ID: clientID} + request.Session = &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "sub"}} - signature := uuid.Must(uuid.NewV4()).String() - require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signature, request)) + // Create token T1 + signatureT1 := uuid.Must(uuid.NewV4()).String() + accessSignatureT1 := uuid.Must(uuid.NewV4()).String() + require.NoError(t, r.Persister().CreateAccessTokenSession(s.t1, accessSignatureT1, request)) + require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signatureT1, accessSignatureT1, request)) - actual := persistencesql.OAuth2RequestSQL{Table: "refresh"} + // Create token T2 + signatureT2 := uuid.Must(uuid.NewV4()).String() + accessSignatureT2 := uuid.Must(uuid.NewV4()).String() + require.ErrorIs(t, r.Persister().RotateRefreshToken(s.t2, request.ID, signatureT2), fosite.ErrNotFound, "Rotation fails as token is non-existent.") + require.NoError(t, r.Persister().CreateAccessTokenSession(s.t2, accessSignatureT2, request)) + require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t2, signatureT2, accessSignatureT2, request)) - store, ok := r.Persister().(*persistencesql.Persister) - if !ok { - t.Fatal("type assertion failed") - } + accessT2 := persistencesql.OAuth2RequestSQL{Table: "access"} + assert.NoError(t, r.Persister().Connection(s.t2).Where("signature = ?", x.SignatureHash(accessSignatureT2)).First(&accessT2)) + require.Equal(t, true, accessT2.Active) - require.NoError(t, store.RevokeRefreshTokenMaybeGracePeriod(s.t2, request.ID, signature)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature)) - require.Equal(t, true, actual.Active) - require.NoError(t, store.RevokeRefreshTokenMaybeGracePeriod(s.t1, request.ID, signature)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature)) - require.Equal(t, false, actual.Active) + accessT1 := persistencesql.OAuth2RequestSQL{Table: "access"} + assert.NoError(t, r.Persister().Connection(s.t1).Where("signature = ?", x.SignatureHash(accessSignatureT1)).First(&accessT1)) + require.Equal(t, true, accessT2.Active) + + // Rotate token T1 + require.NoError(t, r.Persister().RotateRefreshToken(s.t1, request.ID, signatureT1)) + { + refreshT1 := persistencesql.OAuth2RequestSQL{Table: "refresh"} + require.NoError(t, r.Persister().Connection(s.t1).Where("signature = ?", signatureT1).First(&refreshT1)) + require.Equal(t, false, refreshT1.Active) + + accessT1 := persistencesql.OAuth2RequestSQL{Table: "access"} + require.ErrorIs(t, r.Persister().Connection(s.t1).Where("signature = ?", x.SignatureHash(accessSignatureT1)).First(&accessT1), sql.ErrNoRows) + + refreshT2 := persistencesql.OAuth2RequestSQL{Table: "refresh"} + require.NoError(t, r.Persister().Connection(s.t2).Where("signature = ?", signatureT2).First(&refreshT2)) + require.Equal(t, true, refreshT2.Active) + + accessT2 := persistencesql.OAuth2RequestSQL{Table: "access"} + require.NoError(t, r.Persister().Connection(s.t2).Where("signature = ?", x.SignatureHash(accessSignatureT2)).First(&accessT2)) + require.Equal(t, true, accessT2.Active) + } + + require.NoError(t, r.Persister().RotateRefreshToken(s.t2, request.ID, signatureT2)) + { + refreshT2 := persistencesql.OAuth2RequestSQL{Table: "refresh"} + require.NoError(t, r.Persister().Connection(s.t2).Where("signature = ?", signatureT2).First(&refreshT2)) + require.Equal(t, false, refreshT2.Active) + + accessT2 := persistencesql.OAuth2RequestSQL{Table: "access"} + require.ErrorIs(t, r.Persister().Connection(s.t2).Where("signature = ?", x.SignatureHash(accessSignatureT2)).First(&accessT2), sql.ErrNoRows) + require.Equal(t, false, accessT2.Active) + } + }) + + t.Run("without access signature", func(t *testing.T) { + clientID := uuid.Must(uuid.NewV4()).String() + require.NoError(t, r.Persister().CreateClient(s.t1, &client.Client{ID: clientID})) + + request1 := fosite.NewRequest() + request1.Client = &fosite.DefaultClient{ID: clientID} + request1.Session = &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "sub"}} + + signature := uuid.Must(uuid.NewV4()).String() + require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signature, "", request1)) + + accessSignature1 := uuid.Must(uuid.NewV4()).String() + require.NoError(t, r.Persister().CreateAccessTokenSession(s.t1, accessSignature1, request1)) + + accessSignature2 := uuid.Must(uuid.NewV4()).String() + require.NoError(t, r.Persister().CreateAccessTokenSession(s.t1, accessSignature2, request1)) + + require.NoError(t, r.Persister().RotateRefreshToken(s.t1, request1.ID, signature)) + { + accessT1 := persistencesql.OAuth2RequestSQL{Table: "access"} + require.ErrorIs(t, r.Persister().Connection(s.t1).Where("signature = ?", x.SignatureHash(accessSignature1)).First(&accessT1), sql.ErrNoRows) + + refresh := persistencesql.OAuth2RequestSQL{Table: "refresh"} + require.NoError(t, r.Persister().Connection(s.t1).Where("signature = ?", signature).First(&refresh)) + require.Equal(t, false, refresh.Active) + + accessT2 := persistencesql.OAuth2RequestSQL{Table: "access"} + require.ErrorIs(t, r.Persister().Connection(s.t1).Where("signature = ?", x.SignatureHash(accessSignature2)).First(&accessT2), sql.ErrNoRows) + } + }) }) } } diff --git a/persistence/sql/persister_nonce_test.go b/persistence/sql/persister_nonce_test.go index 1de7eda543a..933af0a9a7a 100644 --- a/persistence/sql/persister_nonce_test.go +++ b/persistence/sql/persister_nonce_test.go @@ -8,18 +8,19 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ory/fosite" - "github.com/ory/hydra/v2/internal" "github.com/ory/x/contextx" "github.com/ory/x/randx" ) func TestPersister_Nonce(t *testing.T) { ctx := context.Background() - p := internal.NewMockedRegistry(t, new(contextx.Default)).Persister() + p := testhelpers.NewMockedRegistry(t, new(contextx.Default)).Persister() accessToken := randx.MustString(100, randx.AlphaNum) anotherToken := randx.MustString(100, randx.AlphaNum) diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index b67b6ae17ed..58eb107306b 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -14,23 +14,22 @@ import ( "strings" "time" - "github.com/ory/hydra/v2/x" - - "github.com/ory/x/sqlxx" - - "go.opentelemetry.io/otel/trace" - "github.com/gofrs/uuid" "github.com/pkg/errors" "github.com/tidwall/gjson" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "github.com/ory/fosite" "github.com/ory/fosite/storage" "github.com/ory/hydra/v2/oauth2" + "github.com/ory/hydra/v2/x" "github.com/ory/hydra/v2/x/events" + "github.com/ory/x/dbal" "github.com/ory/x/errorsx" "github.com/ory/x/otelx" "github.com/ory/x/sqlcon" + "github.com/ory/x/sqlxx" "github.com/ory/x/stringsx" ) @@ -60,7 +59,8 @@ type ( } OAuth2RefreshTable struct { OAuth2RequestSQL - FirstUsedAt sql.NullTime `db:"first_used_at"` + FirstUsedAt sql.NullTime `db:"first_used_at"` + AccessTokenSignature sql.NullString `db:"access_token_signature"` } ) @@ -445,41 +445,61 @@ func toEventOptions(requester fosite.Requester) []trace.EventOption { } } -func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature string, requester fosite.Requester) (err error) { +func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature string, accessTokenSignature string, requester fosite.Requester) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateRefreshTokenSession") defer otelx.End(span, &err) events.Trace(ctx, events.RefreshTokenIssued, toEventOptions(requester)...) - return p.createSession(ctx, signature, requester, sqlTableRefresh, requester.GetSession().GetExpiresAt(fosite.RefreshToken).UTC()) + + req, err := p.sqlSchemaFromRequest(ctx, signature, requester, sqlTableRefresh, requester.GetSession().GetExpiresAt(fosite.RefreshToken).UTC()) + if err != nil { + return err + } + + var sig sql.NullString + if len(accessTokenSignature) > 0 { + sig = sql.NullString{ + Valid: true, + String: x.SignatureHash(accessTokenSignature), + } + } + + if err = sqlcon.HandleError(p.CreateWithNetwork(ctx, &OAuth2RefreshTable{ + OAuth2RequestSQL: *req, + AccessTokenSignature: sig, + })); errors.Is(err, sqlcon.ErrConcurrentUpdate) { + return fosite.ErrSerializationFailure.WithWrap(err) + } else if err != nil { + return err + } + + return nil } func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRefreshTokenSession") defer otelx.End(span, &err) - r := OAuth2RefreshTable{OAuth2RequestSQL: OAuth2RequestSQL{Table: sqlTableRefresh}} - err = p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r) - if errors.Is(err, sql.ErrNoRows) { + var row OAuth2RefreshTable + if err := p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&row); errors.Is(err, sql.ErrNoRows) { return nil, errorsx.WithStack(fosite.ErrNotFound) } else if err != nil { return nil, sqlcon.HandleError(err) } - fositeRequest, err := r.toRequest(ctx, session, p) - if err != nil { - return nil, err - } - - if r.Active { - return fositeRequest, nil + gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx) + if row.Active { + // Token is active + return row.toRequest(ctx, session, p) + } else if gracePeriod > 0 && + row.FirstUsedAt.Valid && + row.FirstUsedAt.Time.Add(gracePeriod).After(time.Now()) { + // We return the request as is, which indicates that the token is active (because we are in the grace period still). + return row.toRequest(ctx, session, p) } - if gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx); gracePeriod > 0 && r.FirstUsedAt.Valid { - if r.FirstUsedAt.Time.Add(gracePeriod).Before(time.Now()) { - return fositeRequest, errors.WithStack(fosite.ErrInactiveToken) - } - - r.Active = true // We set active to true because we are in the grace period. - return r.toRequest(ctx, session, p) // And re-generate the request + fositeRequest, err := row.toRequest(ctx, session, p) + if err != nil { + return nil, err } return fositeRequest, errors.WithStack(fosite.ErrInactiveToken) @@ -533,23 +553,7 @@ func (p *Persister) DeletePKCERequestSession(ctx context.Context, signature stri func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshToken") defer otelx.End(span, &err) - return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) -} - -func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, _ string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshTokenMaybeGracePeriod") - defer otelx.End(span, &err) - - /* #nosec G201 table is static */ - return sqlcon.HandleError( - p.Connection(ctx). - RawQuery( - fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()), - id, - p.NetworkID(ctx), - ). - Exec(), - ) + return p.deleteSessionByRequestID(ctx, id, sqlTableRefresh) } func (p *Persister) RevokeAccessToken(ctx context.Context, id string) (err error) { @@ -612,3 +616,123 @@ func (p *Persister) DeleteAccessTokens(ctx context.Context, clientID string) (er p.QueryWithNetwork(ctx).Where("client_id=?", clientID).Delete(&OAuth2RequestSQL{Table: sqlTableAccess}), ) } + +func handleRetryError(err error) error { + if err == nil { + return nil + } + + if errors.Is(err, sqlcon.ErrConcurrentUpdate) { + return fosite.ErrSerializationFailure.WithWrap(err) + } + if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock + return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + } + return err +} + +// strictRefreshRotation implements the strict refresh token rotation strategy. In strict rotation, we disable all +// refresh and access tokens associated with a request ID and subsequently create the only valid, new token pair. +func (p *Persister) strictRefreshRotation(ctx context.Context, requestID string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.strictRefreshRotation", + trace.WithAttributes( + attribute.String("request_id", requestID), + attribute.String("network_id", p.NetworkID(ctx).String()))) + defer otelx.End(span, &err) + + c := p.Connection(ctx) + + // In strict rotation we only have one token chain for every request. Therefore, we remove all + // access tokens associated with the request ID. + if err := p.deleteSessionByRequestID(ctx, requestID, sqlTableAccess); err != nil { + return err + } + + // The same applies to refresh tokens in strict mode. We disable all old refresh tokens when rotating. + count, err := c.RawQuery( + "UPDATE hydra_oauth2_refresh SET active=false WHERE request_id=? AND nid = ? AND active", + requestID, + p.NetworkID(ctx), + ).ExecWithCount() + if err != nil { + return sqlcon.HandleError(err) + } else if count == 0 { + return errorsx.WithStack(fosite.ErrNotFound) + } + + return nil +} + +func (p *Persister) gracefulRefreshRotation(ctx context.Context, requestID string, refreshSignature string, period time.Duration) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.gracefulRefreshRotation", + trace.WithAttributes( + attribute.String("request_id", requestID), + attribute.String("network_id", p.NetworkID(ctx).String()))) + defer otelx.End(span, &err) + + c := p.Connection(ctx) + now := time.Now().UTC().Round(time.Millisecond) + + var accessTokenSignature sql.NullString + if p.conn.Dialect.Name() == dbal.DriverMySQL { + // MySQL does not support returning values from an update query, so we need to do two queries. + var tokenToRevoke OAuth2RefreshTable + if err := c. + Select("access_token_signature"). + // Filtering by "active" status would break graceful token rotation. We know and trust (with tests) + // that Fosite is dealing with the refresh token reuse detection business logic without + // relying on the active filter her. + Where("signature=? AND nid = ?", refreshSignature, p.NetworkID(ctx)). + First(&tokenToRevoke); err != nil { + return sqlcon.HandleError(err) + } + + if count, err := c.RawQuery( + // Signature is the primary key so no limit needed. We only update first_used_at if it is not set yet (otherwise + // we would "refresh" the grace period again and again, and the refresh token would never "expire"). + `UPDATE hydra_oauth2_refresh SET active=false, first_used_at = COALESCE(first_used_at, ?) WHERE signature=? AND nid = ?`, + now, refreshSignature, p.NetworkID(ctx), + ).ExecWithCount(); err != nil { + return sqlcon.HandleError(err) + } else if count == 0 { + return errorsx.WithStack(fosite.ErrNotFound) + } + + accessTokenSignature = tokenToRevoke.AccessTokenSignature + } else { + var tokenToRevoke OAuth2RefreshTable + if err := c.RawQuery( + // Same query like in the MySQL case, but we can return the access token signature directly. + `UPDATE hydra_oauth2_refresh SET active=false, first_used_at = COALESCE(first_used_at, ?) WHERE signature=? AND nid = ? RETURNING access_token_signature`, + now, refreshSignature, p.NetworkID(ctx), + ).First(&tokenToRevoke); err != nil { + return sqlcon.HandleError(err) + } + + accessTokenSignature = tokenToRevoke.AccessTokenSignature + } + + if !accessTokenSignature.Valid { + // If the access token is not found, we fall back to deleting all access tokens associated with the request ID. + if err := p.deleteSessionByRequestID(ctx, requestID, sqlTableAccess); err != nil { + return err + } + return nil + } + + // We have the signature and we will only remove that specific access token as part of the rotation. + return p.deleteSessionBySignature(ctx, accessTokenSignature.String, sqlTableAccess) +} + +func (p *Persister) RotateRefreshToken(ctx context.Context, requestID string, refreshTokenSignature string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RotateRefreshToken") + defer otelx.End(span, &err) + + // If we end up here, we have a valid refresh token and can proceed with the rotation. + gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx) + if gracePeriod > 0 { + return handleRetryError(p.gracefulRefreshRotation(ctx, requestID, refreshTokenSignature, gracePeriod)) + } + + return handleRetryError(p.strictRefreshRotation(ctx, requestID)) +} diff --git a/persistence/sql/persister_test.go b/persistence/sql/persister_test.go index a4818a3e69d..b4c88ef01c3 100644 --- a/persistence/sql/persister_test.go +++ b/persistence/sql/persister_test.go @@ -28,7 +28,6 @@ import ( "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/driver" - "github.com/ory/hydra/v2/internal" ) func init() { @@ -120,11 +119,11 @@ func testRegistry(t *testing.T, ctx context.Context, k string, t1 driver.Registr func TestManagersNextGen(t *testing.T) { regs := map[string]driver.Registry{ - "memory": internal.NewRegistrySQLFromURL(t, dbal.NewSQLiteTestDatabase(t), true, &contextx.Default{}), + "memory": testhelpers.NewRegistrySQLFromURL(t, dbal.NewSQLiteTestDatabase(t), true, &contextx.Default{}), } if !testing.Short() { - regs["postgres"], regs["mysql"], regs["cockroach"], _ = internal.ConnectDatabases(t, true, &contextx.Default{}) + regs["postgres"], regs["mysql"], regs["cockroach"], _ = testhelpers.ConnectDatabases(t, true, &contextx.Default{}) } ctx := context.Background() @@ -153,16 +152,16 @@ func TestManagersNextGen(t *testing.T) { func TestManagers(t *testing.T) { ctx := context.TODO() t1registries := map[string]driver.Registry{ - "memory": internal.NewRegistrySQLFromURL(t, dbal.NewSQLiteTestDatabase(t), true, &contextx.Default{}), + "memory": testhelpers.NewRegistrySQLFromURL(t, dbal.NewSQLiteTestDatabase(t), true, &contextx.Default{}), } t2registries := map[string]driver.Registry{ - "memory": internal.NewRegistrySQLFromURL(t, dbal.NewSQLiteTestDatabase(t), false, &contextx.Default{}), + "memory": testhelpers.NewRegistrySQLFromURL(t, dbal.NewSQLiteTestDatabase(t), false, &contextx.Default{}), } if !testing.Short() { - t2registries["postgres"], t2registries["mysql"], t2registries["cockroach"], _ = internal.ConnectDatabases(t, false, &contextx.Default{}) - t1registries["postgres"], t1registries["mysql"], t1registries["cockroach"], _ = internal.ConnectDatabases(t, true, &contextx.Default{}) + t2registries["postgres"], t2registries["mysql"], t2registries["cockroach"], _ = testhelpers.ConnectDatabases(t, false, &contextx.Default{}) + t1registries["postgres"], t1registries["mysql"], t1registries["cockroach"], _ = testhelpers.ConnectDatabases(t, true, &contextx.Default{}) } network1NID, _ := uuid.NewV4() diff --git a/spec/config.json b/spec/config.json index 72f81534c66..effd1cc866d 100644 --- a/spec/config.json +++ b/spec/config.json @@ -1071,9 +1071,9 @@ "refresh_token": { "type": "object", "properties": { - "grace_period": { + "rotation_grace_period": { "title": "Refresh Token Rotation Grace Period", - "description": "Configures how long a Refresh Token remains valid after it has been used. The maximum value is one hour.", + "description": "Configures how long a Refresh Token remains valid after it has been used. The maximum value is 5 minutes.", "default": "0s", "allOf": [ { diff --git a/x/oauth2cors/cors_test.go b/x/oauth2cors/cors_test.go index d450fe308ab..dee215eae77 100644 --- a/x/oauth2cors/cors_test.go +++ b/x/oauth2cors/cors_test.go @@ -14,6 +14,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/hydra/v2/driver" "github.com/ory/x/contextx" @@ -24,13 +26,12 @@ import ( "github.com/ory/fosite" "github.com/ory/hydra/v2/client" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/oauth2" ) func TestOAuth2AwareCORSMiddleware(t *testing.T) { ctx := context.Background() - r := internal.NewRegistryMemory(t, internal.NewConfigurationWithDefaults(), &contextx.Default{}) + r := testhelpers.NewRegistryMemory(t, testhelpers.NewConfigurationWithDefaults(), &contextx.Default{}) token, signature, _ := r.OAuth2HMACStrategy().GenerateAccessToken(ctx, nil) for k, tc := range []struct { @@ -275,7 +276,7 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { }, } { t.Run(fmt.Sprintf("case=%d/description=%s", k, tc.d), func(t *testing.T) { - r.WithConfig(internal.NewConfigurationWithDefaults()) + r.WithConfig(testhelpers.NewConfigurationWithDefaults()) if tc.prep != nil { tc.prep(t, r) diff --git a/x/tls_termination_test.go b/x/tls_termination_test.go index bdb5581ce91..0c7be56f549 100644 --- a/x/tls_termination_test.go +++ b/x/tls_termination_test.go @@ -10,10 +10,11 @@ import ( "net/url" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/stretchr/testify/assert" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" . "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" ) @@ -27,8 +28,8 @@ func noopHandler(w http.ResponseWriter, r *http.Request) { } func TestDoesRequestSatisfyTermination(t *testing.T) { - c := internal.NewConfigurationWithDefaultsAndHTTPS() - r := internal.NewRegistryMemory(t, c, &contextx.Default{}) + c := testhelpers.NewConfigurationWithDefaultsAndHTTPS() + r := testhelpers.NewRegistryMemory(t, c, &contextx.Default{}) t.Run("case=tls-termination-disabled", func(t *testing.T) { c.MustSet(context.Background(), config.KeyTLSAllowTerminationFrom, "") @@ -178,7 +179,7 @@ func TestDoesRequestSatisfyTermination(t *testing.T) { // test: in case http is forced request should be accepted t.Run("case=forced-http", func(t *testing.T) { - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() res := httptest.NewRecorder() RejectInsecureRequests(r, c.TLS(context.Background(), config.PublicInterface))(res, &http.Request{Header: http.Header{}, URL: new(url.URL)}, noopHandler) assert.EqualValues(t, http.StatusNoContent, res.Code)