Skip to content

Commit

Permalink
fix: faster GetPublicKeys (#3787)
Browse files Browse the repository at this point in the history
GetPublicKeys used to fetch all keys in a set, even if they were actually not being used. This patch fixes that.

Co-authored-by: zepatrik <[email protected]>
  • Loading branch information
aeneasr and zepatrik authored Jun 25, 2024
1 parent c184470 commit 04c34aa
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 60 deletions.
37 changes: 37 additions & 0 deletions jwk/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,16 @@ package jwk

import (
"context"
"encoding/json"
"net/http"
"time"

"github.com/pkg/errors"

"github.com/ory/hydra/v2/aead"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/errorsx"

jose "github.com/go-jose/go-jose/v3"
"github.com/gofrs/uuid"

Expand Down Expand Up @@ -64,8 +71,38 @@ type (
CreatedAt time.Time `db:"created_at"`
Key string `db:"keydata"`
}

SQLDataRows []SQLData
)

func (d SQLData) TableName() string {
return "hydra_jwk"
}

func (d SQLDataRows) ToJWK(ctx context.Context, r interface {
KeyCipher() *aead.AESGCM
}) (keys *jose.JSONWebKeySet, err error) {
if len(d) == 0 {
return nil, errors.Wrap(x.ErrNotFound, "")
}

keys = &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{}}
for _, d := range d {
key, err := r.KeyCipher().Decrypt(ctx, d.Key, nil)
if err != nil {
return nil, errorsx.WithStack(err)
}

var c jose.JSONWebKey
if err := json.Unmarshal(key, &c); err != nil {
return nil, errorsx.WithStack(err)
}
keys.Keys = append(keys.Keys, c)
}

if len(keys.Keys) == 0 {
return nil, errorsx.WithStack(x.ErrNotFound)
}

return keys, nil
}
51 changes: 31 additions & 20 deletions oauth2/fosite_store_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -930,44 +930,55 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) {
})

t.Run("case=only associated key returns", func(t *testing.T) {
keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, "some-key", "sig")
keySetToNotReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, "some-key", "sig")
require.NoError(t, err)
require.NoError(t, keyManager.AddKeySet(context.Background(), "some-set", keySetToNotReturn), "adding a random key should not fail")

err = keyManager.AddKeySet(context.TODO(), "some-set", keySet)
require.NoError(t, err)

keySet, err = jwk.GenerateJWK(context.Background(), jose.RS256, "maria-key", "sig")
require.NoError(t, err)

publicKey := keySet.Keys[0].Public()
issuer := "maria"
subject := "[email protected]"
grant := trust.Grant{

keySet1ToReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, "maria-key-1", "sig")
require.NoError(t, err)
require.NoError(t, grantManager.CreateGrant(context.Background(), trust.Grant{
ID: uuid.New(),
Issuer: issuer,
Subject: subject,
AllowAnySubject: false,
Scope: []string{"openid"},
PublicKey: trust.PublicKey{Set: issuer, KeyID: publicKey.KeyID},
PublicKey: trust.PublicKey{Set: issuer, KeyID: keySet1ToReturn.Keys[0].Public().KeyID},
CreatedAt: time.Now().UTC().Round(time.Second),
ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0),
}
}, keySet1ToReturn.Keys[0].Public()))

err = grantManager.CreateGrant(context.TODO(), grant, publicKey)
keySet2ToReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, "maria-key-2", "sig")
require.NoError(t, err)
require.NoError(t, grantManager.CreateGrant(context.TODO(), trust.Grant{
ID: uuid.New(),
Issuer: issuer,
Subject: subject,
AllowAnySubject: false,
Scope: []string{"openid"},
PublicKey: trust.PublicKey{Set: issuer, KeyID: keySet2ToReturn.Keys[0].Public().KeyID},
CreatedAt: time.Now().UTC().Round(time.Second),
ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0),
}, keySet2ToReturn.Keys[0].Public()))

storedKeySet, err := grantStorage.GetPublicKeys(context.TODO(), issuer, subject)
storedKeySet, err := grantStorage.GetPublicKeys(context.Background(), issuer, subject)
require.NoError(t, err)
assert.Len(t, storedKeySet.Keys, 1)
assert.Equal(t, publicKey.KeyID, storedKeySet.Keys[0].KeyID)
assert.Equal(t, publicKey.Use, storedKeySet.Keys[0].Use)
assert.Equal(t, publicKey.Key, storedKeySet.Keys[0].Key)

storedKeySet, err = grantStorage.GetPublicKeys(context.TODO(), issuer, "non-existing-subject")
require.Len(t, storedKeySet.Keys, 2)
// sorted by created_at DESC, so order is reverse
assert.Equal(t, keySet1ToReturn.Keys[0].Public().KeyID, storedKeySet.Keys[1].KeyID)
assert.Equal(t, keySet1ToReturn.Keys[0].Public().Use, storedKeySet.Keys[1].Use)
assert.Equal(t, keySet1ToReturn.Keys[0].Public().Key, storedKeySet.Keys[1].Key)
assert.Equal(t, keySet2ToReturn.Keys[0].Public().KeyID, storedKeySet.Keys[0].KeyID)
assert.Equal(t, keySet2ToReturn.Keys[0].Public().Use, storedKeySet.Keys[0].Use)
assert.Equal(t, keySet2ToReturn.Keys[0].Public().Key, storedKeySet.Keys[0].Key)

storedKeySet, err = grantStorage.GetPublicKeys(context.Background(), issuer, "non-existing-subject")
require.NoError(t, err)
assert.Len(t, storedKeySet.Keys, 0)

_, err = grantStorage.GetPublicKeyScopes(context.TODO(), issuer, "non-existing-subject", publicKey.KeyID)
_, err = grantStorage.GetPublicKeyScopes(context.Background(), issuer, "non-existing-subject", keySet2ToReturn.Keys[0].Public().KeyID)
require.Error(t, err)
})

Expand Down
43 changes: 31 additions & 12 deletions persistence/sql/persister_grant_jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"strings"
"time"

"github.com/ory/hydra/v2/jwk"

"github.com/pkg/errors"

"github.com/go-jose/go-jose/v3"
Expand Down Expand Up @@ -140,7 +142,7 @@ func (p *Persister) GetPublicKeys(ctx context.Context, issuer string, subject st

grantsData := make([]trust.SQLData, 0)
query := p.QueryWithNetwork(ctx).
Select("key_set", "key_id").
Select("key_id").
Where(expiresAt).
Where("issuer = ?", issuer).
Where("(subject = ? OR allow_any_subject IS TRUE)", subject).
Expand All @@ -155,21 +157,38 @@ func (p *Persister) GetPublicKeys(ctx context.Context, issuer string, subject st
return &jose.JSONWebKeySet{}, nil
}

// because keys must be grouped by issuer, we can retrieve set name from first grant
keySet, err := p.GetKeySet(ctx, grantsData[0].KeySet)
if err != nil {
return nil, err
keyIDs := make([]interface{}, len(grantsData))
for k, d := range grantsData {
keyIDs[k] = d.KeyID
}

// find keys, that belong to grants
filteredKeySet := &jose.JSONWebKeySet{}
for _, data := range grantsData {
if keys := keySet.Key(data.KeyID); len(keys) > 0 {
filteredKeySet.Keys = append(filteredKeySet.Keys, keys...)
}
var js jwk.SQLDataRows
if err := p.QueryWithNetwork(ctx).
// key_set and issuer are set to the same value on creation:
//
// grant := Grant{
// ID: uuid.New().String(),
// Issuer: grantRequest.Issuer,
// Subject: grantRequest.Subject,
// AllowAnySubject: grantRequest.AllowAnySubject,
// Scope: grantRequest.Scope,
// PublicKey: PublicKey{
// Set: grantRequest.Issuer, // group all keys by issuer, so set=issuer
// KeyID: grantRequest.PublicKeyJWK.KeyID,
// },
// CreatedAt: time.Now().UTC().Round(time.Second),
// ExpiresAt: grantRequest.ExpiresAt.UTC().Round(time.Second),
// }
//
// Therefore it is fine if we only look for the issuer here instead of the key set id.
Where("sid = ?", issuer).
Where("kid IN (?)", keyIDs).
Order("created_at DESC").
All(&js); err != nil {
return nil, sqlcon.HandleError(err)
}

return filteredKeySet, nil
return js.ToJWK(ctx, p.r)
}

func (p *Persister) GetPublicKeyScopes(ctx context.Context, issuer string, subject string, keyId string) (_ []string, err error) {
Expand Down
27 changes: 2 additions & 25 deletions persistence/sql/persister_jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"github.com/pkg/errors"

"github.com/ory/hydra/v2/jwk"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/sqlcon"
)

Expand Down Expand Up @@ -152,37 +151,15 @@ func (p *Persister) GetKeySet(ctx context.Context, set string) (keys *jose.JSONW
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetKeySet")
defer span.End()

var js []jwk.SQLData
var js jwk.SQLDataRows
if err := p.QueryWithNetwork(ctx).
Where("sid = ?", set).
Order("created_at DESC").
All(&js); err != nil {
return nil, sqlcon.HandleError(err)
}

if len(js) == 0 {
return nil, errors.Wrap(x.ErrNotFound, "")
}

keys = &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{}}
for _, d := range js {
key, err := p.r.KeyCipher().Decrypt(ctx, d.Key, nil)
if err != nil {
return nil, errorsx.WithStack(err)
}

var c jose.JSONWebKey
if err := json.Unmarshal(key, &c); err != nil {
return nil, errorsx.WithStack(err)
}
keys.Keys = append(keys.Keys, c)
}

if len(keys.Keys) == 0 {
return nil, errorsx.WithStack(x.ErrNotFound)
}

return keys, nil
return js.ToJWK(ctx, p.r)
}

func (p *Persister) DeleteKey(ctx context.Context, set, kid string) error {
Expand Down
8 changes: 5 additions & 3 deletions persistence/sql/persister_nid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1359,13 +1359,15 @@ func (s *PersisterTestSuite) TestGetPublicKeys() {
t := s.T()
for k, r := range s.registries {
t.Run(k, func(t *testing.T) {
ks := newKeySet("ks-id", "use")
const issuer = "ks-id"
ks := newKeySet(issuer, "use")
grant := trust.Grant{
ID: uuid.Must(uuid.NewV4()).String(),
ExpiresAt: time.Now().Add(time.Hour),
PublicKey: trust.PublicKey{Set: "ks-id", KeyID: ks.Keys[0].KeyID},
Issuer: issuer,
PublicKey: trust.PublicKey{Set: issuer, KeyID: ks.Keys[0].KeyID},
}
require.NoError(t, r.Persister().AddKeySet(s.t1, "ks-id", ks))
require.NoError(t, r.Persister().AddKeySet(s.t1, issuer, ks))
require.NoError(t, r.Persister().CreateGrant(s.t1, grant, ks.Keys[0]))

actual, err := r.Persister().GetPublicKeys(s.t2, grant.Issuer, grant.Subject)
Expand Down

0 comments on commit 04c34aa

Please sign in to comment.