Skip to content

Commit

Permalink
Return a new client with SyftClient.login
Browse files Browse the repository at this point in the history
  • Loading branch information
kiendang committed Oct 13, 2023
1 parent c1288b2 commit 9352879
Showing 1 changed file with 32 additions and 26 deletions.
58 changes: 32 additions & 26 deletions packages/syft/src/syft/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,38 +653,42 @@ def login_as_guest(self) -> Self:
return _guest_client

def login(
self, email: str, password: str, cache: bool = True, register=False, **kwargs
self,
email: Optional[str] = None,
password: Optional[str] = None,
cache: bool = True,
register: bool = False,
**kwargs: Any,
) -> Self:
if not register and email is None and password is None:
return self.login_as_guest()

if email is None:
email = input("Email: ")
if password is None:
password = getpass("Password: ")

if register:
if not email:
email = input("Email: ")
if not password:
password = getpass("Password: ")
self.register(
email=email, password=password, password_verify=password, **kwargs
)
if password is None:
password = getpass("Password: ")

user_private_key = self.connection.login(email=email, password=password)
if isinstance(user_private_key, SyftError):
return user_private_key
signing_key = None
if user_private_key is not None:
signing_key = user_private_key.signing_key
self.__user_role = user_private_key.role
if signing_key is not None:
self.credentials = signing_key
self.__logged_in_user = email

# Get current logged in user name
self.__logged_in_username = self.users.get_current_user().name
signing_key = None if user_private_key is None else user_private_key.signing_key

# TODO: How to get the role of the user?
# self.__user_role =
self._fetch_api(self.credentials)
client = self.__class__(
connection=self.connection,
metadata=self.metadata,
credentials=signing_key,
)

if signing_key is not None:
print(
f"Logged into <{self.name}: {self.metadata.node_side_type.capitalize()} side "
f"{self.metadata.node_type.capitalize()}> as <{email}>"
f"Logged into <{client.name}: {client.metadata.node_side_type.capitalize()} side "
f"{client.metadata.node_type.capitalize()}> as <{email}>"
)
# relative
from ..node.node import get_default_root_password
Expand All @@ -700,8 +704,8 @@ def login(
SyftClientSessionCache.add_client(
email=email,
password=password,
connection=self.connection,
syft_client=self,
connection=client.connection,
syft_client=client,
)
# Adding another cache storage
# as this would be useful in retrieving unique clients
Expand All @@ -719,9 +723,9 @@ def login(
# relative
from ..node.node import CODE_RELOADER

CODE_RELOADER[thread_ident()] = self._reload_user_code
CODE_RELOADER[thread_ident()] = client._reload_user_code

return self
return client

def _reload_user_code(self):
# relative
Expand Down Expand Up @@ -756,7 +760,9 @@ def register(
password_verify=password_verify,
institution=institution,
website=website,
created_by=self.credentials,
created_by=(
None if self.__user_role == ServiceRole.GUEST else self.credentials
),
)
except Exception as e:
return SyftError(message=str(e))
Expand Down

0 comments on commit 9352879

Please sign in to comment.