diff --git a/jwk/helper.go b/jwk/helper.go index 50f3a28b2d..2c94b151c8 100644 --- a/jwk/helper.go +++ b/jwk/helper.go @@ -12,7 +12,8 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" - "sync" + + "golang.org/x/sync/singleflight" hydra "github.com/ory/hydra-client-go/v2" @@ -26,17 +27,7 @@ import ( "github.com/pkg/errors" ) -var mapLock sync.RWMutex -var locks = map[string]*sync.RWMutex{} - -func getLock(set string) *sync.RWMutex { - mapLock.Lock() - defer mapLock.Unlock() - if _, ok := locks[set]; !ok { - locks[set] = new(sync.RWMutex) - } - return locks[set] -} +var genJWKGroup singleflight.Group func EnsureAsymmetricKeypairExists(ctx context.Context, r InternalRegistry, alg, set string) error { _, err := GetOrGenerateKeys(ctx, r, r.KeyManager(), set, set, alg) @@ -44,37 +35,33 @@ func EnsureAsymmetricKeypairExists(ctx context.Context, r InternalRegistry, alg, } func GetOrGenerateKeys(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (private *jose.JSONWebKey, err error) { - getLock(set).Lock() - defer getLock(set).Unlock() - keys, err := m.GetKeySet(ctx, set) - if errors.Is(err, x.ErrNotFound) || keys != nil && len(keys.Keys) == 0 { - r.Logger().Warnf("JSON Web Key Set \"%s\" does not exist yet, generating new key pair...", set) - keys, err = m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig") - if err != nil { - return nil, err - } - } else if err != nil { + if err != nil && !errors.Is(err, x.ErrNotFound) { return nil, err } - privKey, privKeyErr := FindPrivateKey(keys) - if privKeyErr == nil { - return privKey, nil - } else { - r.Logger().WithField("jwks", set).Warnf("JSON Web Key not found in JSON Web Key Set %s, generating new key pair...", set) - - keys, err = m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig") - if err != nil { - return nil, err + if keys != nil && len(keys.Keys) > 0 { + privKey, privKeyErr := FindPrivateKey(keys) + if privKeyErr == nil { + return privKey, nil } + } - privKey, err := FindPrivateKey(keys) - if err != nil { - return nil, err - } - return privKey, nil + r.Logger().WithField("jwks", set).Warnf("JSON Web Key not found in JSON Web Key Set %s, generating new key pair...", set) + + // Suppress duplicate keyset generation + keysResult, err, _ := genJWKGroup.Do(set+alg, func() (any, error) { + return m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig") + }) + if err != nil { + return nil, err + } + + privKey, err := FindPrivateKey(keysResult.(*jose.JSONWebKeySet)) + if err != nil { + return nil, err } + return privKey, nil } func First(keys []jose.JSONWebKey) *jose.JSONWebKey {