diff --git a/imperial_coldfront_plugin/oidc.py b/imperial_coldfront_plugin/oidc.py index 872f294..2fe35d1 100644 --- a/imperial_coldfront_plugin/oidc.py +++ b/imperial_coldfront_plugin/oidc.py @@ -1,30 +1,61 @@ """Customisations for the OIDC authentication backend.""" +import logging from typing import Any +from django.conf import settings from django.contrib.auth.models import User from mozilla_django_oidc.auth import OIDCAuthenticationBackend +from .ldap import get_uid_from_ldap +from .models import UnixUID -def _update_user(user: User, claims: dict[str, Any]) -> None: - user.username = claims["preferred_username"].rstrip("@ic.ac.uk") - user.email = claims["email"] - user.first_name = claims["given_name"] - user.last_name = claims["family_name"] - user.save() +logger = logging.getLogger("django") class ICLOIDCAuthenticationBackend(OIDCAuthenticationBackend): """Extension of the OIDC authentication backend for ICL auth.""" + def _get_user_data_from_claims(self, claims: dict[str, Any]) -> dict[str, str]: + return dict( + username=claims["preferred_username"].rstrip("@ic.ac.uk"), + email=claims["email"], + first_name=claims["given_name"], + last_name=claims["family_name"], + ) + + def _update_user_from_dict(self, user: User, data: dict[str, str]) -> None: + user.username = data["username"] + user.email = data["email"] + user.first_name = data["first_name"] + user.last_name = data["last_name"] + def create_user(self, claims: dict[str, Any]) -> User: """Create a new user from the available claims. Args: claims: user info provided by self.get_user_info """ + user_data = self._get_user_data_from_claims(claims) + username = user_data["username"] + if settings.LDAP_SERVER_URI and settings.LDAP_SEARCH_BASE: + try: + uid = get_uid_from_ldap(username) + except Exception: + raise ValueError( + f"Failed to retrieve UID from LDAP for user {username}" + ) + else: + uid = None + logger.warn( + f"LDAP settings not configured, UID not retrieved for user {username}" + ) + user = super().create_user(claims) - _update_user(user, claims) + self._update_user_from_dict(user, user_data) + user.save() + if uid is not None: + UnixUID.objects.create(user=user, identifier=uid) return user def update_user(self, user: User, claims: dict[str, Any]) -> User: @@ -34,7 +65,8 @@ def update_user(self, user: User, claims: dict[str, Any]) -> User: user: user to update claims: user info provided by self.get_user_info """ - _update_user(user, claims) + user_data = self._get_user_data_from_claims(claims) + self._update_user_from_dict(user, user_data) return user def get_userinfo(