From 6d32ac48b6020c58e683d782d8806cd32b271935 Mon Sep 17 00:00:00 2001 From: Martin Ledvinka Date: Mon, 13 Nov 2023 08:58:26 +0100 Subject: [PATCH] [OIDC] Modify role mapping from OIDC access token. This ensures only known roles are mapped, and they are mapped correctly to types used by the record manager. --- .../kbss/study/rest/OidcUserController.java | 24 +++++++++- .../cvut/kbss/study/security/model/Role.java | 36 +++++++++++++++ .../study/security/model/UserDetails.java | 21 +++------ .../study/service/security/SecurityUtils.java | 18 +++++++- .../oidc/OidcGrantedAuthoritiesExtractor.java | 7 ++- .../kbss/study/security/model/RoleTest.java | 44 +++++++++++++++++++ .../service/security/SecurityUtilsTest.java | 30 +++++++++++++ 7 files changed, 162 insertions(+), 18 deletions(-) create mode 100644 src/main/java/cz/cvut/kbss/study/security/model/Role.java create mode 100644 src/test/java/cz/cvut/kbss/study/security/model/RoleTest.java diff --git a/src/main/java/cz/cvut/kbss/study/rest/OidcUserController.java b/src/main/java/cz/cvut/kbss/study/rest/OidcUserController.java index 3504e5b8..c55d3f21 100644 --- a/src/main/java/cz/cvut/kbss/study/rest/OidcUserController.java +++ b/src/main/java/cz/cvut/kbss/study/rest/OidcUserController.java @@ -1,15 +1,20 @@ package cz.cvut.kbss.study.rest; +import cz.cvut.kbss.study.model.Institution; import cz.cvut.kbss.study.model.User; import cz.cvut.kbss.study.security.SecurityConstants; +import cz.cvut.kbss.study.service.InstitutionService; import cz.cvut.kbss.study.service.UserService; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.http.MediaType; import org.springframework.security.access.prepost.PreAuthorize; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; +import java.util.List; + /** * API for getting basic user info. *

@@ -22,8 +27,11 @@ public class OidcUserController extends BaseController { private final UserService userService; - public OidcUserController(UserService userService) { + private final InstitutionService institutionService; + + public OidcUserController(UserService userService, InstitutionService institutionService) { this.userService = userService; + this.institutionService = institutionService; } @PreAuthorize("hasRole('" + SecurityConstants.ROLE_USER + "')") @@ -31,4 +39,18 @@ public OidcUserController(UserService userService) { public User getCurrent() { return userService.getCurrentUser(); } + + @PreAuthorize( + "hasRole('" + SecurityConstants.ROLE_ADMIN + "') " + + "or hasRole('" + SecurityConstants.ROLE_USER + "') and @securityUtils.isMemberOfInstitution(#institutionKey)") + @GetMapping(produces = MediaType.APPLICATION_JSON_VALUE) + public List getUsers(@RequestParam(value = "institution", required = false) String institutionKey) { + return institutionKey != null ? getByInstitution(institutionKey) : userService.findAll(); + } + + private List getByInstitution(String institutionKey) { + assert institutionKey != null; + final Institution institution = institutionService.findByKey(institutionKey); + return userService.findByInstitution(institution); + } } diff --git a/src/main/java/cz/cvut/kbss/study/security/model/Role.java b/src/main/java/cz/cvut/kbss/study/security/model/Role.java new file mode 100644 index 00000000..4b794953 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/study/security/model/Role.java @@ -0,0 +1,36 @@ +package cz.cvut.kbss.study.security.model; + +import cz.cvut.kbss.study.model.Vocabulary; +import cz.cvut.kbss.study.security.SecurityConstants; + +import java.util.Optional; +import java.util.stream.Stream; + +public enum Role { + USER(SecurityConstants.ROLE_USER, Vocabulary.s_c_doctor), + ADMIN(SecurityConstants.ROLE_ADMIN, Vocabulary.s_c_administrator); + + private final String name; + private final String type; + + Role(String name, String type) { + this.name = name; + this.type = type; + } + + public static Optional forType(String type) { + return Stream.of(Role.values()).filter(r -> r.type.equals(type)).findAny(); + } + + public static Optional forName(String name) { + return Stream.of(Role.values()).filter(r -> r.name.equals(name)).findAny(); + } + + public String getName() { + return name; + } + + public String getType() { + return type; + } +} diff --git a/src/main/java/cz/cvut/kbss/study/security/model/UserDetails.java b/src/main/java/cz/cvut/kbss/study/security/model/UserDetails.java index 91a843fc..78eb1eac 100644 --- a/src/main/java/cz/cvut/kbss/study/security/model/UserDetails.java +++ b/src/main/java/cz/cvut/kbss/study/security/model/UserDetails.java @@ -1,34 +1,23 @@ package cz.cvut.kbss.study.security.model; import cz.cvut.kbss.study.model.User; -import cz.cvut.kbss.study.model.Vocabulary; import cz.cvut.kbss.study.security.SecurityConstants; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; -import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; public class UserDetails implements org.springframework.security.core.userdetails.UserDetails { - private static final Map ROLE_MAPPING = initRoleMapping(); - private final User user; private final Set authorities; - private static Map initRoleMapping() { - final Map result = new HashMap<>(); - result.put(Vocabulary.s_c_administrator, SecurityConstants.ROLE_ADMIN); - result.put(Vocabulary.s_c_doctor, SecurityConstants.ROLE_USER); - return result; - } - public UserDetails(User user) { Objects.requireNonNull(user); this.user = user; @@ -46,8 +35,12 @@ public UserDetails(User user, Collection authorities) { } private void resolveRoles() { - authorities.addAll(ROLE_MAPPING.entrySet().stream().filter(e -> user.getTypes().contains(e.getKey())) - .map(e -> new SimpleGrantedAuthority(e.getValue())).toList()); + authorities.addAll( + user.getTypes().stream() + .map(Role::forType) + .filter(Optional::isPresent) + .map(r -> new SimpleGrantedAuthority(r.get().getName())) + .toList()); authorities.add(new SimpleGrantedAuthority(SecurityConstants.ROLE_USER)); } diff --git a/src/main/java/cz/cvut/kbss/study/service/security/SecurityUtils.java b/src/main/java/cz/cvut/kbss/study/service/security/SecurityUtils.java index 0322cb99..fd51a950 100644 --- a/src/main/java/cz/cvut/kbss/study/service/security/SecurityUtils.java +++ b/src/main/java/cz/cvut/kbss/study/service/security/SecurityUtils.java @@ -1,10 +1,14 @@ package cz.cvut.kbss.study.service.security; +import cz.cvut.kbss.study.exception.NotFoundException; import cz.cvut.kbss.study.model.PatientRecord; import cz.cvut.kbss.study.model.User; import cz.cvut.kbss.study.persistence.dao.PatientRecordDao; import cz.cvut.kbss.study.persistence.dao.UserDao; +import cz.cvut.kbss.study.security.model.Role; import cz.cvut.kbss.study.security.model.UserDetails; +import cz.cvut.kbss.study.service.ConfigReader; +import cz.cvut.kbss.study.util.oidc.OidcGrantedAuthoritiesExtractor; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.context.SecurityContext; @@ -15,6 +19,7 @@ import org.springframework.stereotype.Service; import java.util.List; +import java.util.Optional; @Service public class SecurityUtils { @@ -23,9 +28,12 @@ public class SecurityUtils { private final PatientRecordDao patientRecordDao; - public SecurityUtils(UserDao userDao, PatientRecordDao patientRecordDao) { + private final ConfigReader config; + + public SecurityUtils(UserDao userDao, PatientRecordDao patientRecordDao, ConfigReader config) { this.userDao = userDao; this.patientRecordDao = patientRecordDao; + this.config = config; } /** @@ -70,7 +78,13 @@ public User getCurrentUser() { private User resolveAccountFromOAuthPrincipal(Jwt principal) { final OidcUserInfo userInfo = new OidcUserInfo(principal.getClaims()); - return userDao.findByUsername(userInfo.getPreferredUsername()); + final List roles = new OidcGrantedAuthoritiesExtractor(config).extractRoles(principal); + final User user = userDao.findByUsername(userInfo.getPreferredUsername()); + if (user == null) { + throw new NotFoundException("User with username '" + userInfo.getPreferredUsername() + "' not found in repository."); + } + roles.stream().map(Role::forName).filter(Optional::isPresent).forEach(r -> user.addType(r.get().getType())); + return user; } /** diff --git a/src/main/java/cz/cvut/kbss/study/util/oidc/OidcGrantedAuthoritiesExtractor.java b/src/main/java/cz/cvut/kbss/study/util/oidc/OidcGrantedAuthoritiesExtractor.java index 4ea00937..136a4401 100644 --- a/src/main/java/cz/cvut/kbss/study/util/oidc/OidcGrantedAuthoritiesExtractor.java +++ b/src/main/java/cz/cvut/kbss/study/util/oidc/OidcGrantedAuthoritiesExtractor.java @@ -4,6 +4,7 @@ import cz.cvut.kbss.study.util.ConfigParam; import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.authority.SimpleGrantedAuthority; +import org.springframework.security.oauth2.core.ClaimAccessor; import org.springframework.security.oauth2.jwt.Jwt; import java.util.Collection; @@ -21,6 +22,10 @@ public OidcGrantedAuthoritiesExtractor(ConfigReader config) { @Override public Collection convert(Jwt source) { + return extractRoles(source).stream().map(SimpleGrantedAuthority::new).toList(); + } + + public List extractRoles(ClaimAccessor source) { final String rolesClaim = config.getConfig(ConfigParam.OIDC_ROLE_CLAIM); final String[] parts = rolesClaim.split("\\."); assert parts.length > 0; @@ -40,6 +45,6 @@ public Collection convert(Jwt source) { } roles = (List) map.getOrDefault(parts[parts.length - 1], Collections.emptyList()); } - return roles.stream().map(SimpleGrantedAuthority::new).toList(); + return roles; } } diff --git a/src/test/java/cz/cvut/kbss/study/security/model/RoleTest.java b/src/test/java/cz/cvut/kbss/study/security/model/RoleTest.java new file mode 100644 index 00000000..b006c35c --- /dev/null +++ b/src/test/java/cz/cvut/kbss/study/security/model/RoleTest.java @@ -0,0 +1,44 @@ +package cz.cvut.kbss.study.security.model; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.Optional; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.*; + +class RoleTest { + + static Stream generator() { + return Stream.of(Role.values()).map(Arguments::of); + } + + @ParameterizedTest + @MethodSource("generator") + void forTypeReturnsRoleMatchingSpecifiedType(Role r) { + final Optional result = Role.forType(r.getType()); + assertTrue(result.isPresent()); + assertEquals(r, result.get()); + } + + @Test + void forTypeReturnsEmptyOptionalForUnknownRoleType() { + assertTrue(Role.forType("unknownType").isEmpty()); + } + + @ParameterizedTest + @MethodSource("generator") + void forNameReturnsRoleMatchingSpecifiedRoleName(Role r) { + final Optional result = Role.forName(r.getName()); + assertTrue(result.isPresent()); + assertEquals(r, result.get()); + } + + @Test + void forNameReturnsEmptyOptionalForUnknownRoleName() { + assertTrue(Role.forName("unknownName").isEmpty()); + } +} \ No newline at end of file diff --git a/src/test/java/cz/cvut/kbss/study/service/security/SecurityUtilsTest.java b/src/test/java/cz/cvut/kbss/study/service/security/SecurityUtilsTest.java index bd88284b..4cdcdd11 100644 --- a/src/test/java/cz/cvut/kbss/study/service/security/SecurityUtilsTest.java +++ b/src/test/java/cz/cvut/kbss/study/service/security/SecurityUtilsTest.java @@ -5,9 +5,12 @@ import cz.cvut.kbss.study.model.Institution; import cz.cvut.kbss.study.model.PatientRecord; import cz.cvut.kbss.study.model.User; +import cz.cvut.kbss.study.model.Vocabulary; import cz.cvut.kbss.study.persistence.dao.PatientRecordDao; import cz.cvut.kbss.study.persistence.dao.UserDao; import cz.cvut.kbss.study.security.SecurityConstants; +import cz.cvut.kbss.study.service.ConfigReader; +import cz.cvut.kbss.study.util.ConfigParam; import cz.cvut.kbss.study.util.IdentificationUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -26,6 +29,8 @@ import java.time.temporal.ChronoUnit; import java.util.List; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.hasItem; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -40,6 +45,9 @@ public class SecurityUtilsTest { @Mock private PatientRecordDao patientRecordDao; + @Mock + private ConfigReader config; + @InjectMocks private SecurityUtils sut; @@ -71,6 +79,7 @@ public void getCurrentUserReturnsCurrentlyLoggedInUser() { @Test void getCurrentUserRetrievesCurrentUserForOauthJwtAccessToken() { + when(config.getConfig(ConfigParam.OIDC_ROLE_CLAIM)).thenReturn("roles"); final Jwt token = Jwt.withTokenValue("abcdef12345") .header("alg", "RS256") .header("typ", "JWT") @@ -153,4 +162,25 @@ public void isRecordInUsersInstitutionReturnsFalseWhenRecordBelongsToInstitution assertFalse(sut.isRecordInUsersInstitution(record.getKey())); } + + @Test + void getCurrentUserEnhancesRetrievedUserWithTypesCorrespondingToRolesSpecifiedInJwtClaim() { + when(config.getConfig(ConfigParam.OIDC_ROLE_CLAIM)).thenReturn("roles"); + final Jwt token = Jwt.withTokenValue("abcdef12345") + .header("alg", "RS256") + .header("typ", "JWT") + .claim("roles", List.of(SecurityConstants.ROLE_ADMIN)) + .issuer("http://localhost:8080/termit") + .subject(USERNAME) + .claim("preferred_username", USERNAME) + .expiresAt(Instant.now().truncatedTo(ChronoUnit.SECONDS).plusSeconds(300)) + .build(); + SecurityContext context = new SecurityContextImpl(); + context.setAuthentication(new JwtAuthenticationToken(token)); + SecurityContextHolder.setContext(context); + when(userDao.findByUsername(user.getUsername())).thenReturn(user); + + final User result = sut.getCurrentUser(); + assertThat(result.getTypes(), hasItem(Vocabulary.s_c_administrator)); + } }