Skip to content

Commit

Permalink
chore: P4ADEV-1167 verify jwt signature (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
antocalo authored Oct 9, 2024
1 parent 022c960 commit 8a22d1f
Show file tree
Hide file tree
Showing 11 changed files with 213 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import it.gov.pagopa.payhub.auth.exception.custom.InvalidAccessTokenException;
import it.gov.pagopa.payhub.auth.service.AuthnService;
import it.gov.pagopa.payhub.auth.service.ValidateTokenService;
import it.gov.pagopa.payhub.model.generated.UserInfo;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
Expand All @@ -26,18 +27,21 @@
public class JwtAuthenticationFilter extends OncePerRequestFilter {

private final AuthnService authnService;
private final ValidateTokenService validateTokenService;

public JwtAuthenticationFilter(AuthnService authnService) {
public JwtAuthenticationFilter(AuthnService authnService, ValidateTokenService validateTokenService) {
this.authnService = authnService;
this.validateTokenService = validateTokenService;
}

@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
try {
String authorization = request.getHeader(HttpHeaders.AUTHORIZATION);
if (StringUtils.hasText(authorization)) {
UserInfo userInfo = authnService.getUserInfo(authorization.replace("Bearer ", ""));

String token = authorization.replace("Bearer ", "");
validateTokenService.validate(token);
UserInfo userInfo = authnService.getUserInfo(token);
Collection<? extends GrantedAuthority> authorities = null;
if (userInfo.getOrganizationAccess() != null) {
authorities = userInfo.getOrganizations().stream()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package it.gov.pagopa.payhub.auth.service;

import com.auth0.jwt.JWT;
import com.auth0.jwt.interfaces.DecodedJWT;
import it.gov.pagopa.payhub.auth.exception.custom.InvalidTokenException;
import it.gov.pagopa.payhub.auth.service.exchange.AccessTokenBuilderService;
import it.gov.pagopa.payhub.auth.utils.JWTValidator;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

@Service
@Slf4j
public class ValidateTokenService {
private final JWTValidator jwtValidator;

public ValidateTokenService(JWTValidator jwtValidator) {
this.jwtValidator = jwtValidator;
}

public void validate(String token) {
jwtValidator.validateInternalToken(token);
DecodedJWT jwt = JWT.decode(token);
validateAccessType(jwt.getHeaderClaim("typ").asString());
}

private void validateAccessType(String type) {
if(!AccessTokenBuilderService.ACCESS_TOKEN_TYPE.equalsIgnoreCase(type)) {
throw new InvalidTokenException("Invalid token type " + type);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import com.auth0.jwt.algorithms.Algorithm;
import it.gov.pagopa.payhub.auth.utils.CertUtils;
import it.gov.pagopa.payhub.model.generated.AccessToken;
import java.util.HashMap;
import java.util.Map;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

Expand All @@ -17,7 +19,7 @@

@Service
public class AccessTokenBuilderService {

public static final String ACCESS_TOKEN_TYPE = "at+JWT";
private final String allowedAudience;
private final int expireIn;

Expand All @@ -43,8 +45,11 @@ public AccessTokenBuilderService(

public AccessToken build(){
Algorithm algorithm = Algorithm.RSA512(rsaPublicKey, rsaPrivateKey);
Map<String, Object> headerClaims = new HashMap<>();
headerClaims.put("typ", ACCESS_TOKEN_TYPE);
String tokenType = "bearer";
String token = JWT.create()
.withHeader(headerClaims)
.withClaim("typ", tokenType)
.withIssuer(allowedAudience)
.withJWTId(UUID.randomUUID().toString())
Expand Down
33 changes: 33 additions & 0 deletions src/main/java/it/gov/pagopa/payhub/auth/utils/JWTValidator.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
import com.auth0.jwt.interfaces.DecodedJWT;
import it.gov.pagopa.payhub.auth.exception.custom.InvalidTokenException;
import it.gov.pagopa.payhub.auth.exception.custom.TokenExpiredException;
import java.io.IOException;
import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import java.security.interfaces.RSAPublicKey;
Expand All @@ -25,6 +29,15 @@
@Component
public class JWTValidator {

private final JWTVerifier jwtVerifier;

public JWTValidator(@Value("${jwt.access-token.public-key}") String publicKey)
throws NoSuchAlgorithmException, InvalidKeySpecException, IOException {
RSAPublicKey rsaPublicKey = CertUtils.pemPub2PublicKey(publicKey);
Algorithm algorithm = Algorithm.RSA512(rsaPublicKey);
jwtVerifier = JWT.require(algorithm).build();
}

/**
* Validates a JWT against a JWK provider URL.
*
Expand Down Expand Up @@ -53,4 +66,24 @@ public Map<String, Claim> validate(String token, String urlJwkProvider) {
throw new InvalidTokenException("The token is not valid");
}
}

/**
* Validates JWT signature with publickey.
*
* @param token the JWT to validate
* @throws IllegalStateException if the public key cannot be loaded due to
* invalid format, missing algorithm, or I/O issues.
* @throws TokenExpiredException if the token has expired.
* @throws InvalidTokenException if the token is invalid for any other reason
* (e.g., signature verification failure).
*/
public void validateInternalToken(String token) {
try{
jwtVerifier.verify(token);
} catch (com.auth0.jwt.exceptions.TokenExpiredException e){
throw new TokenExpiredException(e.getMessage());
} catch (JWTVerificationException ex) {
throw new InvalidTokenException("The token is not valid");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import it.gov.pagopa.payhub.auth.security.JwtAuthenticationFilter;
import it.gov.pagopa.payhub.auth.security.WebSecurityConfig;
import it.gov.pagopa.payhub.auth.service.AuthnService;
import it.gov.pagopa.payhub.auth.service.ValidateTokenService;
import it.gov.pagopa.payhub.model.generated.AccessToken;
import it.gov.pagopa.payhub.model.generated.AuthErrorDTO;
import it.gov.pagopa.payhub.model.generated.UserInfo;
Expand Down Expand Up @@ -44,6 +45,9 @@ class AuthnControllerTest {
@MockBean
private AuthnService authnServiceMock;

@MockBean
private ValidateTokenService validateTokenServiceMock;

//region desc=postToken tests
@Test
void givenExpectedAuthTokenWhenPostTokenThenOk() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import it.gov.pagopa.payhub.auth.security.WebSecurityConfig;
import it.gov.pagopa.payhub.auth.service.AuthnService;
import it.gov.pagopa.payhub.auth.service.AuthzService;
import it.gov.pagopa.payhub.auth.service.ValidateTokenService;
import it.gov.pagopa.payhub.auth.utils.Constants;
import it.gov.pagopa.payhub.model.generated.*;
import org.junit.jupiter.api.Assertions;
Expand Down Expand Up @@ -45,6 +46,9 @@ class AuthzControllerNoOrganizzationAccessModeTest {
@MockBean
private AuthnService authnServiceMock;

@MockBean
private ValidateTokenService validateTokenServiceMock;

// createOperator region
@Test
void givenUnauthorizedUserWhenCreateOrganizationOperatorThenOk() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import it.gov.pagopa.payhub.auth.security.WebSecurityConfig;
import it.gov.pagopa.payhub.auth.service.AuthnService;
import it.gov.pagopa.payhub.auth.service.AuthzService;
import it.gov.pagopa.payhub.auth.service.ValidateTokenService;
import it.gov.pagopa.payhub.auth.utils.Constants;
import it.gov.pagopa.payhub.model.generated.CreateOperatorRequest;
import it.gov.pagopa.payhub.model.generated.OperatorDTO;
Expand Down Expand Up @@ -51,6 +52,9 @@ class AuthzControllerTest {
@MockBean
private AuthnService authnServiceMock;

@MockBean
private ValidateTokenService validateTokenServiceMock;

//region desc=getOrganizationOperators tests
@Test
void givenAuthorizedUserwhenGetOrganizationOperatorsThenOk() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package it.gov.pagopa.payhub.auth.service;

import com.auth0.jwt.JWT;
import com.auth0.jwt.interfaces.DecodedJWT;
import it.gov.pagopa.payhub.auth.exception.custom.InvalidTokenException;
import it.gov.pagopa.payhub.auth.service.exchange.AccessTokenBuilderService;
import it.gov.pagopa.payhub.auth.utils.JWTValidator;
import org.junit.jupiter.api.Assertions;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.Mockito;

import org.springframework.test.context.junit.jupiter.SpringExtension;

@ExtendWith(SpringExtension.class)
class ValidateTokenServiceTest {

private ValidateTokenService validateTokenService;

@Mock
private JWTValidator jwtValidator;

@Mock
private JWT jwtMock;

private DecodedJWT decodedJWT;

@BeforeEach
void setup(){
validateTokenService = new ValidateTokenService(jwtValidator);
}

@Test
void givenValidJWTThenOk() {
String validToken = "eyJ0eXAiOiJhdCtKV1QiLCJhbGciOiJSUzUxMiJ9.eyJ0eXAiOiJiZWFyZXIiLCJpc3MiOiJkZXYucGlhdHRhZm9ybWF1bml0YXJpYS5wYWdvcGEuaXQiLCJqdGkiOiI5NzZhYTYzMy0wMTVmLTQ3MDMtYWM3NC03NjE2YjJlN2JkNjQiLCJpYXQiOjE3MjgyOTkwOTksImV4cCI6MTcyODMxMzQ5OX0.l3gHHCdyPxq0AOUO3nFIzDzpp4kgwslS6U3K_KUaQ0VExSsxETGM7N7YiVVu3qXfaNy4H8Q7lvtb8bWThGNehh-SA1sX_U_nmTWhdtt0ULEdQ5sbg5_PH5VGuav-bthzqkeS1zv_TbAGl27HswOOCpdA3LhWzRs4KxA55EnKj0gCjxMHIEYuMxLhc400IKXC8dFk888dv_WZk1FgakdCYUbqOGCK_g7eVxa4N6oaFxJTZHaqviRQOs4YBMszwGhRAl34JBgrR1PYwx3Bsy6wcjEjshilqeOLjGIsUBojFoa8Vfw0oYDJ0OrfiG5EuiyABxqtKkS5b4Hs1qnU63wneg";
DecodedJWT jwt = JWT.decode(validToken);

Mockito.doNothing().when(jwtValidator).validateInternalToken(validToken);

validateTokenService.validate(validToken);

Assertions.assertDoesNotThrow(() -> jwtValidator.validateInternalToken(validToken));
Assertions.assertDoesNotThrow(() -> validateTokenService.validate(validToken));
Assertions.assertEquals(AccessTokenBuilderService.ACCESS_TOKEN_TYPE, jwt.getHeaderClaim("typ").asString());
}

@Test
void givenInvalidJWTTypeThenInvalidTokenException() {
String invalidToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzUxMiJ9.eyJ0eXAiOiJiZWFyZXIiLCJpc3MiOiJkZXYucGlhdHRhZm9ybWF1bml0YXJpYS5wYWdvcGEuaXQiLCJqdGkiOiI5NzZhYTYzMy0wMTVmLTQ3MDMtYWM3NC03NjE2YjJlN2JkNjQiLCJpYXQiOjE3MjgyOTkwOTksImV4cCI6MTcyODMxMzQ5OX0.NxbnCRBGcr0iftbagyPU-v3140loAQq4k0JaAg1fdTvI3qHBm4CS8za31s7OnRpNQ2ojlww9ApEAowzcjajnVJRo4L5D1W5M0RcVN_wSdBJrNcvPmN7PFKQn37xCbDkQ00I1d4ZLJVbP5hA2FFekJXu_w0NlUhSHsGPQoSYNOJr70fJUQ15K_asr6zi7J5XfbYSMNJBZWdVSCJoVfQDVRaWCq5H4zcBhfCbiOYtYeVDbYygFDWizHTiz9XwF-79aJcjp9VCTduyJ1ROJCBZfnUqZgN4BM75E5H-bmBEEbahqIT3eAY1lYAyv83s3Y5ys-5n6pFWgi6NuvP5vifl78w";
// When
Mockito.doNothing().when(jwtValidator).validateInternalToken(invalidToken);
Mockito.when(jwtMock.decodeJwt(invalidToken)).thenReturn(decodedJWT);

// Then
Assertions.assertThrows(InvalidTokenException.class, ()->validateTokenService.validate(invalidToken));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ void test(){
String decodedHeader = new String(Base64.getDecoder().decode(decodedAccessToken.getHeader()));
String decodedPayload = new String(Base64.getDecoder().decode(decodedAccessToken.getPayload()));

Assertions.assertEquals("{\"alg\":\"RS512\",\"typ\":\"JWT\"}", decodedHeader);
Assertions.assertEquals("{\"typ\":\"at+JWT\",\"alg\":\"RS512\"}", decodedHeader);
Assertions.assertEquals(EXPIRE_IN, (decodedAccessToken.getExpiresAtAsInstant().toEpochMilli() - decodedAccessToken.getIssuedAtAsInstant().toEpochMilli()) / 1_000);
Assertions.assertTrue(Pattern.compile("\\{\"typ\":\"bearer\",\"iss\":\"APPLICATION_AUDIENCE\",\"jti\":\"[0-9a-z]{8}-[0-9a-z]{4}-[0-9a-z]{4}-[0-9a-z]{4}-[0-9a-z]{12}\",\"iat\":[0-9]+,\"exp\":[0-9]+}").matcher(decodedPayload).matches(), "Payload not matches requested pattern: " + decodedPayload);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import com.github.tomakehurst.wiremock.WireMockServer;
import it.gov.pagopa.payhub.auth.exception.custom.InvalidTokenException;
import it.gov.pagopa.payhub.auth.exception.custom.TokenExpiredException;
import java.security.KeyPair;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand All @@ -23,13 +25,16 @@ class JWTValidatorTest {
private JWTValidator jwtValidator;
private WireMockServer wireMockServer;
private JWTValidatorUtils utils;
private KeyPair keyPair;

@BeforeEach
void setup(){
void setup() throws Exception {
wireMockServer = new WireMockServer(wireMockConfig().dynamicPort());
wireMockServer.start();
utils = new JWTValidatorUtils(wireMockServer);
jwtValidator = new JWTValidator();
keyPair = JWTValidatorUtils.generateKeyPair();
String publicKey = JWTValidatorUtils.getPublicKey(keyPair);
jwtValidator = new JWTValidator(publicKey);
}

@AfterEach
Expand Down Expand Up @@ -63,4 +68,26 @@ void givenInvalidTokenThenThrowInvalidTokenException() {

assertThrows(InvalidTokenException.class, () -> jwtValidator.validate(invalidToken, urlJwkProvider));
}

@Test
void givenValidInternalJWTThenOk() {
String validToken = utils.generateInternalToken(keyPair,new Date(System.currentTimeMillis() + 3600000));
Assertions.assertDoesNotThrow(() -> jwtValidator.validateInternalToken(validToken));
}

@Test
void givenInvalidInternalJWTThenInvalidTokenException() throws Exception {
KeyPair otherKeyPair = JWTValidatorUtils.generateKeyPair();
String invalidToken = utils.generateInternalToken(otherKeyPair, new Date(System.currentTimeMillis() + 3600000));

assertThrows(InvalidTokenException.class, () -> jwtValidator.validateInternalToken(invalidToken));
}

@Test
void givenTokenExpiredThenTokenExpiredException() {
String invalidToken = utils.generateInternalToken(keyPair, new Date(System.currentTimeMillis() - 3600000));

assertThrows(TokenExpiredException.class, () -> jwtValidator.validateInternalToken(invalidToken));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.RSAKey;
import java.io.StringWriter;
import java.security.PublicKey;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import org.bouncycastle.util.io.pem.PemObject;
import org.bouncycastle.util.io.pem.PemWriter;
import org.json.JSONObject;

import java.security.KeyPair;
Expand All @@ -23,6 +30,7 @@ public class JWTValidatorUtils {

private static final String AUD = "AUD";
private static final String ISS = "ISS";
private static final String ACCESS_TOKEN_TYPE = "at+JWT";

public JWTValidatorUtils(WireMockServer wireMockServer) {
this.wireMockServer = wireMockServer;
Expand Down Expand Up @@ -64,9 +72,34 @@ public String getUrlJwkProvider() {
return "http://localhost:" + wireMockServer.port() + "/jwks";
}

private static KeyPair generateKeyPair() throws Exception {
public static KeyPair generateKeyPair() throws Exception {
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
keyPairGenerator.initialize(2048);
return keyPairGenerator.generateKeyPair();
}

public String generateInternalToken(KeyPair keyPair, Date expiresAt) {
Algorithm algorithm = Algorithm.RSA512((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());
Map<String, Object> headerClaims = new HashMap<>();
headerClaims.put("typ", ACCESS_TOKEN_TYPE);
String tokenType = "bearer";
return JWT.create()
.withHeader(headerClaims)
.withClaim("typ", tokenType)
.withIssuer(ISS)
.withJWTId("my-jwt-id")
.withIssuedAt(Instant.now())
.withExpiresAt(expiresAt)
.sign(algorithm);
}

public static String getPublicKey(KeyPair keyPair) throws Exception {
PublicKey publicKey = keyPair.getPublic();
StringWriter stringWriter = new StringWriter();
PemWriter pemWriter = new PemWriter(stringWriter);
pemWriter.writeObject(new PemObject("PUBLIC KEY", publicKey.getEncoded()));
pemWriter.flush();
pemWriter.close();
return stringWriter.toString();
}
}

0 comments on commit 8a22d1f

Please sign in to comment.