diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 78d7205afb6..77995cb5a74 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -172,8 +172,8 @@ repos:
- id: mypy
name: "mypy: syft"
always_run: true
- files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util|^packages/syft/src/syft/service"
- # files: "^packages/syft/src/syft/"
+ files: "^packages/syft/src/syft/"
+ exclude: "packages/syft/src/syft/types/dicttuple.py|^packages/syft/src/syft/service/action/action_graph.py|^packages/syft/src/syft/external/oblv/"
args: [
"--follow-imports=skip",
"--ignore-missing-imports",
diff --git a/packages/syft/src/syft/__init__.py b/packages/syft/src/syft/__init__.py
index d438f9ef9bd..2a0fcfa5b6d 100644
--- a/packages/syft/src/syft/__init__.py
+++ b/packages/syft/src/syft/__init__.py
@@ -93,6 +93,9 @@
logger.start()
try:
+ # third party
+ from IPython import get_ipython
+
get_ipython() # noqa: F821
# TODO: add back later or auto detect
# display(
diff --git a/packages/syft/src/syft/abstract_node.py b/packages/syft/src/syft/abstract_node.py
index 2341d6e4926..046c7e493ff 100644
--- a/packages/syft/src/syft/abstract_node.py
+++ b/packages/syft/src/syft/abstract_node.py
@@ -2,12 +2,17 @@
from enum import Enum
from typing import Callable
from typing import Optional
+from typing import TYPE_CHECKING
from typing import Union
# relative
from .serde.serializable import serializable
from .types.uid import UID
+if TYPE_CHECKING:
+ # relative
+ from .service.service import AbstractService
+
@serializable()
class NodeType(str, Enum):
@@ -37,5 +42,5 @@ class AbstractNode:
node_side_type: Optional[NodeSideType]
in_memory_workers: bool
- def get_service(self, path_or_func: Union[str, Callable]) -> Callable:
+ def get_service(self, path_or_func: Union[str, Callable]) -> "AbstractService":
raise NotImplementedError
diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py
index df948df96ca..5821550b2b9 100644
--- a/packages/syft/src/syft/client/api.py
+++ b/packages/syft/src/syft/client/api.py
@@ -16,6 +16,7 @@
from typing import Tuple
from typing import Union
from typing import _GenericAlias
+from typing import cast
from typing import get_args
from typing import get_origin
@@ -241,8 +242,8 @@ def __ipython_inspector_signature_override__(self) -> Optional[Signature]:
return self.signature
def prepare_args_and_kwargs(
- self, args: List[Any], kwargs: Dict[str, Any]
- ) -> Union[SyftError, Tuple[List[Any], Dict[str, Any]]]:
+ self, args: Union[list, tuple], kwargs: dict[str, Any]
+ ) -> Union[SyftError, tuple[tuple, dict[str, Any]]]:
# Validate and migrate args and kwargs
res = validate_callable_args_and_kwargs(args, kwargs, self.signature)
if isinstance(res, SyftError):
@@ -279,7 +280,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
api_call = SyftAPICall(
node_uid=self.node_uid,
path=self.path,
- args=_valid_args,
+ args=list(_valid_args),
kwargs=_valid_kwargs,
blocking=blocking,
)
@@ -304,8 +305,8 @@ class RemoteUserCodeFunction(RemoteFunction):
api: SyftAPI
def prepare_args_and_kwargs(
- self, args: List[Any], kwargs: Dict[str, Any]
- ) -> Union[SyftError, Tuple[List[Any], Dict[str, Any]]]:
+ self, args: Union[list, tuple], kwargs: Dict[str, Any]
+ ) -> Union[SyftError, tuple[tuple, dict[str, Any]]]:
# relative
from ..service.action.action_object import convert_to_pointers
@@ -506,9 +507,12 @@ def _repr_html_(self) -> Any:
results = self.get_all()
return results._repr_html_()
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
+ return NotImplementedError
+
def debox_signed_syftapicall_response(
- signed_result: SignedSyftAPICall,
+ signed_result: Union[SignedSyftAPICall, Any],
) -> Union[Any, SyftError]:
if not isinstance(signed_result, SignedSyftAPICall):
return SyftError(message="The result is not signed")
@@ -825,16 +829,16 @@ def build_endpoint_tree(
)
@property
- def services(self) -> Optional[APIModule]:
+ def services(self) -> APIModule:
if self.api_module is None:
self.generate_endpoints()
- return self.api_module
+ return cast(APIModule, self.api_module)
@property
- def lib(self) -> Optional[APIModule]:
+ def lib(self) -> APIModule:
if self.libs is None:
self.generate_endpoints()
- return self.libs
+ return cast(APIModule, self.libs)
def has_service(self, service_name: str) -> bool:
return hasattr(self.services, service_name)
@@ -940,19 +944,21 @@ class NodeIdentity(Identity):
node_name: str
@staticmethod
- def from_api(api: SyftAPI) -> Optional[NodeIdentity]:
+ def from_api(api: SyftAPI) -> NodeIdentity:
# stores the name root verify key of the domain node
- if api.connection is not None:
- node_metadata = api.connection.get_node_metadata(api.signing_key)
- return NodeIdentity(
- node_name=node_metadata.name,
- node_id=api.node_uid,
- verify_key=SyftVerifyKey.from_string(node_metadata.verify_key),
- )
- return None
+ if api.connection is None:
+ raise ValueError("{api}'s connection is None. Can't get the node identity")
+ node_metadata = api.connection.get_node_metadata(api.signing_key)
+ return NodeIdentity(
+ node_name=node_metadata.name,
+ node_id=api.node_uid,
+ verify_key=SyftVerifyKey.from_string(node_metadata.verify_key),
+ )
@classmethod
def from_change_context(cls, context: ChangeContext) -> NodeIdentity:
+ if context.node is None:
+ raise ValueError(f"{context}'s node is None")
return cls(
node_name=context.node.name,
node_id=context.node.id,
diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py
index 5d4228c1ce4..505b780f318 100644
--- a/packages/syft/src/syft/client/client.py
+++ b/packages/syft/src/syft/client/client.py
@@ -13,7 +13,6 @@
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
-from typing import Tuple
from typing import Type
from typing import Union
from typing import cast
@@ -25,7 +24,7 @@
from requests import Response
from requests import Session
from requests.adapters import HTTPAdapter
-from requests.packages.urllib3.util.retry import Retry
+from requests.packages.urllib3.util.retry import Retry # type: ignore[import-untyped]
from typing_extensions import Self
# relative
@@ -94,9 +93,9 @@ def forward_message_to_proxy(
proxy_target_uid: UID,
path: str,
credentials: Optional[SyftSigningKey] = None,
- args: Optional[Tuple] = None,
+ args: Optional[list] = None,
kwargs: Optional[Dict] = None,
-):
+) -> Union[Any, SyftError]:
kwargs = {} if kwargs is None else kwargs
args = [] if args is None else args
call = SyftAPICall(
@@ -155,7 +154,7 @@ def get_cache_key(self) -> str:
def api_url(self) -> GridURL:
return self.url.with_path(self.routes.ROUTE_API_CALL.value)
- def to_blob_route(self, path: str, **kwargs) -> GridURL:
+ def to_blob_route(self, path: str, **kwargs: Any) -> GridURL:
_path = self.routes.ROUTE_BLOB_STORE.value + path
return self.url.with_path(_path)
@@ -347,7 +346,7 @@ def get_node_metadata(self, credentials: SyftSigningKey) -> NodeMetadataJSON:
else:
return self.node.metadata.to(NodeMetadataJSON)
- def to_blob_route(self, path: str, host=None) -> GridURL:
+ def to_blob_route(self, path: str, host: Optional[str] = None) -> GridURL:
# TODO: FIX!
if host is not None:
return GridURL(host_or_ip=host, port=8333).with_path(path)
@@ -474,8 +473,8 @@ def __init__(
self.metadata = metadata
self.credentials: Optional[SyftSigningKey] = credentials
self._api = api
- self.communication_protocol = None
- self.current_protocol = None
+ self.communication_protocol: Optional[Union[int, str]] = None
+ self.current_protocol: Optional[Union[int, str]] = None
self.post_init()
@@ -485,7 +484,7 @@ def get_env(self) -> str:
def post_init(self) -> None:
if self.metadata is None:
self._fetch_node_metadata(self.credentials)
-
+ self.metadata = cast(NodeMetadataJSON, self.metadata)
self.communication_protocol = self._get_communication_protocol(
self.metadata.supported_protocols
)
@@ -528,7 +527,8 @@ def create_project(
project = project_create.start()
return project
- def sync_code_from_request(self, request):
+ # TODO: type of request should be REQUEST, but it will give circular import error
+ def sync_code_from_request(self, request: Any) -> Union[SyftSuccess, SyftError]:
# relative
from ..service.code.user_code import UserCode
from ..service.code.user_code import UserCodeStatusCollection
@@ -542,8 +542,11 @@ def sync_code_from_request(self, request):
code.node_uid = self.id
code.user_verify_key = self.verify_key
- def get_nested_codes(code: UserCode):
- result = []
+ def get_nested_codes(code: UserCode) -> list[UserCode]:
+ result: list[UserCode] = []
+ if code.nested_codes is None:
+ return result
+
for _, (linked_code_obj, _) in code.nested_codes.items():
nested_code = linked_code_obj.resolve
nested_code = deepcopy(nested_code)
@@ -552,11 +555,6 @@ def get_nested_codes(code: UserCode):
result.append(nested_code)
result += get_nested_codes(nested_code)
- updated_code_links = {
- nested_code.service_func_name: (LinkedObject.from_obj(nested_code), {})
- for nested_code in result
- }
- code.nested_codes = updated_code_links
return result
def get_code_statusses(codes: List[UserCode]) -> List[UserCodeStatusCollection]:
@@ -607,7 +605,7 @@ def verify_key(self) -> SyftVerifyKey:
@classmethod
def from_url(cls, url: Union[str, GridURL]) -> Self:
- return cls(connection=HTTPConnection(GridURL.from_url(url)))
+ return cls(connection=HTTPConnection(url=GridURL.from_url(url)))
@classmethod
def from_node(cls, node: AbstractNode) -> Self:
@@ -641,8 +639,7 @@ def api(self) -> SyftAPI:
# invalidate API
if self._api is None or (self._api.signing_key != self.credentials):
self._fetch_api(self.credentials)
-
- return self._api
+ return cast(SyftAPI, self._api) # we are sure self._api is not None after fetch
def guest(self) -> Self:
return self.__class__(
@@ -657,7 +654,8 @@ def exchange_route(self, client: Self) -> Union[SyftSuccess, SyftError]:
self_node_route = connection_to_route(self.connection)
remote_node_route = connection_to_route(client.connection)
-
+ if client.metadata is None:
+ return SyftError(f"client {client}'s metadata is None!")
result = self.api.services.network.exchange_credentials_with(
self_node_route=self_node_route,
remote_node_route=remote_node_route,
@@ -715,10 +713,11 @@ def me(self) -> Optional[Union[UserView, SyftError]]:
def login_as_guest(self) -> Self:
_guest_client = self.guest()
- print(
- f"Logged into <{self.name}: {self.metadata.node_side_type.capitalize()}-side "
- f"{self.metadata.node_type.capitalize()}> as GUEST"
- )
+ if self.metadata is not None:
+ print(
+ f"Logged into <{self.name}: {self.metadata.node_side_type.capitalize()}-side "
+ f"{self.metadata.node_type.capitalize()}> as GUEST"
+ )
return _guest_client
@@ -762,11 +761,11 @@ def login(
client.__logged_in_user = email
- if user_private_key is not None:
+ if user_private_key is not None and client.users is not None:
client.__user_role = user_private_key.role
client.__logged_in_username = client.users.get_current_user().name
- if signing_key is not None:
+ if signing_key is not None and client.metadata is not None:
print(
f"Logged into <{client.name}: {client.metadata.node_side_type.capitalize()} side "
f"{client.metadata.node_type.capitalize()}> as <{email}>"
@@ -804,16 +803,18 @@ def login(
# relative
from ..node.node import CODE_RELOADER
- CODE_RELOADER[thread_ident()] = client._reload_user_code
+ thread_id = thread_ident()
+ if thread_id is not None:
+ CODE_RELOADER[thread_id] = client._reload_user_code
return client
- def _reload_user_code(self):
+ def _reload_user_code(self) -> None:
# relative
from ..service.code.user_code import load_approved_policy_code
user_code_items = self.code.get_all_for_user()
- load_approved_policy_code(user_code_items)
+ load_approved_policy_code(user_code_items=user_code_items, context=None)
def register(
self,
@@ -823,7 +824,7 @@ def register(
password_verify: Optional[str] = None,
institution: Optional[str] = None,
website: Optional[str] = None,
- ):
+ ) -> Optional[Union[SyftError, SyftSigningKey]]:
if not email:
email = input("Email: ")
if not password:
@@ -848,7 +849,10 @@ def register(
except Exception as e:
return SyftError(message=str(e))
- if self.metadata.node_side_type == NodeSideType.HIGH_SIDE.value:
+ if (
+ self.metadata
+ and self.metadata.node_side_type == NodeSideType.HIGH_SIDE.value
+ ):
message = (
"You're registering a user to a high side "
f"{self.metadata.node_type}, which could "
@@ -857,7 +861,7 @@ def register(
if self.metadata.show_warnings and not prompt_warning_message(
message=message
):
- return
+ return None
response = self.connection.register(new_user=new_user)
if isinstance(response, tuple):
@@ -894,16 +898,20 @@ def _fetch_node_metadata(self, credentials: SyftSigningKey) -> None:
metadata.check_version(__version__)
self.metadata = metadata
- def _fetch_api(self, credentials: SyftSigningKey):
+ def _fetch_api(self, credentials: SyftSigningKey) -> None:
_api: SyftAPI = self.connection.get_api(
credentials=credentials,
communication_protocol=self.communication_protocol,
)
- def refresh_callback():
+ def refresh_callback() -> None:
return self._fetch_api(self.credentials)
_api.refresh_api_callback = refresh_callback
+
+ if self.credentials is None:
+ raise ValueError(f"{self}'s credentials (signing key) is None!")
+
APIRegistry.set_api_for(
node_uid=self.id,
user_verify_key=self.credentials.verify_key,
@@ -943,7 +951,7 @@ def register(
password: str,
institution: Optional[str] = None,
website: Optional[str] = None,
-):
+) -> Optional[Union[SyftError, SyftSigningKey]]:
guest_client = connect(url=url, port=port)
return guest_client.register(
name=name,
@@ -960,13 +968,13 @@ def login_as_guest(
node: Optional[AbstractNode] = None,
port: Optional[int] = None,
verbose: bool = True,
-):
+) -> SyftClient:
_client = connect(url=url, node=node, port=port)
if isinstance(_client, SyftError):
return _client
- if verbose:
+ if verbose and _client.metadata is not None:
print(
f"Logged into <{_client.name}: {_client.metadata.node_side_type.capitalize()}-"
f"side {_client.metadata.node_type.capitalize()}> as GUEST"
@@ -1039,7 +1047,7 @@ def add_client(
password: str,
connection: NodeConnection,
syft_client: SyftClient,
- ):
+ ) -> None:
hash_key = cls._get_key(email, password, connection.get_cache_key())
cls.__credentials_store__[hash_key] = syft_client
cls.__client_cache__[syft_client.id] = syft_client
@@ -1050,7 +1058,7 @@ def add_client_by_uid_and_verify_key(
verify_key: SyftVerifyKey,
node_uid: UID,
syft_client: SyftClient,
- ):
+ ) -> None:
hash_key = str(node_uid) + str(verify_key)
cls.__client_cache__[hash_key] = syft_client
@@ -1067,8 +1075,8 @@ def get_client(
) -> Optional[SyftClient]:
# we have some bugs here so lets disable until they are fixed.
return None
- hash_key = cls._get_key(email, password, connection.get_cache_key())
- return cls.__credentials_store__.get(hash_key, None)
+ # hash_key = cls._get_key(email, password, connection.get_cache_key())
+ # return cls.__credentials_store__.get(hash_key, None)
@classmethod
def get_client_for_node_uid(cls, node_uid: UID) -> Optional[SyftClient]:
diff --git a/packages/syft/src/syft/client/connection.py b/packages/syft/src/syft/client/connection.py
index 5b9928c8355..a94cb1c0707 100644
--- a/packages/syft/src/syft/client/connection.py
+++ b/packages/syft/src/syft/client/connection.py
@@ -10,7 +10,7 @@ class NodeConnection(SyftObject):
__canonical_name__ = "NodeConnection"
__version__ = SYFT_OBJECT_VERSION_1
- def get_cache_key() -> str:
+ def get_cache_key(self) -> str:
raise NotImplementedError
def __repr__(self) -> str:
diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py
index 80679a49f81..35773178b56 100644
--- a/packages/syft/src/syft/client/domain_client.py
+++ b/packages/syft/src/syft/client/domain_client.py
@@ -4,11 +4,14 @@
# stdlib
from pathlib import Path
import re
+from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
+from typing import cast
# third party
+from hagrid.orchestra import NodeHandle
from loguru import logger
from tqdm import tqdm
@@ -26,8 +29,10 @@
from ..service.response import SyftSuccess
from ..service.sync.diff_state import ResolvedSyncState
from ..service.user.roles import Roles
+from ..service.user.user import UserView
from ..service.user.user_roles import ServiceRole
from ..types.blob_storage import BlobFile
+from ..types.syft_object import SyftObject
from ..types.uid import UID
from ..util.fonts import fonts_css
from ..util.util import get_mb_size
@@ -37,13 +42,14 @@
from .client import SyftClient
from .client import login
from .client import login_as_guest
+from .connection import NodeConnection
if TYPE_CHECKING:
# relative
from ..service.project.project import Project
-def _get_files_from_glob(glob_path: str) -> list:
+def _get_files_from_glob(glob_path: str) -> list[Path]:
files = Path().glob(glob_path)
return [f for f in files if f.is_file() and not f.name.startswith(".")]
@@ -61,7 +67,7 @@ def _contains_subdir(dir: Path) -> bool:
def add_default_uploader(
- user, obj: Union[CreateDataset, CreateAsset]
+ user: UserView, obj: Union[CreateDataset, CreateAsset]
) -> Union[CreateDataset, CreateAsset]:
uploader = None
for contributor in obj.contributors:
@@ -90,6 +96,9 @@ def upload_dataset(self, dataset: CreateDataset) -> Union[SyftSuccess, SyftError
# relative
from ..types.twin_object import TwinObject
+ if self.users is None:
+ return SyftError(f"can't get user service for {self}")
+
user = self.users.get_current_user()
dataset = add_default_uploader(user, dataset)
for i in range(len(dataset.asset_list)):
@@ -97,9 +106,12 @@ def upload_dataset(self, dataset: CreateDataset) -> Union[SyftSuccess, SyftError
dataset.asset_list[i] = add_default_uploader(user, asset)
dataset._check_asset_must_contain_mock()
- dataset_size = 0
+ dataset_size: float = 0.0
# TODO: Refactor so that object can also be passed to generate warnings
+
+ self.api.connection = cast(NodeConnection, self.api.connection)
+
metadata = self.api.connection.get_node_metadata(self.api.signing_key)
if (
@@ -134,18 +146,18 @@ def upload_dataset(self, dataset: CreateDataset) -> Union[SyftSuccess, SyftError
dataset_size += get_mb_size(asset.data)
dataset.mb_size = dataset_size
valid = dataset.check()
- if valid.ok():
- return self.api.services.dataset.add(dataset=dataset)
- else:
- if len(valid.err()) > 0:
- return tuple(valid.err())
- return valid.err()
+ if isinstance(valid, SyftError):
+ return valid
+ return self.api.services.dataset.add(dataset=dataset)
- def create_actionobject(self, action_object):
+ def create_actionobject(self, action_object: SyftObject) -> None:
action_object = action_object.refresh_object()
action_object.send(self)
- def get_permissions_for_other_node(self, items):
+ def get_permissions_for_other_node(
+ self,
+ items: list[Union[ActionObject, SyftObject]],
+ ) -> dict:
if len(items) > 0:
if not len({i.syft_node_location for i in items}) == 1 or (
not len({i.syft_client_verify_key for i in items}) == 1
@@ -155,6 +167,10 @@ def get_permissions_for_other_node(self, items):
api = APIRegistry.api_for(
item.syft_node_location, item.syft_client_verify_key
)
+ if api is None:
+ raise ValueError(
+ f"Can't access the api. Please log in to {item.syft_node_location}"
+ )
return api.services.sync.get_permissions(items)
else:
return {}
@@ -168,7 +184,7 @@ def apply_state(
action_objects = [x for x in items if isinstance(x, ActionObject)]
# permissions = self.get_permissions_for_other_node(items)
- permissions = {}
+ permissions: dict[UID, set[str]] = {}
for p in resolved_state.new_permissions:
if p.uid in permissions:
permissions[p.uid].add(p.permission_string)
@@ -193,16 +209,17 @@ def apply_state(
def upload_files(
self,
file_list: Union[BlobFile, list[BlobFile], str, list[str], Path, list[Path]],
- allow_recursive=False,
- show_files=False,
+ allow_recursive: bool = False,
+ show_files: bool = False,
) -> Union[SyftSuccess, SyftError]:
if not file_list:
return SyftError(message="No files to upload")
if not isinstance(file_list, list):
- file_list = [file_list]
+ file_list = [file_list] # type: ignore[assignment]
+ file_list = cast(list, file_list)
- expanded_file_list = []
+ expanded_file_list: List[Union[BlobFile, Path]] = []
for file in file_list:
if isinstance(file, BlobFile):
@@ -263,7 +280,7 @@ def connect_to_gateway(
handle: Optional[NodeHandle] = None, # noqa: F821
email: Optional[str] = None,
password: Optional[str] = None,
- ) -> None:
+ ) -> Optional[Union[SyftSuccess, SyftError]]:
if via_client is not None:
client = via_client
elif handle is not None:
@@ -279,9 +296,12 @@ def connect_to_gateway(
res = self.exchange_route(client)
if isinstance(res, SyftSuccess):
- return SyftSuccess(
- message=f"Connected {self.metadata.node_type} to {client.name} gateway"
- )
+ if self.metadata:
+ return SyftSuccess(
+ message=f"Connected {self.metadata.node_type} to {client.name} gateway"
+ )
+ else:
+ return SyftSuccess(message=f"Connected to {client.name} gateway")
return res
@property
@@ -374,8 +394,8 @@ def output(self) -> Optional[APIModule]:
def get_project(
self,
- name: str = None,
- uid: UID = None,
+ name: Optional[str] = None,
+ uid: Optional[UID] = None,
) -> Optional[Project]:
"""Get project by name or UID"""
@@ -442,18 +462,17 @@ def _repr_html_(self) -> str:
url = getattr(self.connection, "url", None)
node_details = f"URL: {url}
" if url else ""
- node_details += (
- f"Node Type: {self.metadata.node_type.capitalize()}
"
- )
- node_side_type = (
- "Low Side"
- if self.metadata.node_side_type == NodeSideType.LOW_SIDE.value
- else "High Side"
- )
- node_details += f"Node Side Type: {node_side_type}
"
- node_details += (
- f"Syft Version: {self.metadata.syft_version}
"
- )
+ if self.metadata is not None:
+ node_details += f"Node Type: {self.metadata.node_type.capitalize()}
"
+ node_side_type = (
+ "Low Side"
+ if self.metadata.node_side_type == NodeSideType.LOW_SIDE.value
+ else "High Side"
+ )
+ node_details += f"Node Side Type: {node_side_type}
"
+ node_details += (
+ f"Syft Version: {self.metadata.syft_version}
"
+ )
return f"""