diff --git a/doc/modules/ROOT/pages/configuration.adoc b/doc/modules/ROOT/pages/configuration.adoc index 6d5a122d..20a9490e 100644 --- a/doc/modules/ROOT/pages/configuration.adoc +++ b/doc/modules/ROOT/pages/configuration.adoc @@ -61,6 +61,7 @@ SmallRye JWT supports many properties which can be used to customize the token p |smallrye.jwt.groups-separator|' '|Separator for splitting a string which may contain multiple group values. It will only be used if the `smallrye.jwt.path.groups` property points to a custom claim whose value is a string. The default value is a single space because a standard OAuth2 `scope` claim may contain a space separated sequence. |smallrye.jwt.claims.groups|none| This property can be used to set a default groups claim value when the current token has no standard groups claim available (or no custom groups claim when `smallrye.jwt.path.groups` is used). |smallrye.jwt.jwks.refresh-interval|60|JWK cache refresh interval in minutes. It will be ignored unless the `mp.jwt.verify.publickey.location` points to the HTTP or HTTPS URL based JWK set and no HTTP `Cache-Control` response header with a positive `max-age` parameter value is returned from a JWK set endpoint. +|smallrye.jwt.jwks.retain-cache-on-error-duration|0|JWK cache retain on error duration in minutes which sets the length of time, before trying again, to keep using the cache when an error occurs making the request to the JWKS URI or parsing the response. It will be ignored unless the `mp.jwt.verify.publickey.location` property points to the HTTP or HTTPS URL based JWK set. |smallrye.jwt.jwks.forced-refresh-interval|30|Forced JWK cache refresh interval in minutes which is used to restrict the frequency of the forced refresh attempts which may happen when the token verification fails due to the cache having no JWK key with a `kid` property matching the current token's `kid` header. It will be ignored unless the `mp.jwt.verify.publickey.location` points to the HTTP or HTTPS URL based JWK set. |smallrye.jwt.expiration.grace|0|Expiration grace in seconds. By default an expired token will still be accepted if the current time is no more than 1 min after the token expiry time. This property is deprecated. Use `mp.jwt.verify.clock.skew` instead. |smallrye.jwt.verify.aud|none|Comma separated list of the audiences that a token `aud` claim may contain. This property is deprecated. Use `mp.jwt.verify.audiences` instead. diff --git a/implementation/jwt-auth/src/main/java/io/smallrye/jwt/auth/principal/AbstractKeyLocationResolver.java b/implementation/jwt-auth/src/main/java/io/smallrye/jwt/auth/principal/AbstractKeyLocationResolver.java index f039e04e..cc6db8b4 100644 --- a/implementation/jwt-auth/src/main/java/io/smallrye/jwt/auth/principal/AbstractKeyLocationResolver.java +++ b/implementation/jwt-auth/src/main/java/io/smallrye/jwt/auth/principal/AbstractKeyLocationResolver.java @@ -115,6 +115,7 @@ protected HttpsJwks initializeHttpsJwks(String location) new InetSocketAddress(authContextInfo.getHttpProxyHost(), authContextInfo.getHttpProxyPort()))); } theHttpsJwks.setSimpleHttpGet(httpGet); + theHttpsJwks.setRetainCacheOnErrorDuration(authContextInfo.getJwksRetainCacheOnErrorDuration() * 60L); return theHttpsJwks; } diff --git a/implementation/jwt-auth/src/main/java/io/smallrye/jwt/auth/principal/JWTAuthContextInfo.java b/implementation/jwt-auth/src/main/java/io/smallrye/jwt/auth/principal/JWTAuthContextInfo.java index 9eb36be8..d9486381 100644 --- a/implementation/jwt-auth/src/main/java/io/smallrye/jwt/auth/principal/JWTAuthContextInfo.java +++ b/implementation/jwt-auth/src/main/java/io/smallrye/jwt/auth/principal/JWTAuthContextInfo.java @@ -51,6 +51,7 @@ public class JWTAuthContextInfo { private String decryptionKeyContent; private Integer jwksRefreshInterval = 60; private int forcedJwksRefreshInterval = 30; + private int jwksRetainCacheOnErrorDuration = 0; private String tokenHeader = "Authorization"; private String tokenCookie; private boolean alwaysCheckAuthorization; @@ -121,6 +122,7 @@ public JWTAuthContextInfo(JWTAuthContextInfo orig) { this.decryptionKeyContent = orig.getDecryptionKeyContent(); this.jwksRefreshInterval = orig.getJwksRefreshInterval(); this.forcedJwksRefreshInterval = orig.getForcedJwksRefreshInterval(); + this.jwksRetainCacheOnErrorDuration = orig.getJwksRetainCacheOnErrorDuration(); this.tokenHeader = orig.getTokenHeader(); this.tokenCookie = orig.getTokenCookie(); this.alwaysCheckAuthorization = orig.isAlwaysCheckAuthorization(); @@ -283,6 +285,14 @@ public void setForcedJwksRefreshInterval(int forcedJwksRefreshInterval) { this.forcedJwksRefreshInterval = forcedJwksRefreshInterval; } + public int getJwksRetainCacheOnErrorDuration() { + return jwksRetainCacheOnErrorDuration; + } + + public void setJwksRetainCacheOnErrorDuration(int jwksRetainCacheOnErrorDuration) { + this.jwksRetainCacheOnErrorDuration = jwksRetainCacheOnErrorDuration; + } + public String getTokenHeader() { return tokenHeader; } @@ -436,6 +446,7 @@ public String toString() { ", decryptionKeyLocation='" + decryptionKeyLocation + '\'' + ", decryptionKeyContent='" + decryptionKeyContent + '\'' + ", jwksRefreshInterval=" + jwksRefreshInterval + + ", jwksRetainCacheOnErrorDuration=" + jwksRetainCacheOnErrorDuration + ", tokenHeader='" + tokenHeader + '\'' + ", tokenCookie='" + tokenCookie + '\'' + ", alwaysCheckAuthorization=" + alwaysCheckAuthorization + diff --git a/implementation/jwt-auth/src/main/java/io/smallrye/jwt/config/JWTAuthContextInfoProvider.java b/implementation/jwt-auth/src/main/java/io/smallrye/jwt/config/JWTAuthContextInfoProvider.java index dd2482c4..c4385f48 100644 --- a/implementation/jwt-auth/src/main/java/io/smallrye/jwt/config/JWTAuthContextInfoProvider.java +++ b/implementation/jwt-auth/src/main/java/io/smallrye/jwt/config/JWTAuthContextInfoProvider.java @@ -186,6 +186,7 @@ private static JWTAuthContextInfoProvider create(String key, provider.mpJwtVerifyTokenAge = Optional.empty(); provider.jwksRefreshInterval = 60; provider.forcedJwksRefreshInterval = 30; + provider.jwksRetainCacheOnErrorDuration = 0; provider.signatureAlgorithm = Optional.of(SignatureAlgorithm.RS256); provider.keyEncryptionAlgorithm = Optional.empty(); provider.mpJwtDecryptKeyAlgorithm = new HashSet<>(Arrays.asList(KeyEncryptionAlgorithm.RSA_OAEP, @@ -465,6 +466,15 @@ private static JWTAuthContextInfoProvider create(String key, @ConfigProperty(name = "smallrye.jwt.jwks.forced-refresh-interval", defaultValue = "30") private int forcedJwksRefreshInterval; + /** + * JWK cache retain on error duration in minutes which sets the length of time, before trying again, to keep using the cache + * when an error occurs making the request to the JWKS URI or parsing the response. + * It will be ignored unless the 'mp.jwt.verify.publickey.location' property points to the HTTP or HTTPS URL based JWK set. + */ + @Inject + @ConfigProperty(name = "smallrye.jwt.jwks.retain-cache-on-error-duration", defaultValue = "0") + private int jwksRetainCacheOnErrorDuration; + /** * Supported JSON Web Algorithm asymmetric or symmetric signature algorithm. * @@ -836,6 +846,7 @@ Optional getOptionalContextInfo() { contextInfo.setTokenAge(mpJwtVerifyTokenAge.orElse(null)); contextInfo.setJwksRefreshInterval(jwksRefreshInterval); contextInfo.setForcedJwksRefreshInterval(forcedJwksRefreshInterval); + contextInfo.setJwksRetainCacheOnErrorDuration(jwksRetainCacheOnErrorDuration); Set resolvedAlgorithm = mpJwtPublicKeyAlgorithm; if (signatureAlgorithm.isPresent()) { if (signatureAlgorithm.get().getAlgorithm().startsWith("HS")) { diff --git a/implementation/jwt-auth/src/test/java/io/smallrye/jwt/auth/principal/KeyLocationResolverTest.java b/implementation/jwt-auth/src/test/java/io/smallrye/jwt/auth/principal/KeyLocationResolverTest.java index 6ec74134..57c670c7 100644 --- a/implementation/jwt-auth/src/test/java/io/smallrye/jwt/auth/principal/KeyLocationResolverTest.java +++ b/implementation/jwt-auth/src/test/java/io/smallrye/jwt/auth/principal/KeyLocationResolverTest.java @@ -21,7 +21,11 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.net.Proxy; @@ -30,12 +34,18 @@ import java.security.interfaces.RSAPublicKey; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; +import java.util.concurrent.TimeUnit; import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; +import org.jose4j.base64url.Base64Url; import org.jose4j.http.Get; +import org.jose4j.http.SimpleResponse; +import org.jose4j.json.internal.json_simple.JSONObject; import org.jose4j.jwk.HttpsJwks; import org.jose4j.jwk.JsonWebKey; import org.jose4j.jwk.OctetSequenceJsonWebKey; @@ -68,6 +78,8 @@ class KeyLocationResolverTest { Get mockedGet; @Mock UrlStreamResolver urlResolver; + @Mock + SimpleResponse simpleResponse; RSAPublicKey rsaKey; SecretKey secretKey; @@ -180,6 +192,46 @@ protected Get getHttpGet() { assertNull(keyLocationResolver.key); } + @Test + void keepsRsaKeyFromHttpsJwksWhenErrorDuringRefresh() throws Exception { + long cacheDuration = 1L; + int jwksRetainCacheOnErrorDuration = 10; + JWTAuthContextInfo contextInfo = new JWTAuthContextInfo("https://github.com/my_key.jwks", "issuer"); + contextInfo.setJwksRetainCacheOnErrorDuration(jwksRetainCacheOnErrorDuration); + + HttpsJwks spiedHttpsJwks = Mockito.spy(new HttpsJwks(contextInfo.getPublicKeyLocation())); + spiedHttpsJwks.setDefaultCacheDuration(cacheDuration); + when(simpleResponse.getBody()).thenReturn(generateJWK(rsaKey)); + when(mockedGet.get(contextInfo.getPublicKeyLocation())).thenReturn(simpleResponse); + + KeyLocationResolver keyLocationResolver = new KeyLocationResolver(contextInfo) { + protected HttpsJwks getHttpsJwks(String loc) { + return spiedHttpsJwks; + } + + protected Get getHttpGet() { + return mockedGet; + } + }; + + Mockito.verify(spiedHttpsJwks).setRetainCacheOnErrorDuration(jwksRetainCacheOnErrorDuration * 60L); + Mockito.verify(spiedHttpsJwks).setSimpleHttpGet(mockedGet); + + when(signature.getHeaders()).thenReturn(headers); + when(headers.getStringHeaderValue(JsonWebKey.KEY_ID_PARAMETER)).thenReturn("1"); + when(headers.getStringHeaderValue(JsonWebKey.ALGORITHM_PARAMETER)).thenReturn("RS256"); + + assertEquals(rsaKey, keyLocationResolver.resolveKey(signature, emptyList())); + + doThrow(RuntimeException.class).when(mockedGet).get(contextInfo.getPublicKeyLocation()); + + TimeUnit.SECONDS.sleep(cacheDuration); + + assertEquals(rsaKey, keyLocationResolver.resolveKey(signature, emptyList())); + + verify(mockedGet, atLeastOnce()).get(contextInfo.getPublicKeyLocation()); + } + @Test void loadRsaKeyFromHttpJwks() throws Exception { JWTAuthContextInfo contextInfo = new JWTAuthContextInfo("http://github.com/my_key.jwks", "issuer"); @@ -330,7 +382,7 @@ void loadHttpsPemCrt() throws Exception { contextInfo.setJwksRefreshInterval(10); Mockito.doThrow(new JoseException("")).when(mockedHttpsJwks).refresh(); - Mockito.doReturn(ResourceUtils.getAsClasspathResource("publicCrt.pem")) + doReturn(ResourceUtils.getAsClasspathResource("publicCrt.pem")) .when(urlResolver).resolve(Mockito.any()); KeyLocationResolver keyLocationResolver = new KeyLocationResolver(contextInfo) { protected HttpsJwks initializeHttpsJwks(String loc) { @@ -380,4 +432,18 @@ void loadJWKOnClassPath() throws Exception { assertEquals(keyLocationResolver.key, keyLocationResolver.getJsonWebKey("key1", null).getKey()); } + + private String generateJWK(RSAPublicKey publicKey) { + Map key = new HashMap<>(); + + key.put("alg", "RS256"); + key.put("use", "sig"); + key.put("kty", publicKey.getAlgorithm()); + key.put("kid", "1"); + key.put("n", Base64Url.encode(publicKey.getModulus().toByteArray())); + key.put("e", Base64Url.encode(publicKey.getPublicExponent().toByteArray())); + + return JSONObject.toJSONString(Collections.singletonMap("keys", + Collections.singletonList(key))); + } }