Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return a new client with SyftClient.login #8159

Merged
merged 9 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/hagrid/hagrid/orchestra.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def __init__(
def client(self) -> Any:
if self.port:
sy = get_syft_client()
return sy.login(url=self.url, port=self.port) # type: ignore
return sy.login_as_guest(url=self.url, port=self.port) # type: ignore
elif self.deployment_type == DeploymentType.PYTHON:
return self.python_node.get_guest_client(verbose=False) # type: ignore
else:
Expand Down
78 changes: 42 additions & 36 deletions packages/syft/src/syft/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,38 +653,45 @@ def login_as_guest(self) -> Self:
return _guest_client

def login(
self, email: str, password: str, cache: bool = True, register=False, **kwargs
self,
email: Optional[str] = None,
password: Optional[str] = None,
cache: bool = True,
register: bool = False,
**kwargs: Any,
) -> Self:
if email is None:
shubham3121 marked this conversation as resolved.
Show resolved Hide resolved
email = input("Email: ")
if password is None:
password = getpass("Password: ")

if register:
if not email:
email = input("Email: ")
if not password:
password = getpass("Password: ")
self.register(
email=email, password=password, password_verify=password, **kwargs
)
if password is None:
password = getpass("Password: ")

user_private_key = self.connection.login(email=email, password=password)
if isinstance(user_private_key, SyftError):
return user_private_key
signing_key = None
if user_private_key is not None:
signing_key = user_private_key.signing_key
self.__user_role = user_private_key.role
if signing_key is not None:
self.credentials = signing_key
self.__logged_in_user = email

# Get current logged in user name
self.__logged_in_username = self.users.get_current_user().name
signing_key = None if user_private_key is None else user_private_key.signing_key

# TODO: How to get the role of the user?
# self.__user_role =
self._fetch_api(self.credentials)
client = self.__class__(
connection=self.connection,
metadata=self.metadata,
credentials=signing_key,
)

client.__logged_in_user = email

if user_private_key is not None:
client.__user_role = user_private_key.role
client.__logged_in_username = client.users.get_current_user().name

if signing_key is not None:
print(
f"Logged into <{self.name}: {self.metadata.node_side_type.capitalize()} side "
f"{self.metadata.node_type.capitalize()}> as <{email}>"
f"Logged into <{client.name}: {client.metadata.node_side_type.capitalize()} side "
f"{client.metadata.node_type.capitalize()}> as <{email}>"
)
# relative
from ..node.node import get_default_root_password
Expand All @@ -700,8 +707,8 @@ def login(
SyftClientSessionCache.add_client(
email=email,
password=password,
connection=self.connection,
syft_client=self,
connection=client.connection,
syft_client=client,
)
# Adding another cache storage
# as this would be useful in retrieving unique clients
Expand All @@ -712,16 +719,16 @@ def login(
# combining both email, password and verify key and uid
SyftClientSessionCache.add_client_by_uid_and_verify_key(
verify_key=signing_key.verify_key,
node_uid=self.id,
syft_client=self,
node_uid=client.id,
syft_client=client,
)

# relative
from ..node.node import CODE_RELOADER

CODE_RELOADER[thread_ident()] = self._reload_user_code
CODE_RELOADER[thread_ident()] = client._reload_user_code

return self
return client

def _reload_user_code(self):
# relative
Expand Down Expand Up @@ -756,7 +763,9 @@ def register(
password_verify=password_verify,
institution=institution,
website=website,
created_by=self.credentials,
created_by=(
None if self.__user_role == ServiceRole.GUEST else self.credentials
),
)
except Exception as e:
return SyftError(message=str(e))
Expand Down Expand Up @@ -890,10 +899,10 @@ def login_as_guest(

@instrument
def login(
email: str,
shubham3121 marked this conversation as resolved.
Show resolved Hide resolved
url: Union[str, GridURL] = DEFAULT_PYGRID_ADDRESS,
node: Optional[AbstractNode] = None,
port: Optional[int] = None,
email: Optional[str] = None,
password: Optional[str] = None,
cache: bool = True,
) -> SyftClient:
Expand All @@ -906,10 +915,9 @@ def login(

login_credentials = None

if email:
if not password:
password = getpass("Password: ")
login_credentials = UserLoginCredentials(email=email, password=password)
if not password:
password = getpass("Password: ")
login_credentials = UserLoginCredentials(email=email, password=password)

if cache and login_credentials:
_client_cache = SyftClientSessionCache.get_client(
Expand All @@ -924,13 +932,11 @@ def login(
_client = _client_cache

if not _client.authed and login_credentials:
_client.login(
_client = _client.login(
email=login_credentials.email,
password=login_credentials.password,
cache=cache,
)
if not _client.authed:
Copy link
Collaborator

@tcp tcp Oct 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is a failed login attempt handled?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would return a SyftError

_client = ...
if isinstance(_client, SyftError): return _client

return _client

in both cases we return _client so that check is not needed here

return SyftError(message=f"Failed to login as {login_credentials.email}")

return _client

Expand Down
10 changes: 8 additions & 2 deletions packages/syft/src/syft/client/domain_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .api import APIModule
from .client import SyftClient
from .client import login
from .client import login_as_guest

if TYPE_CHECKING:
# relative
Expand Down Expand Up @@ -121,14 +122,19 @@ def connect_to_gateway(
url: Optional[str] = None,
port: Optional[int] = None,
handle: Optional[NodeHandle] = None, # noqa: F821
**kwargs,
email: Optional[str] = None,
password: Optional[str] = None,
) -> None:
if via_client is not None:
client = via_client
elif handle is not None:
client = handle.client
else:
client = login(url=url, port=port, **kwargs)
client = (
login_as_guest(url=url, port=port)
if email is None
else login(url=url, port=port, email=email, password=password)
)
if isinstance(client, SyftError):
return client

Expand Down
10 changes: 8 additions & 2 deletions packages/syft/src/syft/client/enclave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .api import APIModule
from .client import SyftClient
from .client import login
from .client import login_as_guest

if TYPE_CHECKING:
# relative
Expand Down Expand Up @@ -65,14 +66,19 @@ def connect_to_gateway(
url: Optional[str] = None,
port: Optional[int] = None,
handle: Optional[NodeHandle] = None, # noqa: F821
**kwargs,
email: Optional[str] = None,
password: Optional[str] = None,
) -> None:
if via_client is not None:
client = via_client
elif handle is not None:
client = handle.client
else:
client = login(url=url, port=port, **kwargs)
client = (
login_as_guest(url=url, port=port)
if email is None
else login(url=url, port=port, email=email, password=password)
)
if isinstance(client, SyftError):
return client

Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/client/gateway_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def proxy_client_for(
return SyftError(message=f"No domain with name {name}")
res = self.proxy_to(peer)
if email and password:
res.login(email=email, password=password, **kwargs)
res = res.login(email=email, password=password, **kwargs)
return res

@property
Expand Down
3 changes: 2 additions & 1 deletion packages/syft/src/syft/external/oblv/deployment_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ...client.api import SyftAPI
from ...client.client import SyftClient
from ...client.client import login
from ...client.client import login_as_guest
from ...enclave.metadata import EnclaveMetadata
from ...serde.serializable import serializable
from ...types.uid import UID
Expand Down Expand Up @@ -246,7 +247,7 @@ def register(
website: Optional[str] = None,
):
self.check_connection_string()
guest_client = login(url=self.__conn_string)
guest_client = login_as_guest(url=self.__conn_string)
return guest_client.register(
name=name,
email=email,
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/service/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,7 +1216,7 @@ def verify_members(cls, val: Union[List[SyftClient], List[NodeIdentity]]):
if len(clients) > 0:
emails = {client.logged_in_user for client in clients}
if len(emails) > 1:
raise SyftException(
raise ValueError(
f"All clients must be logged in from the same account. Found multiple: {emails}"
)
return val
Expand Down
4 changes: 2 additions & 2 deletions packages/syft/tests/syft/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_api_cache_invalidation_login(root_verify_key, worker):
assert guest_client.register(
name="q", email="[email protected]", password="aaa", password_verify="aaa"
)
guest_client.login(email="[email protected]", password="aaa")
guest_client = guest_client.login(email="[email protected]", password="aaa")
user_id = worker.document_store.partitions["User"].all(root_verify_key).value[-1].id

def get_role(verify_key):
Expand All @@ -75,6 +75,6 @@ def get_role(verify_key):

assert get_role(guest_client.credentials.verify_key) == ServiceRole.DATA_OWNER

guest_client.login(email="[email protected]", password="aaa")
guest_client = guest_client.login(email="[email protected]", password="aaa")

assert guest_client.upload_dataset(dataset)
10 changes: 5 additions & 5 deletions packages/syft/tests/syft/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@ def test_client_logged_in_user(worker):
guest_client = worker.guest_client
assert guest_client.logged_in_user == ""

guest_client.login(email="[email protected]", password="changethis")
assert guest_client.logged_in_user == "[email protected]"
client = guest_client.login(email="[email protected]", password="changethis")
assert client.logged_in_user == "[email protected]"

guest_client.register(
client.register(
name="sheldon",
email="[email protected]",
password="bazinga",
password_verify="bazinga",
)

guest_client.login(email="[email protected]", password="bazinga")
client = client.login(email="[email protected]", password="bazinga")

assert guest_client.logged_in_user == "[email protected]"
assert client.logged_in_user == "[email protected]"
14 changes: 10 additions & 4 deletions packages/syft/tests/syft/gateways/gateway_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,12 @@ def test_domain_connect_to_gateway(faker: Faker):
assert proxy_domain_client.metadata == domain_client.metadata
assert proxy_domain_client.user_role == ServiceRole.NONE

domain_client.login(email="[email protected]", password="changethis")
proxy_domain_client.login(email="[email protected]", password="changethis")
domain_client = domain_client.login(
email="[email protected]", password="changethis"
)
proxy_domain_client = proxy_domain_client.login(
email="[email protected]", password="changethis"
)

assert proxy_domain_client.logged_in_user == "[email protected]"
assert proxy_domain_client.user_role == ServiceRole.ADMIN
Expand Down Expand Up @@ -129,8 +133,10 @@ def test_enclave_connect_to_gateway(faker: Faker):
password_verify=password,
)

enclave_client.login(email=user_email, password=password)
proxy_enclave_client.login(email=user_email, password=password)
enclave_client = enclave_client.login(email=user_email, password=password)
proxy_enclave_client = proxy_enclave_client.login(
email=user_email, password=password
)

assert proxy_enclave_client.logged_in_user == user_email
assert proxy_enclave_client.user_role == enclave_client.user_role
Expand Down
4 changes: 1 addition & 3 deletions packages/syft/tests/syft/project/project_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ def test_exception_different_email(worker):

ds_sheldon = sy.login(node=worker, email="[email protected]", password="bazinga")

ds_leonard = sy.login(
node=worker, email="[email protected]", password="starwars"
)
ds_leonard = sy.login(node=worker, email="[email protected]", password="penny")

with pytest.raises(ValidationError):
sy.Project(
Expand Down
5 changes: 3 additions & 2 deletions packages/syft/tests/syft/settings/settings_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,9 @@ def get_mock_client(faker, root_client, role):
assert not isinstance(result, SyftError)

guest_client = root_client.guest()
guest_client.login(email=user_create.email, password=user_create.password)
return guest_client
return guest_client.login(
email=user_create.email, password=user_create.password
)

verify_key = SyftSigningKey.generate().verify_key
mock_node_metadata = NodeMetadata(
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/tests/syft/users/user_code_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_scientist_can_list_code_assets(worker: sy.Worker, faker: Faker) -> None

guest_client = root_client.guest()
credentials.pop("name")
guest_client.login(**credentials)
guest_client = guest_client.login(**credentials)

root_client.upload_dataset(dataset=dataset)

Expand Down
2 changes: 1 addition & 1 deletion packages/syft/tests/syft/users/user_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_mock_client(root_client, role) -> DomainClient:
assert worker.root_client.api.services.user.update(
user_id, UserUpdate(user_id=user_id, role=role)
)
client.login(email=mail, password=password)
client = client.login(email=mail, password=password)
client._fetch_api(client.credentials)
# hacky, but useful for testing: patch user id and role on client
client.user_id = user_id
Expand Down
Loading