diff --git a/packages/hagrid/hagrid/orchestra.py b/packages/hagrid/hagrid/orchestra.py index da87600e13b..31ef1ba7f09 100644 --- a/packages/hagrid/hagrid/orchestra.py +++ b/packages/hagrid/hagrid/orchestra.py @@ -164,7 +164,7 @@ def __init__( def client(self) -> Any: if self.port: sy = get_syft_client() - return sy.login(url=self.url, port=self.port, verbose=False) # type: ignore + return sy.login(url=self.url, port=self.port) # type: ignore elif self.deployment_type == DeploymentType.PYTHON: return self.python_node.get_guest_client(verbose=False) # type: ignore else: @@ -172,6 +172,16 @@ def client(self) -> Any: f"client not implemented for the deployment type:{self.deployment_type}" ) + def login_as_guest(self, **kwargs: Any) -> Optional[Any]: + client = self.client + + session = client.login_as_guest(**kwargs) + + if isinstance(session, SyftError): + return session + + return session + def login( self, email: Optional[str] = None, password: Optional[str] = None, **kwargs: Any ) -> Optional[Any]: @@ -183,6 +193,7 @@ def login( password = getpass.getpass("Password: ") session = client.login(email=email, password=password, **kwargs) + if isinstance(session, SyftError): return session diff --git a/packages/syft/src/syft/__init__.py b/packages/syft/src/syft/__init__.py index 947b4359a2f..77e5d6a5ec4 100644 --- a/packages/syft/src/syft/__init__.py +++ b/packages/syft/src/syft/__init__.py @@ -12,6 +12,7 @@ from .abstract_node import NodeType # noqa: F401 from .client.client import connect # noqa: F401 from .client.client import login # noqa: F401 +from .client.client import login_as_guest # noqa: F401 from .client.client import register # noqa: F401 from .client.deploy import Orchestra # noqa: F401 from .client.domain_client import DomainClient # noqa: F401 diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index f669ef66652..c6371f487a0 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -642,6 +642,16 @@ def me(self) -> Optional[Union[UserView, SyftError]]: return self.api.services.user.get_current_user() return None + def login_as_guest(self) -> Self: + _guest_client = self.guest() + + print( + f"Logged into <{self.name}: {self.metadata.node_side_type.capitalize()}-side " + f"{self.metadata.node_type.capitalize()}> as GUEST" + ) + + return _guest_client + def login( self, email: str, password: str, cache: bool = True, register=False, **kwargs ) -> Self: @@ -857,6 +867,27 @@ def register( ) +@instrument +def login_as_guest( + url: Union[str, GridURL] = DEFAULT_PYGRID_ADDRESS, + node: Optional[AbstractNode] = None, + port: Optional[int] = None, + verbose: bool = True, +): + _client = connect(url=url, node=node, port=port) + + if isinstance(_client, SyftError): + return _client + + if verbose: + print( + f"Logged into <{_client.name}: {_client.metadata.node_side_type.capitalize()}-" + f"side {_client.metadata.node_type.capitalize()}> as GUEST" + ) + + return _client.guest() + + @instrument def login( url: Union[str, GridURL] = DEFAULT_PYGRID_ADDRESS, @@ -865,11 +896,12 @@ def login( email: Optional[str] = None, password: Optional[str] = None, cache: bool = True, - verbose: bool = True, ) -> SyftClient: _client = connect(url=url, node=node, port=port) + if isinstance(_client, SyftError): return _client + connection = _client.connection login_credentials = None @@ -879,14 +911,6 @@ def login( password = getpass("Password: ") login_credentials = UserLoginCredentials(email=email, password=password) - if login_credentials is None: - if verbose: - print( - f"Logged into <{_client.name}: {_client.metadata.node_side_type.capitalize()}-" - f"side {_client.metadata.node_type.capitalize()}> as GUEST" - ) - return _client.guest() - if cache and login_credentials: _client_cache = SyftClientSessionCache.get_client( login_credentials.email, diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index 27f51c392c7..9a464d8c5ef 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -41,7 +41,6 @@ from nacl.signing import SigningKey from nacl.signing import VerifyKey import requests -from rich.prompt import Confirm # relative from .logger import critical @@ -456,11 +455,15 @@ def prompt_warning_message(message: str, confirm: bool = False) -> bool: warning = SyftWarning(message=message) display(warning) - if confirm: - allowed = Confirm.ask("Would you like to proceed?") - if not allowed: + while confirm: + response = input("Would you like to proceed? [y/n]: ").lower() + if response == "y": + return True + elif response == "n": display("Aborted !!") return False + else: + print("Invalid response. Please enter Y or N.") return True