diff --git a/backend/lcfs/web/api/user/repo.py b/backend/lcfs/web/api/user/repo.py index 4a8f99513..eabb9911c 100644 --- a/backend/lcfs/web/api/user/repo.py +++ b/backend/lcfs/web/api/user/repo.py @@ -352,31 +352,46 @@ async def create_user( async def update_user( self, user: UserProfile, user_update: UserCreateSchema ) -> None: + """ + Update an existing UserProfile with new data. + """ + + # Extract incoming data from the Pydantic schema user_data = user_update.model_dump() - updated_user_profile = UserProfile(**user_update.model_dump(exclude={"roles"})) - roles = user_data.pop("roles", {}) + new_roles = user_data.pop("roles", {}) + + # Update basic fields directly + user.email = user_update.email + user.title = user_update.title + user.first_name = user_update.first_name + user.last_name = user_update.last_name + user.is_active = user_update.is_active + user.keycloak_email = user_update.keycloak_email + user.keycloak_username = user_update.keycloak_username + user.phone = user_update.phone + user.mobile_phone = user_update.mobile_phone + # Find the RoleEnum member corresponding to each role - new_roles = [ - role_enum for role_enum in RoleEnum if role_enum.value.lower() in roles - ] + new_role_enums = [] + for role_str in new_roles: + try: + name_str = role_str.title() + role_enum = RoleEnum(name_str) + new_role_enums.append(role_enum) + except ValueError: + pass + # Create a set for faster membership checks existing_roles_set = set(user.role_names) - # Update the user object with the new data - user.email = updated_user_profile.email - user.title = updated_user_profile.title - user.first_name = updated_user_profile.first_name - user.last_name = updated_user_profile.last_name - user.is_active = updated_user_profile.is_active - user.keycloak_email = updated_user_profile.keycloak_email - user.keycloak_username = updated_user_profile.keycloak_username - user.phone = updated_user_profile.phone - user.mobile_phone = updated_user_profile.mobile_phone - if user.organization: - await self.update_bceid_roles(user, new_roles, existing_roles_set) + # BCEID logic + await self.update_bceid_roles(user, new_role_enums, existing_roles_set) else: - await self.update_idir_roles(user, new_roles, existing_roles_set) + # IDIR logic + await self.update_idir_roles(user, new_role_enums, existing_roles_set) + + # Add the updated user to the session self.db.add(user) return user @@ -672,9 +687,7 @@ async def create_login_history(self, user: UserProfile): self.db.add(login_history) @repo_handler - async def update_email( - self, user_profile_id: int, email: str - ) -> UserProfile: + async def update_email(self, user_profile_id: int, email: str) -> UserProfile: # Fetch the user profile query = select(UserProfile).where( UserProfile.user_profile_id == user_profile_id