From 09df3cf2c1f802a44dac2cc04e0a8f79f263a7f4 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Thu, 22 Feb 2024 11:25:28 +0700 Subject: [PATCH 01/35] [refactor] done fixing mypy issues of `syft/client` --- .pre-commit-config.yaml | 6 +-- packages/syft/src/syft/client/client.py | 45 ++++++++++--------- packages/syft/src/syft/client/connection.py | 2 +- .../syft/src/syft/client/domain_client.py | 10 +++-- .../syft/src/syft/client/enclave_client.py | 7 ++- .../syft/src/syft/client/gateway_client.py | 14 +++--- packages/syft/src/syft/client/registry.py | 12 ++--- packages/syft/src/syft/client/search.py | 6 +-- .../syft/service/code/user_code_service.py | 4 +- 9 files changed, 56 insertions(+), 50 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ad05bb12a3e..b94ec680e9e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -172,15 +172,14 @@ repos: - id: mypy name: "mypy: syft" always_run: true - files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util/env.py|^packages/syft/src/syft/util/logger.py|^packages/syft/src/syft/util/markdown.py|^packages/syft/src/syft/util/notebook_ui/notebook_addons.py|^packages/syft/src/syft/img/base64.py|^packages/syft/src/syft/store/mongo_codecs.py|^packages/syft/src/syft/service/warnings.py|^packages/syft/src/syft/util/util.py|^packages/syft/src/syft/client/api.py|^packages/syft/src/syft/service/worker|^packages/syft/src/syft/service/user|^packages/syft/src/syft/service/dataset" - #files: "^packages/syft/src/syft/serde" + files: "^packages/syft/src/syft/client" args: [ "--follow-imports=skip", "--ignore-missing-imports", "--scripts-are-modules", "--disallow-incomplete-defs", "--no-implicit-optional", - "--warn-unused-ignores", + # "--warn-unused-ignores", "--warn-redundant-casts", "--strict-equality", "--warn-unreachable", @@ -190,6 +189,7 @@ repos: "--install-types", "--non-interactive", "--config-file=tox.ini", + "--disable-error-code=union-attr", # todo: remove this line after fixing the issue context.node can be None ] - repo: https://github.com/kynan/nbstripout diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 11b4b2ab9a0..39e9f887978 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -13,7 +13,6 @@ from typing import List from typing import Optional from typing import TYPE_CHECKING -from typing import Tuple from typing import Type from typing import Union from typing import cast @@ -25,7 +24,7 @@ from requests import Response from requests import Session from requests.adapters import HTTPAdapter -from requests.packages.urllib3.util.retry import Retry +from requests.packages.urllib3.util.retry import Retry # type: ignore[import-untyped] from typing_extensions import Self # relative @@ -94,9 +93,9 @@ def forward_message_to_proxy( proxy_target_uid: UID, path: str, credentials: Optional[SyftSigningKey] = None, - args: Optional[Tuple] = None, + args: Optional[list] = None, kwargs: Optional[Dict] = None, -): +) -> Union[Any, SyftError]: kwargs = {} if kwargs is None else kwargs args = [] if args is None else args call = SyftAPICall( @@ -155,7 +154,7 @@ def get_cache_key(self) -> str: def api_url(self) -> GridURL: return self.url.with_path(self.routes.ROUTE_API_CALL.value) - def to_blob_route(self, path: str, **kwargs) -> GridURL: + def to_blob_route(self, path: str, **kwargs: Any) -> GridURL: _path = self.routes.ROUTE_BLOB_STORE.value + path return self.url.with_path(_path) @@ -347,7 +346,7 @@ def get_node_metadata(self, credentials: SyftSigningKey) -> NodeMetadataJSON: else: return self.node.metadata.to(NodeMetadataJSON) - def to_blob_route(self, path: str, host=None) -> GridURL: + def to_blob_route(self, path: str, host: Optional[str] = None) -> GridURL: # TODO: FIX! if host is not None: return GridURL(host_or_ip=host, port=8333).with_path(path) @@ -474,7 +473,7 @@ def __init__( self.metadata = metadata self.credentials: Optional[SyftSigningKey] = credentials self._api = api - self.communication_protocol = None + self.communication_protocol: Optional[Union[int, str]] = None self.current_protocol = None self.post_init() @@ -528,7 +527,8 @@ def create_project( project = project_create.start() return project - def sync_code_from_request(self, request): + # TODO: type of request should be REQUEST, but it will give circular import error + def sync_code_from_request(self, request: Any) -> Union[SyftSuccess, SyftError]: # relative from ..service.code.user_code import UserCode from ..store.linked_obj import LinkedObject @@ -541,7 +541,7 @@ def sync_code_from_request(self, request): code.node_uid = self.id code.user_verify_key = self.verify_key - def get_nested_codes(code: UserCode): + def get_nested_codes(code: UserCode) -> list[UserCode]: result = [] for __, (linked_code_obj, _) in code.nested_codes.items(): nested_code = linked_code_obj.resolve @@ -551,7 +551,7 @@ def get_nested_codes(code: UserCode): result.append(nested_code) result += get_nested_codes(nested_code) - updated_code_links = { + updated_code_links: dict[str, tuple[LinkedObject, dict]] = { nested_code.service_func_name: (LinkedObject.from_obj(nested_code), {}) for nested_code in result } @@ -625,7 +625,8 @@ def api(self) -> SyftAPI: # invalidate API if self._api is None or (self._api.signing_key != self.credentials): self._fetch_api(self.credentials) - + if self._api is None: + raise ValueError(f"{self}'s api is None") return self._api def guest(self) -> Self: @@ -792,7 +793,7 @@ def login( return client - def _reload_user_code(self): + def _reload_user_code(self) -> None: # relative from ..service.code.user_code import load_approved_policy_code @@ -807,7 +808,7 @@ def register( password_verify: Optional[str] = None, institution: Optional[str] = None, website: Optional[str] = None, - ): + ) -> Optional[Union[SyftError, Any]]: if not email: email = input("Email: ") if not password: @@ -841,7 +842,7 @@ def register( if self.metadata.show_warnings and not prompt_warning_message( message=message ): - return + return None response = self.connection.register(new_user=new_user) if isinstance(response, tuple): @@ -878,13 +879,13 @@ def _fetch_node_metadata(self, credentials: SyftSigningKey) -> None: metadata.check_version(__version__) self.metadata = metadata - def _fetch_api(self, credentials: SyftSigningKey): + def _fetch_api(self, credentials: SyftSigningKey) -> None: _api: SyftAPI = self.connection.get_api( credentials=credentials, communication_protocol=self.communication_protocol, ) - def refresh_callback(): + def refresh_callback() -> None: return self._fetch_api(self.credentials) _api.refresh_api_callback = refresh_callback @@ -927,7 +928,7 @@ def register( password: str, institution: Optional[str] = None, website: Optional[str] = None, -): +) -> SyftClient: guest_client = connect(url=url, port=port) return guest_client.register( name=name, @@ -944,7 +945,7 @@ def login_as_guest( node: Optional[AbstractNode] = None, port: Optional[int] = None, verbose: bool = True, -): +) -> SyftClient: _client = connect(url=url, node=node, port=port) if isinstance(_client, SyftError): @@ -1023,7 +1024,7 @@ def add_client( password: str, connection: NodeConnection, syft_client: SyftClient, - ): + ) -> None: hash_key = cls._get_key(email, password, connection.get_cache_key()) cls.__credentials_store__[hash_key] = syft_client cls.__client_cache__[syft_client.id] = syft_client @@ -1034,7 +1035,7 @@ def add_client_by_uid_and_verify_key( verify_key: SyftVerifyKey, node_uid: UID, syft_client: SyftClient, - ): + ) -> None: hash_key = str(node_uid) + str(verify_key) cls.__client_cache__[hash_key] = syft_client @@ -1051,8 +1052,8 @@ def get_client( ) -> Optional[SyftClient]: # we have some bugs here so lets disable until they are fixed. return None - hash_key = cls._get_key(email, password, connection.get_cache_key()) - return cls.__credentials_store__.get(hash_key, None) + # hash_key = cls._get_key(email, password, connection.get_cache_key()) + # return cls.__credentials_store__.get(hash_key, None) @classmethod def get_client_for_node_uid(cls, node_uid: UID) -> Optional[SyftClient]: diff --git a/packages/syft/src/syft/client/connection.py b/packages/syft/src/syft/client/connection.py index 5b9928c8355..a94cb1c0707 100644 --- a/packages/syft/src/syft/client/connection.py +++ b/packages/syft/src/syft/client/connection.py @@ -10,7 +10,7 @@ class NodeConnection(SyftObject): __canonical_name__ = "NodeConnection" __version__ = SYFT_OBJECT_VERSION_1 - def get_cache_key() -> str: + def get_cache_key(self) -> str: raise NotImplementedError def __repr__(self) -> str: diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index d9d8cc3ae4e..805fd149807 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -9,6 +9,7 @@ from typing import Union # third party +from hagrid.orchestra import NodeHandle from loguru import logger from tqdm import tqdm @@ -25,6 +26,7 @@ from ..service.response import SyftError from ..service.response import SyftSuccess from ..service.user.roles import Roles +from ..service.user.user import UserView from ..service.user.user_roles import ServiceRole from ..types.blob_storage import BlobFile from ..types.uid import UID @@ -59,7 +61,7 @@ def _contains_subdir(dir: Path) -> bool: def add_default_uploader( - user, obj: Union[CreateDataset, CreateAsset] + user: UserView, obj: Union[CreateDataset, CreateAsset] ) -> Union[CreateDataset, CreateAsset]: uploader = None for contributor in obj.contributors: @@ -142,8 +144,8 @@ def upload_dataset(self, dataset: CreateDataset) -> Union[SyftSuccess, SyftError def upload_files( self, file_list: Union[BlobFile, list[BlobFile], str, list[str], Path, list[Path]], - allow_recursive=False, - show_files=False, + allow_recursive: bool = False, + show_files: bool = False, ) -> Union[SyftSuccess, SyftError]: if not file_list: return SyftError(message="No files to upload") @@ -305,7 +307,7 @@ def worker_images(self) -> Optional[APIModule]: def get_project( self, - name: str = None, + name: Optional[str] = None, uid: UID = None, ) -> Optional[Project]: """Get project by name or UID""" diff --git a/packages/syft/src/syft/client/enclave_client.py b/packages/syft/src/syft/client/enclave_client.py index e0a09167805..07f8d93878b 100644 --- a/packages/syft/src/syft/client/enclave_client.py +++ b/packages/syft/src/syft/client/enclave_client.py @@ -2,8 +2,13 @@ from __future__ import annotations # stdlib +from typing import Any from typing import Optional from typing import TYPE_CHECKING +from typing import Union + +# third party +from hagrid.orchestra import NodeHandle # relative from ..abstract_node import NodeSideType @@ -92,7 +97,7 @@ def connect_to_gateway( def get_enclave_metadata(self) -> EnclaveMetadata: return EnclaveMetadata(route=self.connection.route) - def request_code_execution(self, code: SubmitUserCode): + def request_code_execution(self, code: SubmitUserCode) -> Union[Any, SyftError]: # relative from ..service.code.user_code_service import SubmitUserCode diff --git a/packages/syft/src/syft/client/gateway_client.py b/packages/syft/src/syft/client/gateway_client.py index 2a569cabd22..3f9f3e571a7 100644 --- a/packages/syft/src/syft/client/gateway_client.py +++ b/packages/syft/src/syft/client/gateway_client.py @@ -2,11 +2,9 @@ from typing import Any from typing import List from typing import Optional +from typing import Type from typing import Union -# third party -from typing_extensions import Self - # relative from ..abstract_node import NodeSideType from ..abstract_node import NodeType @@ -26,7 +24,7 @@ class GatewayClient(SyftClient): # TODO: add widget repr for gateway client - def proxy_to(self, peer: Any) -> Self: + def proxy_to(self, peer: Any) -> SyftClient: # relative from .domain_client import DomainClient from .enclave_client import EnclaveClient @@ -34,7 +32,7 @@ def proxy_to(self, peer: Any) -> Self: connection = self.connection.with_proxy(peer.id) metadata = connection.get_node_metadata(credentials=SyftSigningKey.generate()) if metadata.node_type == NodeType.DOMAIN.value: - client_type = DomainClient + client_type: Type[SyftClient] = DomainClient elif metadata.node_type == NodeType.ENCLAVE.value: client_type = EnclaveClient else: @@ -53,8 +51,8 @@ def proxy_client_for( name: str, email: Optional[str] = None, password: Optional[str] = None, - **kwargs, - ): + **kwargs: Any, + ) -> SyftClient: peer = None if self.api.has_service("network"): peer = self.api.services.network.get_peer_by_name(name=name) @@ -178,7 +176,7 @@ def _repr_html_(self) -> str: def __len__(self) -> int: return len(self.retrieve_nodes()) - def __getitem__(self, key: int): + def __getitem__(self, key: int) -> SyftClient: if not isinstance(key, int): raise SyftException(f"Key: {key} must be an integer") diff --git a/packages/syft/src/syft/client/registry.py b/packages/syft/src/syft/client/registry.py index 95a8c470cb3..7669cd69f2e 100644 --- a/packages/syft/src/syft/client/registry.py +++ b/packages/syft/src/syft/client/registry.py @@ -112,7 +112,7 @@ def __repr__(self) -> str: return pd.DataFrame(on).to_string() @staticmethod - def create_client(network: Dict[str, Any]) -> Client: # type: ignore + def create_client(network: Dict[str, Any]) -> Client: # relative from ..client.client import connect @@ -127,7 +127,7 @@ def create_client(network: Dict[str, Any]) -> Client: # type: ignore error(f"Failed to login with: {network}. {e}") raise SyftException(f"Failed to login with: {network}. {e}") - def __getitem__(self, key: Union[str, int]) -> Client: # type: ignore + def __getitem__(self, key: Union[str, int]) -> Client: if isinstance(key, int): return self.create_client(network=self.online_networks[key]) else: @@ -266,14 +266,14 @@ def __repr__(self) -> str: return "(no domains online - try syft.domains.all_domains to see offline domains)" return pd.DataFrame(on).to_string() - def create_client(self, peer: NodePeer) -> Client: # type: ignore + def create_client(self, peer: NodePeer) -> Client: try: return peer.guest_client except Exception as e: error(f"Failed to login to: {peer}. {e}") raise SyftException(f"Failed to login to: {peer}. {e}") - def __getitem__(self, key: Union[str, int]) -> Client: # type: ignore + def __getitem__(self, key: Union[str, int]) -> Client: if isinstance(key, int): return self.create_client(self.online_domains[key][0]) else: @@ -360,7 +360,7 @@ def __repr__(self) -> str: return pd.DataFrame(on).to_string() @staticmethod - def create_client(enclave: Dict[str, Any]) -> Client: # type: ignore + def create_client(enclave: Dict[str, Any]) -> Client: # relative from ..client.client import connect @@ -375,7 +375,7 @@ def create_client(enclave: Dict[str, Any]) -> Client: # type: ignore error(f"Failed to login with: {enclave}. {e}") raise SyftException(f"Failed to login with: {enclave}. {e}") - def __getitem__(self, key: Union[str, int]) -> EnclaveClient: # type: ignore + def __getitem__(self, key: Union[str, int]) -> EnclaveClient: if isinstance(key, int): return self.create_client(enclave=self.online_enclaves[key]) else: diff --git a/packages/syft/src/syft/client/search.py b/packages/syft/src/syft/client/search.py index 66d46c5b3d4..45c5e8c6d34 100644 --- a/packages/syft/src/syft/client/search.py +++ b/packages/syft/src/syft/client/search.py @@ -64,9 +64,9 @@ def __search_one_node( peer, _ = peer_tuple client = peer.guest_client results = client.api.services.dataset.search(name=name) - return (client, results) + return (client, results) # type: ignore[return-value] except: # noqa - return (None, []) + return (None, []) # type: ignore[return-value] def __search(self, name: str) -> List[Tuple[SyftClient, List[Dataset]]]: results = [ @@ -75,7 +75,7 @@ def __search(self, name: str) -> List[Tuple[SyftClient, List[Dataset]]]: # filter out SyftError filtered = ((client, result) for client, result in results if result) - return filtered + return filtered # type: ignore[return-value] def search(self, name: str) -> SearchResults: return SearchResults(self.__search(name)) diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index da409b1dac0..3144a419108 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -126,7 +126,7 @@ def _request_code_execution( context: AuthedServiceContext, code: SubmitUserCode, reason: Optional[str] = "", - ): + ) -> Union[Request, SyftError]: user_code: UserCode = code.to(UserCode, context=context) return self._request_code_execution_inner(context, user_code, reason) @@ -135,7 +135,7 @@ def _request_code_execution_inner( context: AuthedServiceContext, user_code: UserCode, reason: Optional[str] = "", - ): + ) -> Union[Request, SyftError]: if not all( x in user_code.input_owner_verify_keys for x in user_code.output_readers ): From 81d06d411902f293bbdc7515c1203c747d0f4461 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Thu, 22 Feb 2024 15:16:31 +0700 Subject: [PATCH 02/35] [refactor] done fixing mypy issues for `syft/node` --- .pre-commit-config.yaml | 2 +- packages/syft/src/syft/node/node.py | 96 +++++++++++-------- packages/syft/src/syft/node/run.py | 5 +- packages/syft/src/syft/node/server.py | 8 +- .../syft/src/syft/node/worker_settings.py | 11 ++- 5 files changed, 72 insertions(+), 50 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b94ec680e9e..41cbd8bcaac 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -172,7 +172,7 @@ repos: - id: mypy name: "mypy: syft" always_run: true - files: "^packages/syft/src/syft/client" + files: "^packages/syft/src/syft/client|^packages/syft/src/syft/node" args: [ "--follow-imports=skip", "--ignore-missing-imports", diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index a9b161bc63f..e869c308708 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -18,7 +18,6 @@ from typing import Dict from typing import List from typing import Optional -from typing import Tuple from typing import Type from typing import Union import uuid @@ -39,6 +38,7 @@ from ..client.api import SyftAPICall from ..client.api import SyftAPIData from ..client.api import debox_signed_syftapicall_response +from ..client.client import SyftClient from ..exceptions.exception import PySyftException from ..external import OBLV from ..protocol.data_protocol import PROTOCOL_TYPE @@ -63,6 +63,7 @@ from ..service.enclave.enclave_service import EnclaveService from ..service.job.job_service import JobService from ..service.job.job_stash import Job +from ..service.job.job_stash import JobStash from ..service.log.log_service import LogService from ..service.metadata.metadata_service import MetadataService from ..service.metadata.node_metadata import NodeMetadataV3 @@ -101,9 +102,11 @@ from ..service.worker.utils import DEFAULT_WORKER_POOL_NAME from ..service.worker.utils import create_default_image from ..service.worker.worker_image_service import SyftWorkerImageService +from ..service.worker.worker_pool import WorkerPool from ..service.worker.worker_pool_service import SyftWorkerPoolService from ..service.worker.worker_pool_stash import SyftWorkerPoolStash from ..service.worker.worker_service import WorkerService +from ..service.worker.worker_stash import WorkerStash from ..store.blob_storage import BlobStorageConfig from ..store.blob_storage.on_disk import OnDiskBlobStorageClientConfig from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig @@ -196,7 +199,7 @@ def get_default_worker_pool_name() -> str: return get_env("DEFAULT_WORKER_POOL_NAME", DEFAULT_WORKER_POOL_NAME) -def get_default_worker_pool_count(node) -> int: +def get_default_worker_pool_count(node: Node) -> int: return int( get_env( "DEFAULT_WORKER_POOL_COUNT", node.queue_config.client_config.n_consumers @@ -204,7 +207,7 @@ def get_default_worker_pool_count(node) -> int: ) -def in_kubernetes() -> Optional[str]: +def in_kubernetes() -> bool: return get_container_host() == "k8s" @@ -242,7 +245,7 @@ def get_syft_worker_uid() -> Optional[str]: class AuthNodeContextRegistry: - __node_context_registry__: Dict[Tuple, Node] = OrderedDict() + __node_context_registry__: Dict[str, Node] = OrderedDict() @classmethod def set_node_context( @@ -250,7 +253,7 @@ def set_node_context( node_uid: Union[UID, str], context: NodeServiceContext, user_verify_key: Union[SyftVerifyKey, str], - ): + ) -> None: if isinstance(node_uid, str): node_uid = UID.from_string(node_uid) @@ -290,9 +293,9 @@ def __init__( signing_key: Optional[Union[SyftSigningKey, SigningKey]] = None, action_store_config: Optional[StoreConfig] = None, document_store_config: Optional[StoreConfig] = None, - root_email: str = default_root_email, - root_username: str = default_root_username, - root_password: str = default_root_password, + root_email: Optional[str] = default_root_email, + root_username: Optional[str] = default_root_username, + root_password: Optional[str] = default_root_password, processes: int = 0, is_subprocess: bool = False, node_type: Union[str, NodeType] = NodeType.DOMAIN, @@ -394,7 +397,7 @@ def __init__( node=self, ) - self.client_cache = {} + self.client_cache: dict = {} if isinstance(node_type, str): node_type = NodeType(node_type) self.node_type = node_type @@ -425,7 +428,7 @@ def __init__( NodeRegistry.set_node_for(self.id, self) @property - def runs_in_docker(self): + def runs_in_docker(self) -> bool: path = "/proc/self/cgroup" return ( os.path.exists("/.dockerenv") @@ -457,14 +460,14 @@ def init_blob_storage(self, config: Optional[BlobStorageConfig] = None) -> None: remote_profile.profile_name ] = remote_profile - def stop(self): + def stop(self) -> None: for consumer_list in self.queue_manager.consumers.values(): for c in consumer_list: c.close() for p in self.queue_manager.producers.values(): p.close() - def close(self): + def close(self) -> None: self.stop() def create_queue_config( @@ -493,10 +496,10 @@ def create_queue_config( return queue_config_ - def init_queue_manager(self, queue_config: QueueConfig): + def init_queue_manager(self, queue_config: QueueConfig) -> None: MessageHandlers = [APICallMessageHandler] if self.is_subprocess: - return + return None self.queue_manager = QueueManager(config=queue_config) for message_handler in MessageHandlers: @@ -552,7 +555,7 @@ def add_consumer_for_service( syft_worker_id: UID, address: str, message_handler: AbstractMessageHandler = APICallMessageHandler, - ): + ) -> None: consumer: QueueConsumer = self.queue_manager.create_consumer( message_handler, address=address, @@ -664,7 +667,7 @@ def is_root(self, credentials: SyftVerifyKey) -> bool: return credentials == self.verify_key @property - def root_client(self): + def root_client(self) -> SyftClient: # relative from ..client.client import PythonConnection @@ -673,7 +676,8 @@ def root_client(self): if isinstance(client_type, SyftError): return client_type root_client = client_type(connection=connection, credentials=self.signing_key) - root_client.api.refresh_api_callback() + if root_client.api.refresh_api_callback is not None: + root_client.api.refresh_api_callback() return root_client def _find_klasses_pending_for_migration( @@ -707,7 +711,7 @@ def _find_klasses_pending_for_migration( return klasses_to_be_migrated - def find_and_migrate_data(self): + def find_and_migrate_data(self) -> None: # Track all object type that need migration for document store context = AuthedServiceContext( node=self, @@ -772,7 +776,7 @@ def find_and_migrate_data(self): print("Data Migrated to latest version !!!") @property - def guest_client(self): + def guest_client(self) -> SyftClient: return self.get_guest_client() @property @@ -780,7 +784,7 @@ def current_protocol(self) -> List: data_protocol = get_data_protocol() return data_protocol.latest_version - def get_guest_client(self, verbose: bool = True): + def get_guest_client(self, verbose: bool = True) -> SyftClient: # relative from ..client.client import PythonConnection @@ -798,7 +802,8 @@ def get_guest_client(self, verbose: bool = True): guest_client = client_type( connection=connection, credentials=SyftSigningKey.generate() ) - guest_client.api.refresh_api_callback() + if guest_client.api.refresh_api_callback is not None: + guest_client.api.refresh_api_callback() return guest_client def __repr__(self) -> str: @@ -840,7 +845,7 @@ def init_stores( self, document_store_config: Optional[StoreConfig] = None, action_store_config: Optional[StoreConfig] = None, - ): + ) -> None: if document_store_config is None: if self.local_db or (self.processes > 0 and not self.is_subprocess): client_config = SQLiteStoreClientConfig(path=self.sqlite_path) @@ -905,14 +910,14 @@ def init_stores( self.queue_stash = QueueStash(store=self.document_store) @property - def job_stash(self): + def job_stash(self) -> JobStash: return self.get_service("jobservice").stash @property - def worker_stash(self): + def worker_stash(self) -> WorkerStash: return self.get_service("workerservice").stash - def _construct_services(self): + def _construct_services(self) -> None: self.service_path_map = {} for service_klass in self.services: @@ -1135,7 +1140,7 @@ def handle_api_call( self, api_call: Union[SyftAPICall, SignedSyftAPICall], job_id: Optional[UID] = None, - check_call_location=True, + check_call_location: bool = True, ) -> Result[SignedSyftAPICall, Err]: # Get the result result = self.handle_api_call_with_unsigned_result( @@ -1150,7 +1155,7 @@ def handle_api_call_with_unsigned_result( self, api_call: Union[SyftAPICall, SignedSyftAPICall], job_id: Optional[UID] = None, - check_call_location=True, + check_call_location: bool = True, ) -> Result[Union[QueueItem, SyftObject], Err]: if self.required_signed_calls and isinstance(api_call, SyftAPICall): return SyftError( @@ -1212,12 +1217,12 @@ def handle_api_call_with_unsigned_result( def add_action_to_queue( self, - action, - credentials, - parent_job_id=None, + action: Action, + credentials: SyftVerifyKey, + parent_job_id: Optional[UID] = None, has_execute_permissions: bool = False, worker_pool_name: Optional[str] = None, - ): + ) -> Union[Job, SyftError]: job_id = UID() task_uid = UID() worker_settings = WorkerSettings.from_node(node=self) @@ -1267,8 +1272,12 @@ def add_action_to_queue( ) def add_queueitem_to_queue( - self, queue_item, credentials, action=None, parent_job_id=None - ): + self, + queue_item: QueueItem, + credentials: SyftVerifyKey, + action: Optional[Action] = None, + parent_job_id: Optional[UID] = None, + ) -> Union[Job, SyftError]: log_id = UID() role = self.get_role_for_credentials(credentials=credentials) context = AuthedServiceContext(node=self, credentials=credentials, role=role) @@ -1329,7 +1338,9 @@ def _is_usercode_call_on_owned_kwargs( user_code_service = self.get_service("usercodeservice") return user_code_service.is_execution_on_owned_args(api_call.kwargs, context) - def add_api_call_to_queue(self, api_call, parent_job_id=None): + def add_api_call_to_queue( + self, api_call: SyftAPICall, parent_job_id: Optional[UID] = None + ) -> Union[Job, SyftError]: unsigned_call = api_call if isinstance(api_call, SignedSyftAPICall): unsigned_call = api_call.message @@ -1416,7 +1427,7 @@ def pool_stash(self) -> SyftWorkerPoolStash: def user_code_stash(self) -> UserCodeStash: return self.get_service(UserCodeService).stash - def get_default_worker_pool(self): + def get_default_worker_pool(self) -> WorkerPool: result = self.pool_stash.get_by_name( credentials=self.verify_key, pool_name=get_default_worker_pool_name(), @@ -1481,6 +1492,7 @@ def create_initial_settings(self, admin_email: str) -> Optional[NodeSettingsV2]: return None except Exception as e: print("create_worker_metadata failed", e) + return None def create_admin_new( @@ -1521,6 +1533,8 @@ def create_admin_new( except Exception as e: print("Unable to create new admin", e) + return None + def create_oblv_key_pair( worker: Node, @@ -1544,6 +1558,9 @@ def create_oblv_key_pair( print(f"Using Existing Public/Private Key pair: {len(oblv_keys_stash)}") except Exception as e: print("Unable to create Oblv Keys.", e) + return None + + return None class NodeRegistry: @@ -1569,7 +1586,7 @@ def get_all_nodes(cls) -> List[Node]: return list(cls.__node_registry__.values()) -def get_default_worker_tag_by_env(dev_mode=False): +def get_default_worker_tag_by_env(dev_mode: bool = False) -> str: if in_kubernetes(): return get_default_worker_image() elif dev_mode: @@ -1617,7 +1634,7 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]: if isinstance(result, SyftError): print("Failed to build default worker image: ", result.message) - return + return None # Create worker pool if it doesn't exists print( @@ -1650,7 +1667,7 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]: if isinstance(result, SyftError): print(f"Default worker pool error. {result.message}") - return + return None for n in range(worker_to_add_): container_status = result[n] @@ -1659,6 +1676,7 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]: f"Failed to create container: Worker: {container_status.worker}," f"Error: {container_status.error}" ) - return + return None print("Created default worker pool.") + return None diff --git a/packages/syft/src/syft/node/run.py b/packages/syft/src/syft/node/run.py index 3b8376fb4e1..10aa942a498 100644 --- a/packages/syft/src/syft/node/run.py +++ b/packages/syft/src/syft/node/run.py @@ -2,6 +2,9 @@ import argparse from typing import Optional +# third party +from hagrid.orchestra import NodeHandle + # relative from ..client.deploy import Orchestra @@ -14,7 +17,7 @@ def str_to_bool(bool_str: Optional[str]) -> bool: return result -def run(): +def run() -> Optional[NodeHandle]: parser = argparse.ArgumentParser() parser.add_argument("command", help="command: launch", type=str, default="none") parser.add_argument( diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index fc7bbb2bc8d..eabcacc49c1 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -81,7 +81,7 @@ def run_uvicorn( queue_port: Optional[int], create_producer: bool, n_consumers: int, -): +) -> None: async def _run_uvicorn( name: str, node_type: Enum, @@ -90,7 +90,7 @@ async def _run_uvicorn( reset: bool, dev_mode: bool, node_side_type: Enum, - ): + ) -> None: if node_type not in worker_classes: raise NotImplementedError(f"node_type: {node_type} is not supported") worker_class = worker_classes[node_type] @@ -205,7 +205,7 @@ def serve_node( ), ) - def stop(): + def stop() -> None: print(f"Stopping {name}") server_process.terminate() server_process.join(3) @@ -214,7 +214,7 @@ def stop(): server_process.kill() print("killed") - def start(): + def start() -> None: print(f"Starting {name} server on {host}:{port}") server_process.start() diff --git a/packages/syft/src/syft/node/worker_settings.py b/packages/syft/src/syft/node/worker_settings.py index 106e2e94821..509488ac5db 100644 --- a/packages/syft/src/syft/node/worker_settings.py +++ b/packages/syft/src/syft/node/worker_settings.py @@ -2,6 +2,7 @@ from __future__ import annotations # stdlib +from typing import Callable from typing import Optional # third party @@ -55,9 +56,9 @@ class WorkerSettings(SyftObject): blob_store_config: Optional[BlobStorageConfig] queue_config: Optional[QueueConfig] - @staticmethod - def from_node(node: AbstractNode) -> Self: - return WorkerSettings( + @classmethod + def from_node(cls, node: AbstractNode) -> Self: + return cls( id=node.id, name=node.name, node_type=node.node_type, @@ -74,14 +75,14 @@ def from_node(node: AbstractNode) -> Self: @migrate(WorkerSettings, WorkerSettingsV1) -def downgrade_workersettings_v2_to_v1(): +def downgrade_workersettings_v2_to_v1() -> list[Callable]: return [ drop(["queue_config"]), ] @migrate(WorkerSettingsV1, WorkerSettings) -def upgrade_workersettings_v1_to_v2(): +def upgrade_workersettings_v1_to_v2() -> list[Callable]: # relative from ..service.queue.zmq_queue import ZMQQueueConfig From 2604b1c49fdb3d5a8dc75a3f4ce8274ef8dca294 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Thu, 22 Feb 2024 16:32:00 +0700 Subject: [PATCH 03/35] [refactor] done fixing mypy issues of `syft/exceptions` --- .pre-commit-config.yaml | 4 +++- packages/syft/src/syft/exceptions/exception.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 41cbd8bcaac..3132ed207bc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -172,7 +172,9 @@ repos: - id: mypy name: "mypy: syft" always_run: true - files: "^packages/syft/src/syft/client|^packages/syft/src/syft/node" + files: "^packages/syft/src/syft/client|^packages/syft/src/syft/enclave|^packages/syft/src/syft/node|^packages/syft/src/syft/shylock|^packages/syft/src/syft/exceptions" + # files: "^packages/syft/src/syft/" + # files: "^packages/syft/src/syft/abstract_node.py" args: [ "--follow-imports=skip", "--ignore-missing-imports", diff --git a/packages/syft/src/syft/exceptions/exception.py b/packages/syft/src/syft/exceptions/exception.py index 6bab71a747a..16f1717686b 100644 --- a/packages/syft/src/syft/exceptions/exception.py +++ b/packages/syft/src/syft/exceptions/exception.py @@ -2,6 +2,9 @@ from typing import List from typing import Optional +# third party +from typing_extensions import Self + # relative from ..service.context import NodeServiceContext from ..service.response import SyftError @@ -16,7 +19,7 @@ def __init__(self, message: str, roles: Optional[List[ServiceRole]] = None): self.message = message self.roles = roles if roles else [ServiceRole.ADMIN] - def raise_with_context(self, context: NodeServiceContext): + def raise_with_context(self, context: NodeServiceContext) -> Self: self.context = context return self From 30ff73f8ff7c32df682d4467dac90bea21349e02 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Fri, 23 Feb 2024 09:35:19 +0700 Subject: [PATCH 04/35] [refactor] fix mypy issues of `syft/custom_worker` --- .pre-commit-config.yaml | 3 +-- .../syft/src/syft/custom_worker/builder.py | 5 +++-- .../src/syft/custom_worker/builder_docker.py | 12 +++++++----- .../src/syft/custom_worker/builder_k8s.py | 15 +++++++++------ .../src/syft/custom_worker/builder_types.py | 11 +++++++---- .../syft/src/syft/custom_worker/config.py | 12 +++++++----- packages/syft/src/syft/custom_worker/k8s.py | 19 ++++++++++++------- .../syft/src/syft/custom_worker/runner_k8s.py | 13 +++++++------ 8 files changed, 53 insertions(+), 37 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3132ed207bc..615858aede8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -172,9 +172,8 @@ repos: - id: mypy name: "mypy: syft" always_run: true - files: "^packages/syft/src/syft/client|^packages/syft/src/syft/enclave|^packages/syft/src/syft/node|^packages/syft/src/syft/shylock|^packages/syft/src/syft/exceptions" + files: "^packages/syft/src/syft/ast|^packages/syft/src/syft/capnp|^packages/syft/src/syft/client|^packages/syft/src/syft/core|^packages/syft/src/syft/custom_worker|^packages/syft/src/syft/enclave|^packages/syft/src/syft/node|^packages/syft/src/syft/shylock|^packages/syft/src/syft/exceptions" # files: "^packages/syft/src/syft/" - # files: "^packages/syft/src/syft/abstract_node.py" args: [ "--follow-imports=skip", "--ignore-missing-imports", diff --git a/packages/syft/src/syft/custom_worker/builder.py b/packages/syft/src/syft/custom_worker/builder.py index 5e479cee71c..8109ac94b43 100644 --- a/packages/syft/src/syft/custom_worker/builder.py +++ b/packages/syft/src/syft/custom_worker/builder.py @@ -3,6 +3,7 @@ import os.path from pathlib import Path from typing import Any +from typing import Optional # relative from .builder_docker import DockerBuilder @@ -39,7 +40,7 @@ def builder(self) -> BuilderBase: def build_image( self, config: WorkerConfig, - tag: str = None, + tag: Optional[str] = None, **kwargs: Any, ) -> ImageBuildResult: """ @@ -82,7 +83,7 @@ def _build_dockerfile( self, config: DockerWorkerConfig, tag: str, - **kwargs, + **kwargs: Any, ) -> ImageBuildResult: return self.builder.build_image( dockerfile=config.dockerfile, diff --git a/packages/syft/src/syft/custom_worker/builder_docker.py b/packages/syft/src/syft/custom_worker/builder_docker.py index 3f7f16cf185..6b68d1e99c2 100644 --- a/packages/syft/src/syft/custom_worker/builder_docker.py +++ b/packages/syft/src/syft/custom_worker/builder_docker.py @@ -2,6 +2,7 @@ import contextlib import io from pathlib import Path +from typing import Any from typing import Iterable from typing import Optional @@ -22,11 +23,11 @@ class DockerBuilder(BuilderBase): def build_image( self, tag: str, - dockerfile: str = None, - dockerfile_path: Path = None, + dockerfile: Optional[str] = None, + dockerfile_path: Optional[Path] = None, buildargs: Optional[dict] = None, - **kwargs, - ): + **kwargs: Any, + ) -> ImageBuildResult: if dockerfile: # convert dockerfile string to file-like object kwargs["fileobj"] = io.BytesIO(dockerfile.encode("utf-8")) @@ -53,9 +54,10 @@ def build_image( def push_image( self, tag: str, - registry_url: str, username: str, password: str, + registry_url: str, + **kwargs: Any, ) -> ImagePushResult: with contextlib.closing(docker.from_env()) as client: if registry_url and username and password: diff --git a/packages/syft/src/syft/custom_worker/builder_k8s.py b/packages/syft/src/syft/custom_worker/builder_k8s.py index 1be16d3c0ac..24e494c7756 100644 --- a/packages/syft/src/syft/custom_worker/builder_k8s.py +++ b/packages/syft/src/syft/custom_worker/builder_k8s.py @@ -1,6 +1,7 @@ # stdlib from hashlib import sha256 from pathlib import Path +from typing import Any from typing import Dict from typing import List from typing import Optional @@ -33,16 +34,16 @@ class BuildFailed(Exception): class KubernetesBuilder(BuilderBase): COMPONENT = "builder" - def __init__(self): + def __init__(self) -> None: self.client = get_kr8s_client() def build_image( self, tag: str, - dockerfile: str = None, - dockerfile_path: Path = None, + dockerfile: Optional[str] = None, + dockerfile_path: Optional[Path] = None, buildargs: Optional[dict] = None, - **kwargs, + **kwargs: Any, ) -> ImageBuildResult: image_digest = None logs = None @@ -102,7 +103,7 @@ def push_image( username: str, password: str, registry_url: str, - **kwargs, + **kwargs: Any, ) -> ImagePushResult: exit_code = 1 logs = None @@ -354,7 +355,9 @@ def _create_push_job( ) return KubeUtils.create_or_get(job) - def _create_push_secret(self, id: str, url: str, username: str, password: str): + def _create_push_secret( + self, id: str, url: str, username: str, password: str + ) -> Secret: return KubeUtils.create_dockerconfig_secret( secret_name=f"push-secret-{id}", component=KubernetesBuilder.COMPONENT, diff --git a/packages/syft/src/syft/custom_worker/builder_types.py b/packages/syft/src/syft/custom_worker/builder_types.py index 8007bf476e9..9464bafced5 100644 --- a/packages/syft/src/syft/custom_worker/builder_types.py +++ b/packages/syft/src/syft/custom_worker/builder_types.py @@ -2,6 +2,7 @@ from abc import ABC from abc import abstractmethod from pathlib import Path +from typing import Any from typing import Optional # third party @@ -33,20 +34,22 @@ class ImagePushResult(BaseModel): class BuilderBase(ABC): @abstractmethod def build_image( + self, tag: str, - dockerfile: str = None, - dockerfile_path: Path = None, + dockerfile: Optional[str] = None, + dockerfile_path: Optional[Path] = None, buildargs: Optional[dict] = None, - **kwargs, + **kwargs: Any, ) -> ImageBuildResult: pass @abstractmethod def push_image( + self, tag: str, username: str, password: str, registry_url: str, - **kwargs, + **kwargs: Any, ) -> ImagePushResult: pass diff --git a/packages/syft/src/syft/custom_worker/config.py b/packages/syft/src/syft/custom_worker/config.py index c54d4f77c40..58f4b3f626b 100644 --- a/packages/syft/src/syft/custom_worker/config.py +++ b/packages/syft/src/syft/custom_worker/config.py @@ -57,7 +57,7 @@ class CustomBuildConfig(SyftBaseModel): @validator("python_packages") def validate_python_packages(cls, pkgs: List[str]) -> List[str]: for pkg in pkgs: - ver_parts = () + ver_parts: Union[tuple, list] = () name_ver = pkg.split("==") if len(name_ver) != 2: raise ValueError(_malformed_python_package_error_msg(pkg)) @@ -72,13 +72,13 @@ def validate_python_packages(cls, pkgs: List[str]) -> List[str]: return pkgs - def merged_python_pkgs(self, sep=" ") -> str: + def merged_python_pkgs(self, sep: str = " ") -> str: return sep.join(self.python_packages) - def merged_system_pkgs(self, sep=" ") -> str: + def merged_system_pkgs(self, sep: str = " ") -> str: return sep.join(self.system_packages) - def merged_custom_cmds(self, sep=";") -> str: + def merged_custom_cmds(self, sep: str = ";") -> str: return sep.join(self.custom_cmds) @@ -166,7 +166,9 @@ def __str__(self) -> str: def set_description(self, description_text: str) -> None: self.description = description_text - def test_image_build(self, tag: str, **kwargs) -> Union[SyftSuccess, SyftError]: + def test_image_build( + self, tag: str, **kwargs: Any + ) -> Union[SyftSuccess, SyftError]: try: with contextlib.closing(docker.from_env()) as client: if not client.ping(): diff --git a/packages/syft/src/syft/custom_worker/k8s.py b/packages/syft/src/syft/custom_worker/k8s.py index fb777f6ec6c..2c953ff1297 100644 --- a/packages/syft/src/syft/custom_worker/k8s.py +++ b/packages/syft/src/syft/custom_worker/k8s.py @@ -17,6 +17,7 @@ from kr8s.objects import Pod from kr8s.objects import Secret from pydantic import BaseModel +from typing_extensions import Self # Time after which Job will be deleted JOB_COMPLETION_TTL = 60 @@ -47,7 +48,7 @@ class PodCondition(BaseModel): ready: bool @classmethod - def from_conditions(cls, conditions: list): + def from_conditions(cls, conditions: list) -> Self: pod_cond = KubeUtils.list_dict_unpack(conditions, key="type", value="status") pod_cond_flags = {k: v == "True" for k, v in pod_cond.items()} return cls( @@ -67,7 +68,7 @@ class ContainerStatus(BaseModel): startedAt: Optional[str] # when running=True @classmethod - def from_status(cls, cstatus: dict): + def from_status(cls, cstatus: dict) -> Self: cstate = cstatus.get("state", {}) return cls( @@ -86,7 +87,7 @@ class PodStatus(BaseModel): container: ContainerStatus @classmethod - def from_status_dict(cls: "PodStatus", status: dict): + def from_status_dict(cls, status: dict) -> Self: return cls( phase=PodPhase(status.get("phase", "Unknown")), condition=PodCondition.from_conditions(status.get("conditions", [])), @@ -120,8 +121,10 @@ def resolve_pod(client: kr8s.Api, pod: Union[str, Pod]) -> Optional[Pod]: for _pod in client.get("pods", pod): return _pod + return None + @staticmethod - def get_logs(pods: List[Pod]): + def get_logs(pods: List[Pod]) -> str: """Combine and return logs for all the pods as string""" logs = [] for pod in pods: @@ -142,11 +145,13 @@ def get_pod_status(pod: Pod) -> Optional[PodStatus]: def get_pod_env(pod: Pod) -> Optional[List[Dict]]: """Return the environment variables of the first container in the pod.""" if not pod: - return + return None for container in pod.spec.containers: return container.env.to_list() + return None + @staticmethod def get_container_exit_code(pods: List[Pod]) -> List[int]: """Return the exit codes of all the containers in the given pods.""" @@ -203,11 +208,11 @@ def create_secret( type: str, component: str, data: str, - encoded=True, + encoded: bool = True, ) -> Secret: if not encoded: for k, v in data.items(): - data[k] = KubeUtils.b64encode_secret(v) + data[k] = KubeUtils.b64encode_secret(v) # type: ignore secret = Secret( { diff --git a/packages/syft/src/syft/custom_worker/runner_k8s.py b/packages/syft/src/syft/custom_worker/runner_k8s.py index 3b35830c0f4..25d3dbfd2a3 100644 --- a/packages/syft/src/syft/custom_worker/runner_k8s.py +++ b/packages/syft/src/syft/custom_worker/runner_k8s.py @@ -1,4 +1,5 @@ # stdlib +from typing import Any from typing import Dict from typing import List from typing import Optional @@ -21,7 +22,7 @@ class KubernetesRunner: - def __init__(self): + def __init__(self) -> None: self.client = get_kr8s_client() def create_pool( @@ -34,7 +35,7 @@ def create_pool( reg_username: Optional[str] = None, reg_password: Optional[str] = None, reg_url: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> StatefulSet: try: # create pull secret if registry credentials are passed @@ -134,8 +135,8 @@ def _create_image_pull_secret( reg_username: str, reg_password: str, reg_url: str, - **kwargs, - ): + **kwargs: Any, + ) -> Secret: return KubeUtils.create_dockerconfig_secret( secret_name=f"pull-secret-{pool_name}", component=pool_name, @@ -148,11 +149,11 @@ def _create_stateful_set( self, pool_name: str, tag: str, - replicas=1, + replicas: int = 1, env_vars: Optional[List[Dict]] = None, mount_secrets: Optional[Dict] = None, pull_secret: Optional[Secret] = None, - **kwargs, + **kwargs: Any, ) -> StatefulSet: """Create a stateful set for a pool""" From 0f557f40d1add0a683fc3513283f29a5a7d39ad5 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Fri, 23 Feb 2024 11:57:56 +0700 Subject: [PATCH 05/35] [refactor] fix mypy issues for `syft/store` --- .pre-commit-config.yaml | 2 +- .../src/syft/store/blob_storage/__init__.py | 45 ++++++++++----- .../src/syft/store/blob_storage/on_disk.py | 4 +- .../src/syft/store/blob_storage/seaweedfs.py | 26 ++++++--- .../src/syft/store/dict_document_store.py | 9 +-- .../syft/src/syft/store/document_store.py | 57 ++++++++++++------- .../syft/src/syft/store/kv_document_store.py | 20 +++---- packages/syft/src/syft/store/linked_obj.py | 5 +- packages/syft/src/syft/store/locks.py | 35 ++++++------ packages/syft/src/syft/store/mongo_client.py | 4 +- .../src/syft/store/mongo_document_store.py | 51 +++++++++-------- .../src/syft/store/sqlite_document_store.py | 14 ++--- 12 files changed, 163 insertions(+), 109 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 615858aede8..b5633960bab 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -172,7 +172,7 @@ repos: - id: mypy name: "mypy: syft" always_run: true - files: "^packages/syft/src/syft/ast|^packages/syft/src/syft/capnp|^packages/syft/src/syft/client|^packages/syft/src/syft/core|^packages/syft/src/syft/custom_worker|^packages/syft/src/syft/enclave|^packages/syft/src/syft/node|^packages/syft/src/syft/shylock|^packages/syft/src/syft/exceptions" + files: "^packages/syft/src/syft/ast|^packages/syft/src/syft/capnp|^packages/syft/src/syft/client|^packages/syft/src/syft/core|^packages/syft/src/syft/custom_worker|^packages/syft/src/syft/enclave|^packages/syft/src/syft/node|^packages/syft/src/syft/shylock|^packages/syft/src/syft/exceptions|^packages/syft/src/syft/grid|^packages/syft/src/syft/img|^packages/syft/src/syft/lib|^packages/syft/src/syft/store" # files: "^packages/syft/src/syft/" args: [ "--follow-imports=skip", diff --git a/packages/syft/src/syft/store/blob_storage/__init__.py b/packages/syft/src/syft/store/blob_storage/__init__.py index bb86b6ef25a..bc1fb1a9ac1 100644 --- a/packages/syft/src/syft/store/blob_storage/__init__.py +++ b/packages/syft/src/syft/store/blob_storage/__init__.py @@ -42,8 +42,12 @@ # stdlib +from io import BytesIO import os from pathlib import Path +from typing import Any +from typing import Callable +from typing import Generator from typing import Optional from typing import Type from typing import Union @@ -101,14 +105,14 @@ class BlobRetrieval(SyftObject): @migrate(BlobRetrieval, BlobRetrievalV1) -def downgrade_blobretrieval_v2_to_v1(): +def downgrade_blobretrieval_v2_to_v1() -> list[Callable]: return [ drop(["syft_blob_storage_entry_id", "file_size"]), ] @migrate(BlobRetrievalV1, BlobRetrieval) -def upgrade_blobretrieval_v1_to_v2(): +def upgrade_blobretrieval_v1_to_v2() -> list[Callable]: return [ make_set_default("syft_blob_storage_entry_id", None), make_set_default("file_size", 1), @@ -131,7 +135,9 @@ class SyftObjectRetrieval(BlobRetrieval): syft_object: bytes path: Path - def _read_data(self, stream=False, _deserialize=True, **kwargs): + def _read_data( + self, stream: bool = False, _deserialize: bool = True, **kwargs: Any + ) -> Any: # development setup, we can access the same filesystem if os.access(self.path, os.R_OK) and self.path.is_file(): with open(self.path, "rb") as fp: @@ -151,19 +157,19 @@ def _read_data(self, stream=False, _deserialize=True, **kwargs): else: return res - def read(self, _deserialize=True) -> Union[SyftObject, SyftError]: + def read(self, _deserialize: bool = True) -> Union[SyftObject, SyftError]: return self._read_data(_deserialize=_deserialize) @migrate(SyftObjectRetrieval, SyftObjectRetrievalV2) -def downgrade_syftobjretrival_v3_to_v2(): +def downgrade_syftobjretrival_v3_to_v2() -> list[Callable]: return [ drop(["path"]), ] @migrate(SyftObjectRetrievalV2, SyftObjectRetrieval) -def upgrade_syftobjretrival_v2_to_v3(): +def upgrade_syftobjretrival_v2_to_v3() -> list[Callable]: return [ make_set_default("path", Path("")), ] @@ -177,8 +183,11 @@ class BlobRetrievalByURLV1(BlobRetrievalV1): def syft_iter_content( - blob_url, chunk_size, max_retries=MAX_RETRIES, timeout=DEFAULT_TIMEOUT -): + blob_url: Union[str, GridURL], + chunk_size: int, + max_retries: int = MAX_RETRIES, + timeout: int = DEFAULT_TIMEOUT, +) -> Generator: """custom iter content with smart retries (start from last byte read)""" current_byte = 0 for attempt in range(max_retries): @@ -231,7 +240,13 @@ def read(self) -> Union[SyftObject, SyftError]: else: return self._read_data() - def _read_data(self, stream=False, chunk_size=DEFAULT_CHUNK_SIZE, *args, **kwargs): + def _read_data( + self, + stream: bool = False, + chunk_size: int = DEFAULT_CHUNK_SIZE, + *args: Any, + **kwargs: Any, + ) -> Any: # relative from ...client.api import APIRegistry @@ -262,14 +277,14 @@ def _read_data(self, stream=False, chunk_size=DEFAULT_CHUNK_SIZE, *args, **kwarg @migrate(BlobRetrievalByURLV2, BlobRetrievalByURLV1) -def downgrade_blobretrivalbyurl_v2_to_v1(): +def downgrade_blobretrivalbyurl_v2_to_v1() -> list[Callable]: return [ drop(["syft_blob_storage_entry_id", "file_size"]), ] @migrate(BlobRetrievalByURLV1, BlobRetrievalByURLV2) -def upgrade_blobretrivalbyurl_v1_to_v2(): +def upgrade_blobretrivalbyurl_v1_to_v2() -> list[Callable]: return [ make_set_default("syft_blob_storage_entry_id", None), make_set_default("file_size", 1), @@ -277,14 +292,14 @@ def upgrade_blobretrivalbyurl_v1_to_v2(): @migrate(BlobRetrievalByURL, BlobRetrievalByURLV2) -def downgrade_blobretrivalbyurl_v3_to_v2(): +def downgrade_blobretrivalbyurl_v3_to_v2() -> list[Callable]: return [ str_url_to_grid_url, ] @migrate(BlobRetrievalByURLV2, BlobRetrievalByURL) -def upgrade_blobretrivalbyurl_v2_to_v3(): +def upgrade_blobretrivalbyurl_v2_to_v3() -> list[Callable]: return [] @@ -295,7 +310,7 @@ class BlobDeposit(SyftObject): blob_storage_entry_id: UID - def write(self, data: bytes) -> Union[SyftSuccess, SyftError]: + def write(self, data: BytesIO) -> Union[SyftSuccess, SyftError]: pass @@ -308,7 +323,7 @@ class BlobStorageConnection: def __enter__(self) -> Self: raise NotImplementedError - def __exit__(self, *exc) -> None: + def __exit__(self, *exc: Any) -> None: raise NotImplementedError def read(self, fp: SecureFilePathLocation, type_: Optional[Type]) -> BlobRetrieval: diff --git a/packages/syft/src/syft/store/blob_storage/on_disk.py b/packages/syft/src/syft/store/blob_storage/on_disk.py index 1ceebfdb129..a349de26702 100644 --- a/packages/syft/src/syft/store/blob_storage/on_disk.py +++ b/packages/syft/src/syft/store/blob_storage/on_disk.py @@ -53,11 +53,11 @@ def __init__(self, base_directory: Path) -> None: def __enter__(self) -> Self: return self - def __exit__(self, *exc) -> None: + def __exit__(self, *exc: Any) -> None: pass def read( - self, fp: SecureFilePathLocation, type_: Optional[Type], **kwargs + self, fp: SecureFilePathLocation, type_: Optional[Type], **kwargs: Any ) -> BlobRetrieval: file_path = self._base_directory / fp.path return SyftObjectRetrieval( diff --git a/packages/syft/src/syft/store/blob_storage/seaweedfs.py b/packages/syft/src/syft/store/blob_storage/seaweedfs.py index 85524236a37..823fc27450f 100644 --- a/packages/syft/src/syft/store/blob_storage/seaweedfs.py +++ b/packages/syft/src/syft/store/blob_storage/seaweedfs.py @@ -3,7 +3,10 @@ import math from queue import Queue import threading +from typing import Any +from typing import Callable from typing import Dict +from typing import Generator from typing import List from typing import Optional from typing import Type @@ -103,10 +106,12 @@ def write(self, data: BytesIO) -> Union[SyftSuccess, SyftError]: # read a chunk untill we have read part_size class PartGenerator: - def __init__(self): + def __init__(self) -> None: self.no_lines = 0 - def async_generator(self, chunk_size=DEFAULT_UPLOAD_CHUNK_SIZE): + def async_generator( + self, chunk_size: int = DEFAULT_UPLOAD_CHUNK_SIZE + ) -> Generator: item_queue: Queue = Queue() threading.Thread( target=self.add_chunks_to_queue, @@ -120,8 +125,10 @@ def async_generator(self, chunk_size=DEFAULT_UPLOAD_CHUNK_SIZE): item = item_queue.get() def add_chunks_to_queue( - self, queue, chunk_size=DEFAULT_UPLOAD_CHUNK_SIZE - ): + self, + queue: Queue, + chunk_size: int = DEFAULT_UPLOAD_CHUNK_SIZE, + ) -> None: """Creates a data geneator for the part""" n = 0 @@ -166,14 +173,14 @@ def add_chunks_to_queue( @migrate(SeaweedFSBlobDeposit, SeaweedFSBlobDepositV1) -def downgrade_seaweedblobdeposit_v2_to_v1(): +def downgrade_seaweedblobdeposit_v2_to_v1() -> list[Callable]: return [ drop(["size"]), ] @migrate(SeaweedFSBlobDepositV1, SeaweedFSBlobDeposit) -def upgrade_seaweedblobdeposit_v1_to_v2(): +def upgrade_seaweedblobdeposit_v1_to_v2() -> list[Callable]: return [ make_set_default("size", 1), ] @@ -240,11 +247,14 @@ def __init__( def __enter__(self) -> Self: return self - def __exit__(self, *exc) -> None: + def __exit__(self, *exc: Any) -> None: self.client.close() def read( - self, fp: SecureFilePathLocation, type_: Optional[Type], bucket_name=None + self, + fp: SecureFilePathLocation, + type_: Optional[Type], + bucket_name: Optional[str] = None, ) -> BlobRetrieval: if bucket_name is None: bucket_name = self.default_bucket_name diff --git a/packages/syft/src/syft/store/dict_document_store.py b/packages/syft/src/syft/store/dict_document_store.py index 516a2fc85c5..6b5ee9bd8fe 100644 --- a/packages/syft/src/syft/store/dict_document_store.py +++ b/packages/syft/src/syft/store/dict_document_store.py @@ -18,11 +18,12 @@ @serializable() -class DictBackingStore(dict, KeyValueBackingStore): +class DictBackingStore(dict, KeyValueBackingStore): # type: ignore[misc] + # TODO: fix the mypy issue """Dictionary-based Store core logic""" def __init__(self, *args: Any, **kwargs: Any) -> None: - super(dict).__init__() + super().__init__() self._ddtype = kwargs.get("ddtype", None) def __getitem__(self, key: Any) -> Any: @@ -46,7 +47,7 @@ class DictStorePartition(KeyValueStorePartition): DictStore specific configuration """ - def prune(self): + def prune(self) -> None: self.init_store() @@ -71,7 +72,7 @@ def __init__( store_config = DictStoreConfig() super().__init__(root_verify_key=root_verify_key, store_config=store_config) - def reset(self): + def reset(self) -> None: for _, partition in self.partitions.items(): partition.prune() diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index 5151f43294c..3cfe58bcbeb 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -59,7 +59,7 @@ def first_or_none(result: Any) -> Ok: if sys.version_info >= (3, 9): - def is_generic_alias(t: type): + def is_generic_alias(t: type) -> bool: return isinstance(t, (types.GenericAlias, typing._GenericAlias)) else: @@ -117,7 +117,7 @@ class PartitionKeys(BaseModel): pks: Union[PartitionKey, Tuple[PartitionKey, ...], List[PartitionKey]] @property - def all(self) -> List[PartitionKey]: + def all(self) -> Union[tuple[PartitionKey, ...], list[PartitionKey]]: # make sure we always return a list even if there's a single value return self.pks if isinstance(self.pks, (tuple, list)) else [self.pks] @@ -170,8 +170,9 @@ def from_obj(partition_key: PartitionKey, obj: Any) -> QueryKey: # object has a method for getting these types # we can't use properties because we don't seem to be able to get the # return types - if isinstance(pk_value, (types.FunctionType, types.MethodType)): - pk_value = pk_value() + # TODO: fix the mypy issue + if isinstance(pk_value, (types.FunctionType, types.MethodType)): # type: ignore[unreachable] + pk_value = pk_value() # type: ignore[unreachable] if pk_value and not isinstance(pk_value, pk_type): raise Exception( @@ -180,11 +181,11 @@ def from_obj(partition_key: PartitionKey, obj: Any) -> QueryKey: return QueryKey(key=pk_key, type_=pk_type, value=pk_value) @property - def as_dict(self): + def as_dict(self) -> dict[str, Any]: return {self.key: self.value} @property - def as_dict_mongo(self): + def as_dict_mongo(self) -> dict[str, Any]: key = self.key if key == "id": key = "_id" @@ -199,7 +200,7 @@ class PartitionKeysWithUID(PartitionKeys): uid_pk: PartitionKey @property - def all(self) -> List[PartitionKey]: + def all(self) -> Union[tuple[PartitionKey, ...], list[PartitionKey]]: all_keys = self.pks if isinstance(self.pks, (tuple, list)) else [self.pks] if self.uid_pk not in all_keys: all_keys.insert(0, self.uid_pk) @@ -211,7 +212,7 @@ class QueryKeys(SyftBaseModel): qks: Union[QueryKey, Tuple[QueryKey, ...], List[QueryKey]] @property - def all(self) -> List[QueryKey]: + def all(self) -> Union[tuple[QueryKey, ...], list[QueryKey]]: # make sure we always return a list even if there's a single value return self.qks if isinstance(self.qks, (tuple, list)) else [self.qks] @@ -260,7 +261,7 @@ def from_dict(qks_dict: Dict[str, Any]) -> QueryKeys: return QueryKeys(qks=qks) @property - def as_dict(self): + def as_dict(self) -> dict: qk_dict = {} for qk in self.all: qk_key = qk.key @@ -269,7 +270,7 @@ def as_dict(self): return qk_dict @property - def as_dict_mongo(self): + def as_dict_mongo(self) -> dict: qk_dict = {} for qk in self.all: qk_key = qk.key @@ -316,7 +317,7 @@ class StorePartition: def __init__( self, - root_verify_key: SyftVerifyKey, + root_verify_key: Optional[SyftVerifyKey], settings: PartitionSettings, store_config: StoreConfig, ) -> None: @@ -352,7 +353,9 @@ def store_query_keys(self, objs: Any) -> QueryKeys: return QueryKeys(qks=[self.store_query_key(obj) for obj in objs]) # Thread-safe methods - def _thread_safe_cbk(self, cbk: Callable, *args, **kwargs): + def _thread_safe_cbk( + self, cbk: Callable, *args: Any, **kwargs: Any + ) -> Union[Any, Err]: locked = self.lock.acquire(blocking=True) if not locked: print("FAILED TO LOCK") @@ -423,7 +426,7 @@ def update( credentials: SyftVerifyKey, qk: QueryKey, obj: SyftObject, - has_permission=False, + has_permission: bool = False, ) -> Result[SyftObject, str]: return self._thread_safe_cbk( self._update, @@ -444,7 +447,7 @@ def get_all_from_store( ) def delete( - self, credentials: SyftVerifyKey, qk: QueryKey, has_permission=False + self, credentials: SyftVerifyKey, qk: QueryKey, has_permission: bool = False ) -> Result[SyftSuccess, Err]: return self._thread_safe_cbk( self._delete, credentials, qk, has_permission=has_permission @@ -475,12 +478,21 @@ def migrate_data( # These methods are called from the public thread-safe API, and will hang the process. def _set( self, + credentials: SyftVerifyKey, obj: SyftObject, + add_permissions: Optional[List[ActionObjectPermission]] = None, ignore_duplicates: bool = False, ) -> Result[SyftObject, str]: raise NotImplementedError - def _update(self, qk: QueryKey, obj: SyftObject) -> Result[SyftObject, str]: + def _update( + self, + credentials: SyftVerifyKey, + qk: QueryKey, + obj: SyftObject, + has_permission: bool = False, + overwrite: bool = False, + ) -> Result[SyftObject, str]: raise NotImplementedError def _get_all_from_store( @@ -491,10 +503,17 @@ def _get_all_from_store( ) -> Result[List[SyftObject], str]: raise NotImplementedError - def _delete(self, qk: QueryKey) -> Result[SyftSuccess, Err]: + def _delete( + self, credentials: SyftVerifyKey, qk: QueryKey, has_permission: bool = False + ) -> Result[SyftSuccess, Err]: raise NotImplementedError - def _all(self) -> Result[List[BaseStash.object_type], str]: + def _all( + self, + credentials: SyftVerifyKey, + order_by: Optional[PartitionKey] = None, + has_permission: Optional[bool] = False, + ) -> Result[List[BaseStash.object_type], str]: raise NotImplementedError def add_permission(self, permission: ActionObjectPermission) -> None: @@ -688,7 +707,7 @@ def find_and_delete( return self.delete(credentials=credentials, qk=qk) def delete( - self, credentials: SyftVerifyKey, qk: QueryKey, has_permission=False + self, credentials: SyftVerifyKey, qk: QueryKey, has_permission: bool = False ) -> Result[SyftSuccess, Err]: return self.partition.delete( credentials=credentials, qk=qk, has_permission=has_permission @@ -698,7 +717,7 @@ def update( self, credentials: SyftVerifyKey, obj: BaseStash.object_type, - has_permission=False, + has_permission: bool = False, ) -> Result[BaseStash.object_type, str]: qk = self.partition.store_query_key(obj) return self.partition.update( diff --git a/packages/syft/src/syft/store/kv_document_store.py b/packages/syft/src/syft/store/kv_document_store.py index 4c9ad60b122..6b2a5e21eda 100644 --- a/packages/syft/src/syft/store/kv_document_store.py +++ b/packages/syft/src/syft/store/kv_document_store.py @@ -59,16 +59,16 @@ def __repr__(self) -> str: def __len__(self) -> int: raise NotImplementedError - def __delitem__(self, key: str): + def __delitem__(self, key: str) -> None: raise NotImplementedError - def clear(self) -> Self: + def clear(self) -> None: raise NotImplementedError def copy(self) -> Self: raise NotImplementedError - def update(self, *args: Any, **kwargs: Any) -> Self: + def update(self, *args: Any, **kwargs: Any) -> None: raise NotImplementedError def keys(self) -> Any: @@ -248,7 +248,7 @@ def add_permission(self, permission: ActionObjectPermission) -> None: permissions.add(permission.permission_string) self.permissions[permission.uid] = permissions - def remove_permission(self, permission: ActionObjectPermission): + def remove_permission(self, permission: ActionObjectPermission) -> None: permissions = self.permissions[permission.uid] permissions.remove(permission.permission_string) self.permissions[permission.uid] = permissions @@ -370,8 +370,8 @@ def _update( credentials: SyftVerifyKey, qk: QueryKey, obj: SyftObject, - has_permission=False, - overwrite=False, + has_permission: bool = False, + overwrite: bool = False, ) -> Result[SyftObject, str]: try: if qk.value not in self.data: @@ -449,7 +449,7 @@ def create(self, obj: SyftObject) -> Result[SyftObject, str]: pass def _delete( - self, credentials: SyftVerifyKey, qk: QueryKey, has_permission=False + self, credentials: SyftVerifyKey, qk: QueryKey, has_permission: bool = False ) -> Result[SyftSuccess, Err]: try: if has_permission or self.has_permission( @@ -486,9 +486,9 @@ def _delete_search_keys_for(self, obj: SyftObject) -> Result[SyftSuccess, str]: def _get_keys_index(self, qks: QueryKeys) -> Result[Set[Any], str]: try: # match AND - subsets = [] + subsets: list = [] for qk in qks.all: - subset = {} + subset: set = set() pk_key, pk_value = qk.key, qk.value if pk_key not in self.unique_keys: return Err(f"Failed to query index with {qk}") @@ -515,7 +515,7 @@ def _find_keys_search(self, qks: QueryKeys) -> Result[Set[QueryKey], str]: # match AND subsets = [] for qk in qks.all: - subset = {} + subset: set = set() pk_key, pk_value = qk.key, qk.value if pk_key not in self.searchable_keys: return Err(f"Failed to search with {qk}") diff --git a/packages/syft/src/syft/store/linked_obj.py b/packages/syft/src/syft/store/linked_obj.py index 97611d56b64..a5f979da9de 100644 --- a/packages/syft/src/syft/store/linked_obj.py +++ b/packages/syft/src/syft/store/linked_obj.py @@ -128,14 +128,15 @@ def with_context( object_uid=object_uid, ) - @staticmethod + @classmethod def from_uid( + cls, object_uid: UID, object_type: Type[SyftObject], service_type: Type[Any], node_uid: UID, ) -> Self: - return LinkedObject( + return cls( node_uid=node_uid, service_type=service_type, object_type=object_type, diff --git a/packages/syft/src/syft/store/locks.py b/packages/syft/src/syft/store/locks.py index a32bcd67c8d..98c12e08f55 100644 --- a/packages/syft/src/syft/store/locks.py +++ b/packages/syft/src/syft/store/locks.py @@ -5,14 +5,17 @@ from pathlib import Path import threading import time +from typing import Any from typing import Callable from typing import Dict from typing import Optional +from typing import Union import uuid # third party from pydantic import BaseModel import redis +from redis.client import Redis from sherlock.lock import BaseLock from sherlock.lock import FileLock from sherlock.lock import RedisLock @@ -94,13 +97,13 @@ class ThreadingLock(BaseLock): Threading-based Lock. Used to provide the same API as the rest of the locks. """ - def __init__(self, expire: int, **kwargs): + def __init__(self, expire: int, **kwargs: Any) -> None: self.expire = expire - self.locked_timestamp = 0 + self.locked_timestamp: float = 0.0 self.lock = threading.Lock() @property - def _locked(self): + def _locked(self) -> bool: """ Implementation of method to check if lock has been acquired. Must be :returns: if the lock is acquired or not @@ -116,7 +119,7 @@ def _locked(self): return self.lock.locked() - def _acquire(self): + def _acquire(self) -> bool: """ Implementation of acquiring a lock in a non-blocking fashion. :returns: if the lock was successfully acquired or not @@ -137,7 +140,7 @@ def _acquire(self): self.locked_timestamp = time.time() return status - def _release(self): + def _release(self) -> None: """ Implementation of releasing an acquired lock. """ @@ -166,7 +169,7 @@ class PatchedFileLock(FileLock): """ - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: self._lock_file_enabled = True try: super().__init__(*args, **kwargs) @@ -203,7 +206,7 @@ def _thread_safe_cbk(self, cbk: Callable) -> bool: def _acquire(self) -> bool: return self._thread_safe_cbk(self._acquire_file_lock) - def _release(self) -> None: + def _release(self) -> bool: res = self._thread_safe_cbk(self._release_file_lock) return res @@ -245,12 +248,12 @@ def _acquire_file_lock(self) -> bool: self._data_file.write_text(json.dumps(data)) # We succeeded in writing to the file so we now hold the lock. - self._owner = owner + self._owner: Optional[str] = owner return True @property - def _locked(self): + def _locked(self) -> bool: if self._lock_py_thread.locked(): return True @@ -340,7 +343,7 @@ def __init__(self, config: LockingConfig): elif isinstance(config, ThreadingLockingConfig): self._lock = ThreadingLock(**base_params) elif isinstance(config, FileLockingConfig): - client = config.client_path + client: Optional[Union[Path, Redis]] = config.client_path self._lock = PatchedFileLock( **base_params, client=client, @@ -356,7 +359,7 @@ def __init__(self, config: LockingConfig): raise ValueError("Unsupported config type") @property - def _locked(self): + def _locked(self) -> bool: """ Implementation of method to check if lock has been acquired. @@ -380,9 +383,9 @@ def acquire(self, blocking: bool = True) -> bool: if not blocking: return self._acquire() - timeout = self.timeout + timeout: float = float(self.timeout) start_time = time.time() - elapsed = 0 + elapsed: float = 0.0 while timeout >= elapsed: if not self._acquire(): time.sleep(self.retry_interval) @@ -411,17 +414,17 @@ def _acquire(self) -> bool: except BaseException: return False - def _release(self): + def _release(self) -> Optional[bool]: """ Implementation of releasing an acquired lock. """ if self.passthrough: - return + return None try: return self._lock._release() except BaseException: - pass + return None def _renew(self) -> bool: """ diff --git a/packages/syft/src/syft/store/mongo_client.py b/packages/syft/src/syft/store/mongo_client.py index cbbf1c5d4f0..8b2b837da67 100644 --- a/packages/syft/src/syft/store/mongo_client.py +++ b/packages/syft/src/syft/store/mongo_client.py @@ -126,7 +126,7 @@ class MongoStoreClientConfig(StoreClientConfig): class MongoClientCache: - __client_cache__: Dict[str, Type["MongoClient"]] = {} + __client_cache__: Dict[int, Optional[Type["MongoClient"]]] = {} _lock: Lock = Lock() @classmethod @@ -239,6 +239,6 @@ def with_collection_permissions( return Ok(collection_permissions) - def close(self): + def close(self) -> None: self.client.close() MongoClientCache.__client_cache__.pop(hash(str(self.config)), None) diff --git a/packages/syft/src/syft/store/mongo_document_store.py b/packages/syft/src/syft/store/mongo_document_store.py index efdd6496154..c292c16c005 100644 --- a/packages/syft/src/syft/store/mongo_document_store.py +++ b/packages/syft/src/syft/store/mongo_document_store.py @@ -1,11 +1,11 @@ # stdlib from typing import Any +from typing import Callable from typing import Dict from typing import List from typing import Optional from typing import Set from typing import Type -from typing import Union # third party from pymongo import ASCENDING @@ -62,11 +62,11 @@ class MongoDict(SyftBaseObject): def dict(self) -> Dict[Any, Any]: return dict(zip(self.keys, self.values)) - @staticmethod - def from_dict(input: Dict[Any, Any]) -> Self: - return MongoDict(keys=list(input.keys()), values=list(input.values())) + @classmethod + def from_dict(cls, input: Dict[Any, Any]) -> Self: + return cls(keys=list(input.keys()), values=list(input.values())) - def __repr__(self): + def __repr__(self) -> str: return self.dict.__repr__() @@ -105,7 +105,7 @@ def to_mongo(context: TransformContext) -> TransformContext: @transform(SyftObject, MongoBsonObject) -def syft_obj_to_mongo(): +def syft_obj_to_mongo() -> list[Callable]: return [to_mongo] @@ -167,7 +167,9 @@ def _create_update_index(self) -> Result[Ok, Err]: return collection_status collection: MongoCollection = collection_status.ok() - def check_index_keys(current_keys, new_index_keys): + def check_index_keys( + current_keys: list[tuple[str, int]], new_index_keys: list[tuple[str, int]] + ) -> bool: current_keys.sort() new_index_keys.sort() return current_keys == new_index_keys @@ -231,7 +233,7 @@ def permissions(self) -> Result[MongoCollection, Err]: return Ok(self._permissions) - def set(self, *args, **kwargs): + def set(self, *args: Any, **kwargs: Any) -> Result[SyftObject, str]: return self._set(*args, **kwargs) def _set( @@ -244,7 +246,7 @@ def _set( # TODO: Refactor this function since now it's doing both set and # update at the same time write_permission = ActionObjectWRITE(uid=obj.id, credentials=credentials) - can_write = self.has_permission(write_permission) + can_write: bool = self.has_permission(write_permission) store_query_key: QueryKey = self.settings.store_key.with_obj(obj) collection_status = self.collection @@ -258,7 +260,7 @@ def _set( if (not store_key_exists) and (not self.item_keys_exist(obj, collection)): # attempt to claim ownership for writing ownership_result = self.take_ownership(uid=obj.id, credentials=credentials) - can_write: bool = ownership_result.is_ok() + can_write = ownership_result.is_ok() elif not ignore_duplicates: unique_query_keys: QueryKeys = self.settings.unique_keys.with_obj(obj) keys = ", ".join(f"`{key.key}`" for key in unique_query_keys.all) @@ -291,7 +293,7 @@ def _set( else: return Err(f"No permission to write object with id {obj.id}") - def item_keys_exist(self, obj, collection): + def item_keys_exist(self, obj: SyftObject, collection: MongoCollection) -> bool: qks: QueryKeys = self.settings.unique_keys.with_obj(obj) query = {"$or": [{k: v} for k, v in qks.as_dict_mongo.items()]} res = collection.find_one(query) @@ -303,6 +305,7 @@ def _update( qk: QueryKey, obj: SyftObject, has_permission: bool = False, + overwrite: bool = False, ) -> Result[SyftObject, str]: collection_status = self.collection if collection_status.is_err(): @@ -355,13 +358,13 @@ def _find_index_or_search_keys( order_by: Optional[PartitionKey] = None, ) -> Result[List[SyftObject], str]: # TODO: pass index as hint to find method - qks = QueryKeys(qks=(index_qks.all + search_qks.all)) + qks = QueryKeys(qks=(list(index_qks.all) + list(search_qks.all))) return self._get_all_from_store( credentials=credentials, qks=qks, order_by=order_by ) @property - def data(self): + def data(self) -> dict: values: List = self._all(credentials=None, has_permission=True).ok() return {v.id: v for v in values} @@ -535,8 +538,8 @@ def take_ownership( return collection_status collection: MongoCollection = collection_status.ok() - data: List[UID] = collection.find_one({"_id": uid}) - permissions: List[UID] = collection_permissions.find_one({"_id": uid}) + data: Optional[List[UID]] = collection.find_one({"_id": uid}) + permissions: Optional[List[UID]] = collection_permissions.find_one({"_id": uid}) # first person using this UID can claim ownership if permissions is None and data is None: @@ -557,7 +560,7 @@ def _all( credentials: SyftVerifyKey, order_by: Optional[PartitionKey] = None, has_permission: Optional[bool] = False, - ): + ) -> Result[List[SyftObject], str]: qks = QueryKeys(qks=()) return self._get_all_from_store( credentials=credentials, @@ -566,7 +569,7 @@ def _all( has_permission=has_permission, ) - def __len__(self): + def __len__(self) -> int: collection_status = self.collection if collection_status.is_err(): return 0 @@ -653,7 +656,7 @@ def __init__( self.ddtype = ddtype self.init_client() - def init_client(self) -> Union[None, Err]: + def init_client(self) -> Optional[Err]: self.client = MongoClient(config=self.store_config.client_config) collection_status = self.client.with_collection( @@ -664,6 +667,7 @@ def init_client(self) -> Union[None, Err]: if collection_status.is_err(): return collection_status self._collection: MongoCollection = collection_status.ok() + return None @property def collection(self) -> Result[MongoCollection, Err]: @@ -757,7 +761,7 @@ def _len(self) -> int: def __len__(self) -> int: return self._len() - def _delete(self, key: UID) -> None: + def _delete(self, key: UID) -> Result[SyftSuccess, Err]: collection_status = self.collection if collection_status.is_err(): return collection_status @@ -765,8 +769,9 @@ def _delete(self, key: UID) -> None: result = collection.delete_one({"_id": key}) if result.deleted_count != 1: raise KeyError(f"{key} does not exist") + return Ok(SyftSuccess(message="Deleted")) - def __delitem__(self, key: str): + def __delitem__(self, key: str) -> None: self._delete(key) def _delete_all(self) -> None: @@ -776,7 +781,7 @@ def _delete_all(self) -> None: collection: MongoCollection = collection_status.ok() collection.delete_many({}) - def clear(self) -> Self: + def clear(self) -> None: self._delete_all() def _get_all(self) -> Any: @@ -818,14 +823,14 @@ def copy(self) -> Self: # 🟡 TODO raise NotImplementedError - def update(self, *args: Any, **kwargs: Any) -> Self: + def update(self, *args: Any, **kwargs: Any) -> None: """ Inserts the specified items to the dictionary. """ # 🟡 TODO raise NotImplementedError - def __del__(self): + def __del__(self) -> None: """ Close the mongo client connection: - Cleanup client resources and disconnect from MongoDB diff --git a/packages/syft/src/syft/store/sqlite_document_store.py b/packages/syft/src/syft/store/sqlite_document_store.py index fe1201ef92b..999ce20bbd4 100644 --- a/packages/syft/src/syft/store/sqlite_document_store.py +++ b/packages/syft/src/syft/store/sqlite_document_store.py @@ -58,7 +58,7 @@ def _repr_debug_(value: Any) -> str: return repr(value) -def raise_exception(table_name: str, e: Exception): +def raise_exception(table_name: str, e: Exception) -> None: if "disk I/O error" in str(e): message = f"Error usually related to concurrent writes. {str(e)}" raise Exception(message) @@ -183,7 +183,7 @@ def _execute( ) -> Result[Ok[sqlite3.Cursor], Err[str]]: with SyftLock(self.lock_config): cursor: Optional[sqlite3.Cursor] = None - err = None + # err = None try: cursor = self.cur.execute(sql, *args) except Exception as e: @@ -196,8 +196,8 @@ def _execute( # err = Err(str(e)) self.db.commit() # Commit if everything went ok - if err is not None: - return err + # if err is not None: + # return err return Ok(cursor) @@ -323,10 +323,10 @@ def __repr__(self) -> str: def __len__(self) -> int: return self._len() - def __delitem__(self, key: str): + def __delitem__(self, key: str) -> None: self._delete(key) - def clear(self) -> Self: + def clear(self) -> None: self._delete_all() def copy(self) -> Self: @@ -352,7 +352,7 @@ def __contains__(self, key: Any) -> bool: def __iter__(self) -> Any: return iter(self.keys()) - def __del__(self): + def __del__(self) -> None: try: self._close() except BaseException: From 542f83a4cf1eba25f26e8231ca3fbec1ea24ab66 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Sat, 24 Feb 2024 10:17:21 +0700 Subject: [PATCH 06/35] [refactor] fixing mypy issues of `syft/client`, `syft/types` --- .pre-commit-config.yaml | 2 +- packages/syft/src/syft/client/client.py | 2 +- .../syft/src/syft/client/domain_client.py | 10 +-- .../syft/src/syft/client/gateway_client.py | 2 +- packages/syft/src/syft/node/node.py | 2 +- packages/syft/src/syft/types/blob_storage.py | 82 ++++++++++++++----- packages/syft/src/syft/types/datetime.py | 2 +- packages/syft/src/syft/types/syft_object.py | 2 +- 8 files changed, 72 insertions(+), 32 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b5633960bab..8e8497c2ccc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -172,7 +172,7 @@ repos: - id: mypy name: "mypy: syft" always_run: true - files: "^packages/syft/src/syft/ast|^packages/syft/src/syft/capnp|^packages/syft/src/syft/client|^packages/syft/src/syft/core|^packages/syft/src/syft/custom_worker|^packages/syft/src/syft/enclave|^packages/syft/src/syft/node|^packages/syft/src/syft/shylock|^packages/syft/src/syft/exceptions|^packages/syft/src/syft/grid|^packages/syft/src/syft/img|^packages/syft/src/syft/lib|^packages/syft/src/syft/store" + files: "^packages/syft/src/syft/ast|^packages/syft/src/syft/capnp|^packages/syft/src/syft/client|^packages/syft/src/syft/core|^packages/syft/src/syft/custom_worker|^packages/syft/src/syft/enclave|^packages/syft/src/syft/node|^packages/syft/src/syft/shylock|^packages/syft/src/syft/exceptions|^packages/syft/src/syft/grid|^packages/syft/src/syft/img|^packages/syft/src/syft/lib|^packages/syft/src/syft/store|^packages/syft/src/syft/protocol" # files: "^packages/syft/src/syft/" args: [ "--follow-imports=skip", diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 39e9f887978..176c3bec82e 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -474,7 +474,7 @@ def __init__( self.credentials: Optional[SyftSigningKey] = credentials self._api = api self.communication_protocol: Optional[Union[int, str]] = None - self.current_protocol = None + self.current_protocol: Optional[Union[int, str]] = None self.post_init() diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 805fd149807..0ec22e9cccb 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -151,11 +151,11 @@ def upload_files( return SyftError(message="No files to upload") if not isinstance(file_list, list): - file_list = [file_list] + file_list2: list[Union[BlobFile, str, Path]] = [file_list] expanded_file_list = [] - for file in file_list: + for file in file_list2: if isinstance(file, BlobFile): expanded_file_list.append(file) continue @@ -189,13 +189,13 @@ def upload_files( if isinstance(file, BlobFile): print(file.path or file.file_name) else: - print(file.absolute()) + print(file.absolute()) # type: ignore[unreachable] try: result = [] for file in expanded_file_list: if not isinstance(file, BlobFile): - file = BlobFile(path=file, file_name=file.name) + file = BlobFile(path=file, file_name=file.name) # type: ignore[unreachable] print("Uploading", file.file_name) if not file.uploaded: file.upload_to_blobstorage(self) @@ -308,7 +308,7 @@ def worker_images(self) -> Optional[APIModule]: def get_project( self, name: Optional[str] = None, - uid: UID = None, + uid: Optional[UID] = None, ) -> Optional[Project]: """Get project by name or UID""" diff --git a/packages/syft/src/syft/client/gateway_client.py b/packages/syft/src/syft/client/gateway_client.py index 3f9f3e571a7..1e443b05daa 100644 --- a/packages/syft/src/syft/client/gateway_client.py +++ b/packages/syft/src/syft/client/gateway_client.py @@ -176,7 +176,7 @@ def _repr_html_(self) -> str: def __len__(self) -> int: return len(self.retrieve_nodes()) - def __getitem__(self, key: int) -> SyftClient: + def __getitem__(self, key: Union[int, str]) -> SyftClient: if not isinstance(key, int): raise SyftException(f"Key: {key} must be an integer") diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index e869c308708..1d3103674a3 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -780,7 +780,7 @@ def guest_client(self) -> SyftClient: return self.get_guest_client() @property - def current_protocol(self) -> List: + def current_protocol(self) -> Union[str, int]: data_protocol = get_data_protocol() return data_protocol.latest_version diff --git a/packages/syft/src/syft/types/blob_storage.py b/packages/syft/src/syft/types/blob_storage.py index e36dc56df4e..ab911b7df22 100644 --- a/packages/syft/src/syft/types/blob_storage.py +++ b/packages/syft/src/syft/types/blob_storage.py @@ -8,9 +8,12 @@ import threading from time import sleep from typing import Any +from typing import Callable from typing import ClassVar +from typing import Iterator from typing import List from typing import Optional +from typing import TYPE_CHECKING from typing import Type from typing import Union @@ -21,6 +24,8 @@ from typing_extensions import Self # relative +from ..client.api import SyftAPI +from ..client.client import SyftClient from ..node.credentials import SyftVerifyKey from ..serde import serialize from ..serde.serializable import serializable @@ -43,6 +48,12 @@ from .syft_object import SyftObject from .uid import UID +if TYPE_CHECKING: + # relative + from ..store.blob_storage import BlobRetrievalByURL + from ..store.blob_storage import BlobStorageConnection + + READ_EXPIRATION_TIME = 1800 # seconds DEFAULT_CHUNK_SIZE = 10000 * 1024 @@ -81,7 +92,12 @@ class BlobFile(SyftObject): __repr_attrs__ = ["id", "file_name"] - def read(self, stream=False, chunk_size=DEFAULT_CHUNK_SIZE, force=False): + def read( + self, + stream: bool = False, + chunk_size: int = DEFAULT_CHUNK_SIZE, + force: bool = False, + ) -> Any: # get blob retrieval object from api + syft_blob_storage_entry_id read_method = from_api_or_context( "blob_storage.read", self.syft_node_location, self.syft_client_verify_key @@ -92,13 +108,13 @@ def read(self, stream=False, chunk_size=DEFAULT_CHUNK_SIZE, force=False): ) @classmethod - def upload_from_path(self, path, client): + def upload_from_path(cls, path: Union[str, Path], client: SyftClient) -> Any: # syft absolute import syft as sy return sy.ActionObject.from_path(path=path).send(client).syft_action_data - def _upload_to_blobstorage_from_api(self, api): + def _upload_to_blobstorage_from_api(self, api: SyftAPI) -> Optional[SyftError]: if self.path is None: raise ValueError("cannot upload BlobFile, no path specified") storage_entry = CreateBlobStorageEntry.from_path(self.path) @@ -117,12 +133,14 @@ def _upload_to_blobstorage_from_api(self, api): self.syft_blob_storage_entry_id = blob_deposit_object.blob_storage_entry_id self.uploaded = True - def upload_to_blobstorage(self, client): + return None + + def upload_to_blobstorage(self, client: SyftClient) -> Optional[SyftError]: self.syft_node_location = client.id self.syft_client_verify_key = client.verify_key return self._upload_to_blobstorage_from_api(client.api) - def _iter_lines(self, chunk_size=DEFAULT_CHUNK_SIZE): + def _iter_lines(self, chunk_size: int = DEFAULT_CHUNK_SIZE) -> Iterator[bytes]: """Synchronous version of the async iter_lines. This implementation is also optimized in terms of splitting chunks, making it faster for larger lines""" @@ -130,7 +148,7 @@ def _iter_lines(self, chunk_size=DEFAULT_CHUNK_SIZE): for chunk in self.read(stream=True, chunk_size=chunk_size): if b"\n" in chunk: if pending is not None: - chunk = pending + chunk + chunk = pending + chunk # type: ignore[unreachable] lines = chunk.splitlines() if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]: pending = lines.pop() @@ -146,7 +164,13 @@ def _iter_lines(self, chunk_size=DEFAULT_CHUNK_SIZE): if pending is not None: yield pending - def read_queue(self, queue, chunk_size, progress=False, buffer_lines=10000): + def read_queue( + self, + queue: Queue, + chunk_size: int, + progress: bool = False, + buffer_lines: int = 10000, + ) -> None: total_read = 0 for _i, line in enumerate(self._iter_lines(chunk_size=chunk_size)): line_size = len(line) + 1 # add byte for \n @@ -165,7 +189,9 @@ def read_queue(self, queue, chunk_size, progress=False, buffer_lines=10000): # Put anything not a string at the end queue.put(0) - def iter_lines(self, chunk_size=DEFAULT_CHUNK_SIZE, progress=False): + def iter_lines( + self, chunk_size: int = DEFAULT_CHUNK_SIZE, progress: bool = False + ) -> Iterator[str]: item_queue: Queue = Queue() threading.Thread( target=self.read_queue, @@ -177,19 +203,19 @@ def iter_lines(self, chunk_size=DEFAULT_CHUNK_SIZE, progress=False): yield item item = item_queue.get() - def _coll_repr_(self): + def _coll_repr_(self) -> dict[str, str]: return {"file_name": self.file_name} @migrate(BlobFile, BlobFileV1) -def downgrade_blobfile_v2_to_v1(): +def downgrade_blobfile_v2_to_v1() -> list[Callable]: return [ drop(["syft_blob_storage_entry_id", "file_size"]), ] @migrate(BlobFileV1, BlobFile) -def upgrade_blobfile_v1_to_v2(): +def upgrade_blobfile_v1_to_v2() -> list[Callable]: return [ make_set_default("syft_blob_storage_entry_id", None), make_set_default("file_size", None), @@ -225,7 +251,13 @@ class SecureFilePathLocation(SyftObject): def __repr__(self) -> str: return f"{self.path}" - def generate_url(self, *args): + def generate_url( + self, + connection: "BlobStorageConnection", + type_: Optional[Type], + bucket_name: Optional[str], + *args: Any, + ) -> "BlobRetrievalByURL": raise NotImplementedError @@ -244,7 +276,13 @@ class SeaweedSecureFilePathLocation(SecureFilePathLocation): upload_id: Optional[str] = None - def generate_url(self, connection, type_, bucket_name): + def generate_url( + self, + connection: "BlobStorageConnection", + type_: Optional[Type], + bucket_name: Optional[str], + *args: Any, + ) -> "BlobRetrievalByURL": try: url = connection.client.generate_presigned_url( ClientMethod="get_object", @@ -263,12 +301,12 @@ def generate_url(self, connection, type_, bucket_name): @migrate(SeaweedSecureFilePathLocationV1, SeaweedSecureFilePathLocation) -def upgrade_seaweedsecurefilepathlocation_v1_to_v2(): +def upgrade_seaweedsecurefilepathlocation_v1_to_v2() -> list[Callable]: return [make_set_default("bucket_name", "")] @migrate(SeaweedSecureFilePathLocation, SeaweedSecureFilePathLocationV1) -def downgrade_seaweedsecurefilepathlocation_v2_to_v1(): +def downgrade_seaweedsecurefilepathlocation_v2_to_v1() -> list[Callable]: return [ drop(["bucket_name"]), ] @@ -283,7 +321,9 @@ class AzureSecureFilePathLocation(SecureFilePathLocation): azure_profile_name: str # Used by Seaweedfs to refer to a remote config bucket_name: str - def generate_url(self, connection, type_, *args): + def generate_url( + self, connection: "BlobStorageConnection", type_: Optional[Type], *args: Any + ) -> "BlobRetrievalByURL": # SAS is almost the same thing as the presigned url config = connection.config.remote_profiles[self.azure_profile_name] account_name = config.account_name @@ -340,14 +380,14 @@ class BlobStorageEntry(SyftObject): @migrate(BlobStorageEntry, BlobStorageEntryV1) -def downgrade_blobstorageentry_v2_to_v1(): +def downgrade_blobstorageentry_v2_to_v1() -> list[Callable]: return [ drop(["no_lines", "bucket_name"]), ] @migrate(BlobStorageEntryV1, BlobStorageEntry) -def upgrade_blobstorageentry_v1_to_v2(): +def upgrade_blobstorageentry_v1_to_v2() -> list[Callable]: return [make_set_default("no_lines", 1), make_set_default("bucket_name", None)] @@ -373,14 +413,14 @@ class BlobStorageMetadata(SyftObject): @migrate(BlobStorageMetadata, BlobStorageMetadataV1) -def downgrade_blobmeta_v2_to_v1(): +def downgrade_blobmeta_v2_to_v1() -> list[Callable]: return [ drop(["no_lines"]), ] @migrate(BlobStorageMetadataV1, BlobStorageMetadata) -def upgrade_blobmeta_v1_to_v2(): +def upgrade_blobmeta_v1_to_v2() -> list[Callable]: return [make_set_default("no_lines", 1)] @@ -433,7 +473,7 @@ def file_name(self) -> str: @transform(BlobStorageEntry, BlobStorageMetadata) -def storage_entry_to_metadata(): +def storage_entry_to_metadata() -> list[Callable]: return [keep(["id", "type_", "mimetype", "file_size"])] diff --git a/packages/syft/src/syft/types/datetime.py b/packages/syft/src/syft/types/datetime.py index c03e1433fd0..36efca1bc1a 100644 --- a/packages/syft/src/syft/types/datetime.py +++ b/packages/syft/src/syft/types/datetime.py @@ -33,7 +33,7 @@ def __str__(self) -> str: def __hash__(self) -> int: return hash(self.utc_timestamp) - def __eq__(self, other: Self) -> bool: + def __eq__(self, other: object) -> bool: return self.utc_timestamp == other.utc_timestamp def __lt__(self, other: Self) -> bool: diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index 0a278370a6b..a29d9e98064 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -476,7 +476,7 @@ def keys(self) -> KeysView[str]: return self.__dict__.keys() # allows splatting with ** - def __getitem__(self, key: str) -> Any: + def __getitem__(self, key: Union[str, int]) -> Any: return self.__dict__.__getitem__(key) def _upgrade_version(self, latest: bool = True) -> "SyftObject": From 5264f9ccae72b00309ceeeb3eb7bacd2f4add27e Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 26 Feb 2024 09:29:23 +0700 Subject: [PATCH 07/35] [refactor] fix container notebooks not passing due to `error: cannot access local variable where it is not associated with a value` --- packages/syft/src/syft/client/domain_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 0ec22e9cccb..c37ab407476 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -151,11 +151,11 @@ def upload_files( return SyftError(message="No files to upload") if not isinstance(file_list, list): - file_list2: list[Union[BlobFile, str, Path]] = [file_list] + file_list = [file_list] expanded_file_list = [] - for file in file_list2: + for file in file_list: if isinstance(file, BlobFile): expanded_file_list.append(file) continue From 00d8b8538d8104d1de1759baed00af6358cd7592 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Tue, 27 Feb 2024 10:42:28 +0700 Subject: [PATCH 08/35] [refactor] fixing various mypy issues in `syft/` --- .pre-commit-config.yaml | 5 +- packages/syft/setup.cfg | 2 +- packages/syft/src/syft/__init__.py | 3 + packages/syft/src/syft/client/api.py | 19 +++--- packages/syft/src/syft/client/client.py | 37 +++++++---- .../syft/src/syft/client/domain_client.py | 66 +++++++++++-------- .../syft/src/syft/client/enclave_client.py | 33 +++++----- .../syft/src/syft/client/gateway_client.py | 23 ++++--- packages/syft/src/syft/external/oblv/auth.py | 3 +- .../syft/external/oblv/deployment_client.py | 11 ++-- .../src/syft/external/oblv/oblv_service.py | 8 +-- packages/syft/src/syft/node/node.py | 41 ++++++++---- .../syft/src/syft/protocol/data_protocol.py | 11 ++-- .../src/syft/store/blob_storage/seaweedfs.py | 2 +- packages/syft/src/syft/store/locks.py | 10 +-- packages/syft/src/syft/types/identity.py | 2 + .../syft/src/syft/types/syft_metaclass.py | 2 +- packages/syft/src/syft/types/syft_object.py | 59 ++++++++++------- packages/syft/src/syft/types/transforms.py | 65 +++++++++++------- packages/syft/src/syft/types/uid.py | 7 +- packages/syft/src/syft/util/autoreload.py | 6 ++ packages/syft/src/syft/util/schema.py | 16 +++-- .../syft/src/syft/util/trace_decorator.py | 2 +- .../syft/src/syft/util/version_compare.py | 4 +- 24 files changed, 261 insertions(+), 176 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8e8497c2ccc..91c795bf72c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -172,8 +172,8 @@ repos: - id: mypy name: "mypy: syft" always_run: true - files: "^packages/syft/src/syft/ast|^packages/syft/src/syft/capnp|^packages/syft/src/syft/client|^packages/syft/src/syft/core|^packages/syft/src/syft/custom_worker|^packages/syft/src/syft/enclave|^packages/syft/src/syft/node|^packages/syft/src/syft/shylock|^packages/syft/src/syft/exceptions|^packages/syft/src/syft/grid|^packages/syft/src/syft/img|^packages/syft/src/syft/lib|^packages/syft/src/syft/store|^packages/syft/src/syft/protocol" - # files: "^packages/syft/src/syft/" + files: "^packages/syft/src/syft/" + exclude: "^packages/syft/src/syft/service|^packages/syft/src/syft/external/oblv/|^packages/syft/src/syft/types/dicttuple.py|^packages/syft/src/syft/util/telemetry.py" args: [ "--follow-imports=skip", "--ignore-missing-imports", @@ -190,7 +190,6 @@ repos: "--install-types", "--non-interactive", "--config-file=tox.ini", - "--disable-error-code=union-attr", # todo: remove this line after fixing the issue context.node can be None ] - repo: https://github.com/kynan/nbstripout diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index 1a0afd7d28c..e88ce585f21 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -91,7 +91,7 @@ data_science = dev = %(test_plugins)s %(telemetry)s - bandit==1.7.5 + bandit==1.7.7 ruff==0.1.6 importlib-metadata==6.8.0 isort==5.12.0 diff --git a/packages/syft/src/syft/__init__.py b/packages/syft/src/syft/__init__.py index d438f9ef9bd..2a0fcfa5b6d 100644 --- a/packages/syft/src/syft/__init__.py +++ b/packages/syft/src/syft/__init__.py @@ -93,6 +93,9 @@ logger.start() try: + # third party + from IPython import get_ipython + get_ipython() # noqa: F821 # TODO: add back later or auto detect # display( diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index 2e7c7fc25e3..d41ffbaac61 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -16,6 +16,7 @@ from typing import Tuple from typing import Union from typing import _GenericAlias +from typing import cast from typing import get_args from typing import get_origin @@ -240,8 +241,8 @@ def __ipython_inspector_signature_override__(self) -> Optional[Signature]: return self.signature def prepare_args_and_kwargs( - self, args: List[Any], kwargs: Dict[str, Any] - ) -> Union[SyftError, Tuple[List[Any], Dict[str, Any]]]: + self, args: Union[list, tuple], kwargs: dict[str, Any] + ) -> Union[SyftError, tuple[tuple, dict[str, Any]]]: # Validate and migrate args and kwargs res = validate_callable_args_and_kwargs(args, kwargs, self.signature) if isinstance(res, SyftError): @@ -278,7 +279,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: api_call = SyftAPICall( node_uid=self.node_uid, path=self.path, - args=_valid_args, + args=list(_valid_args), kwargs=_valid_kwargs, blocking=blocking, ) @@ -303,8 +304,8 @@ class RemoteUserCodeFunction(RemoteFunction): api: SyftAPI def prepare_args_and_kwargs( - self, args: List[Any], kwargs: Dict[str, Any] - ) -> Union[SyftError, Tuple[List[Any], Dict[str, Any]]]: + self, args: Union[list, tuple], kwargs: Dict[str, Any] + ) -> Union[SyftError, tuple[tuple, dict[str, Any]]]: # relative from ..service.action.action_object import convert_to_pointers @@ -824,16 +825,16 @@ def build_endpoint_tree( ) @property - def services(self) -> Optional[APIModule]: + def services(self) -> APIModule: if self.api_module is None: self.generate_endpoints() - return self.api_module + return cast(APIModule, self.api_module) @property - def lib(self) -> Optional[APIModule]: + def lib(self) -> APIModule: if self.libs is None: self.generate_endpoints() - return self.libs + return cast(APIModule, self.libs) def has_service(self, service_name: str) -> bool: return hasattr(self.services, service_name) diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 176c3bec82e..f5feb2be19a 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -484,7 +484,7 @@ def get_env(self) -> str: def post_init(self) -> None: if self.metadata is None: self._fetch_node_metadata(self.credentials) - + self.metadata = cast(NodeMetadataJSON, self.metadata) self.communication_protocol = self._get_communication_protocol( self.metadata.supported_protocols ) @@ -625,9 +625,7 @@ def api(self) -> SyftAPI: # invalidate API if self._api is None or (self._api.signing_key != self.credentials): self._fetch_api(self.credentials) - if self._api is None: - raise ValueError(f"{self}'s api is None") - return self._api + return cast(SyftAPI, self._api) # we are sure self._api is not None after fetch def guest(self) -> Self: return self.__class__( @@ -642,7 +640,8 @@ def exchange_route(self, client: Self) -> Union[SyftSuccess, SyftError]: self_node_route = connection_to_route(self.connection) remote_node_route = connection_to_route(client.connection) - + if client.metadata is None: + return SyftError(f"client {client}'s metadata is None!") result = self.api.services.network.exchange_credentials_with( self_node_route=self_node_route, remote_node_route=remote_node_route, @@ -700,10 +699,11 @@ def me(self) -> Optional[Union[UserView, SyftError]]: def login_as_guest(self) -> Self: _guest_client = self.guest() - print( - f"Logged into <{self.name}: {self.metadata.node_side_type.capitalize()}-side " - f"{self.metadata.node_type.capitalize()}> as GUEST" - ) + if self.metadata is not None: + print( + f"Logged into <{self.name}: {self.metadata.node_side_type.capitalize()}-side " + f"{self.metadata.node_type.capitalize()}> as GUEST" + ) return _guest_client @@ -747,11 +747,11 @@ def login( client.__logged_in_user = email - if user_private_key is not None: + if user_private_key is not None and client.users is not None: client.__user_role = user_private_key.role client.__logged_in_username = client.users.get_current_user().name - if signing_key is not None: + if signing_key is not None and client.metadata is not None: print( f"Logged into <{client.name}: {client.metadata.node_side_type.capitalize()} side " f"{client.metadata.node_type.capitalize()}> as <{email}>" @@ -789,7 +789,9 @@ def login( # relative from ..node.node import CODE_RELOADER - CODE_RELOADER[thread_ident()] = client._reload_user_code + thread_id = thread_ident() + if thread_id is not None: + CODE_RELOADER[thread_id] = client._reload_user_code return client @@ -808,7 +810,7 @@ def register( password_verify: Optional[str] = None, institution: Optional[str] = None, website: Optional[str] = None, - ) -> Optional[Union[SyftError, Any]]: + ) -> Optional[Union[SyftError, SyftSigningKey]]: if not email: email = input("Email: ") if not password: @@ -833,7 +835,10 @@ def register( except Exception as e: return SyftError(message=str(e)) - if self.metadata.node_side_type == NodeSideType.HIGH_SIDE.value: + if ( + self.metadata + and self.metadata.node_side_type == NodeSideType.HIGH_SIDE.value + ): message = ( "You're registering a user to a high side " f"{self.metadata.node_type}, which could " @@ -889,6 +894,10 @@ def refresh_callback() -> None: return self._fetch_api(self.credentials) _api.refresh_api_callback = refresh_callback + + if self.credentials is None: + raise ValueError(f"{self}'s credentials (signing key) is None!") + APIRegistry.set_api_for( node_uid=self.id, user_verify_key=self.credentials.verify_key, diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index c37ab407476..a0fdae07e07 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -7,6 +7,7 @@ from typing import Optional from typing import TYPE_CHECKING from typing import Union +from typing import cast # third party from hagrid.orchestra import NodeHandle @@ -90,6 +91,9 @@ def upload_dataset(self, dataset: CreateDataset) -> Union[SyftSuccess, SyftError # relative from ..types.twin_object import TwinObject + if self.users is None: + return SyftError(f"can't get user service for {self}") + user = self.users.get_current_user() dataset = add_default_uploader(user, dataset) for i in range(len(dataset.asset_list)): @@ -97,21 +101,22 @@ def upload_dataset(self, dataset: CreateDataset) -> Union[SyftSuccess, SyftError dataset.asset_list[i] = add_default_uploader(user, asset) dataset._check_asset_must_contain_mock() - dataset_size = 0 + dataset_size: float = 0.0 # TODO: Refactor so that object can also be passed to generate warnings - metadata = self.api.connection.get_node_metadata(self.api.signing_key) - - if ( - metadata.show_warnings - and metadata.node_side_type == NodeSideType.HIGH_SIDE.value - ): - message = ( - "You're approving a request on " - f"{metadata.node_side_type} side {metadata.node_type} " - "which may host datasets with private information." - ) - prompt_warning_message(message=message, confirm=True) + if self.api.connection: + metadata = self.api.connection.get_node_metadata(self.api.signing_key) + + if ( + metadata.show_warnings + and metadata.node_side_type == NodeSideType.HIGH_SIDE.value + ): + message = ( + "You're approving a request on " + f"{metadata.node_side_type} side {metadata.node_type} " + "which may host datasets with private information." + ) + prompt_warning_message(message=message, confirm=True) for asset in tqdm(dataset.asset_list): print(f"Uploading: {asset.name}") @@ -151,7 +156,8 @@ def upload_files( return SyftError(message="No files to upload") if not isinstance(file_list, list): - file_list = [file_list] + file_list = [file_list] # type: ignore[assignment] + file_list = cast(list, file_list) expanded_file_list = [] @@ -230,9 +236,12 @@ def connect_to_gateway( res = self.exchange_route(client) if isinstance(res, SyftSuccess): - return SyftSuccess( - message=f"Connected {self.metadata.node_type} to {client.name} gateway" - ) + if self.metadata: + return SyftSuccess( + message=f"Connected {self.metadata.node_type} to {client.name} gateway" + ) + else: + return SyftSuccess(message=f"Connected to {client.name} gateway") return res @property @@ -375,18 +384,17 @@ def _repr_html_(self) -> str: url = getattr(self.connection, "url", None) node_details = f"URL: {url}
" if url else "" - node_details += ( - f"Node Type: {self.metadata.node_type.capitalize()}
" - ) - node_side_type = ( - "Low Side" - if self.metadata.node_side_type == NodeSideType.LOW_SIDE.value - else "High Side" - ) - node_details += f"Node Side Type: {node_side_type}
" - node_details += ( - f"Syft Version: {self.metadata.syft_version}
" - ) + if self.metadata is not None: + node_details += f"Node Type: {self.metadata.node_type.capitalize()}
" + node_side_type = ( + "Low Side" + if self.metadata.node_side_type == NodeSideType.LOW_SIDE.value + else "High Side" + ) + node_details += f"Node Side Type: {node_side_type}
" + node_details += ( + f"Syft Version: {self.metadata.syft_version}
" + ) return f"""