Skip to content

Commit

Permalink
Add jwksRetainOnErrorDuration (#840)
Browse files Browse the repository at this point in the history
  • Loading branch information
luneo7 authored Nov 15, 2024
1 parent 71d80a8 commit a8f50f1
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/modules/ROOT/pages/configuration.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ SmallRye JWT supports many properties which can be used to customize the token p
|smallrye.jwt.groups-separator|' '|Separator for splitting a string which may contain multiple group values. It will only be used if the `smallrye.jwt.path.groups` property points to a custom claim whose value is a string. The default value is a single space because a standard OAuth2 `scope` claim may contain a space separated sequence.
|smallrye.jwt.claims.groups|none| This property can be used to set a default groups claim value when the current token has no standard groups claim available (or no custom groups claim when `smallrye.jwt.path.groups` is used).
|smallrye.jwt.jwks.refresh-interval|60|JWK cache refresh interval in minutes. It will be ignored unless the `mp.jwt.verify.publickey.location` points to the HTTP or HTTPS URL based JWK set and no HTTP `Cache-Control` response header with a positive `max-age` parameter value is returned from a JWK set endpoint.
|smallrye.jwt.jwks.retain-cache-on-error-duration|0|JWK cache retain on error duration in minutes which sets the length of time, before trying again, to keep using the cache when an error occurs making the request to the JWKS URI or parsing the response. It will be ignored unless the `mp.jwt.verify.publickey.location` property points to the HTTP or HTTPS URL based JWK set.
|smallrye.jwt.jwks.forced-refresh-interval|30|Forced JWK cache refresh interval in minutes which is used to restrict the frequency of the forced refresh attempts which may happen when the token verification fails due to the cache having no JWK key with a `kid` property matching the current token's `kid` header. It will be ignored unless the `mp.jwt.verify.publickey.location` points to the HTTP or HTTPS URL based JWK set.
|smallrye.jwt.expiration.grace|0|Expiration grace in seconds. By default an expired token will still be accepted if the current time is no more than 1 min after the token expiry time. This property is deprecated. Use `mp.jwt.verify.clock.skew` instead.
|smallrye.jwt.verify.aud|none|Comma separated list of the audiences that a token `aud` claim may contain. This property is deprecated. Use `mp.jwt.verify.audiences` instead.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ protected HttpsJwks initializeHttpsJwks(String location)
new InetSocketAddress(authContextInfo.getHttpProxyHost(), authContextInfo.getHttpProxyPort())));
}
theHttpsJwks.setSimpleHttpGet(httpGet);
theHttpsJwks.setRetainCacheOnErrorDuration(authContextInfo.getJwksRetainCacheOnErrorDuration() * 60L);
return theHttpsJwks;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public class JWTAuthContextInfo {
private String decryptionKeyContent;
private Integer jwksRefreshInterval = 60;
private int forcedJwksRefreshInterval = 30;
private int jwksRetainCacheOnErrorDuration = 0;
private String tokenHeader = "Authorization";
private String tokenCookie;
private boolean alwaysCheckAuthorization;
Expand Down Expand Up @@ -121,6 +122,7 @@ public JWTAuthContextInfo(JWTAuthContextInfo orig) {
this.decryptionKeyContent = orig.getDecryptionKeyContent();
this.jwksRefreshInterval = orig.getJwksRefreshInterval();
this.forcedJwksRefreshInterval = orig.getForcedJwksRefreshInterval();
this.jwksRetainCacheOnErrorDuration = orig.getJwksRetainCacheOnErrorDuration();
this.tokenHeader = orig.getTokenHeader();
this.tokenCookie = orig.getTokenCookie();
this.alwaysCheckAuthorization = orig.isAlwaysCheckAuthorization();
Expand Down Expand Up @@ -283,6 +285,14 @@ public void setForcedJwksRefreshInterval(int forcedJwksRefreshInterval) {
this.forcedJwksRefreshInterval = forcedJwksRefreshInterval;
}

public int getJwksRetainCacheOnErrorDuration() {
return jwksRetainCacheOnErrorDuration;
}

public void setJwksRetainCacheOnErrorDuration(int jwksRetainCacheOnErrorDuration) {
this.jwksRetainCacheOnErrorDuration = jwksRetainCacheOnErrorDuration;
}

public String getTokenHeader() {
return tokenHeader;
}
Expand Down Expand Up @@ -436,6 +446,7 @@ public String toString() {
", decryptionKeyLocation='" + decryptionKeyLocation + '\'' +
", decryptionKeyContent='" + decryptionKeyContent + '\'' +
", jwksRefreshInterval=" + jwksRefreshInterval +
", jwksRetainCacheOnErrorDuration=" + jwksRetainCacheOnErrorDuration +
", tokenHeader='" + tokenHeader + '\'' +
", tokenCookie='" + tokenCookie + '\'' +
", alwaysCheckAuthorization=" + alwaysCheckAuthorization +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ private static JWTAuthContextInfoProvider create(String key,
provider.mpJwtVerifyTokenAge = Optional.empty();
provider.jwksRefreshInterval = 60;
provider.forcedJwksRefreshInterval = 30;
provider.jwksRetainCacheOnErrorDuration = 0;
provider.signatureAlgorithm = Optional.of(SignatureAlgorithm.RS256);
provider.keyEncryptionAlgorithm = Optional.empty();
provider.mpJwtDecryptKeyAlgorithm = new HashSet<>(Arrays.asList(KeyEncryptionAlgorithm.RSA_OAEP,
Expand Down Expand Up @@ -465,6 +466,15 @@ private static JWTAuthContextInfoProvider create(String key,
@ConfigProperty(name = "smallrye.jwt.jwks.forced-refresh-interval", defaultValue = "30")
private int forcedJwksRefreshInterval;

/**
* JWK cache retain on error duration in minutes which sets the length of time, before trying again, to keep using the cache
* when an error occurs making the request to the JWKS URI or parsing the response.
* It will be ignored unless the 'mp.jwt.verify.publickey.location' property points to the HTTP or HTTPS URL based JWK set.
*/
@Inject
@ConfigProperty(name = "smallrye.jwt.jwks.retain-cache-on-error-duration", defaultValue = "0")
private int jwksRetainCacheOnErrorDuration;

/**
* Supported JSON Web Algorithm asymmetric or symmetric signature algorithm.
*
Expand Down Expand Up @@ -836,6 +846,7 @@ Optional<JWTAuthContextInfo> getOptionalContextInfo() {
contextInfo.setTokenAge(mpJwtVerifyTokenAge.orElse(null));
contextInfo.setJwksRefreshInterval(jwksRefreshInterval);
contextInfo.setForcedJwksRefreshInterval(forcedJwksRefreshInterval);
contextInfo.setJwksRetainCacheOnErrorDuration(jwksRetainCacheOnErrorDuration);
Set<SignatureAlgorithm> resolvedAlgorithm = mpJwtPublicKeyAlgorithm;
if (signatureAlgorithm.isPresent()) {
if (signatureAlgorithm.get().getAlgorithm().startsWith("HS")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.net.Proxy;
Expand All @@ -30,12 +34,18 @@
import java.security.interfaces.RSAPublicKey;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;

import org.jose4j.base64url.Base64Url;
import org.jose4j.http.Get;
import org.jose4j.http.SimpleResponse;
import org.jose4j.json.internal.json_simple.JSONObject;
import org.jose4j.jwk.HttpsJwks;
import org.jose4j.jwk.JsonWebKey;
import org.jose4j.jwk.OctetSequenceJsonWebKey;
Expand Down Expand Up @@ -68,6 +78,8 @@ class KeyLocationResolverTest {
Get mockedGet;
@Mock
UrlStreamResolver urlResolver;
@Mock
SimpleResponse simpleResponse;

RSAPublicKey rsaKey;
SecretKey secretKey;
Expand Down Expand Up @@ -180,6 +192,46 @@ protected Get getHttpGet() {
assertNull(keyLocationResolver.key);
}

@Test
void keepsRsaKeyFromHttpsJwksWhenErrorDuringRefresh() throws Exception {
long cacheDuration = 1L;
int jwksRetainCacheOnErrorDuration = 10;
JWTAuthContextInfo contextInfo = new JWTAuthContextInfo("https://github.com/my_key.jwks", "issuer");
contextInfo.setJwksRetainCacheOnErrorDuration(jwksRetainCacheOnErrorDuration);

HttpsJwks spiedHttpsJwks = Mockito.spy(new HttpsJwks(contextInfo.getPublicKeyLocation()));
spiedHttpsJwks.setDefaultCacheDuration(cacheDuration);
when(simpleResponse.getBody()).thenReturn(generateJWK(rsaKey));
when(mockedGet.get(contextInfo.getPublicKeyLocation())).thenReturn(simpleResponse);

KeyLocationResolver keyLocationResolver = new KeyLocationResolver(contextInfo) {
protected HttpsJwks getHttpsJwks(String loc) {
return spiedHttpsJwks;
}

protected Get getHttpGet() {
return mockedGet;
}
};

Mockito.verify(spiedHttpsJwks).setRetainCacheOnErrorDuration(jwksRetainCacheOnErrorDuration * 60L);
Mockito.verify(spiedHttpsJwks).setSimpleHttpGet(mockedGet);

when(signature.getHeaders()).thenReturn(headers);
when(headers.getStringHeaderValue(JsonWebKey.KEY_ID_PARAMETER)).thenReturn("1");
when(headers.getStringHeaderValue(JsonWebKey.ALGORITHM_PARAMETER)).thenReturn("RS256");

assertEquals(rsaKey, keyLocationResolver.resolveKey(signature, emptyList()));

doThrow(RuntimeException.class).when(mockedGet).get(contextInfo.getPublicKeyLocation());

TimeUnit.SECONDS.sleep(cacheDuration);

assertEquals(rsaKey, keyLocationResolver.resolveKey(signature, emptyList()));

verify(mockedGet, atLeastOnce()).get(contextInfo.getPublicKeyLocation());
}

@Test
void loadRsaKeyFromHttpJwks() throws Exception {
JWTAuthContextInfo contextInfo = new JWTAuthContextInfo("http://github.com/my_key.jwks", "issuer");
Expand Down Expand Up @@ -330,7 +382,7 @@ void loadHttpsPemCrt() throws Exception {
contextInfo.setJwksRefreshInterval(10);

Mockito.doThrow(new JoseException("")).when(mockedHttpsJwks).refresh();
Mockito.doReturn(ResourceUtils.getAsClasspathResource("publicCrt.pem"))
doReturn(ResourceUtils.getAsClasspathResource("publicCrt.pem"))
.when(urlResolver).resolve(Mockito.any());
KeyLocationResolver keyLocationResolver = new KeyLocationResolver(contextInfo) {
protected HttpsJwks initializeHttpsJwks(String loc) {
Expand Down Expand Up @@ -380,4 +432,18 @@ void loadJWKOnClassPath() throws Exception {
assertEquals(keyLocationResolver.key,
keyLocationResolver.getJsonWebKey("key1", null).getKey());
}

private String generateJWK(RSAPublicKey publicKey) {
Map<String, Object> key = new HashMap<>();

key.put("alg", "RS256");
key.put("use", "sig");
key.put("kty", publicKey.getAlgorithm());
key.put("kid", "1");
key.put("n", Base64Url.encode(publicKey.getModulus().toByteArray()));
key.put("e", Base64Url.encode(publicKey.getPublicExponent().toByteArray()));

return JSONObject.toJSONString(Collections.singletonMap("keys",
Collections.singletonList(key)));
}
}

0 comments on commit a8f50f1

Please sign in to comment.