From 9352879234b0d1ede64cb4db04045ac8fa06cb27 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Fri, 13 Oct 2023 16:41:16 +0800 Subject: [PATCH] Return a new client with SyftClient.login --- packages/syft/src/syft/client/client.py | 58 ++++++++++++++----------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index c6371f487a0..fbd230eac0a 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -653,38 +653,42 @@ def login_as_guest(self) -> Self: return _guest_client def login( - self, email: str, password: str, cache: bool = True, register=False, **kwargs + self, + email: Optional[str] = None, + password: Optional[str] = None, + cache: bool = True, + register: bool = False, + **kwargs: Any, ) -> Self: + if not register and email is None and password is None: + return self.login_as_guest() + + if email is None: + email = input("Email: ") + if password is None: + password = getpass("Password: ") + if register: - if not email: - email = input("Email: ") - if not password: - password = getpass("Password: ") self.register( email=email, password=password, password_verify=password, **kwargs ) - if password is None: - password = getpass("Password: ") + user_private_key = self.connection.login(email=email, password=password) if isinstance(user_private_key, SyftError): return user_private_key - signing_key = None - if user_private_key is not None: - signing_key = user_private_key.signing_key - self.__user_role = user_private_key.role - if signing_key is not None: - self.credentials = signing_key - self.__logged_in_user = email - # Get current logged in user name - self.__logged_in_username = self.users.get_current_user().name + signing_key = None if user_private_key is None else user_private_key.signing_key - # TODO: How to get the role of the user? - # self.__user_role = - self._fetch_api(self.credentials) + client = self.__class__( + connection=self.connection, + metadata=self.metadata, + credentials=signing_key, + ) + + if signing_key is not None: print( - f"Logged into <{self.name}: {self.metadata.node_side_type.capitalize()} side " - f"{self.metadata.node_type.capitalize()}> as <{email}>" + f"Logged into <{client.name}: {client.metadata.node_side_type.capitalize()} side " + f"{client.metadata.node_type.capitalize()}> as <{email}>" ) # relative from ..node.node import get_default_root_password @@ -700,8 +704,8 @@ def login( SyftClientSessionCache.add_client( email=email, password=password, - connection=self.connection, - syft_client=self, + connection=client.connection, + syft_client=client, ) # Adding another cache storage # as this would be useful in retrieving unique clients @@ -719,9 +723,9 @@ def login( # relative from ..node.node import CODE_RELOADER - CODE_RELOADER[thread_ident()] = self._reload_user_code + CODE_RELOADER[thread_ident()] = client._reload_user_code - return self + return client def _reload_user_code(self): # relative @@ -756,7 +760,9 @@ def register( password_verify=password_verify, institution=institution, website=website, - created_by=self.credentials, + created_by=( + None if self.__user_role == ServiceRole.GUEST else self.credentials + ), ) except Exception as e: return SyftError(message=str(e))