Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: invalidate user session when user's role memberships changes #15633

Merged
merged 8 commits into from
Nov 14, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@
import static com.google.common.base.Preconditions.checkNotNull;

import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.hisp.dhis.cache.Cache;
import org.hisp.dhis.cache.CacheProvider;
import org.hisp.dhis.organisationunit.OrganisationUnit;
import org.springframework.context.annotation.Lazy;
import org.springframework.security.core.session.SessionInformation;
import org.springframework.security.core.session.SessionRegistry;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

Expand All @@ -51,11 +54,15 @@ public class CurrentUserService {

private final Cache<CurrentUserGroupInfo> currentUserGroupInfoCache;

public CurrentUserService(@Lazy UserStore userStore, CacheProvider cacheProvider) {
private final SessionRegistry sessionRegistry;

public CurrentUserService(
@Lazy UserStore userStore, CacheProvider cacheProvider, SessionRegistry sessionRegistry) {
checkNotNull(userStore);

this.userStore = userStore;
this.currentUserGroupInfoCache = cacheProvider.createCurrentUserGroupInfoCache();
this.sessionRegistry = sessionRegistry;
}

/**
Expand Down Expand Up @@ -113,4 +120,20 @@ public void invalidateUserGroupCache(String userUID) {
// Ignore if key doesn't exist
}
}

public CurrentUserDetailsImpl getCurrentUserPrincipal(String uid) {
return sessionRegistry.getAllPrincipals().stream()
.map(CurrentUserDetailsImpl.class::cast)
.filter(principal -> principal.getUid().equals(uid))
.findFirst()
.orElse(null);
}

public void invalidateUserSessions(String uid) {
CurrentUserDetailsImpl principal = getCurrentUserPrincipal(uid);
if (principal != null) {
List<SessionInformation> allSessions = sessionRegistry.getAllSessions(principal, false);
allSessions.forEach(SessionInformation::expireNow);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@
import org.hisp.dhis.util.ObjectUtils;
import org.jboss.aerogear.security.otp.api.Base32;
import org.springframework.context.annotation.Lazy;
import org.springframework.security.core.session.SessionInformation;
import org.springframework.security.core.session.SessionRegistry;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

Expand All @@ -108,8 +106,6 @@ public class DefaultUserService implements UserService {

private final PasswordManager passwordManager;

private final SessionRegistry sessionRegistry;

private final SecurityService securityService;

private final Cache<String> userDisplayNameCache;
Expand All @@ -126,7 +122,6 @@ public DefaultUserService(
SystemSettingManager systemSettingManager,
CacheProvider cacheProvider,
@Lazy PasswordManager passwordManager,
@Lazy SessionRegistry sessionRegistry,
@Lazy SecurityService securityService,
AclService aclService,
@Lazy OrganisationUnitService organisationUnitService) {
Expand All @@ -135,7 +130,6 @@ public DefaultUserService(
checkNotNull(userRoleStore);
checkNotNull(systemSettingManager);
checkNotNull(passwordManager);
checkNotNull(sessionRegistry);
checkNotNull(securityService);
checkNotNull(aclService);
checkNotNull(organisationUnitService);
Expand All @@ -146,7 +140,6 @@ public DefaultUserService(
this.currentUserService = currentUserService;
this.systemSettingManager = systemSettingManager;
this.passwordManager = passwordManager;
this.sessionRegistry = sessionRegistry;
this.securityService = securityService;
this.userDisplayNameCache = cacheProvider.createUserDisplayNameCache();
this.aclService = aclService;
Expand Down Expand Up @@ -798,9 +791,7 @@ public void privilegedTwoFactorDisable(

@Override
public void expireActiveSessions(User user) {
List<SessionInformation> sessions = sessionRegistry.getAllSessions(user, false);

sessions.forEach(SessionInformation::expireNow);
currentUserService.invalidateUserSessions(user.getUid());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@
@AllArgsConstructor
public class UserObjectBundleHook extends AbstractObjectBundleHook<User> {
public static final String USERNAME = "username";
public static final String INVALIDATE_SESSIONS_KEY = "shouldInvalidateUserSessions";
public static final String PRE_UPDATE_USER_KEY = "preUpdateUser";

private final UserService userService;

Expand Down Expand Up @@ -140,7 +142,6 @@ public void preCreate(User user, ObjectBundle bundle) {

if (currentUser != null) {
user.getCogsDimensionConstraints().addAll(currentUser.getCogsDimensionConstraints());

user.getCatDimensionConstraints().addAll(currentUser.getCatDimensionConstraints());
}
}
Expand All @@ -166,7 +167,8 @@ public void postCreate(User user, ObjectBundle bundle) {
public void preUpdate(User user, User persisted, ObjectBundle bundle) {
if (user == null) return;

bundle.putExtras(user, "preUpdateUser", user);
bundle.putExtras(user, PRE_UPDATE_USER_KEY, user);
bundle.putExtras(persisted, INVALIDATE_SESSIONS_KEY, userRolesUpdated(user, persisted));

if (persisted.getAvatar() != null
&& (user.getAvatar() == null
Expand All @@ -183,17 +185,34 @@ public void preUpdate(User user, User persisted, ObjectBundle bundle) {
}
}

private Boolean userRolesUpdated(User preUpdateUser, User persistedUser) {
Set<String> before =
preUpdateUser.getUserRoles().stream().map(UserRole::getUid).collect(Collectors.toSet());
Set<String> after =
persistedUser.getUserRoles().stream().map(UserRole::getUid).collect(Collectors.toSet());

return !Objects.equals(before, after);
}

@Override
public void postUpdate(User persistedUser, ObjectBundle bundle) {
final User preUpdateUser = (User) bundle.getExtras(persistedUser, "preUpdateUser");
final User preUpdateUser = (User) bundle.getExtras(persistedUser, PRE_UPDATE_USER_KEY);
final Boolean invalidateSessions =
(Boolean) bundle.getExtras(persistedUser, INVALIDATE_SESSIONS_KEY);

if (!StringUtils.isEmpty(preUpdateUser.getPassword())) {
userService.encodeAndSetPassword(persistedUser, preUpdateUser.getPassword());
sessionFactory.getCurrentSession().update(persistedUser);
}

bundle.removeExtras(persistedUser, "preUpdateUser");
userSettingService.saveUserSettings(persistedUser.getSettings(), persistedUser);

if (Boolean.TRUE.equals(invalidateSessions)) {
currentUserService.invalidateUserSessions(persistedUser.getUid());
}

bundle.removeExtras(persistedUser, PRE_UPDATE_USER_KEY);
bundle.removeExtras(persistedUser, INVALIDATE_SESSIONS_KEY);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright (c) 2004-2022, University of Oslo
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
* Neither the name of the HISP project nor the names of its contributors may
* be used to endorse or promote products derived from this software without
* specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
* ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package org.hisp.dhis.dxf2.metadata.objectbundle.hooks;

import java.util.Objects;
import java.util.Set;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.hisp.dhis.dxf2.metadata.objectbundle.ObjectBundle;
import org.hisp.dhis.user.CurrentUserService;
import org.hisp.dhis.user.User;
import org.hisp.dhis.user.UserRole;
import org.springframework.stereotype.Component;

/**
* @author Morten Svanæs <[email protected]>
*/
@Component
@AllArgsConstructor
@Slf4j
public class UserRoleBundleHook extends AbstractObjectBundleHook<UserRole> {

public static final String INVALIDATE_SESSION_KEY = "shouldInvalidateUserSessions";

private final CurrentUserService currentUserService;

@Override
public void preUpdate(UserRole update, UserRole existing, ObjectBundle bundle) {
if (update == null) return;
bundle.putExtras(update, INVALIDATE_SESSION_KEY, userRolesUpdated(update, existing));
}

private Boolean userRolesUpdated(UserRole update, UserRole existing) {
Set<String> newAuthorities = update.getAuthorities();
Set<String> existingAuthorities = existing.getAuthorities();
return !Objects.equals(newAuthorities, existingAuthorities);
}

@Override
public void postUpdate(UserRole updatedUserRole, ObjectBundle bundle) {
final Boolean invalidateSessions =
(Boolean) bundle.getExtras(updatedUserRole, INVALIDATE_SESSION_KEY);

if (Boolean.TRUE.equals(invalidateSessions)) {
for (User user : updatedUserRole.getUsers()) {
currentUserService.invalidateUserSessions(user.getUid());
}
}

bundle.removeExtras(updatedUserRole, INVALIDATE_SESSION_KEY);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.core.session.SessionRegistry;
import org.springframework.security.core.session.SessionRegistryImpl;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.ldap.authentication.LdapAuthenticator;
Expand Down Expand Up @@ -76,6 +78,11 @@ public class IntegrationTestConfig {
POSTGRES_CONTAINER.start();
}

@Bean
public static SessionRegistry sessionRegistry() {
return new SessionRegistryImpl();
}

@Bean
public LdapAuthenticator ldapAuthenticator() {
return authentication -> null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Profile;
import org.springframework.core.annotation.Order;
import org.springframework.security.core.session.SessionRegistry;
import org.springframework.security.core.session.SessionRegistryImpl;

/**
Expand All @@ -51,7 +52,7 @@
@Conditional(value = CacheInvalidationEnabledCondition.class)
public class TestableCacheInvalidationConfiguration {
@Bean
public static SessionRegistryImpl sessionRegistry() {
public static SessionRegistry sessionRegistry() {
return new SessionRegistryImpl();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.DefaultAuthenticationEventPublisher;
import org.springframework.security.authentication.event.AuthenticationFailureBadCredentialsEvent;
import org.springframework.security.core.session.SessionRegistryImpl;
import org.springframework.security.core.session.SessionRegistry;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
import org.springframework.security.ldap.authentication.LdapAuthenticator;
import org.springframework.security.ldap.userdetails.LdapAuthoritiesPopulator;
Expand Down Expand Up @@ -123,7 +123,7 @@
@Order(10)
public class WebTestConfiguration {
@Bean
public static SessionRegistryImpl sessionRegistry() {
public static SessionRegistry sessionRegistry() {
return new org.springframework.security.core.session.SessionRegistryImpl();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,12 @@
import org.hisp.dhis.security.SecurityService;
import org.hisp.dhis.setting.SettingKey;
import org.hisp.dhis.setting.SystemSettingManager;
import org.hisp.dhis.user.CurrentUserDetails;
import org.hisp.dhis.user.CurrentUserService;
import org.hisp.dhis.user.User;
import org.hisp.dhis.user.UserGroup;
import org.hisp.dhis.user.UserRole;
import org.hisp.dhis.user.UserService;
import org.hisp.dhis.user.sharing.Sharing;
import org.hisp.dhis.user.sharing.UserAccess;
import org.hisp.dhis.web.HttpStatus;
Expand All @@ -81,6 +84,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.session.SessionRegistry;

/**
* Tests the {@link org.hisp.dhis.webapi.controller.user.UserController}.
Expand All @@ -96,6 +100,12 @@ class UserControllerTest extends DhisControllerConvenienceTest {

@Autowired private OrganisationUnitService organisationUnitService;

@Autowired private CurrentUserService currentUserService;

@Autowired private SessionRegistry sessionRegistry;

@Autowired private UserService userService;

private User peter;

@Autowired ObjectMapper objectMapper;
Expand All @@ -117,6 +127,56 @@ void setUp() {
assertEquals("[email protected]", user.getEmail());
}

@Test
void updateRolesShouldInvalidateUserSessions() {
netroms marked this conversation as resolved.
Show resolved Hide resolved
CurrentUserDetails sessionPrincipal = userService.createUserDetails(superUser);
sessionRegistry.registerNewSession("session1", sessionPrincipal);
assertFalse(sessionRegistry.getAllSessions(sessionPrincipal, false).isEmpty());

UserRole roleB = createUserRole("ROLE_B", "ALL");
userService.addUserRole(roleB);

String roleBID = userService.getUserRoleByName("ROLE_B").getUid();

PATCH(
"/users/" + superUser.getUid(),
"[{'op':'add','path':'/userRoles','value':[{'id':'" + roleBID + "'}]}]")
.content(HttpStatus.OK);

assertTrue(sessionRegistry.getAllSessions(sessionPrincipal, false).isEmpty());
}

@Test
void updateRolesAuthoritiesShouldInvalidateUserSessions() {
CurrentUserDetails sessionPrincipal = userService.createUserDetails(superUser);

UserRole roleB = createUserRole("ROLE_B", "ALL");
userService.addUserRole(roleB);

PATCH(
"/users/" + superUser.getUid(),
"[{'op':'add','path':'/userRoles','value':[{'id':'" + roleB.getUid() + "'}]}]")
.content(HttpStatus.OK);

String roleBID = userService.getUserRoleByName("ROLE_B").getUid();

sessionRegistry.registerNewSession("session1", sessionPrincipal);
assertFalse(sessionRegistry.getAllSessions(sessionPrincipal, false).isEmpty());

PATCH(
"/userRoles/" + roleBID,
"["
+ " {"
+ " 'op': 'add',"
+ " 'path': '/authorities',"
+ " 'value': ['NONE']"
+ " }"
+ "]")
.content(HttpStatus.OK);

assertTrue(sessionRegistry.getAllSessions(sessionPrincipal, false).isEmpty());
}

@Test
void testResetToInvite() {
assertStatus(HttpStatus.NO_CONTENT, POST("/users/" + peter.getUid() + "/reset"));
Expand Down
Loading