diff --git a/services/enterprise/modules/user/repository.py b/services/enterprise/modules/user/repository.py index 3640cb95..67e1aaa3 100644 --- a/services/enterprise/modules/user/repository.py +++ b/services/enterprise/modules/user/repository.py @@ -13,6 +13,10 @@ def get_user(self, query: dict) -> User: user = MongoDB.find_one(USER_COL, query) return User(id=str(user["_id"]), **user) if user else None + def get_user_by_sub(self, sub: str) -> User: + user = MongoDB.find_one(USER_COL, {"sub": sub}) + return User(id=str(user["_id"]), **user) if user else None + def get_user_by_email(self, email: str) -> User: user = MongoDB.find_one(USER_COL, {"email": email}) return User(id=str(user["_id"]), **user) if user else None diff --git a/services/enterprise/modules/user/service.py b/services/enterprise/modules/user/service.py index 3ded69de..2d5c3320 100644 --- a/services/enterprise/modules/user/service.py +++ b/services/enterprise/modules/user/service.py @@ -28,6 +28,11 @@ def get_users(self, org_id: str) -> list[UserResponse]: def get_user(self, user_id: str, org_id: str) -> UserResponse: return self.get_user_in_org(user_id, org_id) + def get_user_by_sub(self, sub: str) -> User: + """Helper function to get user by Auth0sub.""" + user = self.repo.get_user_by_sub(sub) + return user if user else None + def get_user_by_email(self, email: str) -> User: """Helper function to get user by email.""" user = self.repo.get_user_by_email(email) diff --git a/services/enterprise/utils/auth.py b/services/enterprise/utils/auth.py index ad40d828..0fbe6b9c 100644 --- a/services/enterprise/utils/auth.py +++ b/services/enterprise/utils/auth.py @@ -80,10 +80,10 @@ def _decode_payload(self): class Authorize: def user(self, payload: dict) -> User: - email = payload[auth_settings.auth0_issuer + "email"] - user = user_service.get_user_by_email(email) + sub = payload['sub'] + user = user_service.get_user_by_sub(sub) if not user: - raise UnauthorizedUserError(email=email) + raise UnauthorizedUserError(email=sub) return user def user_in_organization(self, user_id: str, org_id: str):