From 185afde1511e5720460d10e1570bed6a9a64b291 Mon Sep 17 00:00:00 2001 From: Arne Luenser Date: Tue, 16 Jul 2024 15:18:20 +0200 Subject: [PATCH] fix: store expiry times in UTC --- oauth2/fosite_store_helpers.go | 18 +++++++++--------- oauth2/fosite_store_test.go | 2 +- persistence/sql/persister_oauth2.go | 12 ++++++------ 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/oauth2/fosite_store_helpers.go b/oauth2/fosite_store_helpers.go index 5d1918df13e..6b64d61991e 100644 --- a/oauth2/fosite_store_helpers.go +++ b/oauth2/fosite_store_helpers.go @@ -146,7 +146,7 @@ var flushRequests = []*fosite.Request{ }, } -func mockRequestForeignKey(t *testing.T, id string, x InternalRegistry, createClient bool) { +func mockRequestForeignKey(t *testing.T, id string, x InternalRegistry) { cl := &client.Client{ID: "foobar"} cr := &flow.OAuth2ConsentRequest{ Client: cl, @@ -160,7 +160,7 @@ func mockRequestForeignKey(t *testing.T, id string, x InternalRegistry, createCl } ctx := context.Background() - if createClient { + if _, err := x.ClientManager().GetClient(ctx, cl.ID); errors.Is(err, sqlcon.ErrNoRows) { require.NoError(t, x.ClientManager().CreateClient(ctx, cl)) } @@ -230,7 +230,7 @@ func TestHelperRunner(t *testing.T, store InternalRegistry, k string) { func testHelperRequestIDMultiples(m InternalRegistry, _ string) func(t *testing.T) { return func(t *testing.T) { requestId := uuid.New() - mockRequestForeignKey(t, requestId, m, true) + mockRequestForeignKey(t, requestId, m) cl := &client.Client{ID: "foobar"} fositeRequest := &fosite.Request{ @@ -313,8 +313,8 @@ func testHelperRevokeRefreshToken(x InternalRegistry) func(t *testing.T) { reqIdOne := uuid.New() reqIdTwo := uuid.New() - mockRequestForeignKey(t, reqIdOne, x, false) - mockRequestForeignKey(t, reqIdTwo, x, false) + mockRequestForeignKey(t, reqIdOne, x) + mockRequestForeignKey(t, reqIdTwo, x) err = m.CreateRefreshTokenSession(ctx, "1111", &fosite.Request{ ID: reqIdOne, @@ -356,7 +356,7 @@ func testHelperCreateGetDeleteAuthorizeCodes(x InternalRegistry) func(t *testing return func(t *testing.T) { m := x.OAuth2Storage() - mockRequestForeignKey(t, "blank", x, false) + mockRequestForeignKey(t, "blank", x) ctx := context.Background() res, err := m.GetAuthorizeCodeSession(ctx, "4321", NewSession("bar")) @@ -394,7 +394,7 @@ func testHelperExpiryFields(reg InternalRegistry) func(t *testing.T) { m := reg.OAuth2Storage() t.Parallel() - mockRequestForeignKey(t, "blank", reg, false) + mockRequestForeignKey(t, "blank", reg) ctx := context.Background() @@ -583,7 +583,7 @@ func testHelperFlushTokens(x InternalRegistry, lifespan time.Duration) func(t *t return func(t *testing.T) { ctx := context.Background() for _, r := range flushRequests { - mockRequestForeignKey(t, r.ID, x, false) + mockRequestForeignKey(t, r.ID, x) require.NoError(t, m.CreateAccessTokenSession(ctx, r.ID, r)) _, err := m.GetAccessTokenSession(ctx, r.ID, ds) require.NoError(t, err) @@ -630,7 +630,7 @@ func testHelperFlushTokensWithLimitAndBatchSize(x InternalRegistry, limit int, b for i := 0; i < totalCount; i++ { r := createTestRequest(fmt.Sprintf("%s-%d", id, i+1)) r.RequestedAt = time.Now().Add(-2 * time.Hour) - mockRequestForeignKey(t, r.ID, x, false) + mockRequestForeignKey(t, r.ID, x) require.NoError(t, m.CreateAccessTokenSession(ctx, r.ID, r)) _, err := m.GetAccessTokenSession(ctx, r.ID, ds) require.NoError(t, err) diff --git a/oauth2/fosite_store_test.go b/oauth2/fosite_store_test.go index cc437a1e9d6..2a48a52f8e7 100644 --- a/oauth2/fosite_store_test.go +++ b/oauth2/fosite_store_test.go @@ -38,7 +38,7 @@ func setupRegistries(t *testing.T) { if len(registries) == 0 && !testing.Short() { // first time called and sql tests var cleanSQL func(*testing.T) - registries["postgres"], _, _, cleanSQL = internal.ConnectDatabases(t, true, &contextx.Default{}) + registries["postgres"], registries["mysql"], registries["cockroach"], cleanSQL = internal.ConnectDatabases(t, true, &contextx.Default{}) cleanMem := cleanRegistries cleanMem(t) cleanRegistries = func(t *testing.T) { diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index 5888c0b1cad..6e1336b80de 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -312,7 +312,7 @@ func (p *Persister) deactivateSessionByRequestID(ctx context.Context, id string, func (p *Persister) CreateAuthorizeCodeSession(ctx context.Context, signature string, requester fosite.Requester) error { return otelx.WithSpan(ctx, "persistence.sql.CreateAuthorizeCodeSession", func(ctx context.Context) error { - return p.createSession(ctx, signature, requester, sqlTableCode, requester.GetSession().GetExpiresAt(fosite.AuthorizeCode)) + return p.createSession(ctx, signature, requester, sqlTableCode, requester.GetSession().GetExpiresAt(fosite.AuthorizeCode).UTC()) }) } @@ -347,7 +347,7 @@ func (p *Persister) CreateAccessTokenSession(ctx context.Context, signature stri append(toEventOptions(requester), events.WithGrantType(requester.GetRequestForm().Get("grant_type")))..., ) - return p.createSession(ctx, x.SignatureHash(signature), requester, sqlTableAccess, requester.GetSession().GetExpiresAt(fosite.AccessToken)) + return p.createSession(ctx, x.SignatureHash(signature), requester, sqlTableAccess, requester.GetSession().GetExpiresAt(fosite.AccessToken).UTC()) } func (p *Persister) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { @@ -423,7 +423,7 @@ func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature str ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateRefreshTokenSession") defer otelx.End(span, &err) events.Trace(ctx, events.RefreshTokenIssued, toEventOptions(requester)...) - return p.createSession(ctx, signature, requester, sqlTableRefresh, requester.GetSession().GetExpiresAt(fosite.RefreshToken)) + return p.createSession(ctx, signature, requester, sqlTableRefresh, requester.GetSession().GetExpiresAt(fosite.RefreshToken).UTC()) } func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { @@ -442,8 +442,8 @@ func (p *Persister) CreateOpenIDConnectSession(ctx context.Context, signature st ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateOpenIDConnectSession") defer otelx.End(span, &err) events.Trace(ctx, events.IdentityTokenIssued, toEventOptions(requester)...) - // The expiry of a PKCE session is equal to the expiry of the authorization code. If the code is invalid, so is this OIDC request. - return p.createSession(ctx, signature, requester, sqlTableOpenID, requester.GetSession().GetExpiresAt(fosite.AuthorizeCode)) + // The expiry of an OIDC session is equal to the expiry of the authorization code. If the code is invalid, so is this OIDC request. + return p.createSession(ctx, signature, requester, sqlTableOpenID, requester.GetSession().GetExpiresAt(fosite.AuthorizeCode).UTC()) } func (p *Persister) GetOpenIDConnectSession(ctx context.Context, signature string, requester fosite.Requester) (_ fosite.Requester, err error) { @@ -468,7 +468,7 @@ func (p *Persister) CreatePKCERequestSession(ctx context.Context, signature stri ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreatePKCERequestSession") defer otelx.End(span, &err) // The expiry of a PKCE session is equal to the expiry of the authorization code. If the code is invalid, so is this PKCE request. - return p.createSession(ctx, signature, requester, sqlTablePKCE, requester.GetSession().GetExpiresAt(fosite.AuthorizeCode)) + return p.createSession(ctx, signature, requester, sqlTablePKCE, requester.GetSession().GetExpiresAt(fosite.AuthorizeCode).UTC()) } func (p *Persister) DeletePKCERequestSession(ctx context.Context, signature string) (err error) {