From 09da66eaa169a09dd64016e503f99c2713da177a Mon Sep 17 00:00:00 2001 From: Arne Luenser Date: Thu, 3 Aug 2023 12:36:34 +0200 Subject: [PATCH] fix: don't query by raw signature --- persistence/sql/persister_nid_test.go | 52 ++++++++++++++------------- persistence/sql/persister_oauth2.go | 35 ++++-------------- x/audit_test.go | 2 -- x/clean_sql.go | 3 +- 4 files changed, 34 insertions(+), 58 deletions(-) diff --git a/persistence/sql/persister_nid_test.go b/persistence/sql/persister_nid_test.go index 6ad1c937aec..b748660b777 100644 --- a/persistence/sql/persister_nid_test.go +++ b/persistence/sql/persister_nid_test.go @@ -40,14 +40,16 @@ import ( type PersisterTestSuite struct { suite.Suite registries map[string]driver.Registry - clean func(*testing.T) t1 context.Context t2 context.Context t1NID uuid.UUID t2NID uuid.UUID } -var _ PersisterTestSuite = PersisterTestSuite{} +var _ interface { + suite.SetupAllSuite + suite.TearDownTestSuite +} = (*PersisterTestSuite)(nil) func (s *PersisterTestSuite) SetupSuite() { s.registries = map[string]driver.Registry{ @@ -55,7 +57,7 @@ func (s *PersisterTestSuite) SetupSuite() { } if !testing.Short() { - s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], s.clean = internal.ConnectDatabases(s.T(), true, &contextx.Default{}) + s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], _ = internal.ConnectDatabases(s.T(), true, &contextx.Default{}) } s.t1NID, s.t2NID = uuid.Must(uuid.NewV4()), uuid.Must(uuid.NewV4()) @@ -338,7 +340,7 @@ func (s *PersisterTestSuite) TestCreateAuthorizeCodeSession() { fr.Client = &fosite.DefaultClient{ID: c1.LegacyClientID} require.NoError(t, r.Persister().CreateAuthorizeCodeSession(s.t1, sig, fr)) actual := persistencesql.OAuth2RequestSQL{Table: "code"} - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, sig)) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig))) require.Equal(t, s.t1NID, actual.NID) }) } @@ -480,7 +482,7 @@ func (s *PersisterTestSuite) TestCreateOpenIDConnectSession() { require.NoError(t, r.Persister().CreateOpenIDConnectSession(s.t1, authorizeCode, request)) actual := persistencesql.OAuth2RequestSQL{Table: "oidc"} - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode)) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode))) require.Equal(t, s.t1NID, actual.NID) }) } @@ -499,9 +501,9 @@ func (s *PersisterTestSuite) TestCreatePKCERequestSession() { authorizeCode := uuid.Must(uuid.NewV4()).String() actual := persistencesql.OAuth2RequestSQL{Table: "pkce"} - require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode)) + require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode))) require.NoError(t, r.Persister().CreatePKCERequestSession(s.t1, authorizeCode, request)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode)) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode))) require.Equal(t, s.t1NID, actual.NID) }) } @@ -519,9 +521,9 @@ func (s *PersisterTestSuite) TestCreateRefreshTokenSession() { authorizeCode := uuid.Must(uuid.NewV4()).String() actual := persistencesql.OAuth2RequestSQL{Table: "refresh"} - require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode)) + require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode))) require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, authorizeCode, request)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode)) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode))) require.Equal(t, s.t1NID, actual.NID) }) } @@ -558,11 +560,11 @@ func (s *PersisterTestSuite) DeleteAccessTokenSession() { require.NoError(t, r.Persister().DeleteAccessTokenSession(s.t2, sig)) actual := persistencesql.OAuth2RequestSQL{Table: "access"} - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, sig)) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig))) require.Equal(t, s.t1NID, actual.NID) require.NoError(t, r.Persister().DeleteAccessTokenSession(s.t1, sig)) - require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, sig)) + require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig))) }) } } @@ -660,9 +662,9 @@ func (s *PersisterTestSuite) TestDeleteOpenIDConnectSession() { actual := persistencesql.OAuth2RequestSQL{Table: "oidc"} require.NoError(t, r.Persister().DeleteOpenIDConnectSession(s.t2, authorizeCode)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode)) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode))) require.NoError(t, r.Persister().DeleteOpenIDConnectSession(s.t1, authorizeCode)) - require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode)) + require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode))) }) } } @@ -683,9 +685,9 @@ func (s *PersisterTestSuite) TestDeletePKCERequestSession() { actual := persistencesql.OAuth2RequestSQL{Table: "pkce"} require.NoError(t, r.Persister().DeletePKCERequestSession(s.t2, authorizeCode)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode)) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode))) require.NoError(t, r.Persister().DeletePKCERequestSession(s.t1, authorizeCode)) - require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode)) + require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode))) }) } } @@ -706,9 +708,9 @@ func (s *PersisterTestSuite) TestDeleteRefreshTokenSession() { actual := persistencesql.OAuth2RequestSQL{Table: "refresh"} require.NoError(t, r.Persister().DeleteRefreshTokenSession(s.t2, signature)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature)) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(signature))) require.NoError(t, r.Persister().DeleteRefreshTokenSession(s.t1, signature)) - require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, signature)) + require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(signature))) }) } } @@ -910,9 +912,9 @@ func (s *PersisterTestSuite) TestFlushInactiveRefreshTokens() { actual := persistencesql.OAuth2RequestSQL{Table: "refresh"} require.NoError(t, r.Persister().FlushInactiveRefreshTokens(s.t2, time.Now(), 100, 100)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature)) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(signature))) require.NoError(t, r.Persister().FlushInactiveRefreshTokens(s.t1, time.Now(), 100, 100)) - require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, signature)) + require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(signature))) }) } } @@ -1448,12 +1450,12 @@ func (s *PersisterTestSuite) TestInvalidateAuthorizeCodeSession() { require.NoError(t, r.Persister().InvalidateAuthorizeCodeSession(s.t2, sig)) actual := persistencesql.OAuth2RequestSQL{Table: "code"} - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, sig)) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig))) require.Equal(t, true, actual.Active) require.NoError(t, r.Persister().InvalidateAuthorizeCodeSession(s.t1, sig)) actual = persistencesql.OAuth2RequestSQL{Table: "code"} - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, sig)) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig))) require.Equal(t, false, actual.Active) }) } @@ -1748,10 +1750,10 @@ func (s *PersisterTestSuite) TestRevokeRefreshToken() { actual := persistencesql.OAuth2RequestSQL{Table: "refresh"} require.NoError(t, r.Persister().RevokeRefreshToken(s.t2, request.ID)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature)) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(signature))) require.Equal(t, true, actual.Active) require.NoError(t, r.Persister().RevokeRefreshToken(s.t1, request.ID)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature)) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(signature))) require.Equal(t, false, actual.Active) }) } @@ -1778,10 +1780,10 @@ func (s *PersisterTestSuite) TestRevokeRefreshTokenMaybeGracePeriod() { } require.NoError(t, store.RevokeRefreshTokenMaybeGracePeriod(s.t2, request.ID, signature)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature)) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(signature))) require.Equal(t, true, actual.Active) require.NoError(t, store.RevokeRefreshTokenMaybeGracePeriod(s.t1, request.ID, signature)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature)) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(signature))) require.Equal(t, false, actual.Active) }) } diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index fb2faba6c0c..be7b8c819fc 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -101,7 +101,7 @@ func (p *Persister) sqlSchemaFromRequest(ctx context.Context, rawSignature strin return &OAuth2RequestSQL{ Request: r.GetID(), ConsentChallenge: challenge, - ID: p.hashSignature(ctx, rawSignature, table), + ID: SignatureHash(rawSignature), RequestedAt: r.GetRequestedAt(), Client: r.GetClient().GetID(), Scopes: strings.Join(r.GetRequestedScopes(), "|"), @@ -166,14 +166,6 @@ func SignatureHash(signature string) string { return fmt.Sprintf("%x", sha512.Sum384([]byte(signature))) } -// hashSignature prevents errors where the signature is longer than 128 characters (and thus doesn't fit into the pk). -func (p *Persister) hashSignature(_ context.Context, signature string, table tableName) string { - if table == sqlTableAccess { - return SignatureHash(signature) - } - return signature -} - func (p *Persister) ClientAssertionJWTValid(ctx context.Context, jti string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ClientAssertionJWTValid") defer otelx.End(span, &err) @@ -242,19 +234,12 @@ func (p *Persister) createSession(ctx context.Context, signature string, request return nil } -func (p *Persister) findSessionBySignature(ctx context.Context, rawSignature string, session fosite.Session, table tableName) (_ fosite.Requester, err error) { +func (p *Persister) findSessionBySignature(ctx context.Context, signature string, session fosite.Session, table tableName) (_ fosite.Requester, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.findSessionBySignature") defer otelx.End(span, &err) r := OAuth2RequestSQL{Table: table} - - // We look for the signature as well as the hash of the signature here. - // This is because we now always store the hash of the signature in the database, - // regardless of the type of the signature. In previous versions, we only stored - // the hash of the signature for JWT tokens. - // - // This code will be removed in a future version. - err = p.QueryWithNetwork(ctx).Where("signature IN (?, ?)", rawSignature, SignatureHash(rawSignature)).First(&r) + err = p.QueryWithNetwork(ctx).Where("signature = ?", SignatureHash(signature)).First(&r) if errors.Is(err, sql.ErrNoRows) { return nil, errorsx.WithStack(fosite.ErrNotFound) } else if err != nil { @@ -276,17 +261,9 @@ func (p *Persister) deleteSessionBySignature(ctx context.Context, signature stri ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionBySignature") defer otelx.End(span, &err) - signature = p.hashSignature(ctx, signature, table) - - // We look for the signature as well as the hash of the signature here. - // This is because we now always store the hash of the signature in the database, - // regardless of the type of the signature. In previous versions, we only stored - // the hash of the signature for JWT tokens. - // - // This code will be removed in a future version. err = sqlcon.HandleError( p.QueryWithNetwork(ctx). - Where("signature IN (?, ?)", signature, SignatureHash(signature)). + Where("signature = ?", SignatureHash(signature)). Delete(&OAuth2RequestSQL{Table: table})) if errors.Is(err, sqlcon.ErrNoRows) { @@ -356,8 +333,8 @@ func (p *Persister) InvalidateAuthorizeCodeSession(ctx context.Context, signatur return sqlcon.HandleError( p.Connection(ctx). RawQuery( - fmt.Sprintf("UPDATE %s SET active=false WHERE signature=? AND nid = ?", OAuth2RequestSQL{Table: sqlTableCode}.TableName()), - signature, + fmt.Sprintf("UPDATE %s SET active = false WHERE signature = ? AND nid = ?", OAuth2RequestSQL{Table: sqlTableCode}.TableName()), + SignatureHash(signature), p.NetworkID(ctx), ). Exec(), diff --git a/x/audit_test.go b/x/audit_test.go index ef563c04a53..0a4061551d2 100644 --- a/x/audit_test.go +++ b/x/audit_test.go @@ -43,8 +43,6 @@ func TestLogAudit(t *testing.T) { l.Logger.Out = buf LogAudit(r, tc.message, l) - t.Logf("%s", buf.String()) - assert.Contains(t, buf.String(), "audience=audit") for _, expectContain := range tc.expectContains { assert.Contains(t, buf.String(), expectContain) diff --git a/x/clean_sql.go b/x/clean_sql.go index 59628fb3f97..a02a9a054ce 100644 --- a/x/clean_sql.go +++ b/x/clean_sql.go @@ -10,7 +10,6 @@ import ( ) func DeleteHydraRows(t *testing.T, c *pop.Connection) { - t.Logf("Deleting hydra rows in database: %s", c.Dialect.Name()) for _, tb := range []string{ "hydra_oauth2_access", "hydra_oauth2_refresh", @@ -57,7 +56,7 @@ func CleanSQLPop(t *testing.T, c *pop.Connection) { "schema_migration", } { if err := c.RawQuery("DROP TABLE IF EXISTS " + tb).Exec(); err != nil { - t.Logf(`Unable to clean up table "%s": %s`, tb, err) + t.Fatalf(`Unable to clean up table "%s": %s`, tb, err) } } t.Logf("Successfully cleaned up database: %s", c.Dialect.Name())