Skip to content

Commit

Permalink
fix: don't query by raw signature
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr committed Aug 3, 2023
1 parent ca68fe9 commit 09da66e
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 58 deletions.
52 changes: 27 additions & 25 deletions persistence/sql/persister_nid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,24 @@ import (
type PersisterTestSuite struct {
suite.Suite
registries map[string]driver.Registry
clean func(*testing.T)
t1 context.Context
t2 context.Context
t1NID uuid.UUID
t2NID uuid.UUID
}

var _ PersisterTestSuite = PersisterTestSuite{}
var _ interface {
suite.SetupAllSuite
suite.TearDownTestSuite
} = (*PersisterTestSuite)(nil)

func (s *PersisterTestSuite) SetupSuite() {
s.registries = map[string]driver.Registry{
"memory": internal.NewRegistrySQLFromURL(s.T(), dbal.NewSQLiteTestDatabase(s.T()), true, &contextx.Default{}),
}

if !testing.Short() {
s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], s.clean = internal.ConnectDatabases(s.T(), true, &contextx.Default{})
s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], _ = internal.ConnectDatabases(s.T(), true, &contextx.Default{})
}

s.t1NID, s.t2NID = uuid.Must(uuid.NewV4()), uuid.Must(uuid.NewV4())
Expand Down Expand Up @@ -338,7 +340,7 @@ func (s *PersisterTestSuite) TestCreateAuthorizeCodeSession() {
fr.Client = &fosite.DefaultClient{ID: c1.LegacyClientID}
require.NoError(t, r.Persister().CreateAuthorizeCodeSession(s.t1, sig, fr))
actual := persistencesql.OAuth2RequestSQL{Table: "code"}
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, sig))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig)))
require.Equal(t, s.t1NID, actual.NID)
})
}
Expand Down Expand Up @@ -480,7 +482,7 @@ func (s *PersisterTestSuite) TestCreateOpenIDConnectSession() {
require.NoError(t, r.Persister().CreateOpenIDConnectSession(s.t1, authorizeCode, request))

actual := persistencesql.OAuth2RequestSQL{Table: "oidc"}
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode)))
require.Equal(t, s.t1NID, actual.NID)
})
}
Expand All @@ -499,9 +501,9 @@ func (s *PersisterTestSuite) TestCreatePKCERequestSession() {
authorizeCode := uuid.Must(uuid.NewV4()).String()

actual := persistencesql.OAuth2RequestSQL{Table: "pkce"}
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode)))
require.NoError(t, r.Persister().CreatePKCERequestSession(s.t1, authorizeCode, request))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode)))
require.Equal(t, s.t1NID, actual.NID)
})
}
Expand All @@ -519,9 +521,9 @@ func (s *PersisterTestSuite) TestCreateRefreshTokenSession() {

authorizeCode := uuid.Must(uuid.NewV4()).String()
actual := persistencesql.OAuth2RequestSQL{Table: "refresh"}
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode)))
require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, authorizeCode, request))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode)))
require.Equal(t, s.t1NID, actual.NID)
})
}
Expand Down Expand Up @@ -558,11 +560,11 @@ func (s *PersisterTestSuite) DeleteAccessTokenSession() {
require.NoError(t, r.Persister().DeleteAccessTokenSession(s.t2, sig))

actual := persistencesql.OAuth2RequestSQL{Table: "access"}
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, sig))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig)))
require.Equal(t, s.t1NID, actual.NID)

require.NoError(t, r.Persister().DeleteAccessTokenSession(s.t1, sig))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, sig))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig)))
})
}
}
Expand Down Expand Up @@ -660,9 +662,9 @@ func (s *PersisterTestSuite) TestDeleteOpenIDConnectSession() {
actual := persistencesql.OAuth2RequestSQL{Table: "oidc"}

require.NoError(t, r.Persister().DeleteOpenIDConnectSession(s.t2, authorizeCode))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode)))
require.NoError(t, r.Persister().DeleteOpenIDConnectSession(s.t1, authorizeCode))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode)))
})
}
}
Expand All @@ -683,9 +685,9 @@ func (s *PersisterTestSuite) TestDeletePKCERequestSession() {
actual := persistencesql.OAuth2RequestSQL{Table: "pkce"}

require.NoError(t, r.Persister().DeletePKCERequestSession(s.t2, authorizeCode))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode)))
require.NoError(t, r.Persister().DeletePKCERequestSession(s.t1, authorizeCode))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(authorizeCode)))
})
}
}
Expand All @@ -706,9 +708,9 @@ func (s *PersisterTestSuite) TestDeleteRefreshTokenSession() {
actual := persistencesql.OAuth2RequestSQL{Table: "refresh"}

require.NoError(t, r.Persister().DeleteRefreshTokenSession(s.t2, signature))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(signature)))
require.NoError(t, r.Persister().DeleteRefreshTokenSession(s.t1, signature))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, signature))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(signature)))
})
}
}
Expand Down Expand Up @@ -910,9 +912,9 @@ func (s *PersisterTestSuite) TestFlushInactiveRefreshTokens() {
actual := persistencesql.OAuth2RequestSQL{Table: "refresh"}

require.NoError(t, r.Persister().FlushInactiveRefreshTokens(s.t2, time.Now(), 100, 100))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(signature)))
require.NoError(t, r.Persister().FlushInactiveRefreshTokens(s.t1, time.Now(), 100, 100))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, signature))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(signature)))
})
}
}
Expand Down Expand Up @@ -1448,12 +1450,12 @@ func (s *PersisterTestSuite) TestInvalidateAuthorizeCodeSession() {

require.NoError(t, r.Persister().InvalidateAuthorizeCodeSession(s.t2, sig))
actual := persistencesql.OAuth2RequestSQL{Table: "code"}
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, sig))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig)))
require.Equal(t, true, actual.Active)

require.NoError(t, r.Persister().InvalidateAuthorizeCodeSession(s.t1, sig))
actual = persistencesql.OAuth2RequestSQL{Table: "code"}
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, sig))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig)))
require.Equal(t, false, actual.Active)
})
}
Expand Down Expand Up @@ -1748,10 +1750,10 @@ func (s *PersisterTestSuite) TestRevokeRefreshToken() {
actual := persistencesql.OAuth2RequestSQL{Table: "refresh"}

require.NoError(t, r.Persister().RevokeRefreshToken(s.t2, request.ID))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(signature)))
require.Equal(t, true, actual.Active)
require.NoError(t, r.Persister().RevokeRefreshToken(s.t1, request.ID))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(signature)))
require.Equal(t, false, actual.Active)
})
}
Expand All @@ -1778,10 +1780,10 @@ func (s *PersisterTestSuite) TestRevokeRefreshTokenMaybeGracePeriod() {
}

require.NoError(t, store.RevokeRefreshTokenMaybeGracePeriod(s.t2, request.ID, signature))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(signature)))
require.Equal(t, true, actual.Active)
require.NoError(t, store.RevokeRefreshTokenMaybeGracePeriod(s.t1, request.ID, signature))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(signature)))
require.Equal(t, false, actual.Active)
})
}
Expand Down
35 changes: 6 additions & 29 deletions persistence/sql/persister_oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (p *Persister) sqlSchemaFromRequest(ctx context.Context, rawSignature strin
return &OAuth2RequestSQL{
Request: r.GetID(),
ConsentChallenge: challenge,
ID: p.hashSignature(ctx, rawSignature, table),
ID: SignatureHash(rawSignature),
RequestedAt: r.GetRequestedAt(),
Client: r.GetClient().GetID(),
Scopes: strings.Join(r.GetRequestedScopes(), "|"),
Expand Down Expand Up @@ -166,14 +166,6 @@ func SignatureHash(signature string) string {
return fmt.Sprintf("%x", sha512.Sum384([]byte(signature)))
}

// hashSignature prevents errors where the signature is longer than 128 characters (and thus doesn't fit into the pk).
func (p *Persister) hashSignature(_ context.Context, signature string, table tableName) string {
if table == sqlTableAccess {
return SignatureHash(signature)
}
return signature
}

func (p *Persister) ClientAssertionJWTValid(ctx context.Context, jti string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ClientAssertionJWTValid")
defer otelx.End(span, &err)
Expand Down Expand Up @@ -242,19 +234,12 @@ func (p *Persister) createSession(ctx context.Context, signature string, request
return nil
}

func (p *Persister) findSessionBySignature(ctx context.Context, rawSignature string, session fosite.Session, table tableName) (_ fosite.Requester, err error) {
func (p *Persister) findSessionBySignature(ctx context.Context, signature string, session fosite.Session, table tableName) (_ fosite.Requester, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.findSessionBySignature")
defer otelx.End(span, &err)

r := OAuth2RequestSQL{Table: table}

// We look for the signature as well as the hash of the signature here.
// This is because we now always store the hash of the signature in the database,
// regardless of the type of the signature. In previous versions, we only stored
// the hash of the signature for JWT tokens.
//
// This code will be removed in a future version.
err = p.QueryWithNetwork(ctx).Where("signature IN (?, ?)", rawSignature, SignatureHash(rawSignature)).First(&r)
err = p.QueryWithNetwork(ctx).Where("signature = ?", SignatureHash(signature)).First(&r)
if errors.Is(err, sql.ErrNoRows) {
return nil, errorsx.WithStack(fosite.ErrNotFound)
} else if err != nil {
Expand All @@ -276,17 +261,9 @@ func (p *Persister) deleteSessionBySignature(ctx context.Context, signature stri
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionBySignature")
defer otelx.End(span, &err)

signature = p.hashSignature(ctx, signature, table)

// We look for the signature as well as the hash of the signature here.
// This is because we now always store the hash of the signature in the database,
// regardless of the type of the signature. In previous versions, we only stored
// the hash of the signature for JWT tokens.
//
// This code will be removed in a future version.
err = sqlcon.HandleError(
p.QueryWithNetwork(ctx).
Where("signature IN (?, ?)", signature, SignatureHash(signature)).
Where("signature = ?", SignatureHash(signature)).
Delete(&OAuth2RequestSQL{Table: table}))

if errors.Is(err, sqlcon.ErrNoRows) {
Expand Down Expand Up @@ -356,8 +333,8 @@ func (p *Persister) InvalidateAuthorizeCodeSession(ctx context.Context, signatur
return sqlcon.HandleError(
p.Connection(ctx).
RawQuery(
fmt.Sprintf("UPDATE %s SET active=false WHERE signature=? AND nid = ?", OAuth2RequestSQL{Table: sqlTableCode}.TableName()),
signature,
fmt.Sprintf("UPDATE %s SET active = false WHERE signature = ? AND nid = ?", OAuth2RequestSQL{Table: sqlTableCode}.TableName()),
SignatureHash(signature),
p.NetworkID(ctx),
).
Exec(),
Expand Down
2 changes: 0 additions & 2 deletions x/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ func TestLogAudit(t *testing.T) {
l.Logger.Out = buf
LogAudit(r, tc.message, l)

t.Logf("%s", buf.String())

assert.Contains(t, buf.String(), "audience=audit")
for _, expectContain := range tc.expectContains {
assert.Contains(t, buf.String(), expectContain)
Expand Down
3 changes: 1 addition & 2 deletions x/clean_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
)

func DeleteHydraRows(t *testing.T, c *pop.Connection) {
t.Logf("Deleting hydra rows in database: %s", c.Dialect.Name())
for _, tb := range []string{
"hydra_oauth2_access",
"hydra_oauth2_refresh",
Expand Down Expand Up @@ -57,7 +56,7 @@ func CleanSQLPop(t *testing.T, c *pop.Connection) {
"schema_migration",
} {
if err := c.RawQuery("DROP TABLE IF EXISTS " + tb).Exec(); err != nil {
t.Logf(`Unable to clean up table "%s": %s`, tb, err)
t.Fatalf(`Unable to clean up table "%s": %s`, tb, err)
}
}
t.Logf("Successfully cleaned up database: %s", c.Dialect.Name())
Expand Down

0 comments on commit 09da66e

Please sign in to comment.