diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 78d7205afb6..77995cb5a74 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -172,8 +172,8 @@ repos: - id: mypy name: "mypy: syft" always_run: true - files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util|^packages/syft/src/syft/service" - # files: "^packages/syft/src/syft/" + files: "^packages/syft/src/syft/" + exclude: "packages/syft/src/syft/types/dicttuple.py|^packages/syft/src/syft/service/action/action_graph.py|^packages/syft/src/syft/external/oblv/" args: [ "--follow-imports=skip", "--ignore-missing-imports", diff --git a/packages/syft/src/syft/__init__.py b/packages/syft/src/syft/__init__.py index d438f9ef9bd..2a0fcfa5b6d 100644 --- a/packages/syft/src/syft/__init__.py +++ b/packages/syft/src/syft/__init__.py @@ -93,6 +93,9 @@ logger.start() try: + # third party + from IPython import get_ipython + get_ipython() # noqa: F821 # TODO: add back later or auto detect # display( diff --git a/packages/syft/src/syft/abstract_node.py b/packages/syft/src/syft/abstract_node.py index 2341d6e4926..046c7e493ff 100644 --- a/packages/syft/src/syft/abstract_node.py +++ b/packages/syft/src/syft/abstract_node.py @@ -2,12 +2,17 @@ from enum import Enum from typing import Callable from typing import Optional +from typing import TYPE_CHECKING from typing import Union # relative from .serde.serializable import serializable from .types.uid import UID +if TYPE_CHECKING: + # relative + from .service.service import AbstractService + @serializable() class NodeType(str, Enum): @@ -37,5 +42,5 @@ class AbstractNode: node_side_type: Optional[NodeSideType] in_memory_workers: bool - def get_service(self, path_or_func: Union[str, Callable]) -> Callable: + def get_service(self, path_or_func: Union[str, Callable]) -> "AbstractService": raise NotImplementedError diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index df948df96ca..5821550b2b9 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -16,6 +16,7 @@ from typing import Tuple from typing import Union from typing import _GenericAlias +from typing import cast from typing import get_args from typing import get_origin @@ -241,8 +242,8 @@ def __ipython_inspector_signature_override__(self) -> Optional[Signature]: return self.signature def prepare_args_and_kwargs( - self, args: List[Any], kwargs: Dict[str, Any] - ) -> Union[SyftError, Tuple[List[Any], Dict[str, Any]]]: + self, args: Union[list, tuple], kwargs: dict[str, Any] + ) -> Union[SyftError, tuple[tuple, dict[str, Any]]]: # Validate and migrate args and kwargs res = validate_callable_args_and_kwargs(args, kwargs, self.signature) if isinstance(res, SyftError): @@ -279,7 +280,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: api_call = SyftAPICall( node_uid=self.node_uid, path=self.path, - args=_valid_args, + args=list(_valid_args), kwargs=_valid_kwargs, blocking=blocking, ) @@ -304,8 +305,8 @@ class RemoteUserCodeFunction(RemoteFunction): api: SyftAPI def prepare_args_and_kwargs( - self, args: List[Any], kwargs: Dict[str, Any] - ) -> Union[SyftError, Tuple[List[Any], Dict[str, Any]]]: + self, args: Union[list, tuple], kwargs: Dict[str, Any] + ) -> Union[SyftError, tuple[tuple, dict[str, Any]]]: # relative from ..service.action.action_object import convert_to_pointers @@ -506,9 +507,12 @@ def _repr_html_(self) -> Any: results = self.get_all() return results._repr_html_() + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return NotImplementedError + def debox_signed_syftapicall_response( - signed_result: SignedSyftAPICall, + signed_result: Union[SignedSyftAPICall, Any], ) -> Union[Any, SyftError]: if not isinstance(signed_result, SignedSyftAPICall): return SyftError(message="The result is not signed") @@ -825,16 +829,16 @@ def build_endpoint_tree( ) @property - def services(self) -> Optional[APIModule]: + def services(self) -> APIModule: if self.api_module is None: self.generate_endpoints() - return self.api_module + return cast(APIModule, self.api_module) @property - def lib(self) -> Optional[APIModule]: + def lib(self) -> APIModule: if self.libs is None: self.generate_endpoints() - return self.libs + return cast(APIModule, self.libs) def has_service(self, service_name: str) -> bool: return hasattr(self.services, service_name) @@ -940,19 +944,21 @@ class NodeIdentity(Identity): node_name: str @staticmethod - def from_api(api: SyftAPI) -> Optional[NodeIdentity]: + def from_api(api: SyftAPI) -> NodeIdentity: # stores the name root verify key of the domain node - if api.connection is not None: - node_metadata = api.connection.get_node_metadata(api.signing_key) - return NodeIdentity( - node_name=node_metadata.name, - node_id=api.node_uid, - verify_key=SyftVerifyKey.from_string(node_metadata.verify_key), - ) - return None + if api.connection is None: + raise ValueError("{api}'s connection is None. Can't get the node identity") + node_metadata = api.connection.get_node_metadata(api.signing_key) + return NodeIdentity( + node_name=node_metadata.name, + node_id=api.node_uid, + verify_key=SyftVerifyKey.from_string(node_metadata.verify_key), + ) @classmethod def from_change_context(cls, context: ChangeContext) -> NodeIdentity: + if context.node is None: + raise ValueError(f"{context}'s node is None") return cls( node_name=context.node.name, node_id=context.node.id, diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 5d4228c1ce4..505b780f318 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -13,7 +13,6 @@ from typing import List from typing import Optional from typing import TYPE_CHECKING -from typing import Tuple from typing import Type from typing import Union from typing import cast @@ -25,7 +24,7 @@ from requests import Response from requests import Session from requests.adapters import HTTPAdapter -from requests.packages.urllib3.util.retry import Retry +from requests.packages.urllib3.util.retry import Retry # type: ignore[import-untyped] from typing_extensions import Self # relative @@ -94,9 +93,9 @@ def forward_message_to_proxy( proxy_target_uid: UID, path: str, credentials: Optional[SyftSigningKey] = None, - args: Optional[Tuple] = None, + args: Optional[list] = None, kwargs: Optional[Dict] = None, -): +) -> Union[Any, SyftError]: kwargs = {} if kwargs is None else kwargs args = [] if args is None else args call = SyftAPICall( @@ -155,7 +154,7 @@ def get_cache_key(self) -> str: def api_url(self) -> GridURL: return self.url.with_path(self.routes.ROUTE_API_CALL.value) - def to_blob_route(self, path: str, **kwargs) -> GridURL: + def to_blob_route(self, path: str, **kwargs: Any) -> GridURL: _path = self.routes.ROUTE_BLOB_STORE.value + path return self.url.with_path(_path) @@ -347,7 +346,7 @@ def get_node_metadata(self, credentials: SyftSigningKey) -> NodeMetadataJSON: else: return self.node.metadata.to(NodeMetadataJSON) - def to_blob_route(self, path: str, host=None) -> GridURL: + def to_blob_route(self, path: str, host: Optional[str] = None) -> GridURL: # TODO: FIX! if host is not None: return GridURL(host_or_ip=host, port=8333).with_path(path) @@ -474,8 +473,8 @@ def __init__( self.metadata = metadata self.credentials: Optional[SyftSigningKey] = credentials self._api = api - self.communication_protocol = None - self.current_protocol = None + self.communication_protocol: Optional[Union[int, str]] = None + self.current_protocol: Optional[Union[int, str]] = None self.post_init() @@ -485,7 +484,7 @@ def get_env(self) -> str: def post_init(self) -> None: if self.metadata is None: self._fetch_node_metadata(self.credentials) - + self.metadata = cast(NodeMetadataJSON, self.metadata) self.communication_protocol = self._get_communication_protocol( self.metadata.supported_protocols ) @@ -528,7 +527,8 @@ def create_project( project = project_create.start() return project - def sync_code_from_request(self, request): + # TODO: type of request should be REQUEST, but it will give circular import error + def sync_code_from_request(self, request: Any) -> Union[SyftSuccess, SyftError]: # relative from ..service.code.user_code import UserCode from ..service.code.user_code import UserCodeStatusCollection @@ -542,8 +542,11 @@ def sync_code_from_request(self, request): code.node_uid = self.id code.user_verify_key = self.verify_key - def get_nested_codes(code: UserCode): - result = [] + def get_nested_codes(code: UserCode) -> list[UserCode]: + result: list[UserCode] = [] + if code.nested_codes is None: + return result + for _, (linked_code_obj, _) in code.nested_codes.items(): nested_code = linked_code_obj.resolve nested_code = deepcopy(nested_code) @@ -552,11 +555,6 @@ def get_nested_codes(code: UserCode): result.append(nested_code) result += get_nested_codes(nested_code) - updated_code_links = { - nested_code.service_func_name: (LinkedObject.from_obj(nested_code), {}) - for nested_code in result - } - code.nested_codes = updated_code_links return result def get_code_statusses(codes: List[UserCode]) -> List[UserCodeStatusCollection]: @@ -607,7 +605,7 @@ def verify_key(self) -> SyftVerifyKey: @classmethod def from_url(cls, url: Union[str, GridURL]) -> Self: - return cls(connection=HTTPConnection(GridURL.from_url(url))) + return cls(connection=HTTPConnection(url=GridURL.from_url(url))) @classmethod def from_node(cls, node: AbstractNode) -> Self: @@ -641,8 +639,7 @@ def api(self) -> SyftAPI: # invalidate API if self._api is None or (self._api.signing_key != self.credentials): self._fetch_api(self.credentials) - - return self._api + return cast(SyftAPI, self._api) # we are sure self._api is not None after fetch def guest(self) -> Self: return self.__class__( @@ -657,7 +654,8 @@ def exchange_route(self, client: Self) -> Union[SyftSuccess, SyftError]: self_node_route = connection_to_route(self.connection) remote_node_route = connection_to_route(client.connection) - + if client.metadata is None: + return SyftError(f"client {client}'s metadata is None!") result = self.api.services.network.exchange_credentials_with( self_node_route=self_node_route, remote_node_route=remote_node_route, @@ -715,10 +713,11 @@ def me(self) -> Optional[Union[UserView, SyftError]]: def login_as_guest(self) -> Self: _guest_client = self.guest() - print( - f"Logged into <{self.name}: {self.metadata.node_side_type.capitalize()}-side " - f"{self.metadata.node_type.capitalize()}> as GUEST" - ) + if self.metadata is not None: + print( + f"Logged into <{self.name}: {self.metadata.node_side_type.capitalize()}-side " + f"{self.metadata.node_type.capitalize()}> as GUEST" + ) return _guest_client @@ -762,11 +761,11 @@ def login( client.__logged_in_user = email - if user_private_key is not None: + if user_private_key is not None and client.users is not None: client.__user_role = user_private_key.role client.__logged_in_username = client.users.get_current_user().name - if signing_key is not None: + if signing_key is not None and client.metadata is not None: print( f"Logged into <{client.name}: {client.metadata.node_side_type.capitalize()} side " f"{client.metadata.node_type.capitalize()}> as <{email}>" @@ -804,16 +803,18 @@ def login( # relative from ..node.node import CODE_RELOADER - CODE_RELOADER[thread_ident()] = client._reload_user_code + thread_id = thread_ident() + if thread_id is not None: + CODE_RELOADER[thread_id] = client._reload_user_code return client - def _reload_user_code(self): + def _reload_user_code(self) -> None: # relative from ..service.code.user_code import load_approved_policy_code user_code_items = self.code.get_all_for_user() - load_approved_policy_code(user_code_items) + load_approved_policy_code(user_code_items=user_code_items, context=None) def register( self, @@ -823,7 +824,7 @@ def register( password_verify: Optional[str] = None, institution: Optional[str] = None, website: Optional[str] = None, - ): + ) -> Optional[Union[SyftError, SyftSigningKey]]: if not email: email = input("Email: ") if not password: @@ -848,7 +849,10 @@ def register( except Exception as e: return SyftError(message=str(e)) - if self.metadata.node_side_type == NodeSideType.HIGH_SIDE.value: + if ( + self.metadata + and self.metadata.node_side_type == NodeSideType.HIGH_SIDE.value + ): message = ( "You're registering a user to a high side " f"{self.metadata.node_type}, which could " @@ -857,7 +861,7 @@ def register( if self.metadata.show_warnings and not prompt_warning_message( message=message ): - return + return None response = self.connection.register(new_user=new_user) if isinstance(response, tuple): @@ -894,16 +898,20 @@ def _fetch_node_metadata(self, credentials: SyftSigningKey) -> None: metadata.check_version(__version__) self.metadata = metadata - def _fetch_api(self, credentials: SyftSigningKey): + def _fetch_api(self, credentials: SyftSigningKey) -> None: _api: SyftAPI = self.connection.get_api( credentials=credentials, communication_protocol=self.communication_protocol, ) - def refresh_callback(): + def refresh_callback() -> None: return self._fetch_api(self.credentials) _api.refresh_api_callback = refresh_callback + + if self.credentials is None: + raise ValueError(f"{self}'s credentials (signing key) is None!") + APIRegistry.set_api_for( node_uid=self.id, user_verify_key=self.credentials.verify_key, @@ -943,7 +951,7 @@ def register( password: str, institution: Optional[str] = None, website: Optional[str] = None, -): +) -> Optional[Union[SyftError, SyftSigningKey]]: guest_client = connect(url=url, port=port) return guest_client.register( name=name, @@ -960,13 +968,13 @@ def login_as_guest( node: Optional[AbstractNode] = None, port: Optional[int] = None, verbose: bool = True, -): +) -> SyftClient: _client = connect(url=url, node=node, port=port) if isinstance(_client, SyftError): return _client - if verbose: + if verbose and _client.metadata is not None: print( f"Logged into <{_client.name}: {_client.metadata.node_side_type.capitalize()}-" f"side {_client.metadata.node_type.capitalize()}> as GUEST" @@ -1039,7 +1047,7 @@ def add_client( password: str, connection: NodeConnection, syft_client: SyftClient, - ): + ) -> None: hash_key = cls._get_key(email, password, connection.get_cache_key()) cls.__credentials_store__[hash_key] = syft_client cls.__client_cache__[syft_client.id] = syft_client @@ -1050,7 +1058,7 @@ def add_client_by_uid_and_verify_key( verify_key: SyftVerifyKey, node_uid: UID, syft_client: SyftClient, - ): + ) -> None: hash_key = str(node_uid) + str(verify_key) cls.__client_cache__[hash_key] = syft_client @@ -1067,8 +1075,8 @@ def get_client( ) -> Optional[SyftClient]: # we have some bugs here so lets disable until they are fixed. return None - hash_key = cls._get_key(email, password, connection.get_cache_key()) - return cls.__credentials_store__.get(hash_key, None) + # hash_key = cls._get_key(email, password, connection.get_cache_key()) + # return cls.__credentials_store__.get(hash_key, None) @classmethod def get_client_for_node_uid(cls, node_uid: UID) -> Optional[SyftClient]: diff --git a/packages/syft/src/syft/client/connection.py b/packages/syft/src/syft/client/connection.py index 5b9928c8355..a94cb1c0707 100644 --- a/packages/syft/src/syft/client/connection.py +++ b/packages/syft/src/syft/client/connection.py @@ -10,7 +10,7 @@ class NodeConnection(SyftObject): __canonical_name__ = "NodeConnection" __version__ = SYFT_OBJECT_VERSION_1 - def get_cache_key() -> str: + def get_cache_key(self) -> str: raise NotImplementedError def __repr__(self) -> str: diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 80679a49f81..35773178b56 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -4,11 +4,14 @@ # stdlib from pathlib import Path import re +from typing import List from typing import Optional from typing import TYPE_CHECKING from typing import Union +from typing import cast # third party +from hagrid.orchestra import NodeHandle from loguru import logger from tqdm import tqdm @@ -26,8 +29,10 @@ from ..service.response import SyftSuccess from ..service.sync.diff_state import ResolvedSyncState from ..service.user.roles import Roles +from ..service.user.user import UserView from ..service.user.user_roles import ServiceRole from ..types.blob_storage import BlobFile +from ..types.syft_object import SyftObject from ..types.uid import UID from ..util.fonts import fonts_css from ..util.util import get_mb_size @@ -37,13 +42,14 @@ from .client import SyftClient from .client import login from .client import login_as_guest +from .connection import NodeConnection if TYPE_CHECKING: # relative from ..service.project.project import Project -def _get_files_from_glob(glob_path: str) -> list: +def _get_files_from_glob(glob_path: str) -> list[Path]: files = Path().glob(glob_path) return [f for f in files if f.is_file() and not f.name.startswith(".")] @@ -61,7 +67,7 @@ def _contains_subdir(dir: Path) -> bool: def add_default_uploader( - user, obj: Union[CreateDataset, CreateAsset] + user: UserView, obj: Union[CreateDataset, CreateAsset] ) -> Union[CreateDataset, CreateAsset]: uploader = None for contributor in obj.contributors: @@ -90,6 +96,9 @@ def upload_dataset(self, dataset: CreateDataset) -> Union[SyftSuccess, SyftError # relative from ..types.twin_object import TwinObject + if self.users is None: + return SyftError(f"can't get user service for {self}") + user = self.users.get_current_user() dataset = add_default_uploader(user, dataset) for i in range(len(dataset.asset_list)): @@ -97,9 +106,12 @@ def upload_dataset(self, dataset: CreateDataset) -> Union[SyftSuccess, SyftError dataset.asset_list[i] = add_default_uploader(user, asset) dataset._check_asset_must_contain_mock() - dataset_size = 0 + dataset_size: float = 0.0 # TODO: Refactor so that object can also be passed to generate warnings + + self.api.connection = cast(NodeConnection, self.api.connection) + metadata = self.api.connection.get_node_metadata(self.api.signing_key) if ( @@ -134,18 +146,18 @@ def upload_dataset(self, dataset: CreateDataset) -> Union[SyftSuccess, SyftError dataset_size += get_mb_size(asset.data) dataset.mb_size = dataset_size valid = dataset.check() - if valid.ok(): - return self.api.services.dataset.add(dataset=dataset) - else: - if len(valid.err()) > 0: - return tuple(valid.err()) - return valid.err() + if isinstance(valid, SyftError): + return valid + return self.api.services.dataset.add(dataset=dataset) - def create_actionobject(self, action_object): + def create_actionobject(self, action_object: SyftObject) -> None: action_object = action_object.refresh_object() action_object.send(self) - def get_permissions_for_other_node(self, items): + def get_permissions_for_other_node( + self, + items: list[Union[ActionObject, SyftObject]], + ) -> dict: if len(items) > 0: if not len({i.syft_node_location for i in items}) == 1 or ( not len({i.syft_client_verify_key for i in items}) == 1 @@ -155,6 +167,10 @@ def get_permissions_for_other_node(self, items): api = APIRegistry.api_for( item.syft_node_location, item.syft_client_verify_key ) + if api is None: + raise ValueError( + f"Can't access the api. Please log in to {item.syft_node_location}" + ) return api.services.sync.get_permissions(items) else: return {} @@ -168,7 +184,7 @@ def apply_state( action_objects = [x for x in items if isinstance(x, ActionObject)] # permissions = self.get_permissions_for_other_node(items) - permissions = {} + permissions: dict[UID, set[str]] = {} for p in resolved_state.new_permissions: if p.uid in permissions: permissions[p.uid].add(p.permission_string) @@ -193,16 +209,17 @@ def apply_state( def upload_files( self, file_list: Union[BlobFile, list[BlobFile], str, list[str], Path, list[Path]], - allow_recursive=False, - show_files=False, + allow_recursive: bool = False, + show_files: bool = False, ) -> Union[SyftSuccess, SyftError]: if not file_list: return SyftError(message="No files to upload") if not isinstance(file_list, list): - file_list = [file_list] + file_list = [file_list] # type: ignore[assignment] + file_list = cast(list, file_list) - expanded_file_list = [] + expanded_file_list: List[Union[BlobFile, Path]] = [] for file in file_list: if isinstance(file, BlobFile): @@ -263,7 +280,7 @@ def connect_to_gateway( handle: Optional[NodeHandle] = None, # noqa: F821 email: Optional[str] = None, password: Optional[str] = None, - ) -> None: + ) -> Optional[Union[SyftSuccess, SyftError]]: if via_client is not None: client = via_client elif handle is not None: @@ -279,9 +296,12 @@ def connect_to_gateway( res = self.exchange_route(client) if isinstance(res, SyftSuccess): - return SyftSuccess( - message=f"Connected {self.metadata.node_type} to {client.name} gateway" - ) + if self.metadata: + return SyftSuccess( + message=f"Connected {self.metadata.node_type} to {client.name} gateway" + ) + else: + return SyftSuccess(message=f"Connected to {client.name} gateway") return res @property @@ -374,8 +394,8 @@ def output(self) -> Optional[APIModule]: def get_project( self, - name: str = None, - uid: UID = None, + name: Optional[str] = None, + uid: Optional[UID] = None, ) -> Optional[Project]: """Get project by name or UID""" @@ -442,18 +462,17 @@ def _repr_html_(self) -> str: url = getattr(self.connection, "url", None) node_details = f"URL: {url}
" if url else "" - node_details += ( - f"Node Type: {self.metadata.node_type.capitalize()}
" - ) - node_side_type = ( - "Low Side" - if self.metadata.node_side_type == NodeSideType.LOW_SIDE.value - else "High Side" - ) - node_details += f"Node Side Type: {node_side_type}
" - node_details += ( - f"Syft Version: {self.metadata.syft_version}
" - ) + if self.metadata is not None: + node_details += f"Node Type: {self.metadata.node_type.capitalize()}
" + node_side_type = ( + "Low Side" + if self.metadata.node_side_type == NodeSideType.LOW_SIDE.value + else "High Side" + ) + node_details += f"Node Side Type: {node_side_type}
" + node_details += ( + f"Syft Version: {self.metadata.syft_version}
" + ) return f"""